summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/net/scm.h4
-rw-r--r--net/core/scm.c32
-rw-r--r--net/unix/af_unix.c57
-rw-r--r--tools/testing/selftests/net/af_unix/scm_pidfd.c217
4 files changed, 247 insertions, 63 deletions
diff --git a/include/net/scm.h b/include/net/scm.h
index 84c4707e78a5..c52519669349 100644
--- a/include/net/scm.h
+++ b/include/net/scm.h
@@ -69,7 +69,7 @@ static __inline__ void unix_get_peersec_dgram(struct socket *sock, struct scm_co
static __inline__ void scm_set_cred(struct scm_cookie *scm,
struct pid *pid, kuid_t uid, kgid_t gid)
{
- scm->pid = get_pid(pid);
+ scm->pid = get_pid(pid);
scm->creds.pid = pid_vnr(pid);
scm->creds.uid = uid;
scm->creds.gid = gid;
@@ -78,7 +78,7 @@ static __inline__ void scm_set_cred(struct scm_cookie *scm,
static __inline__ void scm_destroy_cred(struct scm_cookie *scm)
{
put_pid(scm->pid);
- scm->pid = NULL;
+ scm->pid = NULL;
}
static __inline__ void scm_destroy(struct scm_cookie *scm)
diff --git a/net/core/scm.c b/net/core/scm.c
index 0225bd94170f..072d5742440a 100644
--- a/net/core/scm.c
+++ b/net/core/scm.c
@@ -23,6 +23,8 @@
#include <linux/security.h>
#include <linux/pid_namespace.h>
#include <linux/pid.h>
+#include <uapi/linux/pidfd.h>
+#include <linux/pidfs.h>
#include <linux/nsproxy.h>
#include <linux/slab.h>
#include <linux/errqueue.h>
@@ -145,6 +147,22 @@ void __scm_destroy(struct scm_cookie *scm)
}
EXPORT_SYMBOL(__scm_destroy);
+static inline int scm_replace_pid(struct scm_cookie *scm, struct pid *pid)
+{
+ int err;
+
+ /* drop all previous references */
+ scm_destroy_cred(scm);
+
+ err = pidfs_register_pid(pid);
+ if (unlikely(err))
+ return err;
+
+ scm->pid = pid;
+ scm->creds.pid = pid_vnr(pid);
+ return 0;
+}
+
int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p)
{
const struct proto_ops *ops = READ_ONCE(sock->ops);
@@ -189,15 +207,21 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p)
if (err)
goto error;
- p->creds.pid = creds.pid;
if (!p->pid || pid_vnr(p->pid) != creds.pid) {
struct pid *pid;
err = -ESRCH;
pid = find_get_pid(creds.pid);
if (!pid)
goto error;
- put_pid(p->pid);
- p->pid = pid;
+
+ /* pass a struct pid reference from
+ * find_get_pid() to scm_replace_pid().
+ */
+ err = scm_replace_pid(p, pid);
+ if (err) {
+ put_pid(pid);
+ goto error;
+ }
}
err = -EINVAL;
@@ -459,7 +483,7 @@ static void scm_pidfd_recv(struct msghdr *msg, struct scm_cookie *scm)
if (!scm->pid)
return;
- pidfd = pidfd_prepare(scm->pid, 0, &pidfd_file);
+ pidfd = pidfd_prepare(scm->pid, PIDFD_STALE, &pidfd_file);
if (put_cmsg(msg, SOL_SOCKET, SCM_PIDFD, sizeof(int), &pidfd)) {
if (pidfd_file) {
diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c
index 129388c309b0..d52811321fce 100644
--- a/net/unix/af_unix.c
+++ b/net/unix/af_unix.c
@@ -1929,7 +1929,7 @@ static void unix_destruct_scm(struct sk_buff *skb)
struct scm_cookie scm;
memset(&scm, 0, sizeof(scm));
- scm.pid = UNIXCB(skb).pid;
+ scm.pid = UNIXCB(skb).pid;
if (UNIXCB(skb).fp)
unix_detach_fds(&scm, skb);
@@ -1955,21 +1955,45 @@ static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, bool sen
return err;
}
-/*
+static void unix_skb_to_scm(struct sk_buff *skb, struct scm_cookie *scm)
+{
+ scm_set_cred(scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid);
+ unix_set_secdata(scm, skb);
+}
+
+/**
+ * unix_maybe_add_creds() - Adds current task uid/gid and struct pid to skb if needed.
+ * @skb: skb to attach creds to.
+ * @sk: Sender sock.
+ * @other: Receiver sock.
+ *
* Some apps rely on write() giving SCM_CREDENTIALS
* We include credentials if source or destination socket
* asserted SOCK_PASSCRED.
+ *
+ * Context: May sleep.
+ * Return: On success zero, on error a negative error code is returned.
*/
-static void unix_maybe_add_creds(struct sk_buff *skb, const struct sock *sk,
- const struct sock *other)
+static int unix_maybe_add_creds(struct sk_buff *skb, const struct sock *sk,
+ const struct sock *other)
{
if (UNIXCB(skb).pid)
- return;
+ return 0;
if (unix_may_passcred(sk) || unix_may_passcred(other)) {
- UNIXCB(skb).pid = get_pid(task_tgid(current));
+ struct pid *pid;
+ int err;
+
+ pid = task_tgid(current);
+ err = pidfs_register_pid(pid);
+ if (unlikely(err))
+ return err;
+
+ UNIXCB(skb).pid = get_pid(pid);
current_uid_gid(&UNIXCB(skb).uid, &UNIXCB(skb).gid);
}
+
+ return 0;
}
static bool unix_skb_scm_eq(struct sk_buff *skb,
@@ -2104,6 +2128,10 @@ lookup:
goto out_sock_put;
}
+ err = unix_maybe_add_creds(skb, sk, other);
+ if (err)
+ goto out_sock_put;
+
restart:
sk_locked = 0;
unix_state_lock(other);
@@ -2212,7 +2240,6 @@ restart_locked:
if (sock_flag(other, SOCK_RCVTSTAMP))
__net_timestamp(skb);
- unix_maybe_add_creds(skb, sk, other);
scm_stat_add(other, skb);
skb_queue_tail(&other->sk_receive_queue, skb);
unix_state_unlock(other);
@@ -2256,6 +2283,10 @@ static int queue_oob(struct sock *sk, struct msghdr *msg, struct sock *other,
if (err < 0)
goto out;
+ err = unix_maybe_add_creds(skb, sk, other);
+ if (err)
+ goto out;
+
skb_put(skb, 1);
err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, 1);
@@ -2275,7 +2306,6 @@ static int queue_oob(struct sock *sk, struct msghdr *msg, struct sock *other,
goto out_unlock;
}
- unix_maybe_add_creds(skb, sk, other);
scm_stat_add(other, skb);
spin_lock(&other->sk_receive_queue.lock);
@@ -2369,6 +2399,10 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
fds_sent = true;
+ err = unix_maybe_add_creds(skb, sk, other);
+ if (err)
+ goto out_free;
+
if (unlikely(msg->msg_flags & MSG_SPLICE_PAGES)) {
skb->ip_summed = CHECKSUM_UNNECESSARY;
err = skb_splice_from_iter(skb, &msg->msg_iter, size,
@@ -2399,7 +2433,6 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
goto out_free;
}
- unix_maybe_add_creds(skb, sk, other);
scm_stat_add(other, skb);
skb_queue_tail(&other->sk_receive_queue, skb);
unix_state_unlock(other);
@@ -2547,8 +2580,7 @@ int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
memset(&scm, 0, sizeof(scm));
- scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid);
- unix_set_secdata(&scm, skb);
+ unix_skb_to_scm(skb, &scm);
if (!(flags & MSG_PEEK)) {
if (UNIXCB(skb).fp)
@@ -2933,8 +2965,7 @@ unlock:
break;
} else if (unix_may_passcred(sk)) {
/* Copy credentials */
- scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid);
- unix_set_secdata(&scm, skb);
+ unix_skb_to_scm(skb, &scm);
check_creds = true;
}
diff --git a/tools/testing/selftests/net/af_unix/scm_pidfd.c b/tools/testing/selftests/net/af_unix/scm_pidfd.c
index 7e534594167e..37e034874034 100644
--- a/tools/testing/selftests/net/af_unix/scm_pidfd.c
+++ b/tools/testing/selftests/net/af_unix/scm_pidfd.c
@@ -15,6 +15,7 @@
#include <sys/types.h>
#include <sys/wait.h>
+#include "../../pidfd/pidfd.h"
#include "../../kselftest_harness.h"
#define clean_errno() (errno == 0 ? "None" : strerror(errno))
@@ -26,6 +27,8 @@
#define SCM_PIDFD 0x04
#endif
+#define CHILD_EXIT_CODE_OK 123
+
static void child_die()
{
exit(1);
@@ -126,16 +129,65 @@ out:
return result;
}
+struct cmsg_data {
+ struct ucred *ucred;
+ int *pidfd;
+};
+
+static int parse_cmsg(struct msghdr *msg, struct cmsg_data *res)
+{
+ struct cmsghdr *cmsg;
+ int data = 0;
+
+ if (msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
+ log_err("recvmsg: truncated");
+ return 1;
+ }
+
+ for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
+ cmsg = CMSG_NXTHDR(msg, cmsg)) {
+ if (cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_PIDFD) {
+ if (cmsg->cmsg_len < sizeof(*res->pidfd)) {
+ log_err("CMSG parse: SCM_PIDFD wrong len");
+ return 1;
+ }
+
+ res->pidfd = (void *)CMSG_DATA(cmsg);
+ }
+
+ if (cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_CREDENTIALS) {
+ if (cmsg->cmsg_len < sizeof(*res->ucred)) {
+ log_err("CMSG parse: SCM_CREDENTIALS wrong len");
+ return 1;
+ }
+
+ res->ucred = (void *)CMSG_DATA(cmsg);
+ }
+ }
+
+ if (!res->pidfd) {
+ log_err("CMSG parse: SCM_PIDFD not found");
+ return 1;
+ }
+
+ if (!res->ucred) {
+ log_err("CMSG parse: SCM_CREDENTIALS not found");
+ return 1;
+ }
+
+ return 0;
+}
+
static int cmsg_check(int fd)
{
struct msghdr msg = { 0 };
- struct cmsghdr *cmsg;
+ struct cmsg_data res;
struct iovec iov;
- struct ucred *ucred = NULL;
int data = 0;
char control[CMSG_SPACE(sizeof(struct ucred)) +
CMSG_SPACE(sizeof(int))] = { 0 };
- int *pidfd = NULL;
pid_t parent_pid;
int err;
@@ -158,53 +210,99 @@ static int cmsg_check(int fd)
return 1;
}
- for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
- cmsg = CMSG_NXTHDR(&msg, cmsg)) {
- if (cmsg->cmsg_level == SOL_SOCKET &&
- cmsg->cmsg_type == SCM_PIDFD) {
- if (cmsg->cmsg_len < sizeof(*pidfd)) {
- log_err("CMSG parse: SCM_PIDFD wrong len");
- return 1;
- }
+ /* send(pfd, "x", sizeof(char), 0) */
+ if (data != 'x') {
+ log_err("recvmsg: data corruption");
+ return 1;
+ }
- pidfd = (void *)CMSG_DATA(cmsg);
- }
+ if (parse_cmsg(&msg, &res)) {
+ log_err("CMSG parse: parse_cmsg() failed");
+ return 1;
+ }
- if (cmsg->cmsg_level == SOL_SOCKET &&
- cmsg->cmsg_type == SCM_CREDENTIALS) {
- if (cmsg->cmsg_len < sizeof(*ucred)) {
- log_err("CMSG parse: SCM_CREDENTIALS wrong len");
- return 1;
- }
+ /* pidfd from SCM_PIDFD should point to the parent process PID */
+ parent_pid =
+ get_pid_from_fdinfo_file(*res.pidfd, "Pid:", sizeof("Pid:") - 1);
+ if (parent_pid != getppid()) {
+ log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
+ close(*res.pidfd);
+ return 1;
+ }
- ucred = (void *)CMSG_DATA(cmsg);
- }
+ close(*res.pidfd);
+ return 0;
+}
+
+static int cmsg_check_dead(int fd, int expected_pid)
+{
+ int err;
+ struct msghdr msg = { 0 };
+ struct cmsg_data res;
+ struct iovec iov;
+ int data = 0;
+ char control[CMSG_SPACE(sizeof(struct ucred)) +
+ CMSG_SPACE(sizeof(int))] = { 0 };
+ pid_t client_pid;
+ struct pidfd_info info = {
+ .mask = PIDFD_INFO_EXIT,
+ };
+
+ iov.iov_base = &data;
+ iov.iov_len = sizeof(data);
+
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ err = recvmsg(fd, &msg, 0);
+ if (err < 0) {
+ log_err("recvmsg");
+ return 1;
}
- /* send(pfd, "x", sizeof(char), 0) */
- if (data != 'x') {
+ if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
+ log_err("recvmsg: truncated");
+ return 1;
+ }
+
+ /* send(cfd, "y", sizeof(char), 0) */
+ if (data != 'y') {
log_err("recvmsg: data corruption");
return 1;
}
- if (!pidfd) {
- log_err("CMSG parse: SCM_PIDFD not found");
+ if (parse_cmsg(&msg, &res)) {
+ log_err("CMSG parse: parse_cmsg() failed");
return 1;
}
- if (!ucred) {
- log_err("CMSG parse: SCM_CREDENTIALS not found");
+ /*
+ * pidfd from SCM_PIDFD should point to the client_pid.
+ * Let's read exit information and check if it's what
+ * we expect to see.
+ */
+ if (ioctl(*res.pidfd, PIDFD_GET_INFO, &info)) {
+ log_err("%s: ioctl(PIDFD_GET_INFO) failed", __func__);
+ close(*res.pidfd);
return 1;
}
- /* pidfd from SCM_PIDFD should point to the parent process PID */
- parent_pid =
- get_pid_from_fdinfo_file(*pidfd, "Pid:", sizeof("Pid:") - 1);
- if (parent_pid != getppid()) {
- log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
+ if (!(info.mask & PIDFD_INFO_EXIT)) {
+ log_err("%s: No exit information from ioctl(PIDFD_GET_INFO)", __func__);
+ close(*res.pidfd);
return 1;
}
+ err = WIFEXITED(info.exit_code) ? WEXITSTATUS(info.exit_code) : 1;
+ if (err != CHILD_EXIT_CODE_OK) {
+ log_err("%s: wrong exit_code %d != %d", __func__, err, CHILD_EXIT_CODE_OK);
+ close(*res.pidfd);
+ return 1;
+ }
+
+ close(*res.pidfd);
return 0;
}
@@ -291,6 +389,24 @@ static void fill_sockaddr(struct sock_addr *addr, bool abstract)
memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name));
}
+static int sk_enable_cred_pass(int sk)
+{
+ int on = 0;
+
+ on = 1;
+ if (setsockopt(sk, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
+ log_err("Failed to set SO_PASSCRED");
+ return 1;
+ }
+
+ if (setsockopt(sk, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
+ log_err("Failed to set SO_PASSPIDFD");
+ return 1;
+ }
+
+ return 0;
+}
+
static void client(FIXTURE_DATA(scm_pidfd) *self,
const FIXTURE_VARIANT(scm_pidfd) *variant)
{
@@ -299,7 +415,6 @@ static void client(FIXTURE_DATA(scm_pidfd) *self,
struct ucred peer_cred;
int peer_pidfd;
pid_t peer_pid;
- int on = 0;
cfd = socket(AF_UNIX, variant->type, 0);
if (cfd < 0) {
@@ -322,14 +437,8 @@ static void client(FIXTURE_DATA(scm_pidfd) *self,
child_die();
}
- on = 1;
- if (setsockopt(cfd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
- log_err("Failed to set SO_PASSCRED");
- child_die();
- }
-
- if (setsockopt(cfd, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
- log_err("Failed to set SO_PASSPIDFD");
+ if (sk_enable_cred_pass(cfd)) {
+ log_err("sk_enable_cred_pass() failed");
child_die();
}
@@ -340,6 +449,12 @@ static void client(FIXTURE_DATA(scm_pidfd) *self,
child_die();
}
+ /* send something to the parent so it can receive SCM_PIDFD too and validate it */
+ if (send(cfd, "y", sizeof(char), 0) == -1) {
+ log_err("Failed to send(cfd, \"y\", sizeof(char), 0)");
+ child_die();
+ }
+
/* skip further for SOCK_DGRAM as it's not applicable */
if (variant->type == SOCK_DGRAM)
return;
@@ -398,7 +513,13 @@ TEST_F(scm_pidfd, test)
close(self->server);
close(self->startup_pipe[0]);
client(self, variant);
- exit(0);
+
+ /*
+ * It's a bit unusual, but in case of success we return non-zero
+ * exit code (CHILD_EXIT_CODE_OK) and then we expect to read it
+ * from ioctl(PIDFD_GET_INFO) in cmsg_check_dead().
+ */
+ exit(CHILD_EXIT_CODE_OK);
}
close(self->startup_pipe[1]);
@@ -421,9 +542,17 @@ TEST_F(scm_pidfd, test)
ASSERT_NE(-1, err);
}
- close(pfd);
waitpid(self->client_pid, &child_status, 0);
- ASSERT_EQ(0, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
+ /* see comment before exit(CHILD_EXIT_CODE_OK) */
+ ASSERT_EQ(CHILD_EXIT_CODE_OK, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
+
+ err = sk_enable_cred_pass(pfd);
+ ASSERT_EQ(0, err);
+
+ err = cmsg_check_dead(pfd, self->client_pid);
+ ASSERT_EQ(0, err);
+
+ close(pfd);
}
TEST_HARNESS_MAIN