diff options
-rw-r--r-- | include/net/mctp.h | 5 | ||||
-rw-r--r-- | include/net/netns/mctp.h | 20 | ||||
-rw-r--r-- | net/mctp/af_mctp.c | 148 | ||||
-rw-r--r-- | net/mctp/route.c | 79 | ||||
-rw-r--r-- | net/mctp/test/route-test.c | 194 | ||||
-rw-r--r-- | net/mctp/test/sock-test.c | 167 | ||||
-rw-r--r-- | net/mctp/test/utils.c | 36 | ||||
-rw-r--r-- | net/mctp/test/utils.h | 17 |
8 files changed, 634 insertions, 32 deletions
diff --git a/include/net/mctp.h b/include/net/mctp.h index ac4f4ecdfc24..c3207ce98f07 100644 --- a/include/net/mctp.h +++ b/include/net/mctp.h @@ -69,7 +69,10 @@ struct mctp_sock { /* bind() params */ unsigned int bind_net; - mctp_eid_t bind_addr; + mctp_eid_t bind_local_addr; + mctp_eid_t bind_peer_addr; + unsigned int bind_peer_net; + bool bind_peer_set; __u8 bind_type; /* sendmsg()/recvmsg() uses struct sockaddr_mctp_ext */ diff --git a/include/net/netns/mctp.h b/include/net/netns/mctp.h index 1db8f9aaddb4..89555f90b97b 100644 --- a/include/net/netns/mctp.h +++ b/include/net/netns/mctp.h @@ -6,19 +6,25 @@ #ifndef __NETNS_MCTP_H__ #define __NETNS_MCTP_H__ +#include <linux/hash.h> +#include <linux/hashtable.h> #include <linux/mutex.h> #include <linux/types.h> +#define MCTP_BINDS_BITS 7 + struct netns_mctp { /* Only updated under RTNL, entries freed via RCU */ struct list_head routes; - /* Bound sockets: list of sockets bound by type. - * This list is updated from non-atomic contexts (under bind_lock), - * and read (under rcu) in packet rx + /* Bound sockets: hash table of sockets, keyed by + * (type, src_eid, dest_eid). + * Specific src_eid/dest_eid entries also have an entry for + * MCTP_ADDR_ANY. This list is updated from non-atomic contexts + * (under bind_lock), and read (under rcu) in packet rx. */ struct mutex bind_lock; - struct hlist_head binds; + DECLARE_HASHTABLE(binds, MCTP_BINDS_BITS); /* tag allocations. This list is read and updated from atomic contexts, * but elements are free()ed after a RCU grace-period @@ -34,4 +40,10 @@ struct netns_mctp { struct list_head neighbours; }; +static inline u32 mctp_bind_hash(u8 type, u8 local_addr, u8 peer_addr) +{ + return hash_32(type | (u32)local_addr << 8 | (u32)peer_addr << 16, + MCTP_BINDS_BITS); +} + #endif /* __NETNS_MCTP_H__ */ diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index aef74308c18e..df4e8cf33899 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -53,6 +53,7 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) { struct sock *sk = sock->sk; struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(&msk->sk); struct sockaddr_mctp *smctp; int rc; @@ -73,14 +74,48 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) lock_sock(sk); - /* TODO: allow rebind */ if (sk_hashed(sk)) { rc = -EADDRINUSE; goto out_release; } - msk->bind_net = smctp->smctp_network; - msk->bind_addr = smctp->smctp_addr.s_addr; - msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */ + + msk->bind_local_addr = smctp->smctp_addr.s_addr; + + /* MCTP_NET_ANY with a specific EID is resolved to the default net + * at bind() time. + * For bind_addr=MCTP_ADDR_ANY it is handled specially at route + * lookup time. + */ + if (smctp->smctp_network == MCTP_NET_ANY && + msk->bind_local_addr != MCTP_ADDR_ANY) { + msk->bind_net = mctp_default_net(net); + } else { + msk->bind_net = smctp->smctp_network; + } + + /* ignore the IC bit */ + smctp->smctp_type &= 0x7f; + + if (msk->bind_peer_set) { + if (msk->bind_type != smctp->smctp_type) { + /* Prior connect() had a different type */ + rc = -EINVAL; + goto out_release; + } + + if (msk->bind_net == MCTP_NET_ANY) { + /* Restrict to the network passed to connect() */ + msk->bind_net = msk->bind_peer_net; + } + + if (msk->bind_net != msk->bind_peer_net) { + /* connect() had a different net to bind() */ + rc = -EINVAL; + goto out_release; + } + } else { + msk->bind_type = smctp->smctp_type; + } rc = sk->sk_prot->hash(sk); @@ -90,6 +125,67 @@ out_release: return rc; } +/* Used to set a specific peer prior to bind. Not used for outbound + * connections (Tag Owner set) since MCTP is a datagram protocol. + */ +static int mctp_connect(struct socket *sock, struct sockaddr *addr, + int addrlen, int flags) +{ + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(&msk->sk); + struct sockaddr_mctp *smctp; + int rc; + + if (addrlen != sizeof(*smctp)) + return -EINVAL; + + if (addr->sa_family != AF_MCTP) + return -EAFNOSUPPORT; + + /* It's a valid sockaddr for MCTP, cast and do protocol checks */ + smctp = (struct sockaddr_mctp *)addr; + + if (!mctp_sockaddr_is_ok(smctp)) + return -EINVAL; + + /* Can't bind by tag */ + if (smctp->smctp_tag) + return -EINVAL; + + /* IC bit must be unset */ + if (smctp->smctp_type & 0x80) + return -EINVAL; + + lock_sock(sk); + + if (sk_hashed(sk)) { + /* bind() already */ + rc = -EADDRINUSE; + goto out_release; + } + + if (msk->bind_peer_set) { + /* connect() already */ + rc = -EADDRINUSE; + goto out_release; + } + + msk->bind_peer_set = true; + msk->bind_peer_addr = smctp->smctp_addr.s_addr; + msk->bind_type = smctp->smctp_type; + if (smctp->smctp_network == MCTP_NET_ANY) + msk->bind_peer_net = mctp_default_net(net); + else + msk->bind_peer_net = smctp->smctp_network; + + rc = 0; + +out_release: + release_sock(sk); + return rc; +} + static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) { DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); @@ -533,7 +629,7 @@ static const struct proto_ops mctp_dgram_ops = { .family = PF_MCTP, .release = mctp_release, .bind = mctp_bind, - .connect = sock_no_connect, + .connect = mctp_connect, .socketpair = sock_no_socketpair, .accept = sock_no_accept, .getname = sock_no_getname, @@ -600,6 +696,7 @@ static int mctp_sk_init(struct sock *sk) INIT_HLIST_HEAD(&msk->keys); timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0); + msk->bind_peer_set = false; return 0; } @@ -611,15 +708,48 @@ static void mctp_sk_close(struct sock *sk, long timeout) static int mctp_sk_hash(struct sock *sk) { struct net *net = sock_net(sk); + struct sock *existing; + struct mctp_sock *msk; + mctp_eid_t remote; + u32 hash; + int rc; + + msk = container_of(sk, struct mctp_sock, sk); + + if (msk->bind_peer_set) + remote = msk->bind_peer_addr; + else + remote = MCTP_ADDR_ANY; + hash = mctp_bind_hash(msk->bind_type, msk->bind_local_addr, remote); + + mutex_lock(&net->mctp.bind_lock); + + /* Prevent duplicate binds. */ + sk_for_each(existing, &net->mctp.binds[hash]) { + struct mctp_sock *mex = + container_of(existing, struct mctp_sock, sk); + + bool same_peer = (mex->bind_peer_set && msk->bind_peer_set && + mex->bind_peer_addr == msk->bind_peer_addr) || + (!mex->bind_peer_set && !msk->bind_peer_set); + + if (mex->bind_type == msk->bind_type && + mex->bind_local_addr == msk->bind_local_addr && same_peer && + mex->bind_net == msk->bind_net) { + rc = -EADDRINUSE; + goto out; + } + } /* Bind lookup runs under RCU, remain live during that. */ sock_set_flag(sk, SOCK_RCU_FREE); - mutex_lock(&net->mctp.bind_lock); - sk_add_node_rcu(sk, &net->mctp.binds); - mutex_unlock(&net->mctp.bind_lock); + sk_add_node_rcu(sk, &net->mctp.binds[hash]); + rc = 0; - return 0; +out: + mutex_unlock(&net->mctp.bind_lock); + return rc; } static void mctp_sk_unhash(struct sock *sk) diff --git a/net/mctp/route.c b/net/mctp/route.c index a20d6b11d418..2b2b958ef6a3 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -40,33 +40,36 @@ static int mctp_dst_discard(struct mctp_dst *dst, struct sk_buff *skb) return 0; } -static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +static struct mctp_sock *mctp_lookup_bind_details(struct net *net, + struct sk_buff *skb, + u8 type, u8 dest, + u8 src, bool allow_net_any) { struct mctp_skb_cb *cb = mctp_cb(skb); - struct mctp_hdr *mh; struct sock *sk; - u8 type; - - WARN_ON(!rcu_read_lock_held()); - - /* TODO: look up in skb->cb? */ - mh = mctp_hdr(skb); + u8 hash; - if (!skb_headlen(skb)) - return NULL; + WARN_ON_ONCE(!rcu_read_lock_held()); - type = (*(u8 *)skb->data) & 0x7f; + hash = mctp_bind_hash(type, dest, src); - sk_for_each_rcu(sk, &net->mctp.binds) { + sk_for_each_rcu(sk, &net->mctp.binds[hash]) { struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + if (!allow_net_any && msk->bind_net == MCTP_NET_ANY) + continue; + if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) continue; if (msk->bind_type != type) continue; - if (!mctp_address_matches(msk->bind_addr, mh->dest)) + if (msk->bind_peer_set && + !mctp_address_matches(msk->bind_peer_addr, src)) + continue; + + if (!mctp_address_matches(msk->bind_local_addr, dest)) continue; return msk; @@ -75,6 +78,54 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) return NULL; } +static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +{ + struct mctp_sock *msk; + struct mctp_hdr *mh; + u8 type; + + /* TODO: look up in skb->cb? */ + mh = mctp_hdr(skb); + + if (!skb_headlen(skb)) + return NULL; + + type = (*(u8 *)skb->data) & 0x7f; + + /* Look for binds in order of widening scope. A given destination or + * source address also implies matching on a particular network. + * + * - Matching destination and source + * - Matching destination + * - Matching source + * - Matching network, any address + * - Any network or address + */ + + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, true); + if (msk) + return msk; + + return NULL; +} + /* A note on the key allocations. * * struct net->mctp.keys contains our set of currently-allocated keys for @@ -1671,7 +1722,7 @@ static int __net_init mctp_routes_net_init(struct net *net) struct netns_mctp *ns = &net->mctp; INIT_LIST_HEAD(&ns->routes); - INIT_HLIST_HEAD(&ns->binds); + hash_init(ns->binds); mutex_init(&ns->bind_lock); INIT_HLIST_HEAD(&ns->keys); spin_lock_init(&ns->keys_lock); diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c index 7a398f41b621..fb6b46a952cb 100644 --- a/net/mctp/test/route-test.c +++ b/net/mctp/test/route-test.c @@ -1164,8 +1164,6 @@ static void mctp_test_route_extaddr_input(struct kunit *test) rc = mctp_dst_input(&dst, skb); KUNIT_ASSERT_EQ(test, rc, 0); - mctp_test_dst_release(&dst, &tpq); - skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2); KUNIT_ASSERT_EQ(test, skb2->len, len); @@ -1179,8 +1177,8 @@ static void mctp_test_route_extaddr_input(struct kunit *test) KUNIT_EXPECT_EQ(test, cb2->halen, sizeof(haddr)); KUNIT_EXPECT_MEMEQ(test, cb2->haddr, haddr, sizeof(haddr)); - skb_free_datagram(sock->sk, skb2); - mctp_test_destroy_dev(dev); + kfree_skb(skb2); + __mctp_route_test_fini(test, dev, &dst, &tpq, sock); } static void mctp_test_route_gw_lookup(struct kunit *test) @@ -1410,6 +1408,193 @@ static void mctp_test_route_gw_output(struct kunit *test) kfree_skb(skb); } +struct mctp_bind_lookup_test { + /* header of incoming message */ + struct mctp_hdr hdr; + u8 ty; + /* mctp network of incoming interface (smctp_network) */ + unsigned int net; + + /* expected socket, matches .name in lookup_binds, NULL for dropped */ + const char *expect; +}; + +/* Single-packet TO-set message */ +#define LK(src, dst) RX_HDR(1, (src), (dst), FL_S | FL_E | FL_TO) + +/* Input message test cases for bind lookup tests. + * + * 10 and 11 are local EIDs. + * 20 and 21 are remote EIDs. + */ +static const struct mctp_bind_lookup_test mctp_bind_lookup_tests[] = { + /* both local-eid and remote-eid binds, remote eid is preferenced */ + { .hdr = LK(20, 10), .ty = 1, .net = 1, .expect = "remote20" }, + + { .hdr = LK(20, 255), .ty = 1, .net = 1, .expect = "remote20" }, + { .hdr = LK(20, 0), .ty = 1, .net = 1, .expect = "remote20" }, + { .hdr = LK(0, 255), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 11), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 0), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 10), .ty = 1, .net = 1, .expect = "local10" }, + { .hdr = LK(21, 10), .ty = 1, .net = 1, .expect = "local10" }, + { .hdr = LK(21, 11), .ty = 1, .net = 1, .expect = "remote21local11" }, + + /* both src and dest set to eid=99. unusual, but accepted + * by MCTP stack currently. + */ + { .hdr = LK(99, 99), .ty = 1, .net = 1, .expect = "any" }, + + /* unbound smctp_type */ + { .hdr = LK(20, 10), .ty = 3, .net = 1, .expect = NULL }, + + /* smctp_network tests */ + + { .hdr = LK(0, 0), .ty = 1, .net = 7, .expect = "any" }, + { .hdr = LK(21, 10), .ty = 1, .net = 2, .expect = "any" }, + + /* remote EID 20 matches, but MCTP_NET_ANY in "remote20" resolved + * to net=1, so lookup doesn't match "remote20" + */ + { .hdr = LK(20, 10), .ty = 1, .net = 3, .expect = "any" }, + + { .hdr = LK(21, 10), .ty = 1, .net = 3, .expect = "remote21net3" }, + { .hdr = LK(21, 10), .ty = 1, .net = 4, .expect = "remote21net4" }, + { .hdr = LK(21, 10), .ty = 1, .net = 5, .expect = "remote21net5" }, + + { .hdr = LK(21, 10), .ty = 1, .net = 5, .expect = "remote21net5" }, + + { .hdr = LK(99, 10), .ty = 1, .net = 8, .expect = "local10net8" }, + + { .hdr = LK(99, 10), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(0, 0), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(99, 99), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(20, 10), .ty = 1, .net = 9, .expect = "anynet9" }, +}; + +/* Binds to create during the lookup tests */ +static const struct mctp_test_bind_setup lookup_binds[] = { + /* any address and net, type 1 */ + { .name = "any", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, }, + /* local eid 10, net 1 (resolved from MCTP_NET_ANY) */ + { .name = "local10", .bind_addr = 10, + .bind_net = MCTP_NET_ANY, .bind_type = 1, }, + /* local eid 10, net 8 */ + { .name = "local10net8", .bind_addr = 10, + .bind_net = 8, .bind_type = 1, }, + /* any EID, net 9 */ + { .name = "anynet9", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 9, .bind_type = 1, }, + + /* remote eid 20, net 1, any local eid */ + { .name = "remote20", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 20, .peer_net = MCTP_NET_ANY, }, + + /* remote eid 20, net 1, local eid 11 */ + { .name = "remote21local11", .bind_addr = 11, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = MCTP_NET_ANY, }, + + /* remote eid 21, specific net=3 for connect() */ + { .name = "remote21net3", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 3, }, + + /* remote eid 21, net 4 for bind, specific net=4 for connect() */ + { .name = "remote21net4", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 4, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 4, }, + + /* remote eid 21, net 5 for bind, specific net=5 for connect() */ + { .name = "remote21net5", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 5, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 5, }, +}; + +static void mctp_bind_lookup_desc(const struct mctp_bind_lookup_test *t, + char *desc) +{ + snprintf(desc, KUNIT_PARAM_DESC_SIZE, + "{src %d dst %d ty %d net %d expect %s}", + t->hdr.src, t->hdr.dest, t->ty, t->net, t->expect); +} + +KUNIT_ARRAY_PARAM(mctp_bind_lookup, mctp_bind_lookup_tests, + mctp_bind_lookup_desc); + +static void mctp_test_bind_lookup(struct kunit *test) +{ + const struct mctp_bind_lookup_test *rx; + struct socket *socks[ARRAY_SIZE(lookup_binds)]; + struct sk_buff *skb_pkt = NULL, *skb_sock = NULL; + struct socket *sock_ty0, *sock_expect = NULL; + struct mctp_test_pktqueue tpq; + struct mctp_test_dev *dev; + struct mctp_dst dst; + int rc; + + rx = test->param_value; + + __mctp_route_test_init(test, &dev, &dst, &tpq, &sock_ty0, rx->net); + /* Create all binds */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) { + mctp_test_bind_run(test, &lookup_binds[i], + &rc, &socks[i]); + KUNIT_ASSERT_EQ(test, rc, 0); + + /* Record the expected receive socket */ + if (rx->expect && + strcmp(rx->expect, lookup_binds[i].name) == 0) { + KUNIT_ASSERT_NULL(test, sock_expect); + sock_expect = socks[i]; + } + } + KUNIT_ASSERT_EQ(test, !!sock_expect, !!rx->expect); + + /* Create test message */ + skb_pkt = mctp_test_create_skb_data(&rx->hdr, &rx->ty); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb_pkt); + mctp_test_skb_set_dev(skb_pkt, dev); + mctp_test_pktqueue_init(&tpq); + + rc = mctp_dst_input(&dst, skb_pkt); + if (rx->expect) { + /* Test the message is received on the expected socket */ + KUNIT_EXPECT_EQ(test, rc, 0); + skb_sock = skb_recv_datagram(sock_expect->sk, + MSG_DONTWAIT, &rc); + if (!skb_sock) { + /* Find which socket received it instead */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) { + skb_sock = skb_recv_datagram(socks[i]->sk, + MSG_DONTWAIT, &rc); + if (skb_sock) { + KUNIT_FAIL(test, + "received on incorrect socket '%s', expect '%s'", + lookup_binds[i].name, + rx->expect); + goto cleanup; + } + } + KUNIT_FAIL(test, "no message received"); + } + } else { + KUNIT_EXPECT_NE(test, rc, 0); + } + +cleanup: + kfree_skb(skb_sock); + kfree_skb(skb_pkt); + + /* Drop all binds */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) + sock_release(socks[i]); + + __mctp_route_test_fini(test, dev, &dst, &tpq, sock_ty0); +} + static struct kunit_case mctp_test_cases[] = { KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params), KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params), @@ -1431,6 +1616,7 @@ static struct kunit_case mctp_test_cases[] = { KUNIT_CASE(mctp_test_route_gw_loop), KUNIT_CASE_PARAM(mctp_test_route_gw_mtu, mctp_route_gw_mtu_gen_params), KUNIT_CASE(mctp_test_route_gw_output), + KUNIT_CASE_PARAM(mctp_test_bind_lookup, mctp_bind_lookup_gen_params), {} }; diff --git a/net/mctp/test/sock-test.c b/net/mctp/test/sock-test.c index 4eb3a724dca3..b0942deb5019 100644 --- a/net/mctp/test/sock-test.c +++ b/net/mctp/test/sock-test.c @@ -215,9 +215,176 @@ static void mctp_test_sock_recvmsg_extaddr(struct kunit *test) __mctp_sock_test_fini(test, dev, rt, sock); } +static const struct mctp_test_bind_setup bind_addrany_netdefault_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = MCTP_NET_ANY, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1, +}; + +/* 1 is default net */ +static const struct mctp_test_bind_setup bind_addr8_net1_type1 = { + .bind_addr = 8, .bind_net = 1, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net1_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1, +}; + +/* 2 is an arbitrary net */ +static const struct mctp_test_bind_setup bind_addr8_net2_type1 = { + .bind_addr = 8, .bind_net = 2, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addr8_netdefault_type1 = { + .bind_addr = 8, .bind_net = MCTP_NET_ANY, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type2 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 2, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type1_peer9 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1, + .have_peer = true, .peer_addr = 9, .peer_net = 2, +}; + +struct mctp_bind_pair_test { + const struct mctp_test_bind_setup *bind1; + const struct mctp_test_bind_setup *bind2; + int error; +}; + +/* Pairs of binds and whether they will conflict */ +static const struct mctp_bind_pair_test mctp_bind_pair_tests[] = { + /* Both ADDR_ANY, conflict */ + { &bind_addrany_netdefault_type1, &bind_addrany_netdefault_type1, + EADDRINUSE }, + /* Same specific EID, conflict */ + { &bind_addr8_netdefault_type1, &bind_addr8_netdefault_type1, + EADDRINUSE }, + /* ADDR_ANY vs specific EID, OK */ + { &bind_addrany_netdefault_type1, &bind_addr8_netdefault_type1, 0 }, + /* ADDR_ANY different types, OK */ + { &bind_addrany_net2_type2, &bind_addrany_net2_type1, 0 }, + /* ADDR_ANY different nets, OK */ + { &bind_addrany_net2_type1, &bind_addrany_netdefault_type1, 0 }, + + /* specific EID, NET_ANY (resolves to default) + * vs specific EID, explicit default net 1, conflict + */ + { &bind_addr8_netdefault_type1, &bind_addr8_net1_type1, EADDRINUSE }, + + /* specific EID, net 1 vs specific EID, net 2, ok */ + { &bind_addr8_net1_type1, &bind_addr8_net2_type1, 0 }, + + /* ANY_ADDR, NET_ANY (doesn't resolve to default) + * vs ADDR_ANY, explicit default net 1, OK + */ + { &bind_addrany_netdefault_type1, &bind_addrany_net1_type1, 0 }, + + /* specific remote peer doesn't conflict with any-peer bind */ + { &bind_addrany_net2_type1_peer9, &bind_addrany_net2_type1, 0 }, + + /* bind() NET_ANY is allowed with a connect() net */ + { &bind_addrany_net2_type1_peer9, &bind_addrany_netdefault_type1, 0 }, +}; + +static void mctp_bind_pair_desc(const struct mctp_bind_pair_test *t, char *desc) +{ + char peer1[25] = {0}, peer2[25] = {0}; + + if (t->bind1->have_peer) + snprintf(peer1, sizeof(peer1), ", peer %d net %d", + t->bind1->peer_addr, t->bind1->peer_net); + if (t->bind2->have_peer) + snprintf(peer2, sizeof(peer2), ", peer %d net %d", + t->bind2->peer_addr, t->bind2->peer_net); + + snprintf(desc, KUNIT_PARAM_DESC_SIZE, + "{bind(addr %d, type %d, net %d%s)} {bind(addr %d, type %d, net %d%s)} -> error %d", + t->bind1->bind_addr, t->bind1->bind_type, + t->bind1->bind_net, peer1, + t->bind2->bind_addr, t->bind2->bind_type, + t->bind2->bind_net, peer2, t->error); +} + +KUNIT_ARRAY_PARAM(mctp_bind_pair, mctp_bind_pair_tests, mctp_bind_pair_desc); + +static void mctp_test_bind_invalid(struct kunit *test) +{ + struct socket *sock; + int rc; + + /* bind() fails if the bind() vs connect() networks mismatch. */ + const struct mctp_test_bind_setup bind_connect_net_mismatch = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1, + .have_peer = true, .peer_addr = 9, .peer_net = 2, + }; + mctp_test_bind_run(test, &bind_connect_net_mismatch, &rc, &sock); + KUNIT_EXPECT_EQ(test, -rc, EINVAL); + sock_release(sock); +} + +static int +mctp_test_bind_conflicts_inner(struct kunit *test, + const struct mctp_test_bind_setup *bind1, + const struct mctp_test_bind_setup *bind2) +{ + struct socket *sock1 = NULL, *sock2 = NULL, *sock3 = NULL; + int bind_errno; + + /* Bind to first address, always succeeds */ + mctp_test_bind_run(test, bind1, &bind_errno, &sock1); + KUNIT_EXPECT_EQ(test, bind_errno, 0); + + /* A second identical bind always fails */ + mctp_test_bind_run(test, bind1, &bind_errno, &sock2); + KUNIT_EXPECT_EQ(test, -bind_errno, EADDRINUSE); + + /* A different bind, result is returned */ + mctp_test_bind_run(test, bind2, &bind_errno, &sock3); + + if (sock1) + sock_release(sock1); + if (sock2) + sock_release(sock2); + if (sock3) + sock_release(sock3); + + return bind_errno; +} + +static void mctp_test_bind_conflicts(struct kunit *test) +{ + const struct mctp_bind_pair_test *pair; + int bind_errno; + + pair = test->param_value; + + bind_errno = + mctp_test_bind_conflicts_inner(test, pair->bind1, pair->bind2); + KUNIT_EXPECT_EQ(test, -bind_errno, pair->error); + + /* swapping the calls, the second bind should still fail */ + bind_errno = + mctp_test_bind_conflicts_inner(test, pair->bind2, pair->bind1); + KUNIT_EXPECT_EQ(test, -bind_errno, pair->error); +} + +static void mctp_test_assumptions(struct kunit *test) +{ + /* check assumption of default net from bind_addr8_net1_type1 */ + KUNIT_ASSERT_EQ(test, mctp_default_net(&init_net), 1); +} + static struct kunit_case mctp_test_cases[] = { + KUNIT_CASE(mctp_test_assumptions), KUNIT_CASE(mctp_test_sock_sendmsg_extaddr), KUNIT_CASE(mctp_test_sock_recvmsg_extaddr), + KUNIT_CASE_PARAM(mctp_test_bind_conflicts, mctp_bind_pair_gen_params), + KUNIT_CASE(mctp_test_bind_invalid), {} }; diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c index 01f5af416b81..953d41902771 100644 --- a/net/mctp/test/utils.c +++ b/net/mctp/test/utils.c @@ -258,3 +258,39 @@ struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr, return skb; } + +void mctp_test_bind_run(struct kunit *test, + const struct mctp_test_bind_setup *setup, + int *ret_bind_errno, struct socket **sock) +{ + struct sockaddr_mctp addr; + int rc; + + *ret_bind_errno = -EIO; + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + /* connect() if requested */ + if (setup->have_peer) { + memset(&addr, 0x0, sizeof(addr)); + addr.smctp_family = AF_MCTP; + addr.smctp_network = setup->peer_net; + addr.smctp_addr.s_addr = setup->peer_addr; + /* connect() type must match bind() type */ + addr.smctp_type = setup->bind_type; + rc = kernel_connect(*sock, (struct sockaddr *)&addr, + sizeof(addr), 0); + KUNIT_EXPECT_EQ(test, rc, 0); + } + + /* bind() */ + memset(&addr, 0x0, sizeof(addr)); + addr.smctp_family = AF_MCTP; + addr.smctp_network = setup->bind_net; + addr.smctp_addr.s_addr = setup->bind_addr; + addr.smctp_type = setup->bind_type; + + *ret_bind_errno = + kernel_bind(*sock, (struct sockaddr *)&addr, sizeof(addr)); +} diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h index f10d1d9066cc..06bdb6cb5eff 100644 --- a/net/mctp/test/utils.h +++ b/net/mctp/test/utils.h @@ -31,6 +31,19 @@ struct mctp_test_pktqueue { struct sk_buff_head pkts; }; +struct mctp_test_bind_setup { + mctp_eid_t bind_addr; + int bind_net; + u8 bind_type; + + bool have_peer; + mctp_eid_t peer_addr; + int peer_net; + + /* optional name. Used for comparison in "lookup" tests */ + const char *name; +}; + struct mctp_test_dev *mctp_test_create_dev(void); struct mctp_test_dev *mctp_test_create_dev_lladdr(unsigned short lladdr_len, const unsigned char *lladdr); @@ -61,4 +74,8 @@ struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr, #define mctp_test_create_skb_data(h, d) \ __mctp_test_create_skb_data(h, d, sizeof(*d)) +void mctp_test_bind_run(struct kunit *test, + const struct mctp_test_bind_setup *setup, + int *ret_bind_errno, struct socket **sock); + #endif /* __NET_MCTP_TEST_UTILS_H */ |