diff options
Diffstat (limited to 'drivers/net/wireguard')
-rw-r--r-- | drivers/net/wireguard/allowedips.c | 102 | ||||
-rw-r--r-- | drivers/net/wireguard/allowedips.h | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/cookie.c | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/device.c | 9 | ||||
-rw-r--r-- | drivers/net/wireguard/netlink.c | 47 | ||||
-rw-r--r-- | drivers/net/wireguard/noise.c | 4 | ||||
-rw-r--r-- | drivers/net/wireguard/selftest/allowedips.c | 48 | ||||
-rw-r--r-- | drivers/net/wireguard/timers.c | 8 |
8 files changed, 163 insertions, 63 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c index 4b8528206cc8..09f7fcd7da78 100644 --- a/drivers/net/wireguard/allowedips.c +++ b/drivers/net/wireguard/allowedips.c @@ -249,6 +249,52 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return 0; } +static void remove_node(struct allowedips_node *node, struct mutex *lock) +{ + struct allowedips_node *child, **parent_bit, *parent; + bool free_parent; + + list_del_init(&node->peer_list); + RCU_INIT_POINTER(node->peer, NULL); + if (node->bit[0] && node->bit[1]) + return; + child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], + lockdep_is_held(lock)); + if (child) + child->parent_bit_packed = node->parent_bit_packed; + parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); + *parent_bit = child; + parent = (void *)parent_bit - + offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); + free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) && + (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer); + if (free_parent) + child = rcu_dereference_protected(parent->bit[!(node->parent_bit_packed & 1)], + lockdep_is_held(lock)); + call_rcu(&node->rcu, node_free_rcu); + if (!free_parent) + return; + if (child) + child->parent_bit_packed = parent->parent_bit_packed; + *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; + call_rcu(&parent->rcu, node_free_rcu); +} + +static int remove(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + struct allowedips_node *node; + + if (unlikely(cidr > bits)) + return -EINVAL; + if (!rcu_access_pointer(*trie) || !node_placement(*trie, key, cidr, bits, &node, lock) || + peer != rcu_access_pointer(node->peer)) + return 0; + + remove_node(node, lock); + return 0; +} + void wg_allowedips_init(struct allowedips *table) { table->root4 = table->root6 = NULL; @@ -300,44 +346,38 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, return add(&table->root6, 128, key, cidr, peer, lock); } +int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + /* Aligned so it can be passed to fls */ + u8 key[4] __aligned(__alignof(u32)); + + ++table->seq; + swap_endian(key, (const u8 *)ip, 32); + return remove(&table->root4, 32, key, cidr, peer, lock); +} + +int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock) +{ + /* Aligned so it can be passed to fls64 */ + u8 key[16] __aligned(__alignof(u64)); + + ++table->seq; + swap_endian(key, (const u8 *)ip, 128); + return remove(&table->root6, 128, key, cidr, peer, lock); +} + void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock) { - struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; - bool free_parent; + struct allowedips_node *node, *tmp; if (list_empty(&peer->allowedips_list)) return; ++table->seq; - list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { - list_del_init(&node->peer_list); - RCU_INIT_POINTER(node->peer, NULL); - if (node->bit[0] && node->bit[1]) - continue; - child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], - lockdep_is_held(lock)); - if (child) - child->parent_bit_packed = node->parent_bit_packed; - parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); - *parent_bit = child; - parent = (void *)parent_bit - - offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); - free_parent = !rcu_access_pointer(node->bit[0]) && - !rcu_access_pointer(node->bit[1]) && - (node->parent_bit_packed & 3) <= 1 && - !rcu_access_pointer(parent->peer); - if (free_parent) - child = rcu_dereference_protected( - parent->bit[!(node->parent_bit_packed & 1)], - lockdep_is_held(lock)); - call_rcu(&node->rcu, node_free_rcu); - if (!free_parent) - continue; - if (child) - child->parent_bit_packed = parent->parent_bit_packed; - *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; - call_rcu(&parent->rcu, node_free_rcu); - } + list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) + remove_node(node, lock); } int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h index 2346c797eb4d..931958cb6e10 100644 --- a/drivers/net/wireguard/allowedips.h +++ b/drivers/net/wireguard/allowedips.h @@ -38,6 +38,10 @@ int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); +int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock); +int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wg_peer *peer, struct mutex *lock); void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock); /* The ip input pointer should be __aligned(__alignof(u64))) */ diff --git a/drivers/net/wireguard/cookie.c b/drivers/net/wireguard/cookie.c index f89581b5e8cb..94d0a7206084 100644 --- a/drivers/net/wireguard/cookie.c +++ b/drivers/net/wireguard/cookie.c @@ -26,8 +26,8 @@ void wg_cookie_checker_init(struct cookie_checker *checker, } enum { COOKIE_KEY_LABEL_LEN = 8 }; -static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----"; -static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--"; +static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "mac1----"; +static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] __nonstring = "cookie--"; static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOISE_PUBLIC_KEY_LEN], diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index 6cf173a008e7..3ffeeba5dccf 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -81,7 +81,7 @@ static int wg_pm_notification(struct notifier_block *nb, unsigned long action, v list_for_each_entry(wg, &device_list, device_list) { mutex_lock(&wg->device_update_lock); list_for_each_entry(peer, &wg->peer_list, peer_list) { - del_timer(&peer->timer_zero_key_material); + timer_delete(&peer->timer_zero_key_material); wg_noise_handshake_clear(&peer->handshake); wg_noise_keypairs_clear(&peer->keypairs); } @@ -307,14 +307,15 @@ static void wg_setup(struct net_device *dev) wg->dev = dev; } -static int wg_newlink(struct net *src_net, struct net_device *dev, - struct nlattr *tb[], struct nlattr *data[], +static int wg_newlink(struct net_device *dev, + struct rtnl_newlink_params *params, struct netlink_ext_ack *extack) { + struct net *link_net = rtnl_newlink_link_net(params); struct wg_device *wg = netdev_priv(dev); int ret = -ENOMEM; - rcu_assign_pointer(wg->creating_net, src_net); + rcu_assign_pointer(wg->creating_net, link_net); init_rwsem(&wg->static_identity.lock); mutex_init(&wg->socket_update_lock); mutex_init(&wg->device_update_lock); diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c index f7055180ba4a..67f962eb8b46 100644 --- a/drivers/net/wireguard/netlink.c +++ b/drivers/net/wireguard/netlink.c @@ -24,7 +24,7 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 }, [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), - [WGDEVICE_A_FLAGS] = { .type = NLA_U32 }, + [WGDEVICE_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGDEVICE_F_ALL), [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 }, [WGDEVICE_A_FWMARK] = { .type = NLA_U32 }, [WGDEVICE_A_PEERS] = { .type = NLA_NESTED } @@ -33,7 +33,7 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN), - [WGPEER_A_FLAGS] = { .type = NLA_U32 }, + [WGPEER_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGPEER_F_ALL), [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)), [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16 }, [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)), @@ -46,7 +46,8 @@ static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = { [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16 }, [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)), - [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 } + [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 }, + [WGALLOWEDIP_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGALLOWEDIP_F_ALL), }; static struct wg_device *lookup_interface(struct nlattr **attrs, @@ -329,6 +330,7 @@ static int set_port(struct wg_device *wg, u16 port) static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) { int ret = -EINVAL; + u32 flags = 0; u16 family; u8 cidr; @@ -337,19 +339,30 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) return ret; family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]); cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]); + if (attrs[WGALLOWEDIP_A_FLAGS]) + flags = nla_get_u32(attrs[WGALLOWEDIP_A_FLAGS]); if (family == AF_INET && cidr <= 32 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = wg_allowedips_insert_v4( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); - else if (family == AF_INET6 && cidr <= 128 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = wg_allowedips_insert_v6( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } else if (family == AF_INET6 && cidr <= 128 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } return ret; } @@ -373,9 +386,6 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs) if (attrs[WGPEER_A_FLAGS]) flags = nla_get_u32(attrs[WGPEER_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGPEER_F_ALL) - goto out; ret = -EPFNOSUPPORT; if (attrs[WGPEER_A_PROTOCOL_VERSION]) { @@ -506,9 +516,6 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) if (info->attrs[WGDEVICE_A_FLAGS]) flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGDEVICE_F_ALL) - goto out; if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) { struct net *net; diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c index 202a33af5a72..7eb9a23a3d4d 100644 --- a/drivers/net/wireguard/noise.c +++ b/drivers/net/wireguard/noise.c @@ -25,8 +25,8 @@ * <- e, ee, se, psk, {} */ -static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; -static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com"; +static const u8 handshake_name[37] __nonstring = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; +static const u8 identifier_name[34] __nonstring = "WireGuard v1 zx2c4 Jason@zx2c4.com"; static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init; static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init; static atomic64_t keypair_counter = ATOMIC64_INIT(0); diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c index 25de7058701a..41837efa70cb 100644 --- a/drivers/net/wireguard/selftest/allowedips.c +++ b/drivers/net/wireguard/selftest/allowedips.c @@ -460,6 +460,10 @@ static __init struct wg_peer *init_peer(void) wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ cidr, mem, &mutex) +#define remove(version, mem, ipa, ipb, ipc, ipd, cidr) \ + wg_allowedips_remove_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ + cidr, mem, &mutex) + #define maybe_fail() do { \ ++i; \ if (!_s) { \ @@ -585,6 +589,50 @@ bool __init wg_allowedips_selftest(void) test_negative(4, a, 192, 0, 0, 0); test_negative(4, a, 255, 0, 0, 0); + insert(4, a, 1, 0, 0, 0, 32); + insert(4, a, 192, 0, 0, 0, 24); + insert(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); + insert(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test(4, a, 1, 0, 0, 0); + test(4, a, 192, 0, 0, 1); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + /* Must be an exact match to remove */ + remove(4, a, 192, 0, 0, 0, 32); + test(4, a, 192, 0, 0, 1); + /* NULL peer should have no effect and return 0 */ + test_boolean(!remove(4, NULL, 192, 0, 0, 0, 24)); + test(4, a, 192, 0, 0, 1); + /* different peer should have no effect and return 0 */ + test_boolean(!remove(4, b, 192, 0, 0, 0, 24)); + test(4, a, 192, 0, 0, 1); + /* invalid CIDR should have no effect and return -EINVAL */ + test_boolean(remove(4, b, 192, 0, 0, 0, 33) == -EINVAL); + test(4, a, 192, 0, 0, 1); + remove(4, a, 192, 0, 0, 0, 24); + test_negative(4, a, 192, 0, 0, 1); + remove(4, a, 1, 0, 0, 0, 32); + test_negative(4, a, 1, 0, 0, 0); + /* Must be an exact match to remove */ + remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* NULL peer should have no effect and return 0 */ + test_boolean(!remove(6, NULL, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* different peer should have no effect and return 0 */ + test_boolean(!remove(6, b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* invalid CIDR should have no effect and return -EINVAL */ + test_boolean(remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 129) == -EINVAL); + test(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + remove(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); + test_negative(6, a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef); + /* Must match the peer to remove */ + remove(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + remove(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + test_negative(6, a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010); + wg_allowedips_free(&t, &mutex); wg_allowedips_init(&t); insert(4, a, 192, 168, 0, 0, 16); diff --git a/drivers/net/wireguard/timers.c b/drivers/net/wireguard/timers.c index 968bdb4df0b3..a9e0890c2f77 100644 --- a/drivers/net/wireguard/timers.c +++ b/drivers/net/wireguard/timers.c @@ -48,7 +48,7 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer) peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, (int)MAX_TIMER_HANDSHAKES + 2); - del_timer(&peer->timer_send_keepalive); + timer_delete(&peer->timer_send_keepalive); /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ @@ -167,7 +167,7 @@ void wg_timers_data_received(struct wg_peer *peer) */ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer) { - del_timer(&peer->timer_send_keepalive); + timer_delete(&peer->timer_send_keepalive); } /* Should be called after any type of authenticated packet is received, whether @@ -175,7 +175,7 @@ void wg_timers_any_authenticated_packet_sent(struct wg_peer *peer) */ void wg_timers_any_authenticated_packet_received(struct wg_peer *peer) { - del_timer(&peer->timer_new_handshake); + timer_delete(&peer->timer_new_handshake); } /* Should be called after a handshake initiation message is sent. */ @@ -191,7 +191,7 @@ void wg_timers_handshake_initiated(struct wg_peer *peer) */ void wg_timers_handshake_complete(struct wg_peer *peer) { - del_timer(&peer->timer_retransmit_handshake); + timer_delete(&peer->timer_retransmit_handshake); peer->timer_handshake_attempts = 0; peer->sent_lastminute_handshake = false; ktime_get_real_ts64(&peer->walltime_last_handshake); |