summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaolo Abeni <pabeni@redhat.com>2025-07-15 12:08:41 +0200
committerPaolo Abeni <pabeni@redhat.com>2025-07-15 12:08:41 +0200
commit55e8757c696210292cfda6f1464991d6f5c4300f (patch)
treea9d1f25875eed8f6c219efa2895652cf9911a51e
parenta8594c956cc9dc6799554a554bc422d1ffd4c46b (diff)
parente6d8e7dbc5a363a8e55a65f3bbe7f9f44f0aeb4f (diff)
Merge branch 'net-mctp-improved-bind-handling'
Matt Johnston says: ==================== net: mctp: Improved bind handling This series improves a couple of aspects of MCTP bind() handling. MCTP wasn't checking whether the same MCTP type was bound by multiple sockets. That would result in messages being received by an arbitrary socket, which isn't useful behaviour. Instead it makes more sense to have the duplicate binds fail, the same as other network protocols. An exception is made for more-specific binds to particular MCTP addresses. It is also useful to be able to limit a bind to only receive incoming request messages (MCTP TO bit set) from a specific peer+type, so that individual processes can communicate with separate MCTP peers. One example is a PLDM firmware update requester, which will initiate communication with a device, and then the device will connect back to the requester process. These limited binds are implemented by a connect() call on the socket prior to bind. connect() isn't used in the general case for MCTP, since a plain send() wouldn't provide the required MCTP tag argument for addressing. Signed-off-by: Matt Johnston <matt@codeconstruct.com.au> ==================== Link: https://patch.msgid.link/20250710-mctp-bind-v4-0-8ec2f6460c56@codeconstruct.com.au Signed-off-by: Paolo Abeni <pabeni@redhat.com>
-rw-r--r--include/net/mctp.h5
-rw-r--r--include/net/netns/mctp.h20
-rw-r--r--net/mctp/af_mctp.c148
-rw-r--r--net/mctp/route.c79
-rw-r--r--net/mctp/test/route-test.c194
-rw-r--r--net/mctp/test/sock-test.c167
-rw-r--r--net/mctp/test/utils.c36
-rw-r--r--net/mctp/test/utils.h17
8 files changed, 634 insertions, 32 deletions
diff --git a/include/net/mctp.h b/include/net/mctp.h
index ac4f4ecdfc24..c3207ce98f07 100644
--- a/include/net/mctp.h
+++ b/include/net/mctp.h
@@ -69,7 +69,10 @@ struct mctp_sock {
/* bind() params */
unsigned int bind_net;
- mctp_eid_t bind_addr;
+ mctp_eid_t bind_local_addr;
+ mctp_eid_t bind_peer_addr;
+ unsigned int bind_peer_net;
+ bool bind_peer_set;
__u8 bind_type;
/* sendmsg()/recvmsg() uses struct sockaddr_mctp_ext */
diff --git a/include/net/netns/mctp.h b/include/net/netns/mctp.h
index 1db8f9aaddb4..89555f90b97b 100644
--- a/include/net/netns/mctp.h
+++ b/include/net/netns/mctp.h
@@ -6,19 +6,25 @@
#ifndef __NETNS_MCTP_H__
#define __NETNS_MCTP_H__
+#include <linux/hash.h>
+#include <linux/hashtable.h>
#include <linux/mutex.h>
#include <linux/types.h>
+#define MCTP_BINDS_BITS 7
+
struct netns_mctp {
/* Only updated under RTNL, entries freed via RCU */
struct list_head routes;
- /* Bound sockets: list of sockets bound by type.
- * This list is updated from non-atomic contexts (under bind_lock),
- * and read (under rcu) in packet rx
+ /* Bound sockets: hash table of sockets, keyed by
+ * (type, src_eid, dest_eid).
+ * Specific src_eid/dest_eid entries also have an entry for
+ * MCTP_ADDR_ANY. This list is updated from non-atomic contexts
+ * (under bind_lock), and read (under rcu) in packet rx.
*/
struct mutex bind_lock;
- struct hlist_head binds;
+ DECLARE_HASHTABLE(binds, MCTP_BINDS_BITS);
/* tag allocations. This list is read and updated from atomic contexts,
* but elements are free()ed after a RCU grace-period
@@ -34,4 +40,10 @@ struct netns_mctp {
struct list_head neighbours;
};
+static inline u32 mctp_bind_hash(u8 type, u8 local_addr, u8 peer_addr)
+{
+ return hash_32(type | (u32)local_addr << 8 | (u32)peer_addr << 16,
+ MCTP_BINDS_BITS);
+}
+
#endif /* __NETNS_MCTP_H__ */
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index aef74308c18e..df4e8cf33899 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -53,6 +53,7 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
{
struct sock *sk = sock->sk;
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
+ struct net *net = sock_net(&msk->sk);
struct sockaddr_mctp *smctp;
int rc;
@@ -73,14 +74,48 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
lock_sock(sk);
- /* TODO: allow rebind */
if (sk_hashed(sk)) {
rc = -EADDRINUSE;
goto out_release;
}
- msk->bind_net = smctp->smctp_network;
- msk->bind_addr = smctp->smctp_addr.s_addr;
- msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
+
+ msk->bind_local_addr = smctp->smctp_addr.s_addr;
+
+ /* MCTP_NET_ANY with a specific EID is resolved to the default net
+ * at bind() time.
+ * For bind_addr=MCTP_ADDR_ANY it is handled specially at route
+ * lookup time.
+ */
+ if (smctp->smctp_network == MCTP_NET_ANY &&
+ msk->bind_local_addr != MCTP_ADDR_ANY) {
+ msk->bind_net = mctp_default_net(net);
+ } else {
+ msk->bind_net = smctp->smctp_network;
+ }
+
+ /* ignore the IC bit */
+ smctp->smctp_type &= 0x7f;
+
+ if (msk->bind_peer_set) {
+ if (msk->bind_type != smctp->smctp_type) {
+ /* Prior connect() had a different type */
+ rc = -EINVAL;
+ goto out_release;
+ }
+
+ if (msk->bind_net == MCTP_NET_ANY) {
+ /* Restrict to the network passed to connect() */
+ msk->bind_net = msk->bind_peer_net;
+ }
+
+ if (msk->bind_net != msk->bind_peer_net) {
+ /* connect() had a different net to bind() */
+ rc = -EINVAL;
+ goto out_release;
+ }
+ } else {
+ msk->bind_type = smctp->smctp_type;
+ }
rc = sk->sk_prot->hash(sk);
@@ -90,6 +125,67 @@ out_release:
return rc;
}
+/* Used to set a specific peer prior to bind. Not used for outbound
+ * connections (Tag Owner set) since MCTP is a datagram protocol.
+ */
+static int mctp_connect(struct socket *sock, struct sockaddr *addr,
+ int addrlen, int flags)
+{
+ struct sock *sk = sock->sk;
+ struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
+ struct net *net = sock_net(&msk->sk);
+ struct sockaddr_mctp *smctp;
+ int rc;
+
+ if (addrlen != sizeof(*smctp))
+ return -EINVAL;
+
+ if (addr->sa_family != AF_MCTP)
+ return -EAFNOSUPPORT;
+
+ /* It's a valid sockaddr for MCTP, cast and do protocol checks */
+ smctp = (struct sockaddr_mctp *)addr;
+
+ if (!mctp_sockaddr_is_ok(smctp))
+ return -EINVAL;
+
+ /* Can't bind by tag */
+ if (smctp->smctp_tag)
+ return -EINVAL;
+
+ /* IC bit must be unset */
+ if (smctp->smctp_type & 0x80)
+ return -EINVAL;
+
+ lock_sock(sk);
+
+ if (sk_hashed(sk)) {
+ /* bind() already */
+ rc = -EADDRINUSE;
+ goto out_release;
+ }
+
+ if (msk->bind_peer_set) {
+ /* connect() already */
+ rc = -EADDRINUSE;
+ goto out_release;
+ }
+
+ msk->bind_peer_set = true;
+ msk->bind_peer_addr = smctp->smctp_addr.s_addr;
+ msk->bind_type = smctp->smctp_type;
+ if (smctp->smctp_network == MCTP_NET_ANY)
+ msk->bind_peer_net = mctp_default_net(net);
+ else
+ msk->bind_peer_net = smctp->smctp_network;
+
+ rc = 0;
+
+out_release:
+ release_sock(sk);
+ return rc;
+}
+
static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
{
DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
@@ -533,7 +629,7 @@ static const struct proto_ops mctp_dgram_ops = {
.family = PF_MCTP,
.release = mctp_release,
.bind = mctp_bind,
- .connect = sock_no_connect,
+ .connect = mctp_connect,
.socketpair = sock_no_socketpair,
.accept = sock_no_accept,
.getname = sock_no_getname,
@@ -600,6 +696,7 @@ static int mctp_sk_init(struct sock *sk)
INIT_HLIST_HEAD(&msk->keys);
timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
+ msk->bind_peer_set = false;
return 0;
}
@@ -611,15 +708,48 @@ static void mctp_sk_close(struct sock *sk, long timeout)
static int mctp_sk_hash(struct sock *sk)
{
struct net *net = sock_net(sk);
+ struct sock *existing;
+ struct mctp_sock *msk;
+ mctp_eid_t remote;
+ u32 hash;
+ int rc;
+
+ msk = container_of(sk, struct mctp_sock, sk);
+
+ if (msk->bind_peer_set)
+ remote = msk->bind_peer_addr;
+ else
+ remote = MCTP_ADDR_ANY;
+ hash = mctp_bind_hash(msk->bind_type, msk->bind_local_addr, remote);
+
+ mutex_lock(&net->mctp.bind_lock);
+
+ /* Prevent duplicate binds. */
+ sk_for_each(existing, &net->mctp.binds[hash]) {
+ struct mctp_sock *mex =
+ container_of(existing, struct mctp_sock, sk);
+
+ bool same_peer = (mex->bind_peer_set && msk->bind_peer_set &&
+ mex->bind_peer_addr == msk->bind_peer_addr) ||
+ (!mex->bind_peer_set && !msk->bind_peer_set);
+
+ if (mex->bind_type == msk->bind_type &&
+ mex->bind_local_addr == msk->bind_local_addr && same_peer &&
+ mex->bind_net == msk->bind_net) {
+ rc = -EADDRINUSE;
+ goto out;
+ }
+ }
/* Bind lookup runs under RCU, remain live during that. */
sock_set_flag(sk, SOCK_RCU_FREE);
- mutex_lock(&net->mctp.bind_lock);
- sk_add_node_rcu(sk, &net->mctp.binds);
- mutex_unlock(&net->mctp.bind_lock);
+ sk_add_node_rcu(sk, &net->mctp.binds[hash]);
+ rc = 0;
- return 0;
+out:
+ mutex_unlock(&net->mctp.bind_lock);
+ return rc;
}
static void mctp_sk_unhash(struct sock *sk)
diff --git a/net/mctp/route.c b/net/mctp/route.c
index a20d6b11d418..2b2b958ef6a3 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -40,33 +40,36 @@ static int mctp_dst_discard(struct mctp_dst *dst, struct sk_buff *skb)
return 0;
}
-static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
+static struct mctp_sock *mctp_lookup_bind_details(struct net *net,
+ struct sk_buff *skb,
+ u8 type, u8 dest,
+ u8 src, bool allow_net_any)
{
struct mctp_skb_cb *cb = mctp_cb(skb);
- struct mctp_hdr *mh;
struct sock *sk;
- u8 type;
-
- WARN_ON(!rcu_read_lock_held());
-
- /* TODO: look up in skb->cb? */
- mh = mctp_hdr(skb);
+ u8 hash;
- if (!skb_headlen(skb))
- return NULL;
+ WARN_ON_ONCE(!rcu_read_lock_held());
- type = (*(u8 *)skb->data) & 0x7f;
+ hash = mctp_bind_hash(type, dest, src);
- sk_for_each_rcu(sk, &net->mctp.binds) {
+ sk_for_each_rcu(sk, &net->mctp.binds[hash]) {
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
+ if (!allow_net_any && msk->bind_net == MCTP_NET_ANY)
+ continue;
+
if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net)
continue;
if (msk->bind_type != type)
continue;
- if (!mctp_address_matches(msk->bind_addr, mh->dest))
+ if (msk->bind_peer_set &&
+ !mctp_address_matches(msk->bind_peer_addr, src))
+ continue;
+
+ if (!mctp_address_matches(msk->bind_local_addr, dest))
continue;
return msk;
@@ -75,6 +78,54 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
return NULL;
}
+static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
+{
+ struct mctp_sock *msk;
+ struct mctp_hdr *mh;
+ u8 type;
+
+ /* TODO: look up in skb->cb? */
+ mh = mctp_hdr(skb);
+
+ if (!skb_headlen(skb))
+ return NULL;
+
+ type = (*(u8 *)skb->data) & 0x7f;
+
+ /* Look for binds in order of widening scope. A given destination or
+ * source address also implies matching on a particular network.
+ *
+ * - Matching destination and source
+ * - Matching destination
+ * - Matching source
+ * - Matching network, any address
+ * - Any network or address
+ */
+
+ msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src,
+ false);
+ if (msk)
+ return msk;
+ msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src,
+ false);
+ if (msk)
+ return msk;
+ msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY,
+ false);
+ if (msk)
+ return msk;
+ msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY,
+ MCTP_ADDR_ANY, false);
+ if (msk)
+ return msk;
+ msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY,
+ MCTP_ADDR_ANY, true);
+ if (msk)
+ return msk;
+
+ return NULL;
+}
+
/* A note on the key allocations.
*
* struct net->mctp.keys contains our set of currently-allocated keys for
@@ -1671,7 +1722,7 @@ static int __net_init mctp_routes_net_init(struct net *net)
struct netns_mctp *ns = &net->mctp;
INIT_LIST_HEAD(&ns->routes);
- INIT_HLIST_HEAD(&ns->binds);
+ hash_init(ns->binds);
mutex_init(&ns->bind_lock);
INIT_HLIST_HEAD(&ns->keys);
spin_lock_init(&ns->keys_lock);
diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c
index 7a398f41b621..fb6b46a952cb 100644
--- a/net/mctp/test/route-test.c
+++ b/net/mctp/test/route-test.c
@@ -1164,8 +1164,6 @@ static void mctp_test_route_extaddr_input(struct kunit *test)
rc = mctp_dst_input(&dst, skb);
KUNIT_ASSERT_EQ(test, rc, 0);
- mctp_test_dst_release(&dst, &tpq);
-
skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2);
KUNIT_ASSERT_EQ(test, skb2->len, len);
@@ -1179,8 +1177,8 @@ static void mctp_test_route_extaddr_input(struct kunit *test)
KUNIT_EXPECT_EQ(test, cb2->halen, sizeof(haddr));
KUNIT_EXPECT_MEMEQ(test, cb2->haddr, haddr, sizeof(haddr));
- skb_free_datagram(sock->sk, skb2);
- mctp_test_destroy_dev(dev);
+ kfree_skb(skb2);
+ __mctp_route_test_fini(test, dev, &dst, &tpq, sock);
}
static void mctp_test_route_gw_lookup(struct kunit *test)
@@ -1410,6 +1408,193 @@ static void mctp_test_route_gw_output(struct kunit *test)
kfree_skb(skb);
}
+struct mctp_bind_lookup_test {
+ /* header of incoming message */
+ struct mctp_hdr hdr;
+ u8 ty;
+ /* mctp network of incoming interface (smctp_network) */
+ unsigned int net;
+
+ /* expected socket, matches .name in lookup_binds, NULL for dropped */
+ const char *expect;
+};
+
+/* Single-packet TO-set message */
+#define LK(src, dst) RX_HDR(1, (src), (dst), FL_S | FL_E | FL_TO)
+
+/* Input message test cases for bind lookup tests.
+ *
+ * 10 and 11 are local EIDs.
+ * 20 and 21 are remote EIDs.
+ */
+static const struct mctp_bind_lookup_test mctp_bind_lookup_tests[] = {
+ /* both local-eid and remote-eid binds, remote eid is preferenced */
+ { .hdr = LK(20, 10), .ty = 1, .net = 1, .expect = "remote20" },
+
+ { .hdr = LK(20, 255), .ty = 1, .net = 1, .expect = "remote20" },
+ { .hdr = LK(20, 0), .ty = 1, .net = 1, .expect = "remote20" },
+ { .hdr = LK(0, 255), .ty = 1, .net = 1, .expect = "any" },
+ { .hdr = LK(0, 11), .ty = 1, .net = 1, .expect = "any" },
+ { .hdr = LK(0, 0), .ty = 1, .net = 1, .expect = "any" },
+ { .hdr = LK(0, 10), .ty = 1, .net = 1, .expect = "local10" },
+ { .hdr = LK(21, 10), .ty = 1, .net = 1, .expect = "local10" },
+ { .hdr = LK(21, 11), .ty = 1, .net = 1, .expect = "remote21local11" },
+
+ /* both src and dest set to eid=99. unusual, but accepted
+ * by MCTP stack currently.
+ */
+ { .hdr = LK(99, 99), .ty = 1, .net = 1, .expect = "any" },
+
+ /* unbound smctp_type */
+ { .hdr = LK(20, 10), .ty = 3, .net = 1, .expect = NULL },
+
+ /* smctp_network tests */
+
+ { .hdr = LK(0, 0), .ty = 1, .net = 7, .expect = "any" },
+ { .hdr = LK(21, 10), .ty = 1, .net = 2, .expect = "any" },
+
+ /* remote EID 20 matches, but MCTP_NET_ANY in "remote20" resolved
+ * to net=1, so lookup doesn't match "remote20"
+ */
+ { .hdr = LK(20, 10), .ty = 1, .net = 3, .expect = "any" },
+
+ { .hdr = LK(21, 10), .ty = 1, .net = 3, .expect = "remote21net3" },
+ { .hdr = LK(21, 10), .ty = 1, .net = 4, .expect = "remote21net4" },
+ { .hdr = LK(21, 10), .ty = 1, .net = 5, .expect = "remote21net5" },
+
+ { .hdr = LK(21, 10), .ty = 1, .net = 5, .expect = "remote21net5" },
+
+ { .hdr = LK(99, 10), .ty = 1, .net = 8, .expect = "local10net8" },
+
+ { .hdr = LK(99, 10), .ty = 1, .net = 9, .expect = "anynet9" },
+ { .hdr = LK(0, 0), .ty = 1, .net = 9, .expect = "anynet9" },
+ { .hdr = LK(99, 99), .ty = 1, .net = 9, .expect = "anynet9" },
+ { .hdr = LK(20, 10), .ty = 1, .net = 9, .expect = "anynet9" },
+};
+
+/* Binds to create during the lookup tests */
+static const struct mctp_test_bind_setup lookup_binds[] = {
+ /* any address and net, type 1 */
+ { .name = "any", .bind_addr = MCTP_ADDR_ANY,
+ .bind_net = MCTP_NET_ANY, .bind_type = 1, },
+ /* local eid 10, net 1 (resolved from MCTP_NET_ANY) */
+ { .name = "local10", .bind_addr = 10,
+ .bind_net = MCTP_NET_ANY, .bind_type = 1, },
+ /* local eid 10, net 8 */
+ { .name = "local10net8", .bind_addr = 10,
+ .bind_net = 8, .bind_type = 1, },
+ /* any EID, net 9 */
+ { .name = "anynet9", .bind_addr = MCTP_ADDR_ANY,
+ .bind_net = 9, .bind_type = 1, },
+
+ /* remote eid 20, net 1, any local eid */
+ { .name = "remote20", .bind_addr = MCTP_ADDR_ANY,
+ .bind_net = MCTP_NET_ANY, .bind_type = 1,
+ .have_peer = true, .peer_addr = 20, .peer_net = MCTP_NET_ANY, },
+
+ /* remote eid 20, net 1, local eid 11 */
+ { .name = "remote21local11", .bind_addr = 11,
+ .bind_net = MCTP_NET_ANY, .bind_type = 1,
+ .have_peer = true, .peer_addr = 21, .peer_net = MCTP_NET_ANY, },
+
+ /* remote eid 21, specific net=3 for connect() */
+ { .name = "remote21net3", .bind_addr = MCTP_ADDR_ANY,
+ .bind_net = MCTP_NET_ANY, .bind_type = 1,
+ .have_peer = true, .peer_addr = 21, .peer_net = 3, },
+
+ /* remote eid 21, net 4 for bind, specific net=4 for connect() */
+ { .name = "remote21net4", .bind_addr = MCTP_ADDR_ANY,
+ .bind_net = 4, .bind_type = 1,
+ .have_peer = true, .peer_addr = 21, .peer_net = 4, },
+
+ /* remote eid 21, net 5 for bind, specific net=5 for connect() */
+ { .name = "remote21net5", .bind_addr = MCTP_ADDR_ANY,
+ .bind_net = 5, .bind_type = 1,
+ .have_peer = true, .peer_addr = 21, .peer_net = 5, },
+};
+
+static void mctp_bind_lookup_desc(const struct mctp_bind_lookup_test *t,
+ char *desc)
+{
+ snprintf(desc, KUNIT_PARAM_DESC_SIZE,
+ "{src %d dst %d ty %d net %d expect %s}",
+ t->hdr.src, t->hdr.dest, t->ty, t->net, t->expect);
+}
+
+KUNIT_ARRAY_PARAM(mctp_bind_lookup, mctp_bind_lookup_tests,
+ mctp_bind_lookup_desc);
+
+static void mctp_test_bind_lookup(struct kunit *test)
+{
+ const struct mctp_bind_lookup_test *rx;
+ struct socket *socks[ARRAY_SIZE(lookup_binds)];
+ struct sk_buff *skb_pkt = NULL, *skb_sock = NULL;
+ struct socket *sock_ty0, *sock_expect = NULL;
+ struct mctp_test_pktqueue tpq;
+ struct mctp_test_dev *dev;
+ struct mctp_dst dst;
+ int rc;
+
+ rx = test->param_value;
+
+ __mctp_route_test_init(test, &dev, &dst, &tpq, &sock_ty0, rx->net);
+ /* Create all binds */
+ for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) {
+ mctp_test_bind_run(test, &lookup_binds[i],
+ &rc, &socks[i]);
+ KUNIT_ASSERT_EQ(test, rc, 0);
+
+ /* Record the expected receive socket */
+ if (rx->expect &&
+ strcmp(rx->expect, lookup_binds[i].name) == 0) {
+ KUNIT_ASSERT_NULL(test, sock_expect);
+ sock_expect = socks[i];
+ }
+ }
+ KUNIT_ASSERT_EQ(test, !!sock_expect, !!rx->expect);
+
+ /* Create test message */
+ skb_pkt = mctp_test_create_skb_data(&rx->hdr, &rx->ty);
+ KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb_pkt);
+ mctp_test_skb_set_dev(skb_pkt, dev);
+ mctp_test_pktqueue_init(&tpq);
+
+ rc = mctp_dst_input(&dst, skb_pkt);
+ if (rx->expect) {
+ /* Test the message is received on the expected socket */
+ KUNIT_EXPECT_EQ(test, rc, 0);
+ skb_sock = skb_recv_datagram(sock_expect->sk,
+ MSG_DONTWAIT, &rc);
+ if (!skb_sock) {
+ /* Find which socket received it instead */
+ for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) {
+ skb_sock = skb_recv_datagram(socks[i]->sk,
+ MSG_DONTWAIT, &rc);
+ if (skb_sock) {
+ KUNIT_FAIL(test,
+ "received on incorrect socket '%s', expect '%s'",
+ lookup_binds[i].name,
+ rx->expect);
+ goto cleanup;
+ }
+ }
+ KUNIT_FAIL(test, "no message received");
+ }
+ } else {
+ KUNIT_EXPECT_NE(test, rc, 0);
+ }
+
+cleanup:
+ kfree_skb(skb_sock);
+ kfree_skb(skb_pkt);
+
+ /* Drop all binds */
+ for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++)
+ sock_release(socks[i]);
+
+ __mctp_route_test_fini(test, dev, &dst, &tpq, sock_ty0);
+}
+
static struct kunit_case mctp_test_cases[] = {
KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
@@ -1431,6 +1616,7 @@ static struct kunit_case mctp_test_cases[] = {
KUNIT_CASE(mctp_test_route_gw_loop),
KUNIT_CASE_PARAM(mctp_test_route_gw_mtu, mctp_route_gw_mtu_gen_params),
KUNIT_CASE(mctp_test_route_gw_output),
+ KUNIT_CASE_PARAM(mctp_test_bind_lookup, mctp_bind_lookup_gen_params),
{}
};
diff --git a/net/mctp/test/sock-test.c b/net/mctp/test/sock-test.c
index 4eb3a724dca3..b0942deb5019 100644
--- a/net/mctp/test/sock-test.c
+++ b/net/mctp/test/sock-test.c
@@ -215,9 +215,176 @@ static void mctp_test_sock_recvmsg_extaddr(struct kunit *test)
__mctp_sock_test_fini(test, dev, rt, sock);
}
+static const struct mctp_test_bind_setup bind_addrany_netdefault_type1 = {
+ .bind_addr = MCTP_ADDR_ANY, .bind_net = MCTP_NET_ANY, .bind_type = 1,
+};
+
+static const struct mctp_test_bind_setup bind_addrany_net2_type1 = {
+ .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1,
+};
+
+/* 1 is default net */
+static const struct mctp_test_bind_setup bind_addr8_net1_type1 = {
+ .bind_addr = 8, .bind_net = 1, .bind_type = 1,
+};
+
+static const struct mctp_test_bind_setup bind_addrany_net1_type1 = {
+ .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1,
+};
+
+/* 2 is an arbitrary net */
+static const struct mctp_test_bind_setup bind_addr8_net2_type1 = {
+ .bind_addr = 8, .bind_net = 2, .bind_type = 1,
+};
+
+static const struct mctp_test_bind_setup bind_addr8_netdefault_type1 = {
+ .bind_addr = 8, .bind_net = MCTP_NET_ANY, .bind_type = 1,
+};
+
+static const struct mctp_test_bind_setup bind_addrany_net2_type2 = {
+ .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 2,
+};
+
+static const struct mctp_test_bind_setup bind_addrany_net2_type1_peer9 = {
+ .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1,
+ .have_peer = true, .peer_addr = 9, .peer_net = 2,
+};
+
+struct mctp_bind_pair_test {
+ const struct mctp_test_bind_setup *bind1;
+ const struct mctp_test_bind_setup *bind2;
+ int error;
+};
+
+/* Pairs of binds and whether they will conflict */
+static const struct mctp_bind_pair_test mctp_bind_pair_tests[] = {
+ /* Both ADDR_ANY, conflict */
+ { &bind_addrany_netdefault_type1, &bind_addrany_netdefault_type1,
+ EADDRINUSE },
+ /* Same specific EID, conflict */
+ { &bind_addr8_netdefault_type1, &bind_addr8_netdefault_type1,
+ EADDRINUSE },
+ /* ADDR_ANY vs specific EID, OK */
+ { &bind_addrany_netdefault_type1, &bind_addr8_netdefault_type1, 0 },
+ /* ADDR_ANY different types, OK */
+ { &bind_addrany_net2_type2, &bind_addrany_net2_type1, 0 },
+ /* ADDR_ANY different nets, OK */
+ { &bind_addrany_net2_type1, &bind_addrany_netdefault_type1, 0 },
+
+ /* specific EID, NET_ANY (resolves to default)
+ * vs specific EID, explicit default net 1, conflict
+ */
+ { &bind_addr8_netdefault_type1, &bind_addr8_net1_type1, EADDRINUSE },
+
+ /* specific EID, net 1 vs specific EID, net 2, ok */
+ { &bind_addr8_net1_type1, &bind_addr8_net2_type1, 0 },
+
+ /* ANY_ADDR, NET_ANY (doesn't resolve to default)
+ * vs ADDR_ANY, explicit default net 1, OK
+ */
+ { &bind_addrany_netdefault_type1, &bind_addrany_net1_type1, 0 },
+
+ /* specific remote peer doesn't conflict with any-peer bind */
+ { &bind_addrany_net2_type1_peer9, &bind_addrany_net2_type1, 0 },
+
+ /* bind() NET_ANY is allowed with a connect() net */
+ { &bind_addrany_net2_type1_peer9, &bind_addrany_netdefault_type1, 0 },
+};
+
+static void mctp_bind_pair_desc(const struct mctp_bind_pair_test *t, char *desc)
+{
+ char peer1[25] = {0}, peer2[25] = {0};
+
+ if (t->bind1->have_peer)
+ snprintf(peer1, sizeof(peer1), ", peer %d net %d",
+ t->bind1->peer_addr, t->bind1->peer_net);
+ if (t->bind2->have_peer)
+ snprintf(peer2, sizeof(peer2), ", peer %d net %d",
+ t->bind2->peer_addr, t->bind2->peer_net);
+
+ snprintf(desc, KUNIT_PARAM_DESC_SIZE,
+ "{bind(addr %d, type %d, net %d%s)} {bind(addr %d, type %d, net %d%s)} -> error %d",
+ t->bind1->bind_addr, t->bind1->bind_type,
+ t->bind1->bind_net, peer1,
+ t->bind2->bind_addr, t->bind2->bind_type,
+ t->bind2->bind_net, peer2, t->error);
+}
+
+KUNIT_ARRAY_PARAM(mctp_bind_pair, mctp_bind_pair_tests, mctp_bind_pair_desc);
+
+static void mctp_test_bind_invalid(struct kunit *test)
+{
+ struct socket *sock;
+ int rc;
+
+ /* bind() fails if the bind() vs connect() networks mismatch. */
+ const struct mctp_test_bind_setup bind_connect_net_mismatch = {
+ .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1,
+ .have_peer = true, .peer_addr = 9, .peer_net = 2,
+ };
+ mctp_test_bind_run(test, &bind_connect_net_mismatch, &rc, &sock);
+ KUNIT_EXPECT_EQ(test, -rc, EINVAL);
+ sock_release(sock);
+}
+
+static int
+mctp_test_bind_conflicts_inner(struct kunit *test,
+ const struct mctp_test_bind_setup *bind1,
+ const struct mctp_test_bind_setup *bind2)
+{
+ struct socket *sock1 = NULL, *sock2 = NULL, *sock3 = NULL;
+ int bind_errno;
+
+ /* Bind to first address, always succeeds */
+ mctp_test_bind_run(test, bind1, &bind_errno, &sock1);
+ KUNIT_EXPECT_EQ(test, bind_errno, 0);
+
+ /* A second identical bind always fails */
+ mctp_test_bind_run(test, bind1, &bind_errno, &sock2);
+ KUNIT_EXPECT_EQ(test, -bind_errno, EADDRINUSE);
+
+ /* A different bind, result is returned */
+ mctp_test_bind_run(test, bind2, &bind_errno, &sock3);
+
+ if (sock1)
+ sock_release(sock1);
+ if (sock2)
+ sock_release(sock2);
+ if (sock3)
+ sock_release(sock3);
+
+ return bind_errno;
+}
+
+static void mctp_test_bind_conflicts(struct kunit *test)
+{
+ const struct mctp_bind_pair_test *pair;
+ int bind_errno;
+
+ pair = test->param_value;
+
+ bind_errno =
+ mctp_test_bind_conflicts_inner(test, pair->bind1, pair->bind2);
+ KUNIT_EXPECT_EQ(test, -bind_errno, pair->error);
+
+ /* swapping the calls, the second bind should still fail */
+ bind_errno =
+ mctp_test_bind_conflicts_inner(test, pair->bind2, pair->bind1);
+ KUNIT_EXPECT_EQ(test, -bind_errno, pair->error);
+}
+
+static void mctp_test_assumptions(struct kunit *test)
+{
+ /* check assumption of default net from bind_addr8_net1_type1 */
+ KUNIT_ASSERT_EQ(test, mctp_default_net(&init_net), 1);
+}
+
static struct kunit_case mctp_test_cases[] = {
+ KUNIT_CASE(mctp_test_assumptions),
KUNIT_CASE(mctp_test_sock_sendmsg_extaddr),
KUNIT_CASE(mctp_test_sock_recvmsg_extaddr),
+ KUNIT_CASE_PARAM(mctp_test_bind_conflicts, mctp_bind_pair_gen_params),
+ KUNIT_CASE(mctp_test_bind_invalid),
{}
};
diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c
index 01f5af416b81..953d41902771 100644
--- a/net/mctp/test/utils.c
+++ b/net/mctp/test/utils.c
@@ -258,3 +258,39 @@ struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr,
return skb;
}
+
+void mctp_test_bind_run(struct kunit *test,
+ const struct mctp_test_bind_setup *setup,
+ int *ret_bind_errno, struct socket **sock)
+{
+ struct sockaddr_mctp addr;
+ int rc;
+
+ *ret_bind_errno = -EIO;
+
+ rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, sock);
+ KUNIT_ASSERT_EQ(test, rc, 0);
+
+ /* connect() if requested */
+ if (setup->have_peer) {
+ memset(&addr, 0x0, sizeof(addr));
+ addr.smctp_family = AF_MCTP;
+ addr.smctp_network = setup->peer_net;
+ addr.smctp_addr.s_addr = setup->peer_addr;
+ /* connect() type must match bind() type */
+ addr.smctp_type = setup->bind_type;
+ rc = kernel_connect(*sock, (struct sockaddr *)&addr,
+ sizeof(addr), 0);
+ KUNIT_EXPECT_EQ(test, rc, 0);
+ }
+
+ /* bind() */
+ memset(&addr, 0x0, sizeof(addr));
+ addr.smctp_family = AF_MCTP;
+ addr.smctp_network = setup->bind_net;
+ addr.smctp_addr.s_addr = setup->bind_addr;
+ addr.smctp_type = setup->bind_type;
+
+ *ret_bind_errno =
+ kernel_bind(*sock, (struct sockaddr *)&addr, sizeof(addr));
+}
diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h
index f10d1d9066cc..06bdb6cb5eff 100644
--- a/net/mctp/test/utils.h
+++ b/net/mctp/test/utils.h
@@ -31,6 +31,19 @@ struct mctp_test_pktqueue {
struct sk_buff_head pkts;
};
+struct mctp_test_bind_setup {
+ mctp_eid_t bind_addr;
+ int bind_net;
+ u8 bind_type;
+
+ bool have_peer;
+ mctp_eid_t peer_addr;
+ int peer_net;
+
+ /* optional name. Used for comparison in "lookup" tests */
+ const char *name;
+};
+
struct mctp_test_dev *mctp_test_create_dev(void);
struct mctp_test_dev *mctp_test_create_dev_lladdr(unsigned short lladdr_len,
const unsigned char *lladdr);
@@ -61,4 +74,8 @@ struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr,
#define mctp_test_create_skb_data(h, d) \
__mctp_test_create_skb_data(h, d, sizeof(*d))
+void mctp_test_bind_run(struct kunit *test,
+ const struct mctp_test_bind_setup *setup,
+ int *ret_bind_errno, struct socket **sock);
+
#endif /* __NET_MCTP_TEST_UTILS_H */