Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--	net/tls/tls_main.c | 422
1 file changed, 237 insertions(+), 185 deletions(-)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index df921a2904b9..b3da6c5ab999 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -39,8 +39,11 @@
 #include <linux/netdevice.h>
 #include <linux/sched/signal.h>
 #include <linux/inetdevice.h>
+#include <linux/inet_diag.h>
 
+#include <net/snmp.h>
 #include <net/tls.h>
+#include <net/tls_toe.h>
 
 MODULE_AUTHOR("Mellanox Technologies");
 MODULE_DESCRIPTION("Transport Layer Security Support");
@@ -57,14 +60,12 @@ static struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
 static struct proto *saved_tcpv4_prot;
 static DEFINE_MUTEX(tcpv4_prot_mutex);
-static LIST_HEAD(device_list);
-static DEFINE_SPINLOCK(device_spinlock);
 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 static struct proto_ops tls_sw_proto_ops;
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 			 struct proto *base);
 
-static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
+void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 
@@ -208,6 +209,17 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 	return tls_push_sg(sk, ctx, sg, offset, flags);
 }
 
+void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
+{
+	struct scatterlist *sg;
+
+	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
+		put_page(sg_page(sg));
+		sk_mem_uncharge(sk, sg->length);
+	}
+	ctx->partially_sent_record = NULL;
+}
+
 static void tls_write_space(struct sock *sk)
 {
 	struct tls_context *ctx = tls_get_ctx(sk);
@@ -231,70 +243,90 @@ static void tls_write_space(struct sock *sk)
 	ctx->sk_write_space(sk);
 }
 
-static void tls_ctx_free(struct tls_context *ctx)
+/**
+ * tls_ctx_free() - free TLS ULP context
+ * @sk:  socket to which @ctx is attached
+ * @ctx: TLS context structure
+ *
+ * Free TLS context. If @sk is %NULL caller guarantees that the socket
+ * to which @ctx was attached has no outstanding references.
+ */
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
 {
 	if (!ctx)
 		return;
 
 	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
 	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
-	kfree(ctx);
+	mutex_destroy(&ctx->tx_lock);
+
+	if (sk)
+		kfree_rcu(ctx, rcu);
+	else
+		kfree(ctx);
 }
 
-static void tls_sk_proto_close(struct sock *sk, long timeout)
+static void tls_sk_proto_cleanup(struct sock *sk,
+				 struct tls_context *ctx, long timeo)
 {
-	struct tls_context *ctx = tls_get_ctx(sk);
-	long timeo = sock_sndtimeo(sk, 0);
-	void (*sk_proto_close)(struct sock *sk, long timeout);
-	bool free_ctx = false;
-
-	lock_sock(sk);
-	sk_proto_close = ctx->sk_proto_close;
-
-	if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
-		goto skip_tx_cleanup;
-
-	if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
-		free_ctx = true;
-		goto skip_tx_cleanup;
-	}
-
-	if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
+	if (unlikely(sk->sk_write_pending) &&
+	    !wait_on_pending_writer(sk, &timeo))
 		tls_handle_open_record(sk, 0);
 
 	/* We need these for tls_sw_fallback handling of other packets */
 	if (ctx->tx_conf == TLS_SW) {
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
-		tls_sw_free_resources_tx(sk);
+		tls_sw_release_resources_tx(sk);
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
+	} else if (ctx->tx_conf == TLS_HW) {
+		tls_device_free_resources_tx(sk);
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
 	}
 
 	if (ctx->rx_conf == TLS_SW) {
-		kfree(ctx->rx.rec_seq);
-		kfree(ctx->rx.iv);
-		tls_sw_free_resources_rx(sk);
+		tls_sw_release_resources_rx(sk);
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
+	} else if (ctx->rx_conf == TLS_HW) {
+		tls_device_offload_cleanup_rx(sk);
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
 	}
+}
 
-#ifdef CONFIG_TLS_DEVICE
-	if (ctx->rx_conf == TLS_HW)
-		tls_device_offload_cleanup_rx(sk);
+static void tls_sk_proto_close(struct sock *sk, long timeout)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	struct tls_context *ctx = tls_get_ctx(sk);
+	long timeo = sock_sndtimeo(sk, 0);
+	bool free_ctx;
 
-	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
-#else
-	{
-#endif
-		tls_ctx_free(ctx);
-		ctx = NULL;
-	}
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_cancel_work_tx(ctx);
 
-skip_tx_cleanup:
+	lock_sock(sk);
+	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
+
+	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
+		tls_sk_proto_cleanup(sk, ctx, timeo);
+
+	write_lock_bh(&sk->sk_callback_lock);
+	if (free_ctx)
+		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
+	sk->sk_prot = ctx->sk_proto;
+	if (sk->sk_write_space == tls_write_space)
+		sk->sk_write_space = ctx->sk_write_space;
+	write_unlock_bh(&sk->sk_callback_lock);
 	release_sock(sk);
-	sk_proto_close(sk, timeout);
-	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
-	 * for sk->sk_prot->unhash [tls_hw_unhash]
-	 */
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_free_ctx_tx(ctx);
+	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+		tls_sw_strparser_done(ctx);
+	if (ctx->rx_conf == TLS_SW)
+		tls_sw_free_ctx_rx(ctx);
+	ctx->sk_proto->close(sk, timeout);
+
 	if (free_ctx)
-		tls_ctx_free(ctx);
+		tls_ctx_free(sk, ctx);
 }
 
 static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
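The kfree_rcu() conversion in tls_ctx_free() pairs with the rcu_assign_pointer(icsk->icsk_ulp_data, NULL) in tls_sk_proto_close() above: the diag path added later in this diff (tls_get_info()) walks icsk_ulp_data under rcu_read_lock(), so the close path may only unpublish the pointer and defer the actual free past a grace period. A minimal sketch of that lifetime rule, where my_ulp_ctx and ulp_data are illustrative stand-ins, not kernel symbols:

    #include <linux/rcupdate.h>
    #include <linux/slab.h>
    #include <linux/types.h>

    struct my_ulp_ctx {			/* stand-in for tls_context */
    	u16 version;
    	struct rcu_head rcu;		/* required by kfree_rcu() */
    };

    static struct my_ulp_ctx __rcu *ulp_data;

    static u16 reader_peek_version(void)	/* e.g. a diag dumper */
    {
    	struct my_ulp_ctx *ctx;
    	u16 v = 0;

    	rcu_read_lock();		/* readers never block */
    	ctx = rcu_dereference(ulp_data);
    	if (ctx)
    		v = ctx->version;
    	rcu_read_unlock();
    	return v;
    }

    static void writer_teardown(void)	/* e.g. socket close */
    {
    	struct my_ulp_ctx *ctx = rcu_dereference_protected(ulp_data, 1);

    	rcu_assign_pointer(ulp_data, NULL);	/* unpublish first */
    	if (ctx)
    		kfree_rcu(ctx, rcu);	/* freed after a grace period */
    }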
@@ -411,7 +443,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
 	struct tls_context *ctx = tls_get_ctx(sk);
 
 	if (level != SOL_TLS)
-		return ctx->getsockopt(sk, level, optname, optval, optlen);
+		return ctx->sk_proto->getsockopt(sk, level,
+						 optname, optval, optlen);
 
 	return do_tls_getsockopt(sk, optname, optval, optlen);
 }
@@ -469,54 +502,63 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 
 	switch (crypto_info->cipher_type) {
 	case TLS_CIPHER_AES_GCM_128:
+		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
+		break;
 	case TLS_CIPHER_AES_GCM_256: {
-		optsize = crypto_info->cipher_type == TLS_CIPHER_AES_GCM_128 ?
-			sizeof(struct tls12_crypto_info_aes_gcm_128) :
-			sizeof(struct tls12_crypto_info_aes_gcm_256);
-		if (optlen != optsize) {
-			rc = -EINVAL;
-			goto err_crypto_info;
-		}
-		rc = copy_from_user(crypto_info + 1,
-				    optval + sizeof(*crypto_info),
-				    optlen - sizeof(*crypto_info));
-		if (rc) {
-			rc = -EFAULT;
-			goto err_crypto_info;
-		}
+		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
 		break;
 	}
+	case TLS_CIPHER_AES_CCM_128:
+		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
+		break;
 	default:
 		rc = -EINVAL;
 		goto err_crypto_info;
 	}
 
+	if (optlen != optsize) {
+		rc = -EINVAL;
+		goto err_crypto_info;
+	}
+
+	rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
+			    optlen - sizeof(*crypto_info));
+	if (rc) {
+		rc = -EFAULT;
+		goto err_crypto_info;
+	}
+
 	if (tx) {
-#ifdef CONFIG_TLS_DEVICE
 		rc = tls_set_device_offload(sk, ctx);
 		conf = TLS_HW;
-		if (rc) {
-#else
-		{
-#endif
+		if (!rc) {
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
+		} else {
 			rc = tls_set_sw_offload(sk, ctx, 1);
+			if (rc)
+				goto err_crypto_info;
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
 			conf = TLS_SW;
 		}
 	} else {
-#ifdef CONFIG_TLS_DEVICE
 		rc = tls_set_device_offload_rx(sk, ctx);
 		conf = TLS_HW;
-		if (rc) {
-#else
-		{
-#endif
+		if (!rc) {
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
+		} else {
 			rc = tls_set_sw_offload(sk, ctx, 0);
+			if (rc)
+				goto err_crypto_info;
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
 			conf = TLS_SW;
 		}
+		tls_sw_strparser_arm(sk, ctx);
 	}
 
-	if (rc)
-		goto err_crypto_info;
-
 	if (tx)
 		ctx->tx_conf = conf;
 	else
@@ -562,12 +604,13 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
 	struct tls_context *ctx = tls_get_ctx(sk);
 
 	if (level != SOL_TLS)
-		return ctx->setsockopt(sk, level, optname, optval, optlen);
+		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
+						 optlen);
 
 	return do_tls_setsockopt(sk, optname, optval, optlen);
 }
 
-static struct tls_context *create_ctx(struct sock *sk)
+struct tls_context *tls_ctx_create(struct sock *sk)
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct tls_context *ctx;
@@ -576,10 +619,9 @@ static struct tls_context *create_ctx(struct sock *sk)
 	if (!ctx)
 		return NULL;
 
-	icsk->icsk_ulp_data = ctx;
-	ctx->setsockopt = sk->sk_prot->setsockopt;
-	ctx->getsockopt = sk->sk_prot->getsockopt;
-	ctx->sk_proto_close = sk->sk_prot->close;
+	mutex_init(&ctx->tx_lock);
+	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
+	ctx->sk_proto = sk->sk_prot;
 	return ctx;
 }
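For reference, the path through tls_init() and do_tls_setsockopt_conf() is driven from userspace roughly as below. This is a hedged sketch of a minimal kTLS TX setup; the key material is a placeholder that in real use comes out of a completed TLS handshake:

    #include <linux/tls.h>
    #include <netinet/in.h>
    #include <netinet/tcp.h>
    #include <string.h>
    #include <sys/socket.h>

    #ifndef TCP_ULP
    #define TCP_ULP 31		/* missing from older libc headers */
    #endif
    #ifndef SOL_TLS
    #define SOL_TLS 282		/* value from linux/socket.h */
    #endif

    static int enable_ktls_tx(int fd)
    {
    	struct tls12_crypto_info_aes_gcm_128 ci;

    	memset(&ci, 0, sizeof(ci));
    	ci.info.version = TLS_1_2_VERSION;
    	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
    	/* real code copies the handshake secrets into
    	 * ci.key, ci.iv, ci.salt and ci.rec_seq here */

    	/* attach the "tls" ULP; this is where tls_init() runs */
    	if (setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")))
    		return -1;

    	/* install TX keys via do_tls_setsockopt_conf(); optlen must
    	 * equal the cipher's crypto_info size, exactly the check
    	 * reworked in the hunk above */
    	return setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
    }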
@@ -609,93 +651,6 @@ static void tls_build_proto(struct sock *sk)
 	}
 }
 
-static void tls_hw_sk_destruct(struct sock *sk)
-{
-	struct tls_context *ctx = tls_get_ctx(sk);
-	struct inet_connection_sock *icsk = inet_csk(sk);
-
-	ctx->sk_destruct(sk);
-	/* Free ctx */
-	kfree(ctx);
-	icsk->icsk_ulp_data = NULL;
-}
-
-static int tls_hw_prot(struct sock *sk)
-{
-	struct tls_context *ctx;
-	struct tls_device *dev;
-	int rc = 0;
-
-	spin_lock_bh(&device_spinlock);
-	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->feature && dev->feature(dev)) {
-			ctx = create_ctx(sk);
-			if (!ctx)
-				goto out;
-
-			spin_unlock_bh(&device_spinlock);
-			tls_build_proto(sk);
-			ctx->hash = sk->sk_prot->hash;
-			ctx->unhash = sk->sk_prot->unhash;
-			ctx->sk_proto_close = sk->sk_prot->close;
-			ctx->sk_destruct = sk->sk_destruct;
-			sk->sk_destruct = tls_hw_sk_destruct;
-			ctx->rx_conf = TLS_HW_RECORD;
-			ctx->tx_conf = TLS_HW_RECORD;
-			update_sk_prot(sk, ctx);
-			spin_lock_bh(&device_spinlock);
-			rc = 1;
-			break;
-		}
-	}
-out:
-	spin_unlock_bh(&device_spinlock);
-	return rc;
-}
-
-static void tls_hw_unhash(struct sock *sk)
-{
-	struct tls_context *ctx = tls_get_ctx(sk);
-	struct tls_device *dev;
-
-	spin_lock_bh(&device_spinlock);
-	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->unhash) {
-			kref_get(&dev->kref);
-			spin_unlock_bh(&device_spinlock);
-			dev->unhash(dev, sk);
-			kref_put(&dev->kref, dev->release);
-			spin_lock_bh(&device_spinlock);
-		}
-	}
-	spin_unlock_bh(&device_spinlock);
-	ctx->unhash(sk);
-}
-
-static int tls_hw_hash(struct sock *sk)
-{
-	struct tls_context *ctx = tls_get_ctx(sk);
-	struct tls_device *dev;
-	int err;
-
-	err = ctx->hash(sk);
-	spin_lock_bh(&device_spinlock);
-	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->hash) {
-			kref_get(&dev->kref);
-			spin_unlock_bh(&device_spinlock);
-			err |= dev->hash(dev, sk);
-			kref_put(&dev->kref, dev->release);
-			spin_lock_bh(&device_spinlock);
-		}
-	}
-	spin_unlock_bh(&device_spinlock);
-
-	if (err)
-		tls_hw_unhash(sk);
-	return err;
-}
-
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 			 struct proto *base)
 {
@@ -733,11 +688,11 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 #endif
-
+#ifdef CONFIG_TLS_TOE
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_hw_hash;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_hw_unhash;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].close	= tls_sk_proto_close;
+	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
+	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
+#endif
 }
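build_protos() fills a matrix of struct proto variants indexed by [tx_conf][rx_conf], and update_sk_prot() (made non-static earlier in this diff) selects one entry, so retargeting a socket's behaviour is a single pointer swap. The idea, reduced to a standalone sketch with hypothetical fake_sock/fake_ops types rather than the kernel's:

    #include <stdio.h>

    enum conf { BASE, SW, HW, NUM_CONF };

    struct fake_ops { const char *sendmsg_impl; };
    struct fake_sock { const struct fake_ops *ops; };

    static struct fake_ops tbl[NUM_CONF][NUM_CONF];

    static void build_table(void)	/* analogue of build_protos() */
    {
    	int tx, rx;

    	for (tx = 0; tx < NUM_CONF; tx++)
    		for (rx = 0; rx < NUM_CONF; rx++)
    			tbl[tx][rx].sendmsg_impl =
    				tx == SW ? "tls_sw_sendmsg" : "tcp_sendmsg";
    }

    /* analogue of update_sk_prot(): one assignment flips every hook */
    static void update(struct fake_sock *s, enum conf tx, enum conf rx)
    {
    	s->ops = &tbl[tx][rx];
    }

    int main(void)
    {
    	struct fake_sock s;

    	build_table();
    	update(&s, SW, BASE);
    	printf("%s\n", s.ops->sendmsg_impl);	/* tls_sw_sendmsg */
    	return 0;
    }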
@@ -745,8 +700,12 @@ static int tls_init(struct sock *sk)
 {
 	struct tls_context *ctx;
 	int rc = 0;
 
-	if (tls_hw_prot(sk))
-		goto out;
+	tls_build_proto(sk);
+
+#ifdef CONFIG_TLS_TOE
+	if (tls_toe_bypass(sk))
+		return 0;
+#endif
 
 	/* The TLS ulp is currently supported only for TCP sockets
 	 * in ESTABLISHED state.
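The tls_hw_* helpers removed above live on as the tls_toe_*() functions in net/tls/tls_toe.c, behind CONFIG_TLS_TOE. A TOE-capable driver would register with that code roughly as sketched below; the callback body and device name are illustrative stubs, and the exact struct layout is an assumption based on the old struct tls_device shown in the removed hunk:

    #include <linux/module.h>
    #include <net/tls_toe.h>

    /* hypothetical feature probe: claim TLS record offload works */
    static int my_toe_feature(struct tls_toe_device *dev)
    {
    	return 1;
    }

    static struct tls_toe_device my_toe_dev = {
    	.feature = my_toe_feature,
    	/* a real driver also wires up .hash/.unhash/.release */
    };

    static int __init my_toe_init(void)
    {
    	tls_toe_register_device(&my_toe_dev);
    	return 0;
    }
    module_init(my_toe_init);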
@@ -758,50 +717,144 @@ static int tls_init(struct sock *sk)
 		return -ENOTSUPP;
 
 	/* allocate tls context */
-	ctx = create_ctx(sk);
+	write_lock_bh(&sk->sk_callback_lock);
+	ctx = tls_ctx_create(sk);
 	if (!ctx) {
 		rc = -ENOMEM;
 		goto out;
 	}
 
-	tls_build_proto(sk);
 	ctx->tx_conf = TLS_BASE;
 	ctx->rx_conf = TLS_BASE;
 	update_sk_prot(sk, ctx);
 out:
+	write_unlock_bh(&sk->sk_callback_lock);
 	return rc;
 }
 
-void tls_register_device(struct tls_device *device)
+static void tls_update(struct sock *sk, struct proto *p)
 {
-	spin_lock_bh(&device_spinlock);
-	list_add_tail(&device->dev_list, &device_list);
-	spin_unlock_bh(&device_spinlock);
+	struct tls_context *ctx;
+
+	ctx = tls_get_ctx(sk);
+	if (likely(ctx))
+		ctx->sk_proto = p;
+	else
+		sk->sk_prot = p;
 }
-EXPORT_SYMBOL(tls_register_device);
 
-void tls_unregister_device(struct tls_device *device)
+static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
 {
-	spin_lock_bh(&device_spinlock);
-	list_del(&device->dev_list);
-	spin_unlock_bh(&device_spinlock);
+	u16 version, cipher_type;
+	struct tls_context *ctx;
+	struct nlattr *start;
+	int err;
+
+	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
+	if (!start)
+		return -EMSGSIZE;
+
+	rcu_read_lock();
+	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
+	if (!ctx) {
+		err = 0;
+		goto nla_failure;
+	}
+	version = ctx->prot_info.version;
+	if (version) {
+		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
+		if (err)
+			goto nla_failure;
+	}
+	cipher_type = ctx->prot_info.cipher_type;
+	if (cipher_type) {
+		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
+		if (err)
+			goto nla_failure;
+	}
+	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
+	if (err)
+		goto nla_failure;
+
+	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
+	if (err)
+		goto nla_failure;
+
+	rcu_read_unlock();
+	nla_nest_end(skb, start);
+	return 0;
+
+nla_failure:
+	rcu_read_unlock();
+	nla_nest_cancel(skb, start);
+	return err;
+}
+
+static size_t tls_get_info_size(const struct sock *sk)
+{
+	size_t size = 0;
+
+	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
+		0;
+
+	return size;
+}
+
+static int __net_init tls_init_net(struct net *net)
+{
+	int err;
+
+	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
+	if (!net->mib.tls_statistics)
+		return -ENOMEM;
+
+	err = tls_proc_init(net);
+	if (err)
+		goto err_free_stats;
+
+	return 0;
+err_free_stats:
+	free_percpu(net->mib.tls_statistics);
+	return err;
 }
-EXPORT_SYMBOL(tls_unregister_device);
+
+static void __net_exit tls_exit_net(struct net *net)
+{
+	tls_proc_fini(net);
+	free_percpu(net->mib.tls_statistics);
+}
+
+static struct pernet_operations tls_proc_ops = {
+	.init = tls_init_net,
+	.exit = tls_exit_net,
+};
 
 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 	.name			= "tls",
 	.owner			= THIS_MODULE,
 	.init			= tls_init,
+	.update			= tls_update,
+	.get_info		= tls_get_info,
+	.get_info_size		= tls_get_info_size,
 };
 
 static int __init tls_register(void)
 {
+	int err;
+
+	err = register_pernet_subsys(&tls_proc_ops);
+	if (err)
+		return err;
+
 	tls_sw_proto_ops = inet_stream_ops;
 	tls_sw_proto_ops.splice_read = tls_sw_splice_read;
+	tls_sw_proto_ops.sendpage_locked   = tls_sw_sendpage_locked,
 
-#ifdef CONFIG_TLS_DEVICE
 	tls_device_init();
-#endif
 	tcp_register_ulp(&tcp_tls_ulp_ops);
 
 	return 0;
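tls_get_info_size() must reserve at least as much netlink space as tls_get_info() can emit into the inet_diag dump. Each nla_total_size(x) is the 4-byte attribute header plus payload, rounded up to a 4-byte boundary. A userspace re-derivation of the reserved size, assuming the usual uapi expansion of those macros:

    #include <stdio.h>

    #define NLA_HDRLEN 4
    #define NLA_ALIGN(x) (((x) + 3) & ~3)
    #define nla_total_size(payload) NLA_ALIGN(NLA_HDRLEN + (payload))

    int main(void)
    {
    	int size = nla_total_size(0)	/* INET_ULP_INFO_TLS nest header */
    		 + 4 * nla_total_size(sizeof(unsigned short)); /* 4 x u16 */

    	/* 4 + 4 * 8 = 36 bytes reserved per dumped socket */
    	printf("%d\n", size);
    	return 0;
    }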
@@ -810,9 +863,8 @@ static int __init tls_register(void)
 static void __exit tls_unregister(void)
 {
 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
-#ifdef CONFIG_TLS_DEVICE
 	tls_device_cleanup();
-#endif
+	unregister_pernet_subsys(&tls_proc_ops);
 }
 
 module_init(tls_register);
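The per-netns counters allocated in tls_init_net() are exported, summed over CPUs, by the proc handler behind tls_proc_init(). On kernels carrying this series they surface as /proc/net/tls_stat; the path and field spellings such as TlsCurrTxSw come from the companion proc patch, not from this hunk, so treat them as an assumption. A small userspace check:

    #include <stdio.h>

    int main(void)
    {
    	char line[128];
    	FILE *f = fopen("/proc/net/tls_stat", "r");

    	if (!f) {
    		perror("tls_stat");	/* kernel without the TLS MIB */
    		return 1;
    	}
    	while (fgets(line, sizeof(line), f))
    		fputs(line, stdout);	/* e.g. "TlsCurrTxSw 42" */
    	fclose(f);
    	return 0;
    }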
