diff options
| -rw-r--r-- | include/net/scm.h | 4 | ||||
| -rw-r--r-- | net/core/scm.c | 32 | ||||
| -rw-r--r-- | net/unix/af_unix.c | 57 | ||||
| -rw-r--r-- | tools/testing/selftests/net/af_unix/scm_pidfd.c | 217 |
4 files changed, 247 insertions, 63 deletions
diff --git a/include/net/scm.h b/include/net/scm.h index 84c4707e78a5..c52519669349 100644 --- a/include/net/scm.h +++ b/include/net/scm.h @@ -69,7 +69,7 @@ static __inline__ void unix_get_peersec_dgram(struct socket *sock, struct scm_co static __inline__ void scm_set_cred(struct scm_cookie *scm, struct pid *pid, kuid_t uid, kgid_t gid) { - scm->pid = get_pid(pid); + scm->pid = get_pid(pid); scm->creds.pid = pid_vnr(pid); scm->creds.uid = uid; scm->creds.gid = gid; @@ -78,7 +78,7 @@ static __inline__ void scm_set_cred(struct scm_cookie *scm, static __inline__ void scm_destroy_cred(struct scm_cookie *scm) { put_pid(scm->pid); - scm->pid = NULL; + scm->pid = NULL; } static __inline__ void scm_destroy(struct scm_cookie *scm) diff --git a/net/core/scm.c b/net/core/scm.c index 0225bd94170f..072d5742440a 100644 --- a/net/core/scm.c +++ b/net/core/scm.c @@ -23,6 +23,8 @@ #include <linux/security.h> #include <linux/pid_namespace.h> #include <linux/pid.h> +#include <uapi/linux/pidfd.h> +#include <linux/pidfs.h> #include <linux/nsproxy.h> #include <linux/slab.h> #include <linux/errqueue.h> @@ -145,6 +147,22 @@ void __scm_destroy(struct scm_cookie *scm) } EXPORT_SYMBOL(__scm_destroy); +static inline int scm_replace_pid(struct scm_cookie *scm, struct pid *pid) +{ + int err; + + /* drop all previous references */ + scm_destroy_cred(scm); + + err = pidfs_register_pid(pid); + if (unlikely(err)) + return err; + + scm->pid = pid; + scm->creds.pid = pid_vnr(pid); + return 0; +} + int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p) { const struct proto_ops *ops = READ_ONCE(sock->ops); @@ -189,15 +207,21 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p) if (err) goto error; - p->creds.pid = creds.pid; if (!p->pid || pid_vnr(p->pid) != creds.pid) { struct pid *pid; err = -ESRCH; pid = find_get_pid(creds.pid); if (!pid) goto error; - put_pid(p->pid); - p->pid = pid; + + /* pass a struct pid reference from + * find_get_pid() to scm_replace_pid(). + */ + err = scm_replace_pid(p, pid); + if (err) { + put_pid(pid); + goto error; + } } err = -EINVAL; @@ -459,7 +483,7 @@ static void scm_pidfd_recv(struct msghdr *msg, struct scm_cookie *scm) if (!scm->pid) return; - pidfd = pidfd_prepare(scm->pid, 0, &pidfd_file); + pidfd = pidfd_prepare(scm->pid, PIDFD_STALE, &pidfd_file); if (put_cmsg(msg, SOL_SOCKET, SCM_PIDFD, sizeof(int), &pidfd)) { if (pidfd_file) { diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index 129388c309b0..d52811321fce 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -1929,7 +1929,7 @@ static void unix_destruct_scm(struct sk_buff *skb) struct scm_cookie scm; memset(&scm, 0, sizeof(scm)); - scm.pid = UNIXCB(skb).pid; + scm.pid = UNIXCB(skb).pid; if (UNIXCB(skb).fp) unix_detach_fds(&scm, skb); @@ -1955,21 +1955,45 @@ static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, bool sen return err; } -/* +static void unix_skb_to_scm(struct sk_buff *skb, struct scm_cookie *scm) +{ + scm_set_cred(scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid); + unix_set_secdata(scm, skb); +} + +/** + * unix_maybe_add_creds() - Adds current task uid/gid and struct pid to skb if needed. + * @skb: skb to attach creds to. + * @sk: Sender sock. + * @other: Receiver sock. + * * Some apps rely on write() giving SCM_CREDENTIALS * We include credentials if source or destination socket * asserted SOCK_PASSCRED. + * + * Context: May sleep. + * Return: On success zero, on error a negative error code is returned. */ -static void unix_maybe_add_creds(struct sk_buff *skb, const struct sock *sk, - const struct sock *other) +static int unix_maybe_add_creds(struct sk_buff *skb, const struct sock *sk, + const struct sock *other) { if (UNIXCB(skb).pid) - return; + return 0; if (unix_may_passcred(sk) || unix_may_passcred(other)) { - UNIXCB(skb).pid = get_pid(task_tgid(current)); + struct pid *pid; + int err; + + pid = task_tgid(current); + err = pidfs_register_pid(pid); + if (unlikely(err)) + return err; + + UNIXCB(skb).pid = get_pid(pid); current_uid_gid(&UNIXCB(skb).uid, &UNIXCB(skb).gid); } + + return 0; } static bool unix_skb_scm_eq(struct sk_buff *skb, @@ -2104,6 +2128,10 @@ lookup: goto out_sock_put; } + err = unix_maybe_add_creds(skb, sk, other); + if (err) + goto out_sock_put; + restart: sk_locked = 0; unix_state_lock(other); @@ -2212,7 +2240,6 @@ restart_locked: if (sock_flag(other, SOCK_RCVTSTAMP)) __net_timestamp(skb); - unix_maybe_add_creds(skb, sk, other); scm_stat_add(other, skb); skb_queue_tail(&other->sk_receive_queue, skb); unix_state_unlock(other); @@ -2256,6 +2283,10 @@ static int queue_oob(struct sock *sk, struct msghdr *msg, struct sock *other, if (err < 0) goto out; + err = unix_maybe_add_creds(skb, sk, other); + if (err) + goto out; + skb_put(skb, 1); err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, 1); @@ -2275,7 +2306,6 @@ static int queue_oob(struct sock *sk, struct msghdr *msg, struct sock *other, goto out_unlock; } - unix_maybe_add_creds(skb, sk, other); scm_stat_add(other, skb); spin_lock(&other->sk_receive_queue.lock); @@ -2369,6 +2399,10 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg, fds_sent = true; + err = unix_maybe_add_creds(skb, sk, other); + if (err) + goto out_free; + if (unlikely(msg->msg_flags & MSG_SPLICE_PAGES)) { skb->ip_summed = CHECKSUM_UNNECESSARY; err = skb_splice_from_iter(skb, &msg->msg_iter, size, @@ -2399,7 +2433,6 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg, goto out_free; } - unix_maybe_add_creds(skb, sk, other); scm_stat_add(other, skb); skb_queue_tail(&other->sk_receive_queue, skb); unix_state_unlock(other); @@ -2547,8 +2580,7 @@ int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size, memset(&scm, 0, sizeof(scm)); - scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid); - unix_set_secdata(&scm, skb); + unix_skb_to_scm(skb, &scm); if (!(flags & MSG_PEEK)) { if (UNIXCB(skb).fp) @@ -2933,8 +2965,7 @@ unlock: break; } else if (unix_may_passcred(sk)) { /* Copy credentials */ - scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid); - unix_set_secdata(&scm, skb); + unix_skb_to_scm(skb, &scm); check_creds = true; } diff --git a/tools/testing/selftests/net/af_unix/scm_pidfd.c b/tools/testing/selftests/net/af_unix/scm_pidfd.c index 7e534594167e..37e034874034 100644 --- a/tools/testing/selftests/net/af_unix/scm_pidfd.c +++ b/tools/testing/selftests/net/af_unix/scm_pidfd.c @@ -15,6 +15,7 @@ #include <sys/types.h> #include <sys/wait.h> +#include "../../pidfd/pidfd.h" #include "../../kselftest_harness.h" #define clean_errno() (errno == 0 ? "None" : strerror(errno)) @@ -26,6 +27,8 @@ #define SCM_PIDFD 0x04 #endif +#define CHILD_EXIT_CODE_OK 123 + static void child_die() { exit(1); @@ -126,16 +129,65 @@ out: return result; } +struct cmsg_data { + struct ucred *ucred; + int *pidfd; +}; + +static int parse_cmsg(struct msghdr *msg, struct cmsg_data *res) +{ + struct cmsghdr *cmsg; + int data = 0; + + if (msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC)) { + log_err("recvmsg: truncated"); + return 1; + } + + for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL; + cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == SOL_SOCKET && + cmsg->cmsg_type == SCM_PIDFD) { + if (cmsg->cmsg_len < sizeof(*res->pidfd)) { + log_err("CMSG parse: SCM_PIDFD wrong len"); + return 1; + } + + res->pidfd = (void *)CMSG_DATA(cmsg); + } + + if (cmsg->cmsg_level == SOL_SOCKET && + cmsg->cmsg_type == SCM_CREDENTIALS) { + if (cmsg->cmsg_len < sizeof(*res->ucred)) { + log_err("CMSG parse: SCM_CREDENTIALS wrong len"); + return 1; + } + + res->ucred = (void *)CMSG_DATA(cmsg); + } + } + + if (!res->pidfd) { + log_err("CMSG parse: SCM_PIDFD not found"); + return 1; + } + + if (!res->ucred) { + log_err("CMSG parse: SCM_CREDENTIALS not found"); + return 1; + } + + return 0; +} + static int cmsg_check(int fd) { struct msghdr msg = { 0 }; - struct cmsghdr *cmsg; + struct cmsg_data res; struct iovec iov; - struct ucred *ucred = NULL; int data = 0; char control[CMSG_SPACE(sizeof(struct ucred)) + CMSG_SPACE(sizeof(int))] = { 0 }; - int *pidfd = NULL; pid_t parent_pid; int err; @@ -158,53 +210,99 @@ static int cmsg_check(int fd) return 1; } - for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; - cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (cmsg->cmsg_level == SOL_SOCKET && - cmsg->cmsg_type == SCM_PIDFD) { - if (cmsg->cmsg_len < sizeof(*pidfd)) { - log_err("CMSG parse: SCM_PIDFD wrong len"); - return 1; - } + /* send(pfd, "x", sizeof(char), 0) */ + if (data != 'x') { + log_err("recvmsg: data corruption"); + return 1; + } - pidfd = (void *)CMSG_DATA(cmsg); - } + if (parse_cmsg(&msg, &res)) { + log_err("CMSG parse: parse_cmsg() failed"); + return 1; + } - if (cmsg->cmsg_level == SOL_SOCKET && - cmsg->cmsg_type == SCM_CREDENTIALS) { - if (cmsg->cmsg_len < sizeof(*ucred)) { - log_err("CMSG parse: SCM_CREDENTIALS wrong len"); - return 1; - } + /* pidfd from SCM_PIDFD should point to the parent process PID */ + parent_pid = + get_pid_from_fdinfo_file(*res.pidfd, "Pid:", sizeof("Pid:") - 1); + if (parent_pid != getppid()) { + log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid()); + close(*res.pidfd); + return 1; + } - ucred = (void *)CMSG_DATA(cmsg); - } + close(*res.pidfd); + return 0; +} + +static int cmsg_check_dead(int fd, int expected_pid) +{ + int err; + struct msghdr msg = { 0 }; + struct cmsg_data res; + struct iovec iov; + int data = 0; + char control[CMSG_SPACE(sizeof(struct ucred)) + + CMSG_SPACE(sizeof(int))] = { 0 }; + pid_t client_pid; + struct pidfd_info info = { + .mask = PIDFD_INFO_EXIT, + }; + + iov.iov_base = &data; + iov.iov_len = sizeof(data); + + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + err = recvmsg(fd, &msg, 0); + if (err < 0) { + log_err("recvmsg"); + return 1; } - /* send(pfd, "x", sizeof(char), 0) */ - if (data != 'x') { + if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) { + log_err("recvmsg: truncated"); + return 1; + } + + /* send(cfd, "y", sizeof(char), 0) */ + if (data != 'y') { log_err("recvmsg: data corruption"); return 1; } - if (!pidfd) { - log_err("CMSG parse: SCM_PIDFD not found"); + if (parse_cmsg(&msg, &res)) { + log_err("CMSG parse: parse_cmsg() failed"); return 1; } - if (!ucred) { - log_err("CMSG parse: SCM_CREDENTIALS not found"); + /* + * pidfd from SCM_PIDFD should point to the client_pid. + * Let's read exit information and check if it's what + * we expect to see. + */ + if (ioctl(*res.pidfd, PIDFD_GET_INFO, &info)) { + log_err("%s: ioctl(PIDFD_GET_INFO) failed", __func__); + close(*res.pidfd); return 1; } - /* pidfd from SCM_PIDFD should point to the parent process PID */ - parent_pid = - get_pid_from_fdinfo_file(*pidfd, "Pid:", sizeof("Pid:") - 1); - if (parent_pid != getppid()) { - log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid()); + if (!(info.mask & PIDFD_INFO_EXIT)) { + log_err("%s: No exit information from ioctl(PIDFD_GET_INFO)", __func__); + close(*res.pidfd); return 1; } + err = WIFEXITED(info.exit_code) ? WEXITSTATUS(info.exit_code) : 1; + if (err != CHILD_EXIT_CODE_OK) { + log_err("%s: wrong exit_code %d != %d", __func__, err, CHILD_EXIT_CODE_OK); + close(*res.pidfd); + return 1; + } + + close(*res.pidfd); return 0; } @@ -291,6 +389,24 @@ static void fill_sockaddr(struct sock_addr *addr, bool abstract) memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name)); } +static int sk_enable_cred_pass(int sk) +{ + int on = 0; + + on = 1; + if (setsockopt(sk, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) { + log_err("Failed to set SO_PASSCRED"); + return 1; + } + + if (setsockopt(sk, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) { + log_err("Failed to set SO_PASSPIDFD"); + return 1; + } + + return 0; +} + static void client(FIXTURE_DATA(scm_pidfd) *self, const FIXTURE_VARIANT(scm_pidfd) *variant) { @@ -299,7 +415,6 @@ static void client(FIXTURE_DATA(scm_pidfd) *self, struct ucred peer_cred; int peer_pidfd; pid_t peer_pid; - int on = 0; cfd = socket(AF_UNIX, variant->type, 0); if (cfd < 0) { @@ -322,14 +437,8 @@ static void client(FIXTURE_DATA(scm_pidfd) *self, child_die(); } - on = 1; - if (setsockopt(cfd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) { - log_err("Failed to set SO_PASSCRED"); - child_die(); - } - - if (setsockopt(cfd, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) { - log_err("Failed to set SO_PASSPIDFD"); + if (sk_enable_cred_pass(cfd)) { + log_err("sk_enable_cred_pass() failed"); child_die(); } @@ -340,6 +449,12 @@ static void client(FIXTURE_DATA(scm_pidfd) *self, child_die(); } + /* send something to the parent so it can receive SCM_PIDFD too and validate it */ + if (send(cfd, "y", sizeof(char), 0) == -1) { + log_err("Failed to send(cfd, \"y\", sizeof(char), 0)"); + child_die(); + } + /* skip further for SOCK_DGRAM as it's not applicable */ if (variant->type == SOCK_DGRAM) return; @@ -398,7 +513,13 @@ TEST_F(scm_pidfd, test) close(self->server); close(self->startup_pipe[0]); client(self, variant); - exit(0); + + /* + * It's a bit unusual, but in case of success we return non-zero + * exit code (CHILD_EXIT_CODE_OK) and then we expect to read it + * from ioctl(PIDFD_GET_INFO) in cmsg_check_dead(). + */ + exit(CHILD_EXIT_CODE_OK); } close(self->startup_pipe[1]); @@ -421,9 +542,17 @@ TEST_F(scm_pidfd, test) ASSERT_NE(-1, err); } - close(pfd); waitpid(self->client_pid, &child_status, 0); - ASSERT_EQ(0, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1); + /* see comment before exit(CHILD_EXIT_CODE_OK) */ + ASSERT_EQ(CHILD_EXIT_CODE_OK, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1); + + err = sk_enable_cred_pass(pfd); + ASSERT_EQ(0, err); + + err = cmsg_check_dead(pfd, self->client_pid); + ASSERT_EQ(0, err); + + close(pfd); } TEST_HARNESS_MAIN |
