diff options
Diffstat (limited to 'drivers/net/wireguard/netlink.c')
-rw-r--r-- | drivers/net/wireguard/netlink.c | 47 |
1 files changed, 27 insertions, 20 deletions
diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c index f7055180ba4a..67f962eb8b46 100644 --- a/drivers/net/wireguard/netlink.c +++ b/drivers/net/wireguard/netlink.c @@ -24,7 +24,7 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { [WGDEVICE_A_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 }, [WGDEVICE_A_PRIVATE_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), [WGDEVICE_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), - [WGDEVICE_A_FLAGS] = { .type = NLA_U32 }, + [WGDEVICE_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGDEVICE_F_ALL), [WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 }, [WGDEVICE_A_FWMARK] = { .type = NLA_U32 }, [WGDEVICE_A_PEERS] = { .type = NLA_NESTED } @@ -33,7 +33,7 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { [WGPEER_A_PUBLIC_KEY] = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN), [WGPEER_A_PRESHARED_KEY] = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN), - [WGPEER_A_FLAGS] = { .type = NLA_U32 }, + [WGPEER_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGPEER_F_ALL), [WGPEER_A_ENDPOINT] = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)), [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL] = { .type = NLA_U16 }, [WGPEER_A_LAST_HANDSHAKE_TIME] = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)), @@ -46,7 +46,8 @@ static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = { [WGALLOWEDIP_A_FAMILY] = { .type = NLA_U16 }, [WGALLOWEDIP_A_IPADDR] = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)), - [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 } + [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 }, + [WGALLOWEDIP_A_FLAGS] = NLA_POLICY_MASK(NLA_U32, __WGALLOWEDIP_F_ALL), }; static struct wg_device *lookup_interface(struct nlattr **attrs, @@ -329,6 +330,7 @@ static int set_port(struct wg_device *wg, u16 port) static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) { int ret = -EINVAL; + u32 flags = 0; u16 family; u8 cidr; @@ -337,19 +339,30 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) return ret; family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]); cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]); + if (attrs[WGALLOWEDIP_A_FLAGS]) + flags = nla_get_u32(attrs[WGALLOWEDIP_A_FLAGS]); if (family == AF_INET && cidr <= 32 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = wg_allowedips_insert_v4( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); - else if (family == AF_INET6 && cidr <= 128 && - nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = wg_allowedips_insert_v6( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } else if (family == AF_INET6 && cidr <= 128 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) { + if (flags & WGALLOWEDIP_F_REMOVE_ME) + ret = wg_allowedips_remove_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + else + ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, + peer, &peer->device->device_update_lock); + } return ret; } @@ -373,9 +386,6 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs) if (attrs[WGPEER_A_FLAGS]) flags = nla_get_u32(attrs[WGPEER_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGPEER_F_ALL) - goto out; ret = -EPFNOSUPPORT; if (attrs[WGPEER_A_PROTOCOL_VERSION]) { @@ -506,9 +516,6 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) if (info->attrs[WGDEVICE_A_FLAGS]) flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]); - ret = -EOPNOTSUPP; - if (flags & ~__WGDEVICE_F_ALL) - goto out; if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) { struct net *net; |