diff options
Diffstat (limited to 'net')
340 files changed, 14440 insertions, 7067 deletions
diff --git a/net/8021q/vlan_dev.c b/net/8021q/vlan_dev.c index 546af0e73ac3..ff720f1ebf73 100644 --- a/net/8021q/vlan_dev.c +++ b/net/8021q/vlan_dev.c @@ -756,8 +756,7 @@ static void vlan_dev_netpoll_cleanup(struct net_device *dev) return; vlan->netpoll = NULL; - - __netpoll_free_async(netpoll); + __netpoll_free(netpoll); } #endif /* CONFIG_NET_POLL_CONTROLLER */ diff --git a/net/9p/Makefile b/net/9p/Makefile index c0486cfc85d9..aa0a5641e5d0 100644 --- a/net/9p/Makefile +++ b/net/9p/Makefile @@ -8,7 +8,6 @@ obj-$(CONFIG_NET_9P_RDMA) += 9pnet_rdma.o mod.o \ client.o \ error.o \ - util.o \ protocol.o \ trans_fd.o \ trans_common.o \ diff --git a/net/9p/client.c b/net/9p/client.c index deae53a7dffc..2c9a17b9b46b 100644 --- a/net/9p/client.c +++ b/net/9p/client.c @@ -231,144 +231,170 @@ free_and_return: return ret; } -static struct p9_fcall *p9_fcall_alloc(int alloc_msize) +static int p9_fcall_init(struct p9_client *c, struct p9_fcall *fc, + int alloc_msize) { - struct p9_fcall *fc; - fc = kmalloc(sizeof(struct p9_fcall) + alloc_msize, GFP_NOFS); - if (!fc) - return NULL; + if (likely(c->fcall_cache) && alloc_msize == c->msize) { + fc->sdata = kmem_cache_alloc(c->fcall_cache, GFP_NOFS); + fc->cache = c->fcall_cache; + } else { + fc->sdata = kmalloc(alloc_msize, GFP_NOFS); + fc->cache = NULL; + } + if (!fc->sdata) + return -ENOMEM; fc->capacity = alloc_msize; - fc->sdata = (char *) fc + sizeof(struct p9_fcall); - return fc; + return 0; +} + +void p9_fcall_fini(struct p9_fcall *fc) +{ + /* sdata can be NULL for interrupted requests in trans_rdma, + * and kmem_cache_free does not do NULL-check for us + */ + if (unlikely(!fc->sdata)) + return; + + if (fc->cache) + kmem_cache_free(fc->cache, fc->sdata); + else + kfree(fc->sdata); } +EXPORT_SYMBOL(p9_fcall_fini); + +static struct kmem_cache *p9_req_cache; /** - * p9_tag_alloc - lookup/allocate a request by tag - * @c: client session to lookup tag within - * @tag: numeric id for transaction - * - * this is a simple array lookup, but will grow the - * request_slots as necessary to accommodate transaction - * ids which did not previously have a slot. - * - * this code relies on the client spinlock to manage locks, its - * possible we should switch to something else, but I'd rather - * stick with something low-overhead for the common case. + * p9_req_alloc - Allocate a new request. + * @c: Client session. + * @type: Transaction type. + * @max_size: Maximum packet size for this request. * + * Context: Process context. + * Return: Pointer to new request. */ - static struct p9_req_t * -p9_tag_alloc(struct p9_client *c, u16 tag, unsigned int max_size) +p9_tag_alloc(struct p9_client *c, int8_t type, unsigned int max_size) { - unsigned long flags; - int row, col; - struct p9_req_t *req; + struct p9_req_t *req = kmem_cache_alloc(p9_req_cache, GFP_NOFS); int alloc_msize = min(c->msize, max_size); + int tag; - /* This looks up the original request by tag so we know which - * buffer to read the data into */ - tag++; - - if (tag >= c->max_tag) { - spin_lock_irqsave(&c->lock, flags); - /* check again since original check was outside of lock */ - while (tag >= c->max_tag) { - row = (tag / P9_ROW_MAXTAG); - c->reqs[row] = kcalloc(P9_ROW_MAXTAG, - sizeof(struct p9_req_t), GFP_ATOMIC); - - if (!c->reqs[row]) { - pr_err("Couldn't grow tag array\n"); - spin_unlock_irqrestore(&c->lock, flags); - return ERR_PTR(-ENOMEM); - } - for (col = 0; col < P9_ROW_MAXTAG; col++) { - req = &c->reqs[row][col]; - req->status = REQ_STATUS_IDLE; - init_waitqueue_head(&req->wq); - } - c->max_tag += P9_ROW_MAXTAG; - } - spin_unlock_irqrestore(&c->lock, flags); - } - row = tag / P9_ROW_MAXTAG; - col = tag % P9_ROW_MAXTAG; - - req = &c->reqs[row][col]; - if (!req->tc) - req->tc = p9_fcall_alloc(alloc_msize); - if (!req->rc) - req->rc = p9_fcall_alloc(alloc_msize); - if (!req->tc || !req->rc) - goto grow_failed; + if (!req) + return ERR_PTR(-ENOMEM); - p9pdu_reset(req->tc); - p9pdu_reset(req->rc); + if (p9_fcall_init(c, &req->tc, alloc_msize)) + goto free_req; + if (p9_fcall_init(c, &req->rc, alloc_msize)) + goto free; - req->tc->tag = tag-1; + p9pdu_reset(&req->tc); + p9pdu_reset(&req->rc); req->status = REQ_STATUS_ALLOC; + init_waitqueue_head(&req->wq); + INIT_LIST_HEAD(&req->req_list); + + idr_preload(GFP_NOFS); + spin_lock_irq(&c->lock); + if (type == P9_TVERSION) + tag = idr_alloc(&c->reqs, req, P9_NOTAG, P9_NOTAG + 1, + GFP_NOWAIT); + else + tag = idr_alloc(&c->reqs, req, 0, P9_NOTAG, GFP_NOWAIT); + req->tc.tag = tag; + spin_unlock_irq(&c->lock); + idr_preload_end(); + if (tag < 0) + goto free; + + /* Init ref to two because in the general case there is one ref + * that is put asynchronously by a writer thread, one ref + * temporarily given by p9_tag_lookup and put by p9_client_cb + * in the recv thread, and one ref put by p9_tag_remove in the + * main thread. The only exception is virtio that does not use + * p9_tag_lookup but does not have a writer thread either + * (the write happens synchronously in the request/zc_request + * callback), so p9_client_cb eats the second ref there + * as the pointer is duplicated directly by virtqueue_add_sgs() + */ + refcount_set(&req->refcount.refcount, 2); return req; -grow_failed: - pr_err("Couldn't grow tag array\n"); - kfree(req->tc); - kfree(req->rc); - req->tc = req->rc = NULL; +free: + p9_fcall_fini(&req->tc); + p9_fcall_fini(&req->rc); +free_req: + kmem_cache_free(p9_req_cache, req); return ERR_PTR(-ENOMEM); } /** - * p9_tag_lookup - lookup a request by tag - * @c: client session to lookup tag within - * @tag: numeric id for transaction + * p9_tag_lookup - Look up a request by tag. + * @c: Client session. + * @tag: Transaction ID. * + * Context: Any context. + * Return: A request, or %NULL if there is no request with that tag. */ - struct p9_req_t *p9_tag_lookup(struct p9_client *c, u16 tag) { - int row, col; - - /* This looks up the original request by tag so we know which - * buffer to read the data into */ - tag++; - - if (tag >= c->max_tag) - return NULL; + struct p9_req_t *req; - row = tag / P9_ROW_MAXTAG; - col = tag % P9_ROW_MAXTAG; + rcu_read_lock(); +again: + req = idr_find(&c->reqs, tag); + if (req) { + /* We have to be careful with the req found under rcu_read_lock + * Thanks to SLAB_TYPESAFE_BY_RCU we can safely try to get the + * ref again without corrupting other data, then check again + * that the tag matches once we have the ref + */ + if (!p9_req_try_get(req)) + goto again; + if (req->tc.tag != tag) { + p9_req_put(req); + goto again; + } + } + rcu_read_unlock(); - return &c->reqs[row][col]; + return req; } EXPORT_SYMBOL(p9_tag_lookup); /** - * p9_tag_init - setup tags structure and contents - * @c: v9fs client struct - * - * This initializes the tags structure for each client instance. + * p9_tag_remove - Remove a tag. + * @c: Client session. + * @r: Request of reference. * + * Context: Any context. */ +static int p9_tag_remove(struct p9_client *c, struct p9_req_t *r) +{ + unsigned long flags; + u16 tag = r->tc.tag; + + p9_debug(P9_DEBUG_MUX, "clnt %p req %p tag: %d\n", c, r, tag); + spin_lock_irqsave(&c->lock, flags); + idr_remove(&c->reqs, tag); + spin_unlock_irqrestore(&c->lock, flags); + return p9_req_put(r); +} -static int p9_tag_init(struct p9_client *c) +static void p9_req_free(struct kref *ref) { - int err = 0; + struct p9_req_t *r = container_of(ref, struct p9_req_t, refcount); + p9_fcall_fini(&r->tc); + p9_fcall_fini(&r->rc); + kmem_cache_free(p9_req_cache, r); +} - c->tagpool = p9_idpool_create(); - if (IS_ERR(c->tagpool)) { - err = PTR_ERR(c->tagpool); - goto error; - } - err = p9_idpool_get(c->tagpool); /* reserve tag 0 */ - if (err < 0) { - p9_idpool_destroy(c->tagpool); - goto error; - } - c->max_tag = 0; -error: - return err; +int p9_req_put(struct p9_req_t *r) +{ + return kref_put(&r->refcount, p9_req_free); } +EXPORT_SYMBOL(p9_req_put); /** * p9_tag_cleanup - cleans up tags structure and reclaims resources @@ -379,52 +405,17 @@ error: */ static void p9_tag_cleanup(struct p9_client *c) { - int row, col; - - /* check to insure all requests are idle */ - for (row = 0; row < (c->max_tag/P9_ROW_MAXTAG); row++) { - for (col = 0; col < P9_ROW_MAXTAG; col++) { - if (c->reqs[row][col].status != REQ_STATUS_IDLE) { - p9_debug(P9_DEBUG_MUX, - "Attempting to cleanup non-free tag %d,%d\n", - row, col); - /* TODO: delay execution of cleanup */ - return; - } - } - } - - if (c->tagpool) { - p9_idpool_put(0, c->tagpool); /* free reserved tag 0 */ - p9_idpool_destroy(c->tagpool); - } + struct p9_req_t *req; + int id; - /* free requests associated with tags */ - for (row = 0; row < (c->max_tag/P9_ROW_MAXTAG); row++) { - for (col = 0; col < P9_ROW_MAXTAG; col++) { - kfree(c->reqs[row][col].tc); - kfree(c->reqs[row][col].rc); - } - kfree(c->reqs[row]); + rcu_read_lock(); + idr_for_each_entry(&c->reqs, req, id) { + pr_info("Tag %d still in use\n", id); + if (p9_tag_remove(c, req) == 0) + pr_warn("Packet with tag %d has still references", + req->tc.tag); } - c->max_tag = 0; -} - -/** - * p9_free_req - free a request and clean-up as necessary - * c: client state - * r: request to release - * - */ - -static void p9_free_req(struct p9_client *c, struct p9_req_t *r) -{ - int tag = r->tc->tag; - p9_debug(P9_DEBUG_MUX, "clnt %p req %p tag: %d\n", c, r, tag); - - r->status = REQ_STATUS_IDLE; - if (tag != P9_NOTAG && p9_idpool_check(tag, c->tagpool)) - p9_idpool_put(tag, c->tagpool); + rcu_read_unlock(); } /** @@ -435,7 +426,7 @@ static void p9_free_req(struct p9_client *c, struct p9_req_t *r) */ void p9_client_cb(struct p9_client *c, struct p9_req_t *req, int status) { - p9_debug(P9_DEBUG_MUX, " tag %d\n", req->tc->tag); + p9_debug(P9_DEBUG_MUX, " tag %d\n", req->tc.tag); /* * This barrier is needed to make sure any change made to req before @@ -445,7 +436,8 @@ void p9_client_cb(struct p9_client *c, struct p9_req_t *req, int status) req->status = status; wake_up(&req->wq); - p9_debug(P9_DEBUG_MUX, "wakeup: %d\n", req->tc->tag); + p9_debug(P9_DEBUG_MUX, "wakeup: %d\n", req->tc.tag); + p9_req_put(req); } EXPORT_SYMBOL(p9_client_cb); @@ -516,18 +508,18 @@ static int p9_check_errors(struct p9_client *c, struct p9_req_t *req) int err; int ecode; - err = p9_parse_header(req->rc, NULL, &type, NULL, 0); - if (req->rc->size >= c->msize) { + err = p9_parse_header(&req->rc, NULL, &type, NULL, 0); + if (req->rc.size >= c->msize) { p9_debug(P9_DEBUG_ERROR, "requested packet size too big: %d\n", - req->rc->size); + req->rc.size); return -EIO; } /* * dump the response from server * This should be after check errors which poplulate pdu_fcall. */ - trace_9p_protocol_dump(c, req->rc); + trace_9p_protocol_dump(c, &req->rc); if (err) { p9_debug(P9_DEBUG_ERROR, "couldn't parse header %d\n", err); return err; @@ -537,7 +529,7 @@ static int p9_check_errors(struct p9_client *c, struct p9_req_t *req) if (!p9_is_proto_dotl(c)) { char *ename; - err = p9pdu_readf(req->rc, c->proto_version, "s?d", + err = p9pdu_readf(&req->rc, c->proto_version, "s?d", &ename, &ecode); if (err) goto out_err; @@ -553,7 +545,7 @@ static int p9_check_errors(struct p9_client *c, struct p9_req_t *req) } kfree(ename); } else { - err = p9pdu_readf(req->rc, c->proto_version, "d", &ecode); + err = p9pdu_readf(&req->rc, c->proto_version, "d", &ecode); err = -ecode; p9_debug(P9_DEBUG_9P, "<<< RLERROR (%d)\n", -ecode); @@ -587,12 +579,12 @@ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, int8_t type; char *ename = NULL; - err = p9_parse_header(req->rc, NULL, &type, NULL, 0); + err = p9_parse_header(&req->rc, NULL, &type, NULL, 0); /* * dump the response from server * This should be after parse_header which poplulate pdu_fcall. */ - trace_9p_protocol_dump(c, req->rc); + trace_9p_protocol_dump(c, &req->rc); if (err) { p9_debug(P9_DEBUG_ERROR, "couldn't parse header %d\n", err); return err; @@ -607,13 +599,13 @@ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, /* 7 = header size for RERROR; */ int inline_len = in_hdrlen - 7; - len = req->rc->size - req->rc->offset; + len = req->rc.size - req->rc.offset; if (len > (P9_ZC_HDR_SZ - 7)) { err = -EFAULT; goto out_err; } - ename = &req->rc->sdata[req->rc->offset]; + ename = &req->rc.sdata[req->rc.offset]; if (len > inline_len) { /* We have error in external buffer */ if (!copy_from_iter_full(ename + inline_len, @@ -623,7 +615,7 @@ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, } } ename = NULL; - err = p9pdu_readf(req->rc, c->proto_version, "s?d", + err = p9pdu_readf(&req->rc, c->proto_version, "s?d", &ename, &ecode); if (err) goto out_err; @@ -639,7 +631,7 @@ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, } kfree(ename); } else { - err = p9pdu_readf(req->rc, c->proto_version, "d", &ecode); + err = p9pdu_readf(&req->rc, c->proto_version, "d", &ecode); err = -ecode; p9_debug(P9_DEBUG_9P, "<<< RLERROR (%d)\n", -ecode); @@ -672,7 +664,7 @@ static int p9_client_flush(struct p9_client *c, struct p9_req_t *oldreq) int16_t oldtag; int err; - err = p9_parse_header(oldreq->tc, NULL, NULL, &oldtag, 1); + err = p9_parse_header(&oldreq->tc, NULL, NULL, &oldtag, 1); if (err) return err; @@ -686,11 +678,12 @@ static int p9_client_flush(struct p9_client *c, struct p9_req_t *oldreq) * if we haven't received a response for oldreq, * remove it from the list */ - if (oldreq->status == REQ_STATUS_SENT) + if (oldreq->status == REQ_STATUS_SENT) { if (c->trans_mod->cancelled) c->trans_mod->cancelled(c, oldreq); + } - p9_free_req(c, req); + p9_tag_remove(c, req); return 0; } @@ -698,7 +691,7 @@ static struct p9_req_t *p9_client_prepare_req(struct p9_client *c, int8_t type, int req_size, const char *fmt, va_list ap) { - int tag, err; + int err; struct p9_req_t *req; p9_debug(P9_DEBUG_MUX, "client %p op %d\n", c, type); @@ -711,27 +704,22 @@ static struct p9_req_t *p9_client_prepare_req(struct p9_client *c, if ((c->status == BeginDisconnect) && (type != P9_TCLUNK)) return ERR_PTR(-EIO); - tag = P9_NOTAG; - if (type != P9_TVERSION) { - tag = p9_idpool_get(c->tagpool); - if (tag < 0) - return ERR_PTR(-ENOMEM); - } - - req = p9_tag_alloc(c, tag, req_size); + req = p9_tag_alloc(c, type, req_size); if (IS_ERR(req)) return req; /* marshall the data */ - p9pdu_prepare(req->tc, tag, type); - err = p9pdu_vwritef(req->tc, c->proto_version, fmt, ap); + p9pdu_prepare(&req->tc, req->tc.tag, type); + err = p9pdu_vwritef(&req->tc, c->proto_version, fmt, ap); if (err) goto reterr; - p9pdu_finalize(c, req->tc); - trace_9p_client_req(c, type, tag); + p9pdu_finalize(c, &req->tc); + trace_9p_client_req(c, type, req->tc.tag); return req; reterr: - p9_free_req(c, req); + p9_tag_remove(c, req); + /* We have to put also the 2nd reference as it won't be used */ + p9_req_put(req); return ERR_PTR(err); } @@ -741,7 +729,7 @@ reterr: * @type: type of request * @fmt: protocol format string (see protocol.c) * - * Returns request structure (which client must free using p9_free_req) + * Returns request structure (which client must free using p9_tag_remove) */ static struct p9_req_t * @@ -766,6 +754,8 @@ p9_client_rpc(struct p9_client *c, int8_t type, const char *fmt, ...) err = c->trans_mod->request(c, req); if (err < 0) { + /* write won't happen */ + p9_req_put(req); if (err != -ERESTARTSYS && err != -EFAULT) c->status = Disconnected; goto recalc_sigpending; @@ -813,11 +803,11 @@ recalc_sigpending: goto reterr; err = p9_check_errors(c, req); - trace_9p_client_res(c, type, req->rc->tag, err); + trace_9p_client_res(c, type, req->rc.tag, err); if (!err) return req; reterr: - p9_free_req(c, req); + p9_tag_remove(c, req); return ERR_PTR(safe_errno(err)); } @@ -832,7 +822,7 @@ reterr: * @hdrlen: reader header size, This is the size of response protocol data * @fmt: protocol format string (see protocol.c) * - * Returns request structure (which client must free using p9_free_req) + * Returns request structure (which client must free using p9_tag_remove) */ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, struct iov_iter *uidata, @@ -895,11 +885,11 @@ recalc_sigpending: goto reterr; err = p9_check_zc_errors(c, req, uidata, in_hdrlen); - trace_9p_client_res(c, type, req->rc->tag, err); + trace_9p_client_res(c, type, req->rc.tag, err); if (!err) return req; reterr: - p9_free_req(c, req); + p9_tag_remove(c, req); return ERR_PTR(safe_errno(err)); } @@ -978,10 +968,10 @@ static int p9_client_version(struct p9_client *c) if (IS_ERR(req)) return PTR_ERR(req); - err = p9pdu_readf(req->rc, c->proto_version, "ds", &msize, &version); + err = p9pdu_readf(&req->rc, c->proto_version, "ds", &msize, &version); if (err) { p9_debug(P9_DEBUG_9P, "version error %d\n", err); - trace_9p_protocol_dump(c, req->rc); + trace_9p_protocol_dump(c, &req->rc); goto error; } @@ -1002,7 +992,7 @@ static int p9_client_version(struct p9_client *c) error: kfree(version); - p9_free_req(c, req); + p9_tag_remove(c, req); return err; } @@ -1020,20 +1010,18 @@ struct p9_client *p9_client_create(const char *dev_name, char *options) clnt->trans_mod = NULL; clnt->trans = NULL; + clnt->fcall_cache = NULL; client_id = utsname()->nodename; memcpy(clnt->name, client_id, strlen(client_id) + 1); spin_lock_init(&clnt->lock); idr_init(&clnt->fids); - - err = p9_tag_init(clnt); - if (err < 0) - goto free_client; + idr_init(&clnt->reqs); err = parse_opts(options, clnt); if (err < 0) - goto destroy_tagpool; + goto free_client; if (!clnt->trans_mod) clnt->trans_mod = v9fs_get_default_trans(); @@ -1042,7 +1030,7 @@ struct p9_client *p9_client_create(const char *dev_name, char *options) err = -EPROTONOSUPPORT; p9_debug(P9_DEBUG_ERROR, "No transport defined or default transport\n"); - goto destroy_tagpool; + goto free_client; } p9_debug(P9_DEBUG_MUX, "clnt %p trans %p msize %d protocol %d\n", @@ -1059,14 +1047,21 @@ struct p9_client *p9_client_create(const char *dev_name, char *options) if (err) goto close_trans; + /* P9_HDRSZ + 4 is the smallest packet header we can have that is + * followed by data accessed from userspace by read + */ + clnt->fcall_cache = + kmem_cache_create_usercopy("9p-fcall-cache", clnt->msize, + 0, 0, P9_HDRSZ + 4, + clnt->msize - (P9_HDRSZ + 4), + NULL); + return clnt; close_trans: clnt->trans_mod->close(clnt); put_trans: v9fs_put_trans(clnt->trans_mod); -destroy_tagpool: - p9_idpool_destroy(clnt->tagpool); free_client: kfree(clnt); return ERR_PTR(err); @@ -1092,6 +1087,7 @@ void p9_client_destroy(struct p9_client *clnt) p9_tag_cleanup(clnt); + kmem_cache_destroy(clnt->fcall_cache); kfree(clnt); } EXPORT_SYMBOL(p9_client_destroy); @@ -1135,10 +1131,10 @@ struct p9_fid *p9_client_attach(struct p9_client *clnt, struct p9_fid *afid, goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "Q", &qid); + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", &qid); if (err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); goto error; } @@ -1147,7 +1143,7 @@ struct p9_fid *p9_client_attach(struct p9_client *clnt, struct p9_fid *afid, memmove(&fid->qid, &qid, sizeof(struct p9_qid)); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return fid; error: @@ -1192,13 +1188,13 @@ struct p9_fid *p9_client_walk(struct p9_fid *oldfid, uint16_t nwname, goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "R", &nwqids, &wqids); + err = p9pdu_readf(&req->rc, clnt->proto_version, "R", &nwqids, &wqids); if (err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); goto clunk_fid; } - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); p9_debug(P9_DEBUG_9P, "<<< RWALK nwqid %d:\n", nwqids); @@ -1259,9 +1255,9 @@ int p9_client_open(struct p9_fid *fid, int mode) goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "Qd", &qid, &iounit); + err = p9pdu_readf(&req->rc, clnt->proto_version, "Qd", &qid, &iounit); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto free_and_error; } @@ -1273,7 +1269,7 @@ int p9_client_open(struct p9_fid *fid, int mode) fid->iounit = iounit; free_and_error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1303,9 +1299,9 @@ int p9_client_create_dotl(struct p9_fid *ofid, const char *name, u32 flags, u32 goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "Qd", qid, &iounit); + err = p9pdu_readf(&req->rc, clnt->proto_version, "Qd", qid, &iounit); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto free_and_error; } @@ -1318,7 +1314,7 @@ int p9_client_create_dotl(struct p9_fid *ofid, const char *name, u32 flags, u32 ofid->iounit = iounit; free_and_error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1348,9 +1344,9 @@ int p9_client_fcreate(struct p9_fid *fid, const char *name, u32 perm, int mode, goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "Qd", &qid, &iounit); + err = p9pdu_readf(&req->rc, clnt->proto_version, "Qd", &qid, &iounit); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto free_and_error; } @@ -1363,7 +1359,7 @@ int p9_client_fcreate(struct p9_fid *fid, const char *name, u32 perm, int mode, fid->iounit = iounit; free_and_error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1387,9 +1383,9 @@ int p9_client_symlink(struct p9_fid *dfid, const char *name, goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "Q", qid); + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", qid); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto free_and_error; } @@ -1397,7 +1393,7 @@ int p9_client_symlink(struct p9_fid *dfid, const char *name, qid->type, (unsigned long long)qid->path, qid->version); free_and_error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1417,7 +1413,7 @@ int p9_client_link(struct p9_fid *dfid, struct p9_fid *oldfid, const char *newna return PTR_ERR(req); p9_debug(P9_DEBUG_9P, "<<< RLINK\n"); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return 0; } EXPORT_SYMBOL(p9_client_link); @@ -1441,7 +1437,7 @@ int p9_client_fsync(struct p9_fid *fid, int datasync) p9_debug(P9_DEBUG_9P, "<<< RFSYNC fid %d\n", fid->fid); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; @@ -1476,7 +1472,7 @@ again: p9_debug(P9_DEBUG_9P, "<<< RCLUNK fid %d\n", fid->fid); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: /* * Fid is not valid even after a failed clunk @@ -1510,7 +1506,7 @@ int p9_client_remove(struct p9_fid *fid) p9_debug(P9_DEBUG_9P, "<<< RREMOVE fid %d\n", fid->fid); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: if (err == -ERESTARTSYS) p9_client_clunk(fid); @@ -1537,7 +1533,7 @@ int p9_client_unlinkat(struct p9_fid *dfid, const char *name, int flags) } p9_debug(P9_DEBUG_9P, "<<< RUNLINKAT fid %d %s\n", dfid->fid, name); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1585,11 +1581,11 @@ p9_client_read(struct p9_fid *fid, u64 offset, struct iov_iter *to, int *err) break; } - *err = p9pdu_readf(req->rc, clnt->proto_version, + *err = p9pdu_readf(&req->rc, clnt->proto_version, "D", &count, &dataptr); if (*err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); break; } if (rsize < count) { @@ -1599,7 +1595,7 @@ p9_client_read(struct p9_fid *fid, u64 offset, struct iov_iter *to, int *err) p9_debug(P9_DEBUG_9P, "<<< RREAD count %d\n", count); if (!count) { - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); break; } @@ -1609,7 +1605,7 @@ p9_client_read(struct p9_fid *fid, u64 offset, struct iov_iter *to, int *err) offset += n; if (n != count) { *err = -EFAULT; - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); break; } } else { @@ -1617,7 +1613,7 @@ p9_client_read(struct p9_fid *fid, u64 offset, struct iov_iter *to, int *err) total += count; offset += count; } - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); } return total; } @@ -1658,10 +1654,10 @@ p9_client_write(struct p9_fid *fid, u64 offset, struct iov_iter *from, int *err) break; } - *err = p9pdu_readf(req->rc, clnt->proto_version, "d", &count); + *err = p9pdu_readf(&req->rc, clnt->proto_version, "d", &count); if (*err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); break; } if (rsize < count) { @@ -1671,7 +1667,7 @@ p9_client_write(struct p9_fid *fid, u64 offset, struct iov_iter *from, int *err) p9_debug(P9_DEBUG_9P, "<<< RWRITE count %d\n", count); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); iov_iter_advance(from, count); total += count; offset += count; @@ -1702,10 +1698,10 @@ struct p9_wstat *p9_client_stat(struct p9_fid *fid) goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "wS", &ignored, ret); + err = p9pdu_readf(&req->rc, clnt->proto_version, "wS", &ignored, ret); if (err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); goto error; } @@ -1722,7 +1718,7 @@ struct p9_wstat *p9_client_stat(struct p9_fid *fid) from_kgid(&init_user_ns, ret->n_gid), from_kuid(&init_user_ns, ret->n_muid)); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return ret; error: @@ -1755,10 +1751,10 @@ struct p9_stat_dotl *p9_client_getattr_dotl(struct p9_fid *fid, goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "A", ret); + err = p9pdu_readf(&req->rc, clnt->proto_version, "A", ret); if (err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); goto error; } @@ -1783,7 +1779,7 @@ struct p9_stat_dotl *p9_client_getattr_dotl(struct p9_fid *fid, ret->st_ctime_nsec, ret->st_btime_sec, ret->st_btime_nsec, ret->st_gen, ret->st_data_version); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return ret; error: @@ -1852,7 +1848,7 @@ int p9_client_wstat(struct p9_fid *fid, struct p9_wstat *wst) p9_debug(P9_DEBUG_9P, "<<< RWSTAT fid %d\n", fid->fid); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1884,7 +1880,7 @@ int p9_client_setattr(struct p9_fid *fid, struct p9_iattr_dotl *p9attr) goto error; } p9_debug(P9_DEBUG_9P, "<<< RSETATTR fid %d\n", fid->fid); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1907,12 +1903,12 @@ int p9_client_statfs(struct p9_fid *fid, struct p9_rstatfs *sb) goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "ddqqqqqqd", &sb->type, - &sb->bsize, &sb->blocks, &sb->bfree, &sb->bavail, - &sb->files, &sb->ffree, &sb->fsid, &sb->namelen); + err = p9pdu_readf(&req->rc, clnt->proto_version, "ddqqqqqqd", &sb->type, + &sb->bsize, &sb->blocks, &sb->bfree, &sb->bavail, + &sb->files, &sb->ffree, &sb->fsid, &sb->namelen); if (err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); goto error; } @@ -1923,7 +1919,7 @@ int p9_client_statfs(struct p9_fid *fid, struct p9_rstatfs *sb) sb->blocks, sb->bfree, sb->bavail, sb->files, sb->ffree, sb->fsid, (long int)sb->namelen); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1951,7 +1947,7 @@ int p9_client_rename(struct p9_fid *fid, p9_debug(P9_DEBUG_9P, "<<< RRENAME fid %d\n", fid->fid); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -1981,7 +1977,7 @@ int p9_client_renameat(struct p9_fid *olddirfid, const char *old_name, p9_debug(P9_DEBUG_9P, "<<< RRENAMEAT newdirfid %d new name %s\n", newdirfid->fid, new_name); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -2015,13 +2011,13 @@ struct p9_fid *p9_client_xattrwalk(struct p9_fid *file_fid, err = PTR_ERR(req); goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "q", attr_size); + err = p9pdu_readf(&req->rc, clnt->proto_version, "q", attr_size); if (err) { - trace_9p_protocol_dump(clnt, req->rc); - p9_free_req(clnt, req); + trace_9p_protocol_dump(clnt, &req->rc); + p9_tag_remove(clnt, req); goto clunk_fid; } - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); p9_debug(P9_DEBUG_9P, "<<< RXATTRWALK fid %d size %llu\n", attr_fid->fid, *attr_size); return attr_fid; @@ -2055,7 +2051,7 @@ int p9_client_xattrcreate(struct p9_fid *fid, const char *name, goto error; } p9_debug(P9_DEBUG_9P, "<<< RXATTRCREATE fid %d\n", fid->fid); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -2070,7 +2066,7 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) struct kvec kv = {.iov_base = data, .iov_len = count}; struct iov_iter to; - iov_iter_kvec(&to, READ | ITER_KVEC, &kv, 1, count); + iov_iter_kvec(&to, READ, &kv, 1, count); p9_debug(P9_DEBUG_9P, ">>> TREADDIR fid %d offset %llu count %d\n", fid->fid, (unsigned long long) offset, count); @@ -2103,9 +2099,9 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) goto error; } - err = p9pdu_readf(req->rc, clnt->proto_version, "D", &count, &dataptr); + err = p9pdu_readf(&req->rc, clnt->proto_version, "D", &count, &dataptr); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto free_and_error; } if (rsize < count) { @@ -2118,11 +2114,11 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) if (non_zc) memmove(data, dataptr, count); - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return count; free_and_error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); error: return err; } @@ -2144,16 +2140,16 @@ int p9_client_mknod_dotl(struct p9_fid *fid, const char *name, int mode, if (IS_ERR(req)) return PTR_ERR(req); - err = p9pdu_readf(req->rc, clnt->proto_version, "Q", qid); + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", qid); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto error; } p9_debug(P9_DEBUG_9P, "<<< RMKNOD qid %x.%llx.%x\n", qid->type, (unsigned long long)qid->path, qid->version); error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return err; } @@ -2175,16 +2171,16 @@ int p9_client_mkdir_dotl(struct p9_fid *fid, const char *name, int mode, if (IS_ERR(req)) return PTR_ERR(req); - err = p9pdu_readf(req->rc, clnt->proto_version, "Q", qid); + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", qid); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto error; } p9_debug(P9_DEBUG_9P, "<<< RMKDIR qid %x.%llx.%x\n", qid->type, (unsigned long long)qid->path, qid->version); error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return err; } @@ -2210,14 +2206,14 @@ int p9_client_lock_dotl(struct p9_fid *fid, struct p9_flock *flock, u8 *status) if (IS_ERR(req)) return PTR_ERR(req); - err = p9pdu_readf(req->rc, clnt->proto_version, "b", status); + err = p9pdu_readf(&req->rc, clnt->proto_version, "b", status); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto error; } p9_debug(P9_DEBUG_9P, "<<< RLOCK status %i\n", *status); error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return err; } @@ -2241,18 +2237,18 @@ int p9_client_getlock_dotl(struct p9_fid *fid, struct p9_getlock *glock) if (IS_ERR(req)) return PTR_ERR(req); - err = p9pdu_readf(req->rc, clnt->proto_version, "bqqds", &glock->type, - &glock->start, &glock->length, &glock->proc_id, - &glock->client_id); + err = p9pdu_readf(&req->rc, clnt->proto_version, "bqqds", &glock->type, + &glock->start, &glock->length, &glock->proc_id, + &glock->client_id); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto error; } p9_debug(P9_DEBUG_9P, "<<< RGETLOCK type %i start %lld length %lld " "proc_id %d client_id %s\n", glock->type, glock->start, glock->length, glock->proc_id, glock->client_id); error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return err; } EXPORT_SYMBOL(p9_client_getlock_dotl); @@ -2271,14 +2267,25 @@ int p9_client_readlink(struct p9_fid *fid, char **target) if (IS_ERR(req)) return PTR_ERR(req); - err = p9pdu_readf(req->rc, clnt->proto_version, "s", target); + err = p9pdu_readf(&req->rc, clnt->proto_version, "s", target); if (err) { - trace_9p_protocol_dump(clnt, req->rc); + trace_9p_protocol_dump(clnt, &req->rc); goto error; } p9_debug(P9_DEBUG_9P, "<<< RREADLINK target %s\n", *target); error: - p9_free_req(clnt, req); + p9_tag_remove(clnt, req); return err; } EXPORT_SYMBOL(p9_client_readlink); + +int __init p9_client_init(void) +{ + p9_req_cache = KMEM_CACHE(p9_req_t, SLAB_TYPESAFE_BY_RCU); + return p9_req_cache ? 0 : -ENOMEM; +} + +void __exit p9_client_exit(void) +{ + kmem_cache_destroy(p9_req_cache); +} diff --git a/net/9p/mod.c b/net/9p/mod.c index 253ba824a325..0da56d6af73b 100644 --- a/net/9p/mod.c +++ b/net/9p/mod.c @@ -171,11 +171,17 @@ void v9fs_put_trans(struct p9_trans_module *m) */ static int __init init_p9(void) { + int ret; + + ret = p9_client_init(); + if (ret) + return ret; + p9_error_init(); pr_info("Installing 9P2000 support\n"); p9_trans_fd_init(); - return 0; + return ret; } /** @@ -188,6 +194,7 @@ static void __exit exit_p9(void) pr_info("Unloading 9P2000 support\n"); p9_trans_fd_exit(); + p9_client_exit(); } module_init(init_p9) diff --git a/net/9p/protocol.c b/net/9p/protocol.c index 4a1e1dd30b52..462ba144cb39 100644 --- a/net/9p/protocol.c +++ b/net/9p/protocol.c @@ -46,10 +46,15 @@ p9pdu_writef(struct p9_fcall *pdu, int proto_version, const char *fmt, ...); void p9stat_free(struct p9_wstat *stbuf) { kfree(stbuf->name); + stbuf->name = NULL; kfree(stbuf->uid); + stbuf->uid = NULL; kfree(stbuf->gid); + stbuf->gid = NULL; kfree(stbuf->muid); + stbuf->muid = NULL; kfree(stbuf->extension); + stbuf->extension = NULL; } EXPORT_SYMBOL(p9stat_free); @@ -566,9 +571,10 @@ int p9stat_read(struct p9_client *clnt, char *buf, int len, struct p9_wstat *st) if (ret) { p9_debug(P9_DEBUG_9P, "<<< p9stat_read failed: %d\n", ret); trace_9p_protocol_dump(clnt, &fake_pdu); + return ret; } - return ret; + return fake_pdu.offset; } EXPORT_SYMBOL(p9stat_read); @@ -617,13 +623,19 @@ int p9dirent_read(struct p9_client *clnt, char *buf, int len, if (ret) { p9_debug(P9_DEBUG_9P, "<<< p9dirent_read failed: %d\n", ret); trace_9p_protocol_dump(clnt, &fake_pdu); - goto out; + return ret; } - strcpy(dirent->d_name, nameptr); + ret = strscpy(dirent->d_name, nameptr, sizeof(dirent->d_name)); + if (ret < 0) { + p9_debug(P9_DEBUG_ERROR, + "On the wire dirent name too long: %s\n", + nameptr); + kfree(nameptr); + return ret; + } kfree(nameptr); -out: return fake_pdu.offset; } EXPORT_SYMBOL(p9dirent_read); diff --git a/net/9p/trans_fd.c b/net/9p/trans_fd.c index e2ef3c782c53..f868cf6fba79 100644 --- a/net/9p/trans_fd.c +++ b/net/9p/trans_fd.c @@ -131,7 +131,8 @@ struct p9_conn { int err; struct list_head req_list; struct list_head unsent_req_list; - struct p9_req_t *req; + struct p9_req_t *rreq; + struct p9_req_t *wreq; char tmp_buf[7]; struct p9_fcall rc; int wpos; @@ -291,7 +292,6 @@ static void p9_read_work(struct work_struct *work) __poll_t n; int err; struct p9_conn *m; - int status = REQ_STATUS_ERROR; m = container_of(work, struct p9_conn, rq); @@ -322,7 +322,7 @@ static void p9_read_work(struct work_struct *work) m->rc.offset += err; /* header read in */ - if ((!m->req) && (m->rc.offset == m->rc.capacity)) { + if ((!m->rreq) && (m->rc.offset == m->rc.capacity)) { p9_debug(P9_DEBUG_TRANS, "got new header\n"); /* Header size */ @@ -346,23 +346,23 @@ static void p9_read_work(struct work_struct *work) "mux %p pkt: size: %d bytes tag: %d\n", m, m->rc.size, m->rc.tag); - m->req = p9_tag_lookup(m->client, m->rc.tag); - if (!m->req || (m->req->status != REQ_STATUS_SENT)) { + m->rreq = p9_tag_lookup(m->client, m->rc.tag); + if (!m->rreq || (m->rreq->status != REQ_STATUS_SENT)) { p9_debug(P9_DEBUG_ERROR, "Unexpected packet tag %d\n", m->rc.tag); err = -EIO; goto error; } - if (m->req->rc == NULL) { + if (!m->rreq->rc.sdata) { p9_debug(P9_DEBUG_ERROR, "No recv fcall for tag %d (req %p), disconnecting!\n", - m->rc.tag, m->req); - m->req = NULL; + m->rc.tag, m->rreq); + m->rreq = NULL; err = -EIO; goto error; } - m->rc.sdata = (char *)m->req->rc + sizeof(struct p9_fcall); + m->rc.sdata = m->rreq->rc.sdata; memcpy(m->rc.sdata, m->tmp_buf, m->rc.capacity); m->rc.capacity = m->rc.size; } @@ -370,20 +370,27 @@ static void p9_read_work(struct work_struct *work) /* packet is read in * not an else because some packets (like clunk) have no payload */ - if ((m->req) && (m->rc.offset == m->rc.capacity)) { + if ((m->rreq) && (m->rc.offset == m->rc.capacity)) { p9_debug(P9_DEBUG_TRANS, "got new packet\n"); - m->req->rc->size = m->rc.offset; + m->rreq->rc.size = m->rc.offset; spin_lock(&m->client->lock); - if (m->req->status != REQ_STATUS_ERROR) - status = REQ_STATUS_RCVD; - list_del(&m->req->req_list); - /* update req->status while holding client->lock */ - p9_client_cb(m->client, m->req, status); + if (m->rreq->status == REQ_STATUS_SENT) { + list_del(&m->rreq->req_list); + p9_client_cb(m->client, m->rreq, REQ_STATUS_RCVD); + } else { + spin_unlock(&m->client->lock); + p9_debug(P9_DEBUG_ERROR, + "Request tag %d errored out while we were reading the reply\n", + m->rc.tag); + err = -EIO; + goto error; + } spin_unlock(&m->client->lock); m->rc.sdata = NULL; m->rc.offset = 0; m->rc.capacity = 0; - m->req = NULL; + p9_req_put(m->rreq); + m->rreq = NULL; } end_clear: @@ -469,9 +476,11 @@ static void p9_write_work(struct work_struct *work) p9_debug(P9_DEBUG_TRANS, "move req %p\n", req); list_move_tail(&req->req_list, &m->req_list); - m->wbuf = req->tc->sdata; - m->wsize = req->tc->size; + m->wbuf = req->tc.sdata; + m->wsize = req->tc.size; m->wpos = 0; + p9_req_get(req); + m->wreq = req; spin_unlock(&m->client->lock); } @@ -492,8 +501,11 @@ static void p9_write_work(struct work_struct *work) } m->wpos += err; - if (m->wpos == m->wsize) + if (m->wpos == m->wsize) { m->wpos = m->wsize = 0; + p9_req_put(m->wreq); + m->wreq = NULL; + } end_clear: clear_bit(Wworksched, &m->wsched); @@ -663,7 +675,7 @@ static int p9_fd_request(struct p9_client *client, struct p9_req_t *req) struct p9_conn *m = &ts->conn; p9_debug(P9_DEBUG_TRANS, "mux %p task %p tcall %p id %d\n", - m, current, req->tc, req->tc->id); + m, current, &req->tc, req->tc.id); if (m->err < 0) return m->err; @@ -694,6 +706,7 @@ static int p9_fd_cancel(struct p9_client *client, struct p9_req_t *req) if (req->status == REQ_STATUS_UNSENT) { list_del(&req->req_list); req->status = REQ_STATUS_FLSHD; + p9_req_put(req); ret = 0; } spin_unlock(&client->lock); @@ -711,6 +724,7 @@ static int p9_fd_cancelled(struct p9_client *client, struct p9_req_t *req) spin_lock(&client->lock); list_del(&req->req_list); spin_unlock(&client->lock); + p9_req_put(req); return 0; } @@ -862,7 +876,15 @@ static void p9_conn_destroy(struct p9_conn *m) p9_mux_poll_stop(m); cancel_work_sync(&m->rq); + if (m->rreq) { + p9_req_put(m->rreq); + m->rreq = NULL; + } cancel_work_sync(&m->wq); + if (m->wreq) { + p9_req_put(m->wreq); + m->wreq = NULL; + } p9_conn_cancel(m, -ECONNRESET); diff --git a/net/9p/trans_rdma.c b/net/9p/trans_rdma.c index b513cffeeb3c..119103bfa82e 100644 --- a/net/9p/trans_rdma.c +++ b/net/9p/trans_rdma.c @@ -122,7 +122,7 @@ struct p9_rdma_context { dma_addr_t busa; union { struct p9_req_t *req; - struct p9_fcall *rc; + struct p9_fcall rc; }; }; @@ -274,8 +274,7 @@ p9_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) case RDMA_CM_EVENT_DISCONNECTED: if (rdma) rdma->state = P9_RDMA_CLOSED; - if (c) - c->status = Disconnected; + c->status = Disconnected; break; case RDMA_CM_EVENT_TIMEWAIT_EXIT: @@ -320,8 +319,8 @@ recv_done(struct ib_cq *cq, struct ib_wc *wc) if (wc->status != IB_WC_SUCCESS) goto err_out; - c->rc->size = wc->byte_len; - err = p9_parse_header(c->rc, NULL, NULL, &tag, 1); + c->rc.size = wc->byte_len; + err = p9_parse_header(&c->rc, NULL, NULL, &tag, 1); if (err) goto err_out; @@ -331,12 +330,13 @@ recv_done(struct ib_cq *cq, struct ib_wc *wc) /* Check that we have not yet received a reply for this request. */ - if (unlikely(req->rc)) { + if (unlikely(req->rc.sdata)) { pr_err("Duplicate reply for request %d", tag); goto err_out; } - req->rc = c->rc; + req->rc.size = c->rc.size; + req->rc.sdata = c->rc.sdata; p9_client_cb(client, req, REQ_STATUS_RCVD); out: @@ -361,9 +361,10 @@ send_done(struct ib_cq *cq, struct ib_wc *wc) container_of(wc->wr_cqe, struct p9_rdma_context, cqe); ib_dma_unmap_single(rdma->cm_id->device, - c->busa, c->req->tc->size, + c->busa, c->req->tc.size, DMA_TO_DEVICE); up(&rdma->sq_sem); + p9_req_put(c->req); kfree(c); } @@ -401,7 +402,7 @@ post_recv(struct p9_client *client, struct p9_rdma_context *c) struct ib_sge sge; c->busa = ib_dma_map_single(rdma->cm_id->device, - c->rc->sdata, client->msize, + c->rc.sdata, client->msize, DMA_FROM_DEVICE); if (ib_dma_mapping_error(rdma->cm_id->device, c->busa)) goto error; @@ -443,9 +444,9 @@ static int rdma_request(struct p9_client *client, struct p9_req_t *req) **/ if (unlikely(atomic_read(&rdma->excess_rc) > 0)) { if ((atomic_sub_return(1, &rdma->excess_rc) >= 0)) { - /* Got one ! */ - kfree(req->rc); - req->rc = NULL; + /* Got one! */ + p9_fcall_fini(&req->rc); + req->rc.sdata = NULL; goto dont_need_post_recv; } else { /* We raced and lost. */ @@ -459,7 +460,7 @@ static int rdma_request(struct p9_client *client, struct p9_req_t *req) err = -ENOMEM; goto recv_error; } - rpl_context->rc = req->rc; + rpl_context->rc.sdata = req->rc.sdata; /* * Post a receive buffer for this request. We need to ensure @@ -475,11 +476,11 @@ static int rdma_request(struct p9_client *client, struct p9_req_t *req) err = post_recv(client, rpl_context); if (err) { - p9_debug(P9_DEBUG_FCALL, "POST RECV failed\n"); + p9_debug(P9_DEBUG_ERROR, "POST RECV failed: %d\n", err); goto recv_error; } /* remove posted receive buffer from request structure */ - req->rc = NULL; + req->rc.sdata = NULL; dont_need_post_recv: /* Post the request */ @@ -491,7 +492,7 @@ dont_need_post_recv: c->req = req; c->busa = ib_dma_map_single(rdma->cm_id->device, - c->req->tc->sdata, c->req->tc->size, + c->req->tc.sdata, c->req->tc.size, DMA_TO_DEVICE); if (ib_dma_mapping_error(rdma->cm_id->device, c->busa)) { err = -EIO; @@ -501,7 +502,7 @@ dont_need_post_recv: c->cqe.done = send_done; sge.addr = c->busa; - sge.length = c->req->tc->size; + sge.length = c->req->tc.size; sge.lkey = rdma->pd->local_dma_lkey; wr.next = NULL; @@ -544,7 +545,7 @@ dont_need_post_recv: recv_error: kfree(rpl_context); spin_lock_irqsave(&rdma->req_lock, flags); - if (rdma->state < P9_RDMA_CLOSING) { + if (err != -EINTR && rdma->state < P9_RDMA_CLOSING) { rdma->state = P9_RDMA_CLOSING; spin_unlock_irqrestore(&rdma->req_lock, flags); rdma_disconnect(rdma->cm_id); diff --git a/net/9p/trans_virtio.c b/net/9p/trans_virtio.c index 7728b0acde09..b1d39cabf125 100644 --- a/net/9p/trans_virtio.c +++ b/net/9p/trans_virtio.c @@ -155,7 +155,7 @@ static void req_done(struct virtqueue *vq) } if (len) { - req->rc->size = len; + req->rc.size = len; p9_client_cb(chan->client, req, REQ_STATUS_RCVD); } } @@ -207,6 +207,13 @@ static int p9_virtio_cancel(struct p9_client *client, struct p9_req_t *req) return 1; } +/* Reply won't come, so drop req ref */ +static int p9_virtio_cancelled(struct p9_client *client, struct p9_req_t *req) +{ + p9_req_put(req); + return 0; +} + /** * pack_sg_list_p - Just like pack_sg_list. Instead of taking a buffer, * this takes a list of pages. @@ -273,12 +280,12 @@ req_retry: out_sgs = in_sgs = 0; /* Handle out VirtIO ring buffers */ out = pack_sg_list(chan->sg, 0, - VIRTQUEUE_NUM, req->tc->sdata, req->tc->size); + VIRTQUEUE_NUM, req->tc.sdata, req->tc.size); if (out) sgs[out_sgs++] = chan->sg; in = pack_sg_list(chan->sg, out, - VIRTQUEUE_NUM, req->rc->sdata, req->rc->capacity); + VIRTQUEUE_NUM, req->rc.sdata, req->rc.capacity); if (in) sgs[out_sgs + in_sgs++] = chan->sg + out; @@ -322,7 +329,7 @@ static int p9_get_mapped_pages(struct virtio_chan *chan, if (!iov_iter_count(data)) return 0; - if (!(data->type & ITER_KVEC)) { + if (!iov_iter_is_kvec(data)) { int n; /* * We allow only p9_max_pages pinned. We wait for the @@ -404,6 +411,7 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, struct scatterlist *sgs[4]; size_t offs; int need_drop = 0; + int kicked = 0; p9_debug(P9_DEBUG_TRANS, "virtio request\n"); @@ -411,29 +419,33 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, __le32 sz; int n = p9_get_mapped_pages(chan, &out_pages, uodata, outlen, &offs, &need_drop); - if (n < 0) - return n; + if (n < 0) { + err = n; + goto err_out; + } out_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); if (n != outlen) { __le32 v = cpu_to_le32(n); - memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4); + memcpy(&req->tc.sdata[req->tc.size - 4], &v, 4); outlen = n; } /* The size field of the message must include the length of the * header and the length of the data. We didn't actually know * the length of the data until this point so add it in now. */ - sz = cpu_to_le32(req->tc->size + outlen); - memcpy(&req->tc->sdata[0], &sz, sizeof(sz)); + sz = cpu_to_le32(req->tc.size + outlen); + memcpy(&req->tc.sdata[0], &sz, sizeof(sz)); } else if (uidata) { int n = p9_get_mapped_pages(chan, &in_pages, uidata, inlen, &offs, &need_drop); - if (n < 0) - return n; + if (n < 0) { + err = n; + goto err_out; + } in_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); if (n != inlen) { __le32 v = cpu_to_le32(n); - memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4); + memcpy(&req->tc.sdata[req->tc.size - 4], &v, 4); inlen = n; } } @@ -445,7 +457,7 @@ req_retry_pinned: /* out data */ out = pack_sg_list(chan->sg, 0, - VIRTQUEUE_NUM, req->tc->sdata, req->tc->size); + VIRTQUEUE_NUM, req->tc.sdata, req->tc.size); if (out) sgs[out_sgs++] = chan->sg; @@ -464,7 +476,7 @@ req_retry_pinned: * alloced memory and payload onto the user buffer. */ in = pack_sg_list(chan->sg, out, - VIRTQUEUE_NUM, req->rc->sdata, in_hdr_len); + VIRTQUEUE_NUM, req->rc.sdata, in_hdr_len); if (in) sgs[out_sgs + in_sgs++] = chan->sg + out; @@ -498,6 +510,7 @@ req_retry_pinned: } virtqueue_kick(chan->vq); spin_unlock_irqrestore(&chan->lock, flags); + kicked = 1; p9_debug(P9_DEBUG_TRANS, "virtio request kicked\n"); err = wait_event_killable(req->wq, req->status >= REQ_STATUS_RCVD); /* @@ -518,6 +531,10 @@ err_out: } kvfree(in_pages); kvfree(out_pages); + if (!kicked) { + /* reply won't come */ + p9_req_put(req); + } return err; } @@ -750,6 +767,7 @@ static struct p9_trans_module p9_virtio_trans = { .request = p9_virtio_request, .zc_request = p9_virtio_zc_request, .cancel = p9_virtio_cancel, + .cancelled = p9_virtio_cancelled, /* * We leave one entry for input and one entry for response * headers. We also skip one more entry to accomodate, address diff --git a/net/9p/trans_xen.c b/net/9p/trans_xen.c index c2d54ac76bfd..e2fbf3677b9b 100644 --- a/net/9p/trans_xen.c +++ b/net/9p/trans_xen.c @@ -141,7 +141,7 @@ static int p9_xen_request(struct p9_client *client, struct p9_req_t *p9_req) struct xen_9pfs_front_priv *priv = NULL; RING_IDX cons, prod, masked_cons, masked_prod; unsigned long flags; - u32 size = p9_req->tc->size; + u32 size = p9_req->tc.size; struct xen_9pfs_dataring *ring; int num; @@ -154,7 +154,7 @@ static int p9_xen_request(struct p9_client *client, struct p9_req_t *p9_req) if (!priv || priv->client != client) return -EINVAL; - num = p9_req->tc->tag % priv->num_rings; + num = p9_req->tc.tag % priv->num_rings; ring = &priv->rings[num]; again: @@ -176,7 +176,7 @@ again: masked_prod = xen_9pfs_mask(prod, XEN_9PFS_RING_SIZE); masked_cons = xen_9pfs_mask(cons, XEN_9PFS_RING_SIZE); - xen_9pfs_write_packet(ring->data.out, p9_req->tc->sdata, size, + xen_9pfs_write_packet(ring->data.out, p9_req->tc.sdata, size, &masked_prod, masked_cons, XEN_9PFS_RING_SIZE); p9_req->status = REQ_STATUS_SENT; @@ -185,6 +185,7 @@ again: ring->intf->out_prod = prod; spin_unlock_irqrestore(&ring->lock, flags); notify_remote_via_irq(ring->irq); + p9_req_put(p9_req); return 0; } @@ -229,12 +230,12 @@ static void p9_xen_response(struct work_struct *work) continue; } - memcpy(req->rc, &h, sizeof(h)); - req->rc->offset = 0; + memcpy(&req->rc, &h, sizeof(h)); + req->rc.offset = 0; masked_cons = xen_9pfs_mask(cons, XEN_9PFS_RING_SIZE); /* Then, read the whole packet (including the header) */ - xen_9pfs_read_packet(req->rc->sdata, ring->data.in, h.size, + xen_9pfs_read_packet(req->rc.sdata, ring->data.in, h.size, masked_prod, &masked_cons, XEN_9PFS_RING_SIZE); @@ -391,8 +392,8 @@ static int xen_9pfs_front_probe(struct xenbus_device *dev, unsigned int max_rings, max_ring_order, len = 0; versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len); - if (!len) - return -EINVAL; + if (IS_ERR(versions)) + return PTR_ERR(versions); if (strcmp(versions, "1")) { kfree(versions); return -EINVAL; diff --git a/net/9p/util.c b/net/9p/util.c deleted file mode 100644 index 55ad98277e85..000000000000 --- a/net/9p/util.c +++ /dev/null @@ -1,140 +0,0 @@ -/* - * net/9p/util.c - * - * This file contains some helper functions - * - * Copyright (C) 2007 by Latchesar Ionkov <lucho@ionkov.net> - * Copyright (C) 2004 by Eric Van Hensbergen <ericvh@gmail.com> - * Copyright (C) 2002 by Ron Minnich <rminnich@lanl.gov> - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License version 2 - * as published by the Free Software Foundation. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, write to: - * Free Software Foundation - * 51 Franklin Street, Fifth Floor - * Boston, MA 02111-1301 USA - * - */ - -#include <linux/module.h> -#include <linux/errno.h> -#include <linux/fs.h> -#include <linux/sched.h> -#include <linux/parser.h> -#include <linux/idr.h> -#include <linux/slab.h> -#include <net/9p/9p.h> - -/** - * struct p9_idpool - per-connection accounting for tag idpool - * @lock: protects the pool - * @pool: idr to allocate tag id from - * - */ - -struct p9_idpool { - spinlock_t lock; - struct idr pool; -}; - -/** - * p9_idpool_create - create a new per-connection id pool - * - */ - -struct p9_idpool *p9_idpool_create(void) -{ - struct p9_idpool *p; - - p = kmalloc(sizeof(struct p9_idpool), GFP_KERNEL); - if (!p) - return ERR_PTR(-ENOMEM); - - spin_lock_init(&p->lock); - idr_init(&p->pool); - - return p; -} -EXPORT_SYMBOL(p9_idpool_create); - -/** - * p9_idpool_destroy - create a new per-connection id pool - * @p: idpool to destroy - */ - -void p9_idpool_destroy(struct p9_idpool *p) -{ - idr_destroy(&p->pool); - kfree(p); -} -EXPORT_SYMBOL(p9_idpool_destroy); - -/** - * p9_idpool_get - allocate numeric id from pool - * @p: pool to allocate from - * - * Bugs: This seems to be an awful generic function, should it be in idr.c with - * the lock included in struct idr? - */ - -int p9_idpool_get(struct p9_idpool *p) -{ - int i; - unsigned long flags; - - idr_preload(GFP_NOFS); - spin_lock_irqsave(&p->lock, flags); - - /* no need to store exactly p, we just need something non-null */ - i = idr_alloc(&p->pool, p, 0, 0, GFP_NOWAIT); - - spin_unlock_irqrestore(&p->lock, flags); - idr_preload_end(); - if (i < 0) - return -1; - - p9_debug(P9_DEBUG_MUX, " id %d pool %p\n", i, p); - return i; -} -EXPORT_SYMBOL(p9_idpool_get); - -/** - * p9_idpool_put - release numeric id from pool - * @id: numeric id which is being released - * @p: pool to release id into - * - * Bugs: This seems to be an awful generic function, should it be in idr.c with - * the lock included in struct idr? - */ - -void p9_idpool_put(int id, struct p9_idpool *p) -{ - unsigned long flags; - - p9_debug(P9_DEBUG_MUX, " id %d pool %p\n", id, p); - - spin_lock_irqsave(&p->lock, flags); - idr_remove(&p->pool, id); - spin_unlock_irqrestore(&p->lock, flags); -} -EXPORT_SYMBOL(p9_idpool_put); - -/** - * p9_idpool_check - check if the specified id is available - * @id: id to check - * @p: pool to check - */ - -int p9_idpool_check(int id, struct p9_idpool *p) -{ - return idr_find(&p->pool, id) != NULL; -} -EXPORT_SYMBOL(p9_idpool_check); diff --git a/net/Kconfig b/net/Kconfig index 228dfa382eec..f235edb593ba 100644 --- a/net/Kconfig +++ b/net/Kconfig @@ -300,8 +300,11 @@ config BPF_JIT config BPF_STREAM_PARSER bool "enable BPF STREAM_PARSER" + depends on INET depends on BPF_SYSCALL + depends on CGROUP_BPF select STREAM_PARSER + select NET_SOCK_MSG ---help--- Enabling this allows a stream parser to be used with BPF_MAP_TYPE_SOCKMAP. @@ -413,6 +416,14 @@ config GRO_CELLS config SOCK_VALIDATE_XMIT bool +config NET_SOCK_MSG + bool + default n + help + The NET_SOCK_MSG provides a framework for plain sockets (e.g. TCP) or + ULPs (upper layer modules, e.g. TLS) to process L7 application data + with the help of BPF programs. + config NET_DEVLINK tristate "Network physical/parent device Netlink interface" help diff --git a/net/atm/common.c b/net/atm/common.c index 9f8cb0d2e71e..a38c174fc766 100644 --- a/net/atm/common.c +++ b/net/atm/common.c @@ -653,7 +653,7 @@ __poll_t vcc_poll(struct file *file, struct socket *sock, poll_table *wait) struct atm_vcc *vcc; __poll_t mask; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); mask = 0; vcc = ATM_SD(sock); diff --git a/net/bluetooth/6lowpan.c b/net/bluetooth/6lowpan.c index 4e2576fc0c59..828e87fe8027 100644 --- a/net/bluetooth/6lowpan.c +++ b/net/bluetooth/6lowpan.c @@ -467,7 +467,7 @@ static int send_pkt(struct l2cap_chan *chan, struct sk_buff *skb, iv.iov_len = skb->len; memset(&msg, 0, sizeof(msg)); - iov_iter_kvec(&msg.msg_iter, WRITE | ITER_KVEC, &iv, 1, skb->len); + iov_iter_kvec(&msg.msg_iter, WRITE, &iv, 1, skb->len); err = l2cap_chan_send(chan, &msg, skb->len); if (err > 0) { diff --git a/net/bluetooth/a2mp.c b/net/bluetooth/a2mp.c index 51c2cf2d8923..58fc6333d412 100644 --- a/net/bluetooth/a2mp.c +++ b/net/bluetooth/a2mp.c @@ -63,7 +63,7 @@ static void a2mp_send(struct amp_mgr *mgr, u8 code, u8 ident, u16 len, void *dat memset(&msg, 0, sizeof(msg)); - iov_iter_kvec(&msg.msg_iter, WRITE | ITER_KVEC, &iv, 1, total_len); + iov_iter_kvec(&msg.msg_iter, WRITE, &iv, 1, total_len); l2cap_chan_send(chan, &msg, total_len); diff --git a/net/bluetooth/bnep/sock.c b/net/bluetooth/bnep/sock.c index 00deacdcb51c..cfd83c5521ae 100644 --- a/net/bluetooth/bnep/sock.c +++ b/net/bluetooth/bnep/sock.c @@ -49,18 +49,17 @@ static int bnep_sock_release(struct socket *sock) return 0; } -static int bnep_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +static int do_bnep_sock_ioctl(struct socket *sock, unsigned int cmd, void __user *argp) { struct bnep_connlist_req cl; struct bnep_connadd_req ca; struct bnep_conndel_req cd; struct bnep_conninfo ci; struct socket *nsock; - void __user *argp = (void __user *)arg; __u32 supp_feat = BIT(BNEP_SETUP_RESPONSE); int err; - BT_DBG("cmd %x arg %lx", cmd, arg); + BT_DBG("cmd %x arg %p", cmd, argp); switch (cmd) { case BNEPCONNADD: @@ -134,16 +133,22 @@ static int bnep_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long return 0; } +static int bnep_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return do_bnep_sock_ioctl(sock, cmd, (void __user *)arg); +} + #ifdef CONFIG_COMPAT static int bnep_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) { + void __user *argp = compat_ptr(arg); if (cmd == BNEPGETCONNLIST) { struct bnep_connlist_req cl; + unsigned __user *p = argp; u32 uci; int err; - if (get_user(cl.cnum, (u32 __user *) arg) || - get_user(uci, (u32 __user *) (arg + 4))) + if (get_user(cl.cnum, p) || get_user(uci, p + 1)) return -EFAULT; cl.ci = compat_ptr(uci); @@ -153,13 +158,13 @@ static int bnep_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigne err = bnep_get_connlist(&cl); - if (!err && put_user(cl.cnum, (u32 __user *) arg)) + if (!err && put_user(cl.cnum, p)) err = -EFAULT; return err; } - return bnep_sock_ioctl(sock, cmd, arg); + return do_bnep_sock_ioctl(sock, cmd, argp); } #endif diff --git a/net/bluetooth/cmtp/sock.c b/net/bluetooth/cmtp/sock.c index e08f28fadd65..defdd4871919 100644 --- a/net/bluetooth/cmtp/sock.c +++ b/net/bluetooth/cmtp/sock.c @@ -63,17 +63,16 @@ static int cmtp_sock_release(struct socket *sock) return 0; } -static int cmtp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +static int do_cmtp_sock_ioctl(struct socket *sock, unsigned int cmd, void __user *argp) { struct cmtp_connadd_req ca; struct cmtp_conndel_req cd; struct cmtp_connlist_req cl; struct cmtp_conninfo ci; struct socket *nsock; - void __user *argp = (void __user *)arg; int err; - BT_DBG("cmd %x arg %lx", cmd, arg); + BT_DBG("cmd %x arg %p", cmd, argp); switch (cmd) { case CMTPCONNADD: @@ -137,16 +136,22 @@ static int cmtp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long return -EINVAL; } +static int cmtp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return do_cmtp_sock_ioctl(sock, cmd, (void __user *)arg); +} + #ifdef CONFIG_COMPAT static int cmtp_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) { + void __user *argp = compat_ptr(arg); if (cmd == CMTPGETCONNLIST) { struct cmtp_connlist_req cl; + u32 __user *p = argp; u32 uci; int err; - if (get_user(cl.cnum, (u32 __user *) arg) || - get_user(uci, (u32 __user *) (arg + 4))) + if (get_user(cl.cnum, p) || get_user(uci, p + 1)) return -EFAULT; cl.ci = compat_ptr(uci); @@ -156,13 +161,13 @@ static int cmtp_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigne err = cmtp_get_connlist(&cl); - if (!err && put_user(cl.cnum, (u32 __user *) arg)) + if (!err && put_user(cl.cnum, p)) err = -EFAULT; return err; } - return cmtp_sock_ioctl(sock, cmd, arg); + return do_cmtp_sock_ioctl(sock, cmd, argp); } #endif diff --git a/net/bluetooth/hci_event.c b/net/bluetooth/hci_event.c index f47f8fad757a..ef9928d7b4fb 100644 --- a/net/bluetooth/hci_event.c +++ b/net/bluetooth/hci_event.c @@ -4937,31 +4937,27 @@ static void le_conn_complete_evt(struct hci_dev *hdev, u8 status, hci_debugfs_create_conn(conn); hci_conn_add_sysfs(conn); - if (!status) { - /* The remote features procedure is defined for master - * role only. So only in case of an initiated connection - * request the remote features. - * - * If the local controller supports slave-initiated features - * exchange, then requesting the remote features in slave - * role is possible. Otherwise just transition into the - * connected state without requesting the remote features. - */ - if (conn->out || - (hdev->le_features[0] & HCI_LE_SLAVE_FEATURES)) { - struct hci_cp_le_read_remote_features cp; + /* The remote features procedure is defined for master + * role only. So only in case of an initiated connection + * request the remote features. + * + * If the local controller supports slave-initiated features + * exchange, then requesting the remote features in slave + * role is possible. Otherwise just transition into the + * connected state without requesting the remote features. + */ + if (conn->out || + (hdev->le_features[0] & HCI_LE_SLAVE_FEATURES)) { + struct hci_cp_le_read_remote_features cp; - cp.handle = __cpu_to_le16(conn->handle); + cp.handle = __cpu_to_le16(conn->handle); - hci_send_cmd(hdev, HCI_OP_LE_READ_REMOTE_FEATURES, - sizeof(cp), &cp); + hci_send_cmd(hdev, HCI_OP_LE_READ_REMOTE_FEATURES, + sizeof(cp), &cp); - hci_conn_hold(conn); - } else { - conn->state = BT_CONNECTED; - hci_connect_cfm(conn, status); - } + hci_conn_hold(conn); } else { + conn->state = BT_CONNECTED; hci_connect_cfm(conn, status); } diff --git a/net/bluetooth/hidp/core.c b/net/bluetooth/hidp/core.c index 3734dc1788b4..a442e21f3894 100644 --- a/net/bluetooth/hidp/core.c +++ b/net/bluetooth/hidp/core.c @@ -649,7 +649,7 @@ static void hidp_process_transmit(struct hidp_session *session, } static int hidp_setup_input(struct hidp_session *session, - struct hidp_connadd_req *req) + const struct hidp_connadd_req *req) { struct input_dev *input; int i; @@ -748,7 +748,7 @@ EXPORT_SYMBOL_GPL(hidp_hid_driver); /* This function sets up the hid device. It does not add it to the HID system. That is done in hidp_add_connection(). */ static int hidp_setup_hid(struct hidp_session *session, - struct hidp_connadd_req *req) + const struct hidp_connadd_req *req) { struct hid_device *hid; int err; @@ -807,7 +807,7 @@ fault: /* initialize session devices */ static int hidp_session_dev_init(struct hidp_session *session, - struct hidp_connadd_req *req) + const struct hidp_connadd_req *req) { int ret; @@ -906,7 +906,7 @@ static void hidp_session_dev_work(struct work_struct *work) static int hidp_session_new(struct hidp_session **out, const bdaddr_t *bdaddr, struct socket *ctrl_sock, struct socket *intr_sock, - struct hidp_connadd_req *req, + const struct hidp_connadd_req *req, struct l2cap_conn *conn) { struct hidp_session *session; @@ -1338,7 +1338,7 @@ static int hidp_verify_sockets(struct socket *ctrl_sock, return 0; } -int hidp_connection_add(struct hidp_connadd_req *req, +int hidp_connection_add(const struct hidp_connadd_req *req, struct socket *ctrl_sock, struct socket *intr_sock) { diff --git a/net/bluetooth/hidp/hidp.h b/net/bluetooth/hidp/hidp.h index 8798492a6e99..6ef88d0a1919 100644 --- a/net/bluetooth/hidp/hidp.h +++ b/net/bluetooth/hidp/hidp.h @@ -122,7 +122,7 @@ struct hidp_connlist_req { struct hidp_conninfo __user *ci; }; -int hidp_connection_add(struct hidp_connadd_req *req, struct socket *ctrl_sock, struct socket *intr_sock); +int hidp_connection_add(const struct hidp_connadd_req *req, struct socket *ctrl_sock, struct socket *intr_sock); int hidp_connection_del(struct hidp_conndel_req *req); int hidp_get_connlist(struct hidp_connlist_req *req); int hidp_get_conninfo(struct hidp_conninfo *ci); diff --git a/net/bluetooth/hidp/sock.c b/net/bluetooth/hidp/sock.c index 1eaac01f85de..9f85a1943be9 100644 --- a/net/bluetooth/hidp/sock.c +++ b/net/bluetooth/hidp/sock.c @@ -46,9 +46,8 @@ static int hidp_sock_release(struct socket *sock) return 0; } -static int hidp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +static int do_hidp_sock_ioctl(struct socket *sock, unsigned int cmd, void __user *argp) { - void __user *argp = (void __user *) arg; struct hidp_connadd_req ca; struct hidp_conndel_req cd; struct hidp_connlist_req cl; @@ -57,7 +56,7 @@ static int hidp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long struct socket *isock; int err; - BT_DBG("cmd %x arg %lx", cmd, arg); + BT_DBG("cmd %x arg %p", cmd, argp); switch (cmd) { case HIDPCONNADD: @@ -122,6 +121,11 @@ static int hidp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long return -EINVAL; } +static int hidp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return do_hidp_sock_ioctl(sock, cmd, (void __user *)arg); +} + #ifdef CONFIG_COMPAT struct compat_hidp_connadd_req { int ctrl_sock; /* Connected control socket */ @@ -141,13 +145,15 @@ struct compat_hidp_connadd_req { static int hidp_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) { + void __user *argp = compat_ptr(arg); + int err; + if (cmd == HIDPGETCONNLIST) { struct hidp_connlist_req cl; + u32 __user *p = argp; u32 uci; - int err; - if (get_user(cl.cnum, (u32 __user *) arg) || - get_user(uci, (u32 __user *) (arg + 4))) + if (get_user(cl.cnum, p) || get_user(uci, p + 1)) return -EFAULT; cl.ci = compat_ptr(uci); @@ -157,39 +163,54 @@ static int hidp_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigne err = hidp_get_connlist(&cl); - if (!err && put_user(cl.cnum, (u32 __user *) arg)) + if (!err && put_user(cl.cnum, p)) err = -EFAULT; return err; } else if (cmd == HIDPCONNADD) { - struct compat_hidp_connadd_req ca; - struct hidp_connadd_req __user *uca; + struct compat_hidp_connadd_req ca32; + struct hidp_connadd_req ca; + struct socket *csock; + struct socket *isock; - uca = compat_alloc_user_space(sizeof(*uca)); + if (!capable(CAP_NET_ADMIN)) + return -EPERM; - if (copy_from_user(&ca, (void __user *) arg, sizeof(ca))) + if (copy_from_user(&ca32, (void __user *) arg, sizeof(ca32))) return -EFAULT; - if (put_user(ca.ctrl_sock, &uca->ctrl_sock) || - put_user(ca.intr_sock, &uca->intr_sock) || - put_user(ca.parser, &uca->parser) || - put_user(ca.rd_size, &uca->rd_size) || - put_user(compat_ptr(ca.rd_data), &uca->rd_data) || - put_user(ca.country, &uca->country) || - put_user(ca.subclass, &uca->subclass) || - put_user(ca.vendor, &uca->vendor) || - put_user(ca.product, &uca->product) || - put_user(ca.version, &uca->version) || - put_user(ca.flags, &uca->flags) || - put_user(ca.idle_to, &uca->idle_to) || - copy_to_user(&uca->name[0], &ca.name[0], 128)) - return -EFAULT; + ca.ctrl_sock = ca32.ctrl_sock; + ca.intr_sock = ca32.intr_sock; + ca.parser = ca32.parser; + ca.rd_size = ca32.rd_size; + ca.rd_data = compat_ptr(ca32.rd_data); + ca.country = ca32.country; + ca.subclass = ca32.subclass; + ca.vendor = ca32.vendor; + ca.product = ca32.product; + ca.version = ca32.version; + ca.flags = ca32.flags; + ca.idle_to = ca32.idle_to; + memcpy(ca.name, ca32.name, 128); + + csock = sockfd_lookup(ca.ctrl_sock, &err); + if (!csock) + return err; - arg = (unsigned long) uca; + isock = sockfd_lookup(ca.intr_sock, &err); + if (!isock) { + sockfd_put(csock); + return err; + } - /* Fall through. We don't actually write back any _changes_ - to the structure anyway, so there's no need to copy back - into the original compat version */ + err = hidp_connection_add(&ca, csock, isock); + if (!err && copy_to_user(argp, &ca32, sizeof(ca32))) + err = -EFAULT; + + sockfd_put(csock); + sockfd_put(isock); + + return err; } return hidp_sock_ioctl(sock, cmd, arg); diff --git a/net/bluetooth/l2cap_core.c b/net/bluetooth/l2cap_core.c index 514899f7f0d4..2146e0f3b6f8 100644 --- a/net/bluetooth/l2cap_core.c +++ b/net/bluetooth/l2cap_core.c @@ -680,9 +680,9 @@ static void l2cap_chan_le_connect_reject(struct l2cap_chan *chan) u16 result; if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) - result = L2CAP_CR_AUTHORIZATION; + result = L2CAP_CR_LE_AUTHORIZATION; else - result = L2CAP_CR_BAD_PSM; + result = L2CAP_CR_LE_BAD_PSM; l2cap_state_change(chan, BT_DISCONN); @@ -3670,7 +3670,7 @@ void __l2cap_le_connect_rsp_defer(struct l2cap_chan *chan) rsp.mtu = cpu_to_le16(chan->imtu); rsp.mps = cpu_to_le16(chan->mps); rsp.credits = cpu_to_le16(chan->rx_credits); - rsp.result = cpu_to_le16(L2CAP_CR_SUCCESS); + rsp.result = cpu_to_le16(L2CAP_CR_LE_SUCCESS); l2cap_send_cmd(conn, chan->ident, L2CAP_LE_CONN_RSP, sizeof(rsp), &rsp); @@ -3816,9 +3816,17 @@ static struct l2cap_chan *l2cap_connect(struct l2cap_conn *conn, result = L2CAP_CR_NO_MEM; + /* Check for valid dynamic CID range (as per Erratum 3253) */ + if (scid < L2CAP_CID_DYN_START || scid > L2CAP_CID_DYN_END) { + result = L2CAP_CR_INVALID_SCID; + goto response; + } + /* Check if we already have channel with that dcid */ - if (__l2cap_get_chan_by_dcid(conn, scid)) + if (__l2cap_get_chan_by_dcid(conn, scid)) { + result = L2CAP_CR_SCID_IN_USE; goto response; + } chan = pchan->ops->new_connection(pchan); if (!chan) @@ -5280,7 +5288,7 @@ static int l2cap_le_connect_rsp(struct l2cap_conn *conn, credits = __le16_to_cpu(rsp->credits); result = __le16_to_cpu(rsp->result); - if (result == L2CAP_CR_SUCCESS && (mtu < 23 || mps < 23 || + if (result == L2CAP_CR_LE_SUCCESS && (mtu < 23 || mps < 23 || dcid < L2CAP_CID_DYN_START || dcid > L2CAP_CID_LE_DYN_END)) return -EPROTO; @@ -5301,7 +5309,7 @@ static int l2cap_le_connect_rsp(struct l2cap_conn *conn, l2cap_chan_lock(chan); switch (result) { - case L2CAP_CR_SUCCESS: + case L2CAP_CR_LE_SUCCESS: if (__l2cap_get_chan_by_dcid(conn, dcid)) { err = -EBADSLT; break; @@ -5315,8 +5323,8 @@ static int l2cap_le_connect_rsp(struct l2cap_conn *conn, l2cap_chan_ready(chan); break; - case L2CAP_CR_AUTHENTICATION: - case L2CAP_CR_ENCRYPTION: + case L2CAP_CR_LE_AUTHENTICATION: + case L2CAP_CR_LE_ENCRYPTION: /* If we already have MITM protection we can't do * anything. */ @@ -5459,7 +5467,7 @@ static int l2cap_le_connect_req(struct l2cap_conn *conn, pchan = l2cap_global_chan_by_psm(BT_LISTEN, psm, &conn->hcon->src, &conn->hcon->dst, LE_LINK); if (!pchan) { - result = L2CAP_CR_BAD_PSM; + result = L2CAP_CR_LE_BAD_PSM; chan = NULL; goto response; } @@ -5469,28 +5477,28 @@ static int l2cap_le_connect_req(struct l2cap_conn *conn, if (!smp_sufficient_security(conn->hcon, pchan->sec_level, SMP_ALLOW_STK)) { - result = L2CAP_CR_AUTHENTICATION; + result = L2CAP_CR_LE_AUTHENTICATION; chan = NULL; goto response_unlock; } /* Check for valid dynamic CID range */ if (scid < L2CAP_CID_DYN_START || scid > L2CAP_CID_LE_DYN_END) { - result = L2CAP_CR_INVALID_SCID; + result = L2CAP_CR_LE_INVALID_SCID; chan = NULL; goto response_unlock; } /* Check if we already have channel with that dcid */ if (__l2cap_get_chan_by_dcid(conn, scid)) { - result = L2CAP_CR_SCID_IN_USE; + result = L2CAP_CR_LE_SCID_IN_USE; chan = NULL; goto response_unlock; } chan = pchan->ops->new_connection(pchan); if (!chan) { - result = L2CAP_CR_NO_MEM; + result = L2CAP_CR_LE_NO_MEM; goto response_unlock; } @@ -5526,7 +5534,7 @@ static int l2cap_le_connect_req(struct l2cap_conn *conn, chan->ops->defer(chan); } else { l2cap_chan_ready(chan); - result = L2CAP_CR_SUCCESS; + result = L2CAP_CR_LE_SUCCESS; } response_unlock: diff --git a/net/bluetooth/mgmt.c b/net/bluetooth/mgmt.c index 3bdc8f3ca259..ccce954f8146 100644 --- a/net/bluetooth/mgmt.c +++ b/net/bluetooth/mgmt.c @@ -2434,9 +2434,8 @@ static int unpair_device(struct sock *sk, struct hci_dev *hdev, void *data, /* LE address type */ addr_type = le_addr_type(cp->addr.type); - hci_remove_irk(hdev, &cp->addr.bdaddr, addr_type); - - err = hci_remove_ltk(hdev, &cp->addr.bdaddr, addr_type); + /* Abort any ongoing SMP pairing. Removes ltk and irk if they exist. */ + err = smp_cancel_and_remove_pairing(hdev, &cp->addr.bdaddr, addr_type); if (err < 0) { err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNPAIR_DEVICE, MGMT_STATUS_NOT_PAIRED, &rp, @@ -2450,8 +2449,6 @@ static int unpair_device(struct sock *sk, struct hci_dev *hdev, void *data, goto done; } - /* Abort any ongoing SMP pairing */ - smp_cancel_pairing(conn); /* Defer clearing up the connection parameters until closing to * give a chance of keeping them if a repairing happens. diff --git a/net/bluetooth/rfcomm/tty.c b/net/bluetooth/rfcomm/tty.c index 5e44d842cc5d..0c7d31c6c18c 100644 --- a/net/bluetooth/rfcomm/tty.c +++ b/net/bluetooth/rfcomm/tty.c @@ -839,18 +839,6 @@ static int rfcomm_tty_ioctl(struct tty_struct *tty, unsigned int cmd, unsigned l BT_DBG("TIOCMIWAIT"); break; - case TIOCGSERIAL: - BT_ERR("TIOCGSERIAL is not supported"); - return -ENOIOCTLCMD; - - case TIOCSSERIAL: - BT_ERR("TIOCSSERIAL is not supported"); - return -ENOIOCTLCMD; - - case TIOCSERGSTRUCT: - BT_ERR("TIOCSERGSTRUCT is not supported"); - return -ENOIOCTLCMD; - case TIOCSERGETLSR: BT_ERR("TIOCSERGETLSR is not supported"); return -ENOIOCTLCMD; diff --git a/net/bluetooth/smp.c b/net/bluetooth/smp.c index 090670fe385f..c822e626761b 100644 --- a/net/bluetooth/smp.c +++ b/net/bluetooth/smp.c @@ -622,7 +622,7 @@ static void smp_send_cmd(struct l2cap_conn *conn, u8 code, u16 len, void *data) memset(&msg, 0, sizeof(msg)); - iov_iter_kvec(&msg.msg_iter, WRITE | ITER_KVEC, iv, 2, 1 + len); + iov_iter_kvec(&msg.msg_iter, WRITE, iv, 2, 1 + len); l2cap_chan_send(chan, &msg, 1 + len); @@ -2419,30 +2419,51 @@ unlock: return ret; } -void smp_cancel_pairing(struct hci_conn *hcon) +int smp_cancel_and_remove_pairing(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type) { - struct l2cap_conn *conn = hcon->l2cap_data; + struct hci_conn *hcon; + struct l2cap_conn *conn; struct l2cap_chan *chan; struct smp_chan *smp; + int err; + + err = hci_remove_ltk(hdev, bdaddr, addr_type); + hci_remove_irk(hdev, bdaddr, addr_type); + + hcon = hci_conn_hash_lookup_le(hdev, bdaddr, addr_type); + if (!hcon) + goto done; + conn = hcon->l2cap_data; if (!conn) - return; + goto done; chan = conn->smp; if (!chan) - return; + goto done; l2cap_chan_lock(chan); smp = chan->data; if (smp) { + /* Set keys to NULL to make sure smp_failure() does not try to + * remove and free already invalidated rcu list entries. */ + smp->ltk = NULL; + smp->slave_ltk = NULL; + smp->remote_irk = NULL; + if (test_bit(SMP_FLAG_COMPLETE, &smp->flags)) smp_failure(conn, 0); else smp_failure(conn, SMP_UNSPECIFIED); + err = 0; } l2cap_chan_unlock(chan); + +done: + return err; } static int smp_cmd_encrypt_info(struct l2cap_conn *conn, struct sk_buff *skb) diff --git a/net/bluetooth/smp.h b/net/bluetooth/smp.h index 0ff6247eaa6c..121edadd5f8d 100644 --- a/net/bluetooth/smp.h +++ b/net/bluetooth/smp.h @@ -181,7 +181,8 @@ enum smp_key_pref { }; /* SMP Commands */ -void smp_cancel_pairing(struct hci_conn *hcon); +int smp_cancel_and_remove_pairing(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type); bool smp_sufficient_security(struct hci_conn *hcon, u8 sec_level, enum smp_key_pref key_pref); int smp_conn_security(struct hci_conn *hcon, __u8 sec_level); diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c index f4078830ea50..c89c22c49015 100644 --- a/net/bpf/test_run.c +++ b/net/bpf/test_run.c @@ -10,9 +10,11 @@ #include <linux/etherdevice.h> #include <linux/filter.h> #include <linux/sched/signal.h> +#include <net/sock.h> +#include <net/tcp.h> static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx, - struct bpf_cgroup_storage *storage) + struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE]) { u32 ret; @@ -28,13 +30,20 @@ static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx, static u32 bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *time) { - struct bpf_cgroup_storage *storage = NULL; + struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = { 0 }; + enum bpf_cgroup_storage_type stype; u64 time_start, time_spent = 0; u32 ret = 0, i; - storage = bpf_cgroup_storage_alloc(prog); - if (IS_ERR(storage)) - return PTR_ERR(storage); + for_each_cgroup_storage_type(stype) { + storage[stype] = bpf_cgroup_storage_alloc(prog, stype); + if (IS_ERR(storage[stype])) { + storage[stype] = NULL; + for_each_cgroup_storage_type(stype) + bpf_cgroup_storage_free(storage[stype]); + return -ENOMEM; + } + } if (!repeat) repeat = 1; @@ -53,7 +62,8 @@ static u32 bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *time) do_div(time_spent, repeat); *time = time_spent > U32_MAX ? U32_MAX : (u32)time_spent; - bpf_cgroup_storage_free(storage); + for_each_cgroup_storage_type(stype) + bpf_cgroup_storage_free(storage[stype]); return ret; } @@ -107,6 +117,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, u32 retval, duration; int hh_len = ETH_HLEN; struct sk_buff *skb; + struct sock *sk; void *data; int ret; @@ -129,11 +140,21 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, break; } + sk = kzalloc(sizeof(struct sock), GFP_USER); + if (!sk) { + kfree(data); + return -ENOMEM; + } + sock_net_set(sk, current->nsproxy->net_ns); + sock_init_data(NULL, sk); + skb = build_skb(data, 0); if (!skb) { kfree(data); + kfree(sk); return -ENOMEM; } + skb->sk = sk; skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN); __skb_put(skb, size); @@ -151,6 +172,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, if (pskb_expand_head(skb, nhead, 0, GFP_USER)) { kfree_skb(skb); + kfree(sk); return -ENOMEM; } } @@ -163,6 +185,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, size = skb_headlen(skb); ret = bpf_test_finish(kattr, uattr, skb->data, size, retval, duration); kfree_skb(skb); + kfree(sk); return ret; } diff --git a/net/bpfilter/bpfilter_kern.c b/net/bpfilter/bpfilter_kern.c index f0fc182d3db7..7acfc83087d5 100644 --- a/net/bpfilter/bpfilter_kern.c +++ b/net/bpfilter/bpfilter_kern.c @@ -23,9 +23,11 @@ static void shutdown_umh(struct umh_info *info) if (!info->pid) return; - tsk = pid_task(find_vpid(info->pid), PIDTYPE_PID); - if (tsk) + tsk = get_pid_task(find_vpid(info->pid), PIDTYPE_PID); + if (tsk) { force_sig(SIGKILL, tsk); + put_task_struct(tsk); + } fput(info->pipe_to_umh); fput(info->pipe_from_umh); info->pid = 0; @@ -59,7 +61,7 @@ static int __bpfilter_process_sockopt(struct sock *sk, int optname, req.is_set = is_set; req.pid = current->pid; req.cmd = optname; - req.addr = (long)optval; + req.addr = (long __force __user)optval; req.len = optlen; mutex_lock(&bpfilter_lock); if (!info.pid) @@ -90,6 +92,7 @@ static int __init load_umh(void) int err; /* fork usermode process */ + info.cmdline = "bpfilter_umh"; err = fork_usermode_blob(&bpfilter_umh_start, &bpfilter_umh_end - &bpfilter_umh_start, &info); @@ -98,7 +101,7 @@ static int __init load_umh(void) pr_info("Loaded bpfilter_umh pid %d\n", info.pid); /* health check that usermode process started correctly */ - if (__bpfilter_process_sockopt(NULL, 0, 0, 0, 0) != 0) { + if (__bpfilter_process_sockopt(NULL, 0, NULL, 0, 0) != 0) { stop_umh(); return -EFAULT; } diff --git a/net/bridge/Kconfig b/net/bridge/Kconfig index aa0d3b2f1bb7..3625d6ade45c 100644 --- a/net/bridge/Kconfig +++ b/net/bridge/Kconfig @@ -17,7 +17,7 @@ config BRIDGE other third party bridge products. In order to use the Ethernet bridge, you'll need the bridge - configuration tools; see <file:Documentation/networking/bridge.txt> + configuration tools; see <file:Documentation/networking/bridge.rst> for location. Please read the Bridge mini-HOWTO for more information. diff --git a/net/bridge/br.c b/net/bridge/br.c index e411e40333e2..360ad66c21e9 100644 --- a/net/bridge/br.c +++ b/net/bridge/br.c @@ -151,7 +151,7 @@ static int br_switchdev_event(struct notifier_block *unused, break; } br_fdb_offloaded_set(br, p, fdb_info->addr, - fdb_info->vid); + fdb_info->vid, true); break; case SWITCHDEV_FDB_DEL_TO_BRIDGE: fdb_info = ptr; @@ -163,7 +163,7 @@ static int br_switchdev_event(struct notifier_block *unused, case SWITCHDEV_FDB_OFFLOADED: fdb_info = ptr; br_fdb_offloaded_set(br, p, fdb_info->addr, - fdb_info->vid); + fdb_info->vid, fdb_info->offloaded); break; } diff --git a/net/bridge/br_device.c b/net/bridge/br_device.c index e053a4e43758..c6abf927f0c9 100644 --- a/net/bridge/br_device.c +++ b/net/bridge/br_device.c @@ -344,7 +344,7 @@ void br_netpoll_disable(struct net_bridge_port *p) p->np = NULL; - __netpoll_free_async(np); + __netpoll_free(np); } #endif diff --git a/net/bridge/br_fdb.c b/net/bridge/br_fdb.c index 74331690a390..e56ba3912a90 100644 --- a/net/bridge/br_fdb.c +++ b/net/bridge/br_fdb.c @@ -1152,7 +1152,7 @@ int br_fdb_external_learn_del(struct net_bridge *br, struct net_bridge_port *p, } void br_fdb_offloaded_set(struct net_bridge *br, struct net_bridge_port *p, - const unsigned char *addr, u16 vid) + const unsigned char *addr, u16 vid, bool offloaded) { struct net_bridge_fdb_entry *fdb; @@ -1160,7 +1160,7 @@ void br_fdb_offloaded_set(struct net_bridge *br, struct net_bridge_port *p, fdb = br_fdb_find(br, addr, vid); if (fdb) - fdb->offloaded = 1; + fdb->offloaded = offloaded; spin_unlock_bh(&br->hash_lock); } diff --git a/net/bridge/br_mdb.c b/net/bridge/br_mdb.c index a4a848bf827b..a7ea2d431714 100644 --- a/net/bridge/br_mdb.c +++ b/net/bridge/br_mdb.c @@ -162,6 +162,29 @@ out: return err; } +static int br_mdb_valid_dump_req(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct br_port_msg *bpm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*bpm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for mdb dump request"); + return -EINVAL; + } + + bpm = nlmsg_data(nlh); + if (bpm->ifindex) { + NL_SET_ERR_MSG_MOD(extack, "Filtering by device index is not supported for mdb dump request"); + return -EINVAL; + } + if (nlmsg_attrlen(nlh, sizeof(*bpm))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in mdb dump request"); + return -EINVAL; + } + + return 0; +} + static int br_mdb_dump(struct sk_buff *skb, struct netlink_callback *cb) { struct net_device *dev; @@ -169,6 +192,13 @@ static int br_mdb_dump(struct sk_buff *skb, struct netlink_callback *cb) struct nlmsghdr *nlh = NULL; int idx = 0, s_idx; + if (cb->strict_check) { + int err = br_mdb_valid_dump_req(cb->nlh, cb->extack); + + if (err < 0) + return err; + } + s_idx = cb->args[0]; rcu_read_lock(); diff --git a/net/bridge/br_multicast.c b/net/bridge/br_multicast.c index 024139b51d3a..6bac0d6b7b94 100644 --- a/net/bridge/br_multicast.c +++ b/net/bridge/br_multicast.c @@ -1422,7 +1422,14 @@ static void br_multicast_query_received(struct net_bridge *br, return; br_multicast_update_query_timer(br, query, max_delay); - br_multicast_mark_router(br, port); + + /* Based on RFC4541, section 2.1.1 IGMP Forwarding Rules, + * the arrival port for IGMP Queries where the source address + * is 0.0.0.0 should not be added to router port list. + */ + if ((saddr->proto == htons(ETH_P_IP) && saddr->u.ip4) || + saddr->proto == htons(ETH_P_IPV6)) + br_multicast_mark_router(br, port); } static void br_ip4_multicast_query(struct net_bridge *br, diff --git a/net/bridge/br_netfilter_hooks.c b/net/bridge/br_netfilter_hooks.c index e0a3b038d052..b1b5e8516724 100644 --- a/net/bridge/br_netfilter_hooks.c +++ b/net/bridge/br_netfilter_hooks.c @@ -836,7 +836,8 @@ static unsigned int ip_sabotage_in(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) { - if (skb->nf_bridge && !skb->nf_bridge->in_prerouting) { + if (skb->nf_bridge && !skb->nf_bridge->in_prerouting && + !netif_is_l3_master(skb->dev)) { state->okfn(state->net, state->sk, skb); return NF_STOLEN; } diff --git a/net/bridge/br_netlink.c b/net/bridge/br_netlink.c index e5a5bc5d5232..3345f1984542 100644 --- a/net/bridge/br_netlink.c +++ b/net/bridge/br_netlink.c @@ -1034,6 +1034,7 @@ static const struct nla_policy br_policy[IFLA_BR_MAX + 1] = { [IFLA_BR_MCAST_STATS_ENABLED] = { .type = NLA_U8 }, [IFLA_BR_MCAST_IGMP_VERSION] = { .type = NLA_U8 }, [IFLA_BR_MCAST_MLD_VERSION] = { .type = NLA_U8 }, + [IFLA_BR_VLAN_STATS_PER_PORT] = { .type = NLA_U8 }, }; static int br_changelink(struct net_device *brdev, struct nlattr *tb[], @@ -1114,6 +1115,14 @@ static int br_changelink(struct net_device *brdev, struct nlattr *tb[], if (err) return err; } + + if (data[IFLA_BR_VLAN_STATS_PER_PORT]) { + __u8 per_port = nla_get_u8(data[IFLA_BR_VLAN_STATS_PER_PORT]); + + err = br_vlan_set_stats_per_port(br, per_port); + if (err) + return err; + } #endif if (data[IFLA_BR_GROUP_FWD_MASK]) { @@ -1327,6 +1336,7 @@ static size_t br_get_size(const struct net_device *brdev) nla_total_size(sizeof(__be16)) + /* IFLA_BR_VLAN_PROTOCOL */ nla_total_size(sizeof(u16)) + /* IFLA_BR_VLAN_DEFAULT_PVID */ nla_total_size(sizeof(u8)) + /* IFLA_BR_VLAN_STATS_ENABLED */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_VLAN_STATS_PER_PORT */ #endif nla_total_size(sizeof(u16)) + /* IFLA_BR_GROUP_FWD_MASK */ nla_total_size(sizeof(struct ifla_bridge_id)) + /* IFLA_BR_ROOT_ID */ @@ -1417,7 +1427,9 @@ static int br_fill_info(struct sk_buff *skb, const struct net_device *brdev) if (nla_put_be16(skb, IFLA_BR_VLAN_PROTOCOL, br->vlan_proto) || nla_put_u16(skb, IFLA_BR_VLAN_DEFAULT_PVID, br->default_pvid) || nla_put_u8(skb, IFLA_BR_VLAN_STATS_ENABLED, - br_opt_get(br, BROPT_VLAN_STATS_ENABLED))) + br_opt_get(br, BROPT_VLAN_STATS_ENABLED)) || + nla_put_u8(skb, IFLA_BR_VLAN_STATS_PER_PORT, + br_opt_get(br, IFLA_BR_VLAN_STATS_PER_PORT))) return -EMSGSIZE; #endif #ifdef CONFIG_BRIDGE_IGMP_SNOOPING diff --git a/net/bridge/br_private.h b/net/bridge/br_private.h index 57229b9d800f..2920e06a5403 100644 --- a/net/bridge/br_private.h +++ b/net/bridge/br_private.h @@ -320,6 +320,7 @@ enum net_bridge_opts { BROPT_HAS_IPV6_ADDR, BROPT_NEIGH_SUPPRESS_ENABLED, BROPT_MTU_SET_BY_USER, + BROPT_VLAN_STATS_PER_PORT, }; struct net_bridge { @@ -573,7 +574,7 @@ int br_fdb_external_learn_del(struct net_bridge *br, struct net_bridge_port *p, const unsigned char *addr, u16 vid, bool swdev_notify); void br_fdb_offloaded_set(struct net_bridge *br, struct net_bridge_port *p, - const unsigned char *addr, u16 vid); + const unsigned char *addr, u16 vid, bool offloaded); /* br_forward.c */ enum br_pkt_type { @@ -859,6 +860,7 @@ int br_vlan_filter_toggle(struct net_bridge *br, unsigned long val); int __br_vlan_set_proto(struct net_bridge *br, __be16 proto); int br_vlan_set_proto(struct net_bridge *br, unsigned long val); int br_vlan_set_stats(struct net_bridge *br, unsigned long val); +int br_vlan_set_stats_per_port(struct net_bridge *br, unsigned long val); int br_vlan_init(struct net_bridge *br); int br_vlan_set_default_pvid(struct net_bridge *br, unsigned long val); int __br_vlan_set_default_pvid(struct net_bridge *br, u16 pvid); diff --git a/net/bridge/br_switchdev.c b/net/bridge/br_switchdev.c index d77f807420c4..b993df770675 100644 --- a/net/bridge/br_switchdev.c +++ b/net/bridge/br_switchdev.c @@ -103,7 +103,7 @@ int br_switchdev_set_port_flag(struct net_bridge_port *p, static void br_switchdev_fdb_call_notifiers(bool adding, const unsigned char *mac, u16 vid, struct net_device *dev, - bool added_by_user) + bool added_by_user, bool offloaded) { struct switchdev_notifier_fdb_info info; unsigned long notifier_type; @@ -111,6 +111,7 @@ br_switchdev_fdb_call_notifiers(bool adding, const unsigned char *mac, info.addr = mac; info.vid = vid; info.added_by_user = added_by_user; + info.offloaded = offloaded; notifier_type = adding ? SWITCHDEV_FDB_ADD_TO_DEVICE : SWITCHDEV_FDB_DEL_TO_DEVICE; call_switchdev_notifiers(notifier_type, dev, &info.info); } @@ -126,13 +127,15 @@ br_switchdev_fdb_notify(const struct net_bridge_fdb_entry *fdb, int type) br_switchdev_fdb_call_notifiers(false, fdb->key.addr.addr, fdb->key.vlan_id, fdb->dst->dev, - fdb->added_by_user); + fdb->added_by_user, + fdb->offloaded); break; case RTM_NEWNEIGH: br_switchdev_fdb_call_notifiers(true, fdb->key.addr.addr, fdb->key.vlan_id, fdb->dst->dev, - fdb->added_by_user); + fdb->added_by_user, + fdb->offloaded); break; } } diff --git a/net/bridge/br_sysfs_br.c b/net/bridge/br_sysfs_br.c index c93c5724609e..60182bef6341 100644 --- a/net/bridge/br_sysfs_br.c +++ b/net/bridge/br_sysfs_br.c @@ -803,6 +803,22 @@ static ssize_t vlan_stats_enabled_store(struct device *d, return store_bridge_parm(d, buf, len, br_vlan_set_stats); } static DEVICE_ATTR_RW(vlan_stats_enabled); + +static ssize_t vlan_stats_per_port_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br_opt_get(br, BROPT_VLAN_STATS_PER_PORT)); +} + +static ssize_t vlan_stats_per_port_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, br_vlan_set_stats_per_port); +} +static DEVICE_ATTR_RW(vlan_stats_per_port); #endif static struct attribute *bridge_attrs[] = { @@ -856,6 +872,7 @@ static struct attribute *bridge_attrs[] = { &dev_attr_vlan_protocol.attr, &dev_attr_default_pvid.attr, &dev_attr_vlan_stats_enabled.attr, + &dev_attr_vlan_stats_per_port.attr, #endif NULL }; diff --git a/net/bridge/br_vlan.c b/net/bridge/br_vlan.c index 5942e03dd845..8c9297a01947 100644 --- a/net/bridge/br_vlan.c +++ b/net/bridge/br_vlan.c @@ -190,6 +190,19 @@ static void br_vlan_put_master(struct net_bridge_vlan *masterv) } } +static void nbp_vlan_rcu_free(struct rcu_head *rcu) +{ + struct net_bridge_vlan *v; + + v = container_of(rcu, struct net_bridge_vlan, rcu); + WARN_ON(br_vlan_is_master(v)); + /* if we had per-port stats configured then free them here */ + if (v->brvlan->stats != v->stats) + free_percpu(v->stats); + v->stats = NULL; + kfree(v); +} + /* This is the shared VLAN add function which works for both ports and bridge * devices. There are four possible calls to this function in terms of the * vlan entry type: @@ -245,7 +258,15 @@ static int __vlan_add(struct net_bridge_vlan *v, u16 flags) if (!masterv) goto out_filt; v->brvlan = masterv; - v->stats = masterv->stats; + if (br_opt_get(br, BROPT_VLAN_STATS_PER_PORT)) { + v->stats = netdev_alloc_pcpu_stats(struct br_vlan_stats); + if (!v->stats) { + err = -ENOMEM; + goto out_filt; + } + } else { + v->stats = masterv->stats; + } } else { err = br_switchdev_port_vlan_add(dev, v->vid, flags); if (err && err != -EOPNOTSUPP) @@ -282,6 +303,10 @@ out_filt: if (p) { __vlan_vid_del(dev, br, v->vid); if (masterv) { + if (v->stats && masterv->stats != v->stats) + free_percpu(v->stats); + v->stats = NULL; + br_vlan_put_master(masterv); v->brvlan = NULL; } @@ -329,7 +354,7 @@ static int __vlan_del(struct net_bridge_vlan *v) rhashtable_remove_fast(&vg->vlan_hash, &v->vnode, br_vlan_rht_params); __vlan_del_list(v); - kfree_rcu(v, rcu); + call_rcu(&v->rcu, nbp_vlan_rcu_free); } br_vlan_put_master(masterv); @@ -830,6 +855,30 @@ int br_vlan_set_stats(struct net_bridge *br, unsigned long val) return 0; } +int br_vlan_set_stats_per_port(struct net_bridge *br, unsigned long val) +{ + struct net_bridge_port *p; + + /* allow to change the option if there are no port vlans configured */ + list_for_each_entry(p, &br->port_list, list) { + struct net_bridge_vlan_group *vg = nbp_vlan_group(p); + + if (vg->num_vlans) + return -EBUSY; + } + + switch (val) { + case 0: + case 1: + br_opt_toggle(br, BROPT_VLAN_STATS_PER_PORT, !!val); + break; + default: + return -EINVAL; + } + + return 0; +} + static bool vlan_default_pvid(struct net_bridge_vlan_group *vg, u16 vid) { struct net_bridge_vlan *v; diff --git a/net/caif/caif_socket.c b/net/caif/caif_socket.c index d18965f3291f..416717c57cd1 100644 --- a/net/caif/caif_socket.c +++ b/net/caif/caif_socket.c @@ -941,7 +941,7 @@ static __poll_t caif_poll(struct file *file, __poll_t mask; struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); mask = 0; /* exceptional events? */ diff --git a/net/ceph/crypto.c b/net/ceph/crypto.c index 02172c408ff2..5d6724cee38f 100644 --- a/net/ceph/crypto.c +++ b/net/ceph/crypto.c @@ -46,9 +46,9 @@ static int set_secret(struct ceph_crypto_key *key, void *buf) goto fail; } - /* crypto_alloc_skcipher() allocates with GFP_KERNEL */ + /* crypto_alloc_sync_skcipher() allocates with GFP_KERNEL */ noio_flag = memalloc_noio_save(); - key->tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC); + key->tfm = crypto_alloc_sync_skcipher("cbc(aes)", 0, 0); memalloc_noio_restore(noio_flag); if (IS_ERR(key->tfm)) { ret = PTR_ERR(key->tfm); @@ -56,7 +56,7 @@ static int set_secret(struct ceph_crypto_key *key, void *buf) goto fail; } - ret = crypto_skcipher_setkey(key->tfm, key->key, key->len); + ret = crypto_sync_skcipher_setkey(key->tfm, key->key, key->len); if (ret) goto fail; @@ -136,7 +136,7 @@ void ceph_crypto_key_destroy(struct ceph_crypto_key *key) if (key) { kfree(key->key); key->key = NULL; - crypto_free_skcipher(key->tfm); + crypto_free_sync_skcipher(key->tfm); key->tfm = NULL; } } @@ -216,7 +216,7 @@ static void teardown_sgtable(struct sg_table *sgt) static int ceph_aes_crypt(const struct ceph_crypto_key *key, bool encrypt, void *buf, int buf_len, int in_len, int *pout_len) { - SKCIPHER_REQUEST_ON_STACK(req, key->tfm); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, key->tfm); struct sg_table sgt; struct scatterlist prealloc_sg; char iv[AES_BLOCK_SIZE] __aligned(8); @@ -232,7 +232,7 @@ static int ceph_aes_crypt(const struct ceph_crypto_key *key, bool encrypt, return ret; memcpy(iv, aes_iv, AES_BLOCK_SIZE); - skcipher_request_set_tfm(req, key->tfm); + skcipher_request_set_sync_tfm(req, key->tfm); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sgt.sgl, sgt.sgl, crypt_len, iv); diff --git a/net/ceph/crypto.h b/net/ceph/crypto.h index bb45c7d43739..96ef4d860bc9 100644 --- a/net/ceph/crypto.h +++ b/net/ceph/crypto.h @@ -13,7 +13,7 @@ struct ceph_crypto_key { struct ceph_timespec created; int len; void *key; - struct crypto_skcipher *tfm; + struct crypto_sync_skcipher *tfm; }; int ceph_crypto_key_clone(struct ceph_crypto_key *dst, diff --git a/net/ceph/messenger.c b/net/ceph/messenger.c index 0a187196aeed..57fcc6b4bf6e 100644 --- a/net/ceph/messenger.c +++ b/net/ceph/messenger.c @@ -156,7 +156,6 @@ static bool con_flag_test_and_set(struct ceph_connection *con, /* Slab caches for frequently-allocated structures */ static struct kmem_cache *ceph_msg_cache; -static struct kmem_cache *ceph_msg_data_cache; /* static tag bytes (protocol control messages) */ static char tag_msg = CEPH_MSGR_TAG_MSG; @@ -235,23 +234,11 @@ static int ceph_msgr_slab_init(void) if (!ceph_msg_cache) return -ENOMEM; - BUG_ON(ceph_msg_data_cache); - ceph_msg_data_cache = KMEM_CACHE(ceph_msg_data, 0); - if (ceph_msg_data_cache) - return 0; - - kmem_cache_destroy(ceph_msg_cache); - ceph_msg_cache = NULL; - - return -ENOMEM; + return 0; } static void ceph_msgr_slab_exit(void) { - BUG_ON(!ceph_msg_data_cache); - kmem_cache_destroy(ceph_msg_data_cache); - ceph_msg_data_cache = NULL; - BUG_ON(!ceph_msg_cache); kmem_cache_destroy(ceph_msg_cache); ceph_msg_cache = NULL; @@ -526,7 +513,7 @@ static int ceph_tcp_recvmsg(struct socket *sock, void *buf, size_t len) if (!buf) msg.msg_flags |= MSG_TRUNC; - iov_iter_kvec(&msg.msg_iter, READ | ITER_KVEC, &iov, 1, len); + iov_iter_kvec(&msg.msg_iter, READ, &iov, 1, len); r = sock_recvmsg(sock, &msg, msg.msg_flags); if (r == -EAGAIN) r = 0; @@ -545,7 +532,7 @@ static int ceph_tcp_recvpage(struct socket *sock, struct page *page, int r; BUG_ON(page_offset + length > PAGE_SIZE); - iov_iter_bvec(&msg.msg_iter, READ | ITER_BVEC, &bvec, 1, length); + iov_iter_bvec(&msg.msg_iter, READ, &bvec, 1, length); r = sock_recvmsg(sock, &msg, msg.msg_flags); if (r == -EAGAIN) r = 0; @@ -607,7 +594,7 @@ static int ceph_tcp_sendpage(struct socket *sock, struct page *page, else msg.msg_flags |= MSG_EOR; /* superfluous, but what the hell */ - iov_iter_bvec(&msg.msg_iter, WRITE | ITER_BVEC, &bvec, 1, size); + iov_iter_bvec(&msg.msg_iter, WRITE, &bvec, 1, size); ret = sock_sendmsg(sock, &msg); if (ret == -EAGAIN) ret = 0; @@ -1141,16 +1128,13 @@ static void __ceph_msg_data_cursor_init(struct ceph_msg_data_cursor *cursor) static void ceph_msg_data_cursor_init(struct ceph_msg *msg, size_t length) { struct ceph_msg_data_cursor *cursor = &msg->cursor; - struct ceph_msg_data *data; BUG_ON(!length); BUG_ON(length > msg->data_length); - BUG_ON(list_empty(&msg->data)); + BUG_ON(!msg->num_data_items); - cursor->data_head = &msg->data; cursor->total_resid = length; - data = list_first_entry(&msg->data, struct ceph_msg_data, links); - cursor->data = data; + cursor->data = msg->data; __ceph_msg_data_cursor_init(cursor); } @@ -1231,8 +1215,7 @@ static void ceph_msg_data_advance(struct ceph_msg_data_cursor *cursor, if (!cursor->resid && cursor->total_resid) { WARN_ON(!cursor->last_piece); - BUG_ON(list_is_last(&cursor->data->links, cursor->data_head)); - cursor->data = list_next_entry(cursor->data, links); + cursor->data++; __ceph_msg_data_cursor_init(cursor); new_piece = true; } @@ -1248,9 +1231,6 @@ static size_t sizeof_footer(struct ceph_connection *con) static void prepare_message_data(struct ceph_msg *msg, u32 data_len) { - BUG_ON(!msg); - BUG_ON(!data_len); - /* Initialize data cursor */ ceph_msg_data_cursor_init(msg, (size_t)data_len); @@ -1590,7 +1570,7 @@ static int write_partial_message_data(struct ceph_connection *con) dout("%s %p msg %p\n", __func__, con, msg); - if (list_empty(&msg->data)) + if (!msg->num_data_items) return -EINVAL; /* @@ -2347,8 +2327,7 @@ static int read_partial_msg_data(struct ceph_connection *con) u32 crc = 0; int ret; - BUG_ON(!msg); - if (list_empty(&msg->data)) + if (!msg->num_data_items) return -EIO; if (do_datacrc) @@ -3256,32 +3235,16 @@ bool ceph_con_keepalive_expired(struct ceph_connection *con, return false; } -static struct ceph_msg_data *ceph_msg_data_create(enum ceph_msg_data_type type) +static struct ceph_msg_data *ceph_msg_data_add(struct ceph_msg *msg) { - struct ceph_msg_data *data; - - if (WARN_ON(!ceph_msg_data_type_valid(type))) - return NULL; - - data = kmem_cache_zalloc(ceph_msg_data_cache, GFP_NOFS); - if (!data) - return NULL; - - data->type = type; - INIT_LIST_HEAD(&data->links); - - return data; + BUG_ON(msg->num_data_items >= msg->max_data_items); + return &msg->data[msg->num_data_items++]; } static void ceph_msg_data_destroy(struct ceph_msg_data *data) { - if (!data) - return; - - WARN_ON(!list_empty(&data->links)); if (data->type == CEPH_MSG_DATA_PAGELIST) ceph_pagelist_release(data->pagelist); - kmem_cache_free(ceph_msg_data_cache, data); } void ceph_msg_data_add_pages(struct ceph_msg *msg, struct page **pages, @@ -3292,13 +3255,12 @@ void ceph_msg_data_add_pages(struct ceph_msg *msg, struct page **pages, BUG_ON(!pages); BUG_ON(!length); - data = ceph_msg_data_create(CEPH_MSG_DATA_PAGES); - BUG_ON(!data); + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_PAGES; data->pages = pages; data->length = length; data->alignment = alignment & ~PAGE_MASK; - list_add_tail(&data->links, &msg->data); msg->data_length += length; } EXPORT_SYMBOL(ceph_msg_data_add_pages); @@ -3311,11 +3273,11 @@ void ceph_msg_data_add_pagelist(struct ceph_msg *msg, BUG_ON(!pagelist); BUG_ON(!pagelist->length); - data = ceph_msg_data_create(CEPH_MSG_DATA_PAGELIST); - BUG_ON(!data); + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_PAGELIST; + refcount_inc(&pagelist->refcnt); data->pagelist = pagelist; - list_add_tail(&data->links, &msg->data); msg->data_length += pagelist->length; } EXPORT_SYMBOL(ceph_msg_data_add_pagelist); @@ -3326,12 +3288,11 @@ void ceph_msg_data_add_bio(struct ceph_msg *msg, struct ceph_bio_iter *bio_pos, { struct ceph_msg_data *data; - data = ceph_msg_data_create(CEPH_MSG_DATA_BIO); - BUG_ON(!data); + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_BIO; data->bio_pos = *bio_pos; data->bio_length = length; - list_add_tail(&data->links, &msg->data); msg->data_length += length; } EXPORT_SYMBOL(ceph_msg_data_add_bio); @@ -3342,11 +3303,10 @@ void ceph_msg_data_add_bvecs(struct ceph_msg *msg, { struct ceph_msg_data *data; - data = ceph_msg_data_create(CEPH_MSG_DATA_BVECS); - BUG_ON(!data); + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_BVECS; data->bvec_pos = *bvec_pos; - list_add_tail(&data->links, &msg->data); msg->data_length += bvec_pos->iter.bi_size; } EXPORT_SYMBOL(ceph_msg_data_add_bvecs); @@ -3355,8 +3315,8 @@ EXPORT_SYMBOL(ceph_msg_data_add_bvecs); * construct a new message with given type, size * the new msg has a ref count of 1. */ -struct ceph_msg *ceph_msg_new(int type, int front_len, gfp_t flags, - bool can_fail) +struct ceph_msg *ceph_msg_new2(int type, int front_len, int max_data_items, + gfp_t flags, bool can_fail) { struct ceph_msg *m; @@ -3370,7 +3330,6 @@ struct ceph_msg *ceph_msg_new(int type, int front_len, gfp_t flags, INIT_LIST_HEAD(&m->list_head); kref_init(&m->kref); - INIT_LIST_HEAD(&m->data); /* front */ if (front_len) { @@ -3385,6 +3344,15 @@ struct ceph_msg *ceph_msg_new(int type, int front_len, gfp_t flags, } m->front_alloc_len = m->front.iov_len = front_len; + if (max_data_items) { + m->data = kmalloc_array(max_data_items, sizeof(*m->data), + flags); + if (!m->data) + goto out2; + + m->max_data_items = max_data_items; + } + dout("ceph_msg_new %p front %d\n", m, front_len); return m; @@ -3401,6 +3369,13 @@ out: } return NULL; } +EXPORT_SYMBOL(ceph_msg_new2); + +struct ceph_msg *ceph_msg_new(int type, int front_len, gfp_t flags, + bool can_fail) +{ + return ceph_msg_new2(type, front_len, 0, flags, can_fail); +} EXPORT_SYMBOL(ceph_msg_new); /* @@ -3496,13 +3471,14 @@ static void ceph_msg_free(struct ceph_msg *m) { dout("%s %p\n", __func__, m); kvfree(m->front.iov_base); + kfree(m->data); kmem_cache_free(ceph_msg_cache, m); } static void ceph_msg_release(struct kref *kref) { struct ceph_msg *m = container_of(kref, struct ceph_msg, kref); - struct ceph_msg_data *data, *next; + int i; dout("%s %p\n", __func__, m); WARN_ON(!list_empty(&m->list_head)); @@ -3515,11 +3491,8 @@ static void ceph_msg_release(struct kref *kref) m->middle = NULL; } - list_for_each_entry_safe(data, next, &m->data, links) { - list_del_init(&data->links); - ceph_msg_data_destroy(data); - } - m->data_length = 0; + for (i = 0; i < m->num_data_items; i++) + ceph_msg_data_destroy(&m->data[i]); if (m->pool) ceph_msgpool_put(m->pool, m); diff --git a/net/ceph/msgpool.c b/net/ceph/msgpool.c index 72571535883f..e3ecb80cd182 100644 --- a/net/ceph/msgpool.c +++ b/net/ceph/msgpool.c @@ -14,7 +14,8 @@ static void *msgpool_alloc(gfp_t gfp_mask, void *arg) struct ceph_msgpool *pool = arg; struct ceph_msg *msg; - msg = ceph_msg_new(pool->type, pool->front_len, gfp_mask, true); + msg = ceph_msg_new2(pool->type, pool->front_len, pool->max_data_items, + gfp_mask, true); if (!msg) { dout("msgpool_alloc %s failed\n", pool->name); } else { @@ -35,11 +36,13 @@ static void msgpool_free(void *element, void *arg) } int ceph_msgpool_init(struct ceph_msgpool *pool, int type, - int front_len, int size, bool blocking, const char *name) + int front_len, int max_data_items, int size, + const char *name) { dout("msgpool %s init\n", name); pool->type = type; pool->front_len = front_len; + pool->max_data_items = max_data_items; pool->pool = mempool_create(size, msgpool_alloc, msgpool_free, pool); if (!pool->pool) return -ENOMEM; @@ -53,18 +56,21 @@ void ceph_msgpool_destroy(struct ceph_msgpool *pool) mempool_destroy(pool->pool); } -struct ceph_msg *ceph_msgpool_get(struct ceph_msgpool *pool, - int front_len) +struct ceph_msg *ceph_msgpool_get(struct ceph_msgpool *pool, int front_len, + int max_data_items) { struct ceph_msg *msg; - if (front_len > pool->front_len) { - dout("msgpool_get %s need front %d, pool size is %d\n", - pool->name, front_len, pool->front_len); - WARN_ON(1); + if (front_len > pool->front_len || + max_data_items > pool->max_data_items) { + pr_warn_ratelimited("%s need %d/%d, pool %s has %d/%d\n", + __func__, front_len, max_data_items, pool->name, + pool->front_len, pool->max_data_items); + WARN_ON_ONCE(1); /* try to alloc a fresh message */ - return ceph_msg_new(pool->type, front_len, GFP_NOFS, false); + return ceph_msg_new2(pool->type, front_len, max_data_items, + GFP_NOFS, false); } msg = mempool_alloc(pool->pool, GFP_NOFS); @@ -80,6 +86,9 @@ void ceph_msgpool_put(struct ceph_msgpool *pool, struct ceph_msg *msg) msg->front.iov_len = pool->front_len; msg->hdr.front_len = cpu_to_le32(pool->front_len); + msg->data_length = 0; + msg->num_data_items = 0; + kref_init(&msg->kref); /* retake single ref */ mempool_free(msg, pool->pool); } diff --git a/net/ceph/osd_client.c b/net/ceph/osd_client.c index 60934bd8796c..d23a9f81f3d7 100644 --- a/net/ceph/osd_client.c +++ b/net/ceph/osd_client.c @@ -126,6 +126,9 @@ static void ceph_osd_data_init(struct ceph_osd_data *osd_data) osd_data->type = CEPH_OSD_DATA_TYPE_NONE; } +/* + * Consumes @pages if @own_pages is true. + */ static void ceph_osd_data_pages_init(struct ceph_osd_data *osd_data, struct page **pages, u64 length, u32 alignment, bool pages_from_pool, bool own_pages) @@ -138,6 +141,9 @@ static void ceph_osd_data_pages_init(struct ceph_osd_data *osd_data, osd_data->own_pages = own_pages; } +/* + * Consumes a ref on @pagelist. + */ static void ceph_osd_data_pagelist_init(struct ceph_osd_data *osd_data, struct ceph_pagelist *pagelist) { @@ -362,6 +368,8 @@ static void ceph_osd_data_release(struct ceph_osd_data *osd_data) num_pages = calc_pages_for((u64)osd_data->alignment, (u64)osd_data->length); ceph_release_page_vector(osd_data->pages, num_pages); + } else if (osd_data->type == CEPH_OSD_DATA_TYPE_PAGELIST) { + ceph_pagelist_release(osd_data->pagelist); } ceph_osd_data_init(osd_data); } @@ -402,6 +410,9 @@ static void osd_req_op_data_release(struct ceph_osd_request *osd_req, case CEPH_OSD_OP_LIST_WATCHERS: ceph_osd_data_release(&op->list_watchers.response_data); break; + case CEPH_OSD_OP_COPY_FROM: + ceph_osd_data_release(&op->copy_from.osd_data); + break; default: break; } @@ -606,12 +617,15 @@ static int ceph_oloc_encoding_size(const struct ceph_object_locator *oloc) return 8 + 4 + 4 + 4 + (oloc->pool_ns ? oloc->pool_ns->len : 0); } -int ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp) +static int __ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp, + int num_request_data_items, + int num_reply_data_items) { struct ceph_osd_client *osdc = req->r_osdc; struct ceph_msg *msg; int msg_size; + WARN_ON(req->r_request || req->r_reply); WARN_ON(ceph_oid_empty(&req->r_base_oid)); WARN_ON(ceph_oloc_empty(&req->r_base_oloc)); @@ -633,9 +647,11 @@ int ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp) msg_size += 4 + 8; /* retry_attempt, features */ if (req->r_mempool) - msg = ceph_msgpool_get(&osdc->msgpool_op, 0); + msg = ceph_msgpool_get(&osdc->msgpool_op, msg_size, + num_request_data_items); else - msg = ceph_msg_new(CEPH_MSG_OSD_OP, msg_size, gfp, true); + msg = ceph_msg_new2(CEPH_MSG_OSD_OP, msg_size, + num_request_data_items, gfp, true); if (!msg) return -ENOMEM; @@ -648,9 +664,11 @@ int ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp) msg_size += req->r_num_ops * sizeof(struct ceph_osd_op); if (req->r_mempool) - msg = ceph_msgpool_get(&osdc->msgpool_op_reply, 0); + msg = ceph_msgpool_get(&osdc->msgpool_op_reply, msg_size, + num_reply_data_items); else - msg = ceph_msg_new(CEPH_MSG_OSD_OPREPLY, msg_size, gfp, true); + msg = ceph_msg_new2(CEPH_MSG_OSD_OPREPLY, msg_size, + num_reply_data_items, gfp, true); if (!msg) return -ENOMEM; @@ -658,7 +676,6 @@ int ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp) return 0; } -EXPORT_SYMBOL(ceph_osdc_alloc_messages); static bool osd_req_opcode_valid(u16 opcode) { @@ -671,6 +688,65 @@ __CEPH_FORALL_OSD_OPS(GENERATE_CASE) } } +static void get_num_data_items(struct ceph_osd_request *req, + int *num_request_data_items, + int *num_reply_data_items) +{ + struct ceph_osd_req_op *op; + + *num_request_data_items = 0; + *num_reply_data_items = 0; + + for (op = req->r_ops; op != &req->r_ops[req->r_num_ops]; op++) { + switch (op->op) { + /* request */ + case CEPH_OSD_OP_WRITE: + case CEPH_OSD_OP_WRITEFULL: + case CEPH_OSD_OP_SETXATTR: + case CEPH_OSD_OP_CMPXATTR: + case CEPH_OSD_OP_NOTIFY_ACK: + case CEPH_OSD_OP_COPY_FROM: + *num_request_data_items += 1; + break; + + /* reply */ + case CEPH_OSD_OP_STAT: + case CEPH_OSD_OP_READ: + case CEPH_OSD_OP_LIST_WATCHERS: + *num_reply_data_items += 1; + break; + + /* both */ + case CEPH_OSD_OP_NOTIFY: + *num_request_data_items += 1; + *num_reply_data_items += 1; + break; + case CEPH_OSD_OP_CALL: + *num_request_data_items += 2; + *num_reply_data_items += 1; + break; + + default: + WARN_ON(!osd_req_opcode_valid(op->op)); + break; + } + } +} + +/* + * oid, oloc and OSD op opcode(s) must be filled in before this function + * is called. + */ +int ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp) +{ + int num_request_data_items, num_reply_data_items; + + get_num_data_items(req, &num_request_data_items, &num_reply_data_items); + return __ceph_osdc_alloc_messages(req, gfp, num_request_data_items, + num_reply_data_items); +} +EXPORT_SYMBOL(ceph_osdc_alloc_messages); + /* * This is an osd op init function for opcodes that have no data or * other information associated with them. It also serves as a @@ -767,22 +843,19 @@ void osd_req_op_extent_dup_last(struct ceph_osd_request *osd_req, EXPORT_SYMBOL(osd_req_op_extent_dup_last); int osd_req_op_cls_init(struct ceph_osd_request *osd_req, unsigned int which, - u16 opcode, const char *class, const char *method) + const char *class, const char *method) { - struct ceph_osd_req_op *op = _osd_req_op_init(osd_req, which, - opcode, 0); + struct ceph_osd_req_op *op; struct ceph_pagelist *pagelist; size_t payload_len = 0; size_t size; - BUG_ON(opcode != CEPH_OSD_OP_CALL); + op = _osd_req_op_init(osd_req, which, CEPH_OSD_OP_CALL, 0); - pagelist = kmalloc(sizeof (*pagelist), GFP_NOFS); + pagelist = ceph_pagelist_alloc(GFP_NOFS); if (!pagelist) return -ENOMEM; - ceph_pagelist_init(pagelist); - op->cls.class_name = class; size = strlen(class); BUG_ON(size > (size_t) U8_MAX); @@ -815,12 +888,10 @@ int osd_req_op_xattr_init(struct ceph_osd_request *osd_req, unsigned int which, BUG_ON(opcode != CEPH_OSD_OP_SETXATTR && opcode != CEPH_OSD_OP_CMPXATTR); - pagelist = kmalloc(sizeof(*pagelist), GFP_NOFS); + pagelist = ceph_pagelist_alloc(GFP_NOFS); if (!pagelist) return -ENOMEM; - ceph_pagelist_init(pagelist); - payload_len = strlen(name); op->xattr.name_len = payload_len; ceph_pagelist_append(pagelist, name, payload_len); @@ -900,12 +971,6 @@ static void ceph_osdc_msg_data_add(struct ceph_msg *msg, static u32 osd_req_encode_op(struct ceph_osd_op *dst, const struct ceph_osd_req_op *src) { - if (WARN_ON(!osd_req_opcode_valid(src->op))) { - pr_err("unrecognized osd opcode %d\n", src->op); - - return 0; - } - switch (src->op) { case CEPH_OSD_OP_STAT: break; @@ -955,6 +1020,14 @@ static u32 osd_req_encode_op(struct ceph_osd_op *dst, case CEPH_OSD_OP_CREATE: case CEPH_OSD_OP_DELETE: break; + case CEPH_OSD_OP_COPY_FROM: + dst->copy_from.snapid = cpu_to_le64(src->copy_from.snapid); + dst->copy_from.src_version = + cpu_to_le64(src->copy_from.src_version); + dst->copy_from.flags = src->copy_from.flags; + dst->copy_from.src_fadvise_flags = + cpu_to_le32(src->copy_from.src_fadvise_flags); + break; default: pr_err("unsupported osd opcode %s\n", ceph_osd_op_name(src->op)); @@ -1038,7 +1111,15 @@ struct ceph_osd_request *ceph_osdc_new_request(struct ceph_osd_client *osdc, if (flags & CEPH_OSD_FLAG_WRITE) req->r_data_offset = off; - r = ceph_osdc_alloc_messages(req, GFP_NOFS); + if (num_ops > 1) + /* + * This is a special case for ceph_writepages_start(), but it + * also covers ceph_uninline_data(). If more multi-op request + * use cases emerge, we will need a separate helper. + */ + r = __ceph_osdc_alloc_messages(req, GFP_NOFS, num_ops, 0); + else + r = ceph_osdc_alloc_messages(req, GFP_NOFS); if (r) goto fail; @@ -1845,48 +1926,55 @@ static bool should_plug_request(struct ceph_osd_request *req) return true; } -static void setup_request_data(struct ceph_osd_request *req, - struct ceph_msg *msg) +/* + * Keep get_num_data_items() in sync with this function. + */ +static void setup_request_data(struct ceph_osd_request *req) { - u32 data_len = 0; - int i; + struct ceph_msg *request_msg = req->r_request; + struct ceph_msg *reply_msg = req->r_reply; + struct ceph_osd_req_op *op; - if (!list_empty(&msg->data)) + if (req->r_request->num_data_items || req->r_reply->num_data_items) return; - WARN_ON(msg->data_length); - for (i = 0; i < req->r_num_ops; i++) { - struct ceph_osd_req_op *op = &req->r_ops[i]; - + WARN_ON(request_msg->data_length || reply_msg->data_length); + for (op = req->r_ops; op != &req->r_ops[req->r_num_ops]; op++) { switch (op->op) { /* request */ case CEPH_OSD_OP_WRITE: case CEPH_OSD_OP_WRITEFULL: WARN_ON(op->indata_len != op->extent.length); - ceph_osdc_msg_data_add(msg, &op->extent.osd_data); + ceph_osdc_msg_data_add(request_msg, + &op->extent.osd_data); break; case CEPH_OSD_OP_SETXATTR: case CEPH_OSD_OP_CMPXATTR: WARN_ON(op->indata_len != op->xattr.name_len + op->xattr.value_len); - ceph_osdc_msg_data_add(msg, &op->xattr.osd_data); + ceph_osdc_msg_data_add(request_msg, + &op->xattr.osd_data); break; case CEPH_OSD_OP_NOTIFY_ACK: - ceph_osdc_msg_data_add(msg, + ceph_osdc_msg_data_add(request_msg, &op->notify_ack.request_data); break; + case CEPH_OSD_OP_COPY_FROM: + ceph_osdc_msg_data_add(request_msg, + &op->copy_from.osd_data); + break; /* reply */ case CEPH_OSD_OP_STAT: - ceph_osdc_msg_data_add(req->r_reply, + ceph_osdc_msg_data_add(reply_msg, &op->raw_data_in); break; case CEPH_OSD_OP_READ: - ceph_osdc_msg_data_add(req->r_reply, + ceph_osdc_msg_data_add(reply_msg, &op->extent.osd_data); break; case CEPH_OSD_OP_LIST_WATCHERS: - ceph_osdc_msg_data_add(req->r_reply, + ceph_osdc_msg_data_add(reply_msg, &op->list_watchers.response_data); break; @@ -1895,25 +1983,23 @@ static void setup_request_data(struct ceph_osd_request *req, WARN_ON(op->indata_len != op->cls.class_len + op->cls.method_len + op->cls.indata_len); - ceph_osdc_msg_data_add(msg, &op->cls.request_info); + ceph_osdc_msg_data_add(request_msg, + &op->cls.request_info); /* optional, can be NONE */ - ceph_osdc_msg_data_add(msg, &op->cls.request_data); + ceph_osdc_msg_data_add(request_msg, + &op->cls.request_data); /* optional, can be NONE */ - ceph_osdc_msg_data_add(req->r_reply, + ceph_osdc_msg_data_add(reply_msg, &op->cls.response_data); break; case CEPH_OSD_OP_NOTIFY: - ceph_osdc_msg_data_add(msg, + ceph_osdc_msg_data_add(request_msg, &op->notify.request_data); - ceph_osdc_msg_data_add(req->r_reply, + ceph_osdc_msg_data_add(reply_msg, &op->notify.response_data); break; } - - data_len += op->indata_len; } - - WARN_ON(data_len != msg->data_length); } static void encode_pgid(void **p, const struct ceph_pg *pgid) @@ -1961,7 +2047,7 @@ static void encode_request_partial(struct ceph_osd_request *req, req->r_data_offset || req->r_snapc); } - setup_request_data(req, msg); + setup_request_data(req); encode_spgid(&p, &req->r_t.spgid); /* actual spg */ ceph_encode_32(&p, req->r_t.pgid.seed); /* raw hash */ @@ -3001,11 +3087,21 @@ static void linger_submit(struct ceph_osd_linger_request *lreq) struct ceph_osd_client *osdc = lreq->osdc; struct ceph_osd *osd; + down_write(&osdc->lock); + linger_register(lreq); + if (lreq->is_watch) { + lreq->reg_req->r_ops[0].watch.cookie = lreq->linger_id; + lreq->ping_req->r_ops[0].watch.cookie = lreq->linger_id; + } else { + lreq->reg_req->r_ops[0].notify.cookie = lreq->linger_id; + } + calc_target(osdc, &lreq->t, NULL, false); osd = lookup_create_osd(osdc, lreq->t.osd, true); link_linger(osd, lreq); send_linger(lreq); + up_write(&osdc->lock); } static void cancel_linger_map_check(struct ceph_osd_linger_request *lreq) @@ -4318,9 +4414,7 @@ static void handle_watch_notify(struct ceph_osd_client *osdc, lreq->notify_id, notify_id); } else if (!completion_done(&lreq->notify_finish_wait)) { struct ceph_msg_data *data = - list_first_entry_or_null(&msg->data, - struct ceph_msg_data, - links); + msg->num_data_items ? &msg->data[0] : NULL; if (data) { if (lreq->preply_pages) { @@ -4476,6 +4570,23 @@ alloc_linger_request(struct ceph_osd_linger_request *lreq) ceph_oid_copy(&req->r_base_oid, &lreq->t.base_oid); ceph_oloc_copy(&req->r_base_oloc, &lreq->t.base_oloc); + return req; +} + +static struct ceph_osd_request * +alloc_watch_request(struct ceph_osd_linger_request *lreq, u8 watch_opcode) +{ + struct ceph_osd_request *req; + + req = alloc_linger_request(lreq); + if (!req) + return NULL; + + /* + * Pass 0 for cookie because we don't know it yet, it will be + * filled in by linger_submit(). + */ + osd_req_op_watch_init(req, 0, 0, watch_opcode); if (ceph_osdc_alloc_messages(req, GFP_NOIO)) { ceph_osdc_put_request(req); @@ -4514,27 +4625,19 @@ ceph_osdc_watch(struct ceph_osd_client *osdc, lreq->t.flags = CEPH_OSD_FLAG_WRITE; ktime_get_real_ts64(&lreq->mtime); - lreq->reg_req = alloc_linger_request(lreq); + lreq->reg_req = alloc_watch_request(lreq, CEPH_OSD_WATCH_OP_WATCH); if (!lreq->reg_req) { ret = -ENOMEM; goto err_put_lreq; } - lreq->ping_req = alloc_linger_request(lreq); + lreq->ping_req = alloc_watch_request(lreq, CEPH_OSD_WATCH_OP_PING); if (!lreq->ping_req) { ret = -ENOMEM; goto err_put_lreq; } - down_write(&osdc->lock); - linger_register(lreq); /* before osd_req_op_* */ - osd_req_op_watch_init(lreq->reg_req, 0, lreq->linger_id, - CEPH_OSD_WATCH_OP_WATCH); - osd_req_op_watch_init(lreq->ping_req, 0, lreq->linger_id, - CEPH_OSD_WATCH_OP_PING); linger_submit(lreq); - up_write(&osdc->lock); - ret = linger_reg_commit_wait(lreq); if (ret) { linger_cancel(lreq); @@ -4599,11 +4702,10 @@ static int osd_req_op_notify_ack_init(struct ceph_osd_request *req, int which, op = _osd_req_op_init(req, which, CEPH_OSD_OP_NOTIFY_ACK, 0); - pl = kmalloc(sizeof(*pl), GFP_NOIO); + pl = ceph_pagelist_alloc(GFP_NOIO); if (!pl) return -ENOMEM; - ceph_pagelist_init(pl); ret = ceph_pagelist_encode_64(pl, notify_id); ret |= ceph_pagelist_encode_64(pl, cookie); if (payload) { @@ -4641,12 +4743,12 @@ int ceph_osdc_notify_ack(struct ceph_osd_client *osdc, ceph_oloc_copy(&req->r_base_oloc, oloc); req->r_flags = CEPH_OSD_FLAG_READ; - ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + ret = osd_req_op_notify_ack_init(req, 0, notify_id, cookie, payload, + payload_len); if (ret) goto out_put_req; - ret = osd_req_op_notify_ack_init(req, 0, notify_id, cookie, payload, - payload_len); + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); if (ret) goto out_put_req; @@ -4670,11 +4772,10 @@ static int osd_req_op_notify_init(struct ceph_osd_request *req, int which, op = _osd_req_op_init(req, which, CEPH_OSD_OP_NOTIFY, 0); op->notify.cookie = cookie; - pl = kmalloc(sizeof(*pl), GFP_NOIO); + pl = ceph_pagelist_alloc(GFP_NOIO); if (!pl) return -ENOMEM; - ceph_pagelist_init(pl); ret = ceph_pagelist_encode_32(pl, 1); /* prot_ver */ ret |= ceph_pagelist_encode_32(pl, timeout); ret |= ceph_pagelist_encode_32(pl, payload_len); @@ -4733,29 +4834,30 @@ int ceph_osdc_notify(struct ceph_osd_client *osdc, goto out_put_lreq; } + /* + * Pass 0 for cookie because we don't know it yet, it will be + * filled in by linger_submit(). + */ + ret = osd_req_op_notify_init(lreq->reg_req, 0, 0, 1, timeout, + payload, payload_len); + if (ret) + goto out_put_lreq; + /* for notify_id */ pages = ceph_alloc_page_vector(1, GFP_NOIO); if (IS_ERR(pages)) { ret = PTR_ERR(pages); goto out_put_lreq; } - - down_write(&osdc->lock); - linger_register(lreq); /* before osd_req_op_* */ - ret = osd_req_op_notify_init(lreq->reg_req, 0, lreq->linger_id, 1, - timeout, payload, payload_len); - if (ret) { - linger_unregister(lreq); - up_write(&osdc->lock); - ceph_release_page_vector(pages, 1); - goto out_put_lreq; - } ceph_osd_data_pages_init(osd_req_op_data(lreq->reg_req, 0, notify, response_data), pages, PAGE_SIZE, 0, false, true); - linger_submit(lreq); - up_write(&osdc->lock); + ret = ceph_osdc_alloc_messages(lreq->reg_req, GFP_NOIO); + if (ret) + goto out_put_lreq; + + linger_submit(lreq); ret = linger_reg_commit_wait(lreq); if (!ret) ret = linger_notify_finish_wait(lreq); @@ -4881,10 +4983,6 @@ int ceph_osdc_list_watchers(struct ceph_osd_client *osdc, ceph_oloc_copy(&req->r_base_oloc, oloc); req->r_flags = CEPH_OSD_FLAG_READ; - ret = ceph_osdc_alloc_messages(req, GFP_NOIO); - if (ret) - goto out_put_req; - pages = ceph_alloc_page_vector(1, GFP_NOIO); if (IS_ERR(pages)) { ret = PTR_ERR(pages); @@ -4896,6 +4994,10 @@ int ceph_osdc_list_watchers(struct ceph_osd_client *osdc, response_data), pages, PAGE_SIZE, 0, false, true); + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + if (ret) + goto out_put_req; + ceph_osdc_start_request(osdc, req, false); ret = ceph_osdc_wait_request(osdc, req); if (ret >= 0) { @@ -4958,11 +5060,7 @@ int ceph_osdc_call(struct ceph_osd_client *osdc, ceph_oloc_copy(&req->r_base_oloc, oloc); req->r_flags = flags; - ret = ceph_osdc_alloc_messages(req, GFP_NOIO); - if (ret) - goto out_put_req; - - ret = osd_req_op_cls_init(req, 0, CEPH_OSD_OP_CALL, class, method); + ret = osd_req_op_cls_init(req, 0, class, method); if (ret) goto out_put_req; @@ -4973,6 +5071,10 @@ int ceph_osdc_call(struct ceph_osd_client *osdc, osd_req_op_cls_response_data_pages(req, 0, &resp_page, *resp_len, 0, false, false); + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + if (ret) + goto out_put_req; + ceph_osdc_start_request(osdc, req, false); ret = ceph_osdc_wait_request(osdc, req); if (ret >= 0) { @@ -5021,11 +5123,12 @@ int ceph_osdc_init(struct ceph_osd_client *osdc, struct ceph_client *client) goto out_map; err = ceph_msgpool_init(&osdc->msgpool_op, CEPH_MSG_OSD_OP, - PAGE_SIZE, 10, true, "osd_op"); + PAGE_SIZE, CEPH_OSD_SLAB_OPS, 10, "osd_op"); if (err < 0) goto out_mempool; err = ceph_msgpool_init(&osdc->msgpool_op_reply, CEPH_MSG_OSD_OPREPLY, - PAGE_SIZE, 10, true, "osd_op_reply"); + PAGE_SIZE, CEPH_OSD_SLAB_OPS, 10, + "osd_op_reply"); if (err < 0) goto out_msgpool; @@ -5168,6 +5271,80 @@ int ceph_osdc_writepages(struct ceph_osd_client *osdc, struct ceph_vino vino, } EXPORT_SYMBOL(ceph_osdc_writepages); +static int osd_req_op_copy_from_init(struct ceph_osd_request *req, + u64 src_snapid, u64 src_version, + struct ceph_object_id *src_oid, + struct ceph_object_locator *src_oloc, + u32 src_fadvise_flags, + u32 dst_fadvise_flags, + u8 copy_from_flags) +{ + struct ceph_osd_req_op *op; + struct page **pages; + void *p, *end; + + pages = ceph_alloc_page_vector(1, GFP_KERNEL); + if (IS_ERR(pages)) + return PTR_ERR(pages); + + op = _osd_req_op_init(req, 0, CEPH_OSD_OP_COPY_FROM, dst_fadvise_flags); + op->copy_from.snapid = src_snapid; + op->copy_from.src_version = src_version; + op->copy_from.flags = copy_from_flags; + op->copy_from.src_fadvise_flags = src_fadvise_flags; + + p = page_address(pages[0]); + end = p + PAGE_SIZE; + ceph_encode_string(&p, end, src_oid->name, src_oid->name_len); + encode_oloc(&p, end, src_oloc); + op->indata_len = PAGE_SIZE - (end - p); + + ceph_osd_data_pages_init(&op->copy_from.osd_data, pages, + op->indata_len, 0, false, true); + return 0; +} + +int ceph_osdc_copy_from(struct ceph_osd_client *osdc, + u64 src_snapid, u64 src_version, + struct ceph_object_id *src_oid, + struct ceph_object_locator *src_oloc, + u32 src_fadvise_flags, + struct ceph_object_id *dst_oid, + struct ceph_object_locator *dst_oloc, + u32 dst_fadvise_flags, + u8 copy_from_flags) +{ + struct ceph_osd_request *req; + int ret; + + req = ceph_osdc_alloc_request(osdc, NULL, 1, false, GFP_KERNEL); + if (!req) + return -ENOMEM; + + req->r_flags = CEPH_OSD_FLAG_WRITE; + + ceph_oloc_copy(&req->r_t.base_oloc, dst_oloc); + ceph_oid_copy(&req->r_t.base_oid, dst_oid); + + ret = osd_req_op_copy_from_init(req, src_snapid, src_version, src_oid, + src_oloc, src_fadvise_flags, + dst_fadvise_flags, copy_from_flags); + if (ret) + goto out; + + ret = ceph_osdc_alloc_messages(req, GFP_KERNEL); + if (ret) + goto out; + + ceph_osdc_start_request(osdc, req, false); + ret = ceph_osdc_wait_request(osdc, req); + +out: + ceph_osdc_put_request(req); + return ret; +} +EXPORT_SYMBOL(ceph_osdc_copy_from); + int __init ceph_osdc_setup(void) { size_t size = sizeof(struct ceph_osd_request) + @@ -5295,7 +5472,7 @@ static struct ceph_msg *alloc_msg_with_page_vector(struct ceph_msg_header *hdr) u32 front_len = le32_to_cpu(hdr->front_len); u32 data_len = le32_to_cpu(hdr->data_len); - m = ceph_msg_new(type, front_len, GFP_NOIO, false); + m = ceph_msg_new2(type, front_len, 1, GFP_NOIO, false); if (!m) return NULL; diff --git a/net/ceph/pagelist.c b/net/ceph/pagelist.c index 2ea0564771d2..65e34f78b05d 100644 --- a/net/ceph/pagelist.c +++ b/net/ceph/pagelist.c @@ -6,6 +6,26 @@ #include <linux/highmem.h> #include <linux/ceph/pagelist.h> +struct ceph_pagelist *ceph_pagelist_alloc(gfp_t gfp_flags) +{ + struct ceph_pagelist *pl; + + pl = kmalloc(sizeof(*pl), gfp_flags); + if (!pl) + return NULL; + + INIT_LIST_HEAD(&pl->head); + pl->mapped_tail = NULL; + pl->length = 0; + pl->room = 0; + INIT_LIST_HEAD(&pl->free_list); + pl->num_pages_free = 0; + refcount_set(&pl->refcnt, 1); + + return pl; +} +EXPORT_SYMBOL(ceph_pagelist_alloc); + static void ceph_pagelist_unmap_tail(struct ceph_pagelist *pl) { if (pl->mapped_tail) { diff --git a/net/compat.c b/net/compat.c index 3b2105f6549d..47a614b370cd 100644 --- a/net/compat.c +++ b/net/compat.c @@ -812,21 +812,21 @@ COMPAT_SYSCALL_DEFINE6(recvfrom, int, fd, void __user *, buf, compat_size_t, len static int __compat_sys_recvmmsg(int fd, struct compat_mmsghdr __user *mmsg, unsigned int vlen, unsigned int flags, - struct compat_timespec __user *timeout) + struct old_timespec32 __user *timeout) { int datagrams; - struct timespec ktspec; + struct timespec64 ktspec; if (timeout == NULL) return __sys_recvmmsg(fd, (struct mmsghdr __user *)mmsg, vlen, flags | MSG_CMSG_COMPAT, NULL); - if (compat_get_timespec(&ktspec, timeout)) + if (compat_get_timespec64(&ktspec, timeout)) return -EFAULT; datagrams = __sys_recvmmsg(fd, (struct mmsghdr __user *)mmsg, vlen, flags | MSG_CMSG_COMPAT, &ktspec); - if (datagrams > 0 && compat_put_timespec(&ktspec, timeout)) + if (datagrams > 0 && compat_put_timespec64(&ktspec, timeout)) datagrams = -EFAULT; return datagrams; @@ -834,7 +834,7 @@ static int __compat_sys_recvmmsg(int fd, struct compat_mmsghdr __user *mmsg, COMPAT_SYSCALL_DEFINE5(recvmmsg, int, fd, struct compat_mmsghdr __user *, mmsg, unsigned int, vlen, unsigned int, flags, - struct compat_timespec __user *, timeout) + struct old_timespec32 __user *, timeout) { return __compat_sys_recvmmsg(fd, mmsg, vlen, flags, timeout); } diff --git a/net/core/Makefile b/net/core/Makefile index 80175e6a2eb8..fccd31e0e7f7 100644 --- a/net/core/Makefile +++ b/net/core/Makefile @@ -16,6 +16,7 @@ obj-y += dev.o ethtool.o dev_addr_lists.o dst.o netevent.o \ obj-y += net-sysfs.o obj-$(CONFIG_PAGE_POOL) += page_pool.o obj-$(CONFIG_PROC_FS) += net-procfs.o +obj-$(CONFIG_NET_SOCK_MSG) += skmsg.o obj-$(CONFIG_NET_PKTGEN) += pktgen.o obj-$(CONFIG_NETPOLL) += netpoll.o obj-$(CONFIG_FIB_RULES) += fib_rules.o @@ -27,6 +28,7 @@ obj-$(CONFIG_CGROUP_NET_PRIO) += netprio_cgroup.o obj-$(CONFIG_CGROUP_NET_CLASSID) += netclassid_cgroup.o obj-$(CONFIG_LWTUNNEL) += lwtunnel.o obj-$(CONFIG_LWTUNNEL_BPF) += lwt_bpf.o +obj-$(CONFIG_BPF_STREAM_PARSER) += sock_map.o obj-$(CONFIG_DST_CACHE) += dst_cache.o obj-$(CONFIG_HWBM) += hwbm.o obj-$(CONFIG_NET_DEVLINK) += devlink.o diff --git a/net/core/datagram.c b/net/core/datagram.c index 9aac0d63d53e..57f3a6fcfc1e 100644 --- a/net/core/datagram.c +++ b/net/core/datagram.c @@ -808,8 +808,9 @@ int skb_copy_and_csum_datagram_msg(struct sk_buff *skb, return -EINVAL; } - if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE)) - netdev_rx_csum_fault(skb->dev); + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(NULL); } return 0; fault: @@ -837,7 +838,7 @@ __poll_t datagram_poll(struct file *file, struct socket *sock, struct sock *sk = sock->sk; __poll_t mask; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); mask = 0; /* exceptional events? */ diff --git a/net/core/dev.c b/net/core/dev.c index 0b2d777e5b9e..77d43ae2a7bb 100644 --- a/net/core/dev.c +++ b/net/core/dev.c @@ -1752,6 +1752,28 @@ int call_netdevice_notifiers(unsigned long val, struct net_device *dev) } EXPORT_SYMBOL(call_netdevice_notifiers); +/** + * call_netdevice_notifiers_mtu - call all network notifier blocks + * @val: value passed unmodified to notifier function + * @dev: net_device pointer passed unmodified to notifier function + * @arg: additional u32 argument passed to the notifier function + * + * Call all network notifier blocks. Parameters and return value + * are as for raw_notifier_call_chain(). + */ +static int call_netdevice_notifiers_mtu(unsigned long val, + struct net_device *dev, u32 arg) +{ + struct netdev_notifier_info_ext info = { + .info.dev = dev, + .ext.mtu = arg, + }; + + BUILD_BUG_ON(offsetof(struct netdev_notifier_info_ext, info) != 0); + + return call_netdevice_notifiers_info(val, &info.info); +} + #ifdef CONFIG_NET_INGRESS static DEFINE_STATIC_KEY_FALSE(ingress_needed_key); @@ -1954,6 +1976,17 @@ static inline bool skb_loop_sk(struct packet_type *ptype, struct sk_buff *skb) return false; } +/** + * dev_nit_active - return true if any network interface taps are in use + * + * @dev: network device to check for the presence of taps + */ +bool dev_nit_active(struct net_device *dev) +{ + return !list_empty(&ptype_all) || !list_empty(&dev->ptype_all); +} +EXPORT_SYMBOL_GPL(dev_nit_active); + /* * Support routine. Sends outgoing frames to any network * taps currently in use. @@ -3211,7 +3244,7 @@ static int xmit_one(struct sk_buff *skb, struct net_device *dev, unsigned int len; int rc; - if (!list_empty(&ptype_all) || !list_empty(&dev->ptype_all)) + if (dev_nit_active(dev)) dev_queue_xmit_nit(skb, dev); len = skb->len; @@ -4258,6 +4291,9 @@ static u32 netif_receive_generic_xdp(struct sk_buff *skb, struct netdev_rx_queue *rxqueue; void *orig_data, *orig_data_end; u32 metalen, act = XDP_DROP; + __be16 orig_eth_type; + struct ethhdr *eth; + bool orig_bcast; int hlen, off; u32 mac_len; @@ -4298,6 +4334,9 @@ static u32 netif_receive_generic_xdp(struct sk_buff *skb, xdp->data_hard_start = skb->data - skb_headroom(skb); orig_data_end = xdp->data_end; orig_data = xdp->data; + eth = (struct ethhdr *)xdp->data; + orig_bcast = is_multicast_ether_addr_64bits(eth->h_dest); + orig_eth_type = eth->h_proto; rxqueue = netif_get_rxqueue(skb); xdp->rxq = &rxqueue->xdp_rxq; @@ -4321,6 +4360,14 @@ static u32 netif_receive_generic_xdp(struct sk_buff *skb, } + /* check if XDP changed eth hdr such SKB needs update */ + eth = (struct ethhdr *)xdp->data; + if ((orig_eth_type != eth->h_proto) || + (orig_bcast != is_multicast_ether_addr_64bits(eth->h_dest))) { + __skb_push(skb, ETH_HLEN); + skb->protocol = eth_type_trans(skb, skb->dev); + } + switch (act) { case XDP_REDIRECT: case XDP_TX: @@ -5410,7 +5457,7 @@ static void gro_flush_oldest(struct list_head *head) /* Do not adjust napi->gro_hash[].count, caller is adding a new * SKB to the chain. */ - list_del(&oldest->list); + skb_list_del_init(oldest); napi_gro_complete(oldest); } @@ -7575,14 +7622,16 @@ int dev_set_mtu_ext(struct net_device *dev, int new_mtu, err = __dev_set_mtu(dev, new_mtu); if (!err) { - err = call_netdevice_notifiers(NETDEV_CHANGEMTU, dev); + err = call_netdevice_notifiers_mtu(NETDEV_CHANGEMTU, dev, + orig_mtu); err = notifier_to_errno(err); if (err) { /* setting mtu back and notifying everyone again, * so that they have a chance to revert changes. */ __dev_set_mtu(dev, orig_mtu); - call_netdevice_notifiers(NETDEV_CHANGEMTU, dev); + call_netdevice_notifiers_mtu(NETDEV_CHANGEMTU, dev, + new_mtu); } } return err; diff --git a/net/core/devlink.c b/net/core/devlink.c index 8c0ed225e280..3a4b29a13d31 100644 --- a/net/core/devlink.c +++ b/net/core/devlink.c @@ -1626,7 +1626,7 @@ static int devlink_nl_cmd_eswitch_set_doit(struct sk_buff *skb, if (!ops->eswitch_mode_set) return -EOPNOTSUPP; mode = nla_get_u16(info->attrs[DEVLINK_ATTR_ESWITCH_MODE]); - err = ops->eswitch_mode_set(devlink, mode); + err = ops->eswitch_mode_set(devlink, mode, info->extack); if (err) return err; } @@ -1636,7 +1636,8 @@ static int devlink_nl_cmd_eswitch_set_doit(struct sk_buff *skb, return -EOPNOTSUPP; inline_mode = nla_get_u8( info->attrs[DEVLINK_ATTR_ESWITCH_INLINE_MODE]); - err = ops->eswitch_inline_mode_set(devlink, inline_mode); + err = ops->eswitch_inline_mode_set(devlink, inline_mode, + info->extack); if (err) return err; } @@ -1645,7 +1646,8 @@ static int devlink_nl_cmd_eswitch_set_doit(struct sk_buff *skb, if (!ops->eswitch_encap_mode_set) return -EOPNOTSUPP; encap_mode = nla_get_u8(info->attrs[DEVLINK_ATTR_ESWITCH_ENCAP_MODE]); - err = ops->eswitch_encap_mode_set(devlink, encap_mode); + err = ops->eswitch_encap_mode_set(devlink, encap_mode, + info->extack); if (err) return err; } @@ -2675,6 +2677,21 @@ static const struct devlink_param devlink_param_generic[] = { .name = DEVLINK_PARAM_GENERIC_REGION_SNAPSHOT_NAME, .type = DEVLINK_PARAM_GENERIC_REGION_SNAPSHOT_TYPE, }, + { + .id = DEVLINK_PARAM_GENERIC_ID_IGNORE_ARI, + .name = DEVLINK_PARAM_GENERIC_IGNORE_ARI_NAME, + .type = DEVLINK_PARAM_GENERIC_IGNORE_ARI_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_MSIX_VEC_PER_PF_MAX, + .name = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MAX_NAME, + .type = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MAX_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_MSIX_VEC_PER_PF_MIN, + .name = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MIN_NAME, + .type = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MIN_TYPE, + }, }; static int devlink_param_generic_verify(const struct devlink_param *param) @@ -2995,6 +3012,8 @@ devlink_param_value_get_from_info(const struct devlink_param *param, struct genl_info *info, union devlink_param_value *value) { + int len; + if (param->type != DEVLINK_PARAM_TYPE_BOOL && !info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA]) return -EINVAL; @@ -3010,10 +3029,13 @@ devlink_param_value_get_from_info(const struct devlink_param *param, value->vu32 = nla_get_u32(info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA]); break; case DEVLINK_PARAM_TYPE_STRING: - if (nla_len(info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA]) > - DEVLINK_PARAM_MAX_STRING_VALUE) + len = strnlen(nla_data(info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA]), + nla_len(info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA])); + if (len == nla_len(info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA]) || + len >= __DEVLINK_PARAM_MAX_STRING_VALUE) return -EINVAL; - value->vstr = nla_data(info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA]); + strcpy(value->vstr, + nla_data(info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA])); break; case DEVLINK_PARAM_TYPE_BOOL: value->vbool = info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA] ? @@ -3100,7 +3122,10 @@ static int devlink_nl_cmd_param_set_doit(struct sk_buff *skb, return -EOPNOTSUPP; if (cmode == DEVLINK_PARAM_CMODE_DRIVERINIT) { - param_item->driverinit_value = value; + if (param->type == DEVLINK_PARAM_TYPE_STRING) + strcpy(param_item->driverinit_value.vstr, value.vstr); + else + param_item->driverinit_value = value; param_item->driverinit_value_valid = true; } else { if (!param->set) @@ -3487,7 +3512,7 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, start_offset = *((u64 *)&cb->args[0]); err = nlmsg_parse(cb->nlh, GENL_HDRLEN + devlink_nl_family.hdrsize, - attrs, DEVLINK_ATTR_MAX, ops->policy, NULL); + attrs, DEVLINK_ATTR_MAX, ops->policy, cb->extack); if (err) goto out; @@ -4540,7 +4565,10 @@ int devlink_param_driverinit_value_get(struct devlink *devlink, u32 param_id, DEVLINK_PARAM_CMODE_DRIVERINIT)) return -EOPNOTSUPP; - *init_val = param_item->driverinit_value; + if (param_item->param->type == DEVLINK_PARAM_TYPE_STRING) + strcpy(init_val->vstr, param_item->driverinit_value.vstr); + else + *init_val = param_item->driverinit_value; return 0; } @@ -4571,7 +4599,10 @@ int devlink_param_driverinit_value_set(struct devlink *devlink, u32 param_id, DEVLINK_PARAM_CMODE_DRIVERINIT)) return -EOPNOTSUPP; - param_item->driverinit_value = init_val; + if (param_item->param->type == DEVLINK_PARAM_TYPE_STRING) + strcpy(param_item->driverinit_value.vstr, init_val.vstr); + else + param_item->driverinit_value = init_val; param_item->driverinit_value_valid = true; devlink_param_notify(devlink, param_item, DEVLINK_CMD_PARAM_NEW); @@ -4604,6 +4635,23 @@ void devlink_param_value_changed(struct devlink *devlink, u32 param_id) EXPORT_SYMBOL_GPL(devlink_param_value_changed); /** + * devlink_param_value_str_fill - Safely fill-up the string preventing + * from overflow of the preallocated buffer + * + * @dst_val: destination devlink_param_value + * @src: source buffer + */ +void devlink_param_value_str_fill(union devlink_param_value *dst_val, + const char *src) +{ + size_t len; + + len = strlcpy(dst_val->vstr, src, __DEVLINK_PARAM_MAX_STRING_VALUE); + WARN_ON(len >= __DEVLINK_PARAM_MAX_STRING_VALUE); +} +EXPORT_SYMBOL_GPL(devlink_param_value_str_fill); + +/** * devlink_region_create - create a new address region * * @devlink: devlink diff --git a/net/core/ethtool.c b/net/core/ethtool.c index 96afc55aa61e..d05402868575 100644 --- a/net/core/ethtool.c +++ b/net/core/ethtool.c @@ -27,6 +27,7 @@ #include <linux/rtnetlink.h> #include <linux/sched/signal.h> #include <linux/net.h> +#include <net/xdp_sock.h> /* * Some useful ethtool_ops methods that're device independent. @@ -927,6 +928,9 @@ static noinline_for_stack int ethtool_get_rxnfc(struct net_device *dev, return -EINVAL; } + if (info.cmd != cmd) + return -EINVAL; + if (info.cmd == ETHTOOL_GRXCLSRLALL) { if (info.rule_cnt > 0) { if (info.rule_cnt <= KMALLOC_MAX_SIZE / sizeof(u32)) @@ -1395,6 +1399,7 @@ static int ethtool_get_wol(struct net_device *dev, char __user *useraddr) static int ethtool_set_wol(struct net_device *dev, char __user *useraddr) { struct ethtool_wolinfo wol; + int ret; if (!dev->ethtool_ops->set_wol) return -EOPNOTSUPP; @@ -1402,7 +1407,13 @@ static int ethtool_set_wol(struct net_device *dev, char __user *useraddr) if (copy_from_user(&wol, useraddr, sizeof(wol))) return -EFAULT; - return dev->ethtool_ops->set_wol(dev, &wol); + ret = dev->ethtool_ops->set_wol(dev, &wol); + if (ret) + return ret; + + dev->wol_enabled = !!wol.wolopts; + + return 0; } static int ethtool_get_eee(struct net_device *dev, char __user *useraddr) @@ -1655,8 +1666,10 @@ static noinline_for_stack int ethtool_get_channels(struct net_device *dev, static noinline_for_stack int ethtool_set_channels(struct net_device *dev, void __user *useraddr) { - struct ethtool_channels channels, max = { .cmd = ETHTOOL_GCHANNELS }; + struct ethtool_channels channels, curr = { .cmd = ETHTOOL_GCHANNELS }; + u16 from_channel, to_channel; u32 max_rx_in_use = 0; + unsigned int i; if (!dev->ethtool_ops->set_channels || !dev->ethtool_ops->get_channels) return -EOPNOTSUPP; @@ -1664,13 +1677,13 @@ static noinline_for_stack int ethtool_set_channels(struct net_device *dev, if (copy_from_user(&channels, useraddr, sizeof(channels))) return -EFAULT; - dev->ethtool_ops->get_channels(dev, &max); + dev->ethtool_ops->get_channels(dev, &curr); /* ensure new counts are within the maximums */ - if ((channels.rx_count > max.max_rx) || - (channels.tx_count > max.max_tx) || - (channels.combined_count > max.max_combined) || - (channels.other_count > max.max_other)) + if (channels.rx_count > curr.max_rx || + channels.tx_count > curr.max_tx || + channels.combined_count > curr.max_combined || + channels.other_count > curr.max_other) return -EINVAL; /* ensure the new Rx count fits within the configured Rx flow @@ -1680,6 +1693,14 @@ static noinline_for_stack int ethtool_set_channels(struct net_device *dev, (channels.combined_count + channels.rx_count) <= max_rx_in_use) return -EINVAL; + /* Disabling channels, query zero-copy AF_XDP sockets */ + from_channel = channels.combined_count + + min(channels.rx_count, channels.tx_count); + to_channel = curr.combined_count + max(curr.rx_count, curr.tx_count); + for (i = from_channel; i < to_channel; i++) + if (xdp_get_umem_from_qid(dev, i)) + return -EINVAL; + return dev->ethtool_ops->set_channels(dev, &channels); } @@ -2374,13 +2395,17 @@ roll_back: return ret; } -static int ethtool_set_per_queue(struct net_device *dev, void __user *useraddr) +static int ethtool_set_per_queue(struct net_device *dev, + void __user *useraddr, u32 sub_cmd) { struct ethtool_per_queue_op per_queue_opt; if (copy_from_user(&per_queue_opt, useraddr, sizeof(per_queue_opt))) return -EFAULT; + if (per_queue_opt.sub_command != sub_cmd) + return -EINVAL; + switch (per_queue_opt.sub_command) { case ETHTOOL_GCOALESCE: return ethtool_get_per_queue_coalesce(dev, useraddr, &per_queue_opt); @@ -2751,7 +2776,7 @@ int dev_ethtool(struct net *net, struct ifreq *ifr) rc = ethtool_get_phy_stats(dev, useraddr); break; case ETHTOOL_PERQUEUE: - rc = ethtool_set_per_queue(dev, useraddr); + rc = ethtool_set_per_queue(dev, useraddr, sub_cmd); break; case ETHTOOL_GLINKSETTINGS: rc = ethtool_get_link_ksettings(dev, useraddr); diff --git a/net/core/fib_rules.c b/net/core/fib_rules.c index 0ff3953f64aa..ffbb827723a2 100644 --- a/net/core/fib_rules.c +++ b/net/core/fib_rules.c @@ -1063,13 +1063,47 @@ skip: return err; } +static int fib_valid_dumprule_req(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct fib_rule_hdr *frh; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*frh))) { + NL_SET_ERR_MSG(extack, "Invalid header for fib rule dump request"); + return -EINVAL; + } + + frh = nlmsg_data(nlh); + if (frh->dst_len || frh->src_len || frh->tos || frh->table || + frh->res1 || frh->res2 || frh->action || frh->flags) { + NL_SET_ERR_MSG(extack, + "Invalid values in header for fib rule dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*frh))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in fib rule dump request"); + return -EINVAL; + } + + return 0; +} + static int fib_nl_dumprule(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); struct fib_rules_ops *ops; int idx = 0, family; - family = rtnl_msg_family(cb->nlh); + if (cb->strict_check) { + int err = fib_valid_dumprule_req(nlh, cb->extack); + + if (err < 0) + return err; + } + + family = rtnl_msg_family(nlh); if (family != AF_UNSPEC) { /* Protocol specific dump request */ ops = lookup_rules_ops(net, family); diff --git a/net/core/filter.c b/net/core/filter.c index 72db8afb7cb6..e521c5ebc7d1 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -38,6 +38,7 @@ #include <net/protocol.h> #include <net/netlink.h> #include <linux/skbuff.h> +#include <linux/skmsg.h> #include <net/sock.h> #include <net/flow_dissector.h> #include <linux/errno.h> @@ -58,13 +59,17 @@ #include <net/busy_poll.h> #include <net/tcp.h> #include <net/xfrm.h> +#include <net/udp.h> #include <linux/bpf_trace.h> #include <net/xdp_sock.h> #include <linux/inetdevice.h> +#include <net/inet_hashtables.h> +#include <net/inet6_hashtables.h> #include <net/ip_fib.h> #include <net/flow.h> #include <net/arp.h> #include <net/ipv6.h> +#include <net/net_namespace.h> #include <linux/seg6_local.h> #include <net/seg6.h> #include <net/seg6_local.h> @@ -2138,123 +2143,7 @@ static const struct bpf_func_proto bpf_redirect_proto = { .arg2_type = ARG_ANYTHING, }; -BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb, - struct bpf_map *, map, void *, key, u64, flags) -{ - struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); - - /* If user passes invalid input drop the packet. */ - if (unlikely(flags & ~(BPF_F_INGRESS))) - return SK_DROP; - - tcb->bpf.flags = flags; - tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key); - if (!tcb->bpf.sk_redir) - return SK_DROP; - - return SK_PASS; -} - -static const struct bpf_func_proto bpf_sk_redirect_hash_proto = { - .func = bpf_sk_redirect_hash, - .gpl_only = false, - .ret_type = RET_INTEGER, - .arg1_type = ARG_PTR_TO_CTX, - .arg2_type = ARG_CONST_MAP_PTR, - .arg3_type = ARG_PTR_TO_MAP_KEY, - .arg4_type = ARG_ANYTHING, -}; - -BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb, - struct bpf_map *, map, u32, key, u64, flags) -{ - struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); - - /* If user passes invalid input drop the packet. */ - if (unlikely(flags & ~(BPF_F_INGRESS))) - return SK_DROP; - - tcb->bpf.flags = flags; - tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key); - if (!tcb->bpf.sk_redir) - return SK_DROP; - - return SK_PASS; -} - -struct sock *do_sk_redirect_map(struct sk_buff *skb) -{ - struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); - - return tcb->bpf.sk_redir; -} - -static const struct bpf_func_proto bpf_sk_redirect_map_proto = { - .func = bpf_sk_redirect_map, - .gpl_only = false, - .ret_type = RET_INTEGER, - .arg1_type = ARG_PTR_TO_CTX, - .arg2_type = ARG_CONST_MAP_PTR, - .arg3_type = ARG_ANYTHING, - .arg4_type = ARG_ANYTHING, -}; - -BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg_buff *, msg, - struct bpf_map *, map, void *, key, u64, flags) -{ - /* If user passes invalid input drop the packet. */ - if (unlikely(flags & ~(BPF_F_INGRESS))) - return SK_DROP; - - msg->flags = flags; - msg->sk_redir = __sock_hash_lookup_elem(map, key); - if (!msg->sk_redir) - return SK_DROP; - - return SK_PASS; -} - -static const struct bpf_func_proto bpf_msg_redirect_hash_proto = { - .func = bpf_msg_redirect_hash, - .gpl_only = false, - .ret_type = RET_INTEGER, - .arg1_type = ARG_PTR_TO_CTX, - .arg2_type = ARG_CONST_MAP_PTR, - .arg3_type = ARG_PTR_TO_MAP_KEY, - .arg4_type = ARG_ANYTHING, -}; - -BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg, - struct bpf_map *, map, u32, key, u64, flags) -{ - /* If user passes invalid input drop the packet. */ - if (unlikely(flags & ~(BPF_F_INGRESS))) - return SK_DROP; - - msg->flags = flags; - msg->sk_redir = __sock_map_lookup_elem(map, key); - if (!msg->sk_redir) - return SK_DROP; - - return SK_PASS; -} - -struct sock *do_msg_redirect_map(struct sk_msg_buff *msg) -{ - return msg->sk_redir; -} - -static const struct bpf_func_proto bpf_msg_redirect_map_proto = { - .func = bpf_msg_redirect_map, - .gpl_only = false, - .ret_type = RET_INTEGER, - .arg1_type = ARG_PTR_TO_CTX, - .arg2_type = ARG_CONST_MAP_PTR, - .arg3_type = ARG_ANYTHING, - .arg4_type = ARG_ANYTHING, -}; - -BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg_buff *, msg, u32, bytes) +BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg *, msg, u32, bytes) { msg->apply_bytes = bytes; return 0; @@ -2268,7 +2157,7 @@ static const struct bpf_func_proto bpf_msg_apply_bytes_proto = { .arg2_type = ARG_ANYTHING, }; -BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u32, bytes) +BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg *, msg, u32, bytes) { msg->cork_bytes = bytes; return 0; @@ -2282,45 +2171,37 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = { .arg2_type = ARG_ANYTHING, }; -#define sk_msg_iter_var(var) \ - do { \ - var++; \ - if (var == MAX_SKB_FRAGS) \ - var = 0; \ - } while (0) - -BPF_CALL_4(bpf_msg_pull_data, - struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags) +BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, + u32, end, u64, flags) { - unsigned int len = 0, offset = 0, copy = 0, poffset = 0; - int bytes = end - start, bytes_sg_total; - struct scatterlist *sg = msg->sg_data; - int first_sg, last_sg, i, shift; - unsigned char *p, *to, *from; + u32 len = 0, offset = 0, copy = 0, poffset = 0, bytes = end - start; + u32 first_sge, last_sge, i, shift, bytes_sg_total; + struct scatterlist *sge; + u8 *raw, *to, *from; struct page *page; if (unlikely(flags || end <= start)) return -EINVAL; /* First find the starting scatterlist element */ - i = msg->sg_start; + i = msg->sg.start; do { - len = sg[i].length; + len = sk_msg_elem(msg, i)->length; if (start < offset + len) break; offset += len; - sk_msg_iter_var(i); - } while (i != msg->sg_end); + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); if (unlikely(start >= offset + len)) return -EINVAL; - first_sg = i; + first_sge = i; /* The start may point into the sg element so we need to also * account for the headroom. */ bytes_sg_total = start - offset + bytes; - if (!msg->sg_copy[i] && bytes_sg_total <= len) + if (!msg->sg.copy[i] && bytes_sg_total <= len) goto out; /* At this point we need to linearize multiple scatterlist @@ -2334,12 +2215,12 @@ BPF_CALL_4(bpf_msg_pull_data, * will copy the entire sg entry. */ do { - copy += sg[i].length; - sk_msg_iter_var(i); + copy += sk_msg_elem(msg, i)->length; + sk_msg_iter_var_next(i); if (bytes_sg_total <= copy) break; - } while (i != msg->sg_end); - last_sg = i; + } while (i != msg->sg.end); + last_sge = i; if (unlikely(bytes_sg_total > copy)) return -EINVAL; @@ -2348,63 +2229,61 @@ BPF_CALL_4(bpf_msg_pull_data, get_order(copy)); if (unlikely(!page)) return -ENOMEM; - p = page_address(page); - i = first_sg; + raw = page_address(page); + i = first_sge; do { - from = sg_virt(&sg[i]); - len = sg[i].length; - to = p + poffset; + sge = sk_msg_elem(msg, i); + from = sg_virt(sge); + len = sge->length; + to = raw + poffset; memcpy(to, from, len); poffset += len; - sg[i].length = 0; - put_page(sg_page(&sg[i])); + sge->length = 0; + put_page(sg_page(sge)); - sk_msg_iter_var(i); - } while (i != last_sg); + sk_msg_iter_var_next(i); + } while (i != last_sge); - sg[first_sg].length = copy; - sg_set_page(&sg[first_sg], page, copy, 0); + sg_set_page(&msg->sg.data[first_sge], page, copy, 0); /* To repair sg ring we need to shift entries. If we only * had a single entry though we can just replace it and * be done. Otherwise walk the ring and shift the entries. */ - WARN_ON_ONCE(last_sg == first_sg); - shift = last_sg > first_sg ? - last_sg - first_sg - 1 : - MAX_SKB_FRAGS - first_sg + last_sg - 1; + WARN_ON_ONCE(last_sge == first_sge); + shift = last_sge > first_sge ? + last_sge - first_sge - 1 : + MAX_SKB_FRAGS - first_sge + last_sge - 1; if (!shift) goto out; - i = first_sg; - sk_msg_iter_var(i); + i = first_sge; + sk_msg_iter_var_next(i); do { - int move_from; + u32 move_from; - if (i + shift >= MAX_SKB_FRAGS) - move_from = i + shift - MAX_SKB_FRAGS; + if (i + shift >= MAX_MSG_FRAGS) + move_from = i + shift - MAX_MSG_FRAGS; else move_from = i + shift; - - if (move_from == msg->sg_end) + if (move_from == msg->sg.end) break; - sg[i] = sg[move_from]; - sg[move_from].length = 0; - sg[move_from].page_link = 0; - sg[move_from].offset = 0; - - sk_msg_iter_var(i); + msg->sg.data[i] = msg->sg.data[move_from]; + msg->sg.data[move_from].length = 0; + msg->sg.data[move_from].page_link = 0; + msg->sg.data[move_from].offset = 0; + sk_msg_iter_var_next(i); } while (1); - msg->sg_end -= shift; - if (msg->sg_end < 0) - msg->sg_end += MAX_SKB_FRAGS; + + msg->sg.end = msg->sg.end - shift > msg->sg.end ? + msg->sg.end - shift + MAX_MSG_FRAGS : + msg->sg.end - shift; out: - msg->data = sg_virt(&sg[first_sg]) + start - offset; + msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset; msg->data_end = msg->data + bytes; - return 0; } @@ -2418,6 +2297,137 @@ static const struct bpf_func_proto bpf_msg_pull_data_proto = { .arg4_type = ARG_ANYTHING, }; +BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start, + u32, len, u64, flags) +{ + struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge; + u32 new, i = 0, l, space, copy = 0, offset = 0; + u8 *raw, *to, *from; + struct page *page; + + if (unlikely(flags)) + return -EINVAL; + + /* First find the starting scatterlist element */ + i = msg->sg.start; + do { + l = sk_msg_elem(msg, i)->length; + + if (start < offset + l) + break; + offset += l; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); + + if (start >= offset + l) + return -EINVAL; + + space = MAX_MSG_FRAGS - sk_msg_elem_used(msg); + + /* If no space available will fallback to copy, we need at + * least one scatterlist elem available to push data into + * when start aligns to the beginning of an element or two + * when it falls inside an element. We handle the start equals + * offset case because its the common case for inserting a + * header. + */ + if (!space || (space == 1 && start != offset)) + copy = msg->sg.data[i].length; + + page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP, + get_order(copy + len)); + if (unlikely(!page)) + return -ENOMEM; + + if (copy) { + int front, back; + + raw = page_address(page); + + psge = sk_msg_elem(msg, i); + front = start - offset; + back = psge->length - front; + from = sg_virt(psge); + + if (front) + memcpy(raw, from, front); + + if (back) { + from += front; + to = raw + front + len; + + memcpy(to, from, back); + } + + put_page(sg_page(psge)); + } else if (start - offset) { + psge = sk_msg_elem(msg, i); + rsge = sk_msg_elem_cpy(msg, i); + + psge->length = start - offset; + rsge.length -= psge->length; + rsge.offset += start; + + sk_msg_iter_var_next(i); + sg_unmark_end(psge); + sk_msg_iter_next(msg, end); + } + + /* Slot(s) to place newly allocated data */ + new = i; + + /* Shift one or two slots as needed */ + if (!copy) { + sge = sk_msg_elem_cpy(msg, i); + + sk_msg_iter_var_next(i); + sg_unmark_end(&sge); + sk_msg_iter_next(msg, end); + + nsge = sk_msg_elem_cpy(msg, i); + if (rsge.length) { + sk_msg_iter_var_next(i); + nnsge = sk_msg_elem_cpy(msg, i); + } + + while (i != msg->sg.end) { + msg->sg.data[i] = sge; + sge = nsge; + sk_msg_iter_var_next(i); + if (rsge.length) { + nsge = nnsge; + nnsge = sk_msg_elem_cpy(msg, i); + } else { + nsge = sk_msg_elem_cpy(msg, i); + } + } + } + + /* Place newly allocated data buffer */ + sk_mem_charge(msg->sk, len); + msg->sg.size += len; + msg->sg.copy[new] = false; + sg_set_page(&msg->sg.data[new], page, len + copy, 0); + if (rsge.length) { + get_page(sg_page(&rsge)); + sk_msg_iter_var_next(new); + msg->sg.data[new] = rsge; + } + + sk_msg_compute_data_pointers(msg); + return 0; +} + +static const struct bpf_func_proto bpf_msg_push_data_proto = { + .func = bpf_msg_push_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb) { return task_get_classid(skb); @@ -3923,8 +3933,8 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock, sk->sk_userlocks |= SOCK_SNDBUF_LOCK; sk->sk_sndbuf = max_t(int, val * 2, SOCK_MIN_SNDBUF); break; - case SO_MAX_PACING_RATE: - sk->sk_max_pacing_rate = val; + case SO_MAX_PACING_RATE: /* 32bit version */ + sk->sk_max_pacing_rate = (val == ~0U) ? ~0UL : val; sk->sk_pacing_rate = min(sk->sk_pacing_rate, sk->sk_max_pacing_rate); break; @@ -4813,6 +4823,149 @@ static const struct bpf_func_proto bpf_lwt_seg6_adjust_srh_proto = { }; #endif /* CONFIG_IPV6_SEG6_BPF */ +#ifdef CONFIG_INET +static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple, + struct sk_buff *skb, u8 family, u8 proto) +{ + bool refcounted = false; + struct sock *sk = NULL; + int dif = 0; + + if (skb->dev) + dif = skb->dev->ifindex; + + if (family == AF_INET) { + __be32 src4 = tuple->ipv4.saddr; + __be32 dst4 = tuple->ipv4.daddr; + int sdif = inet_sdif(skb); + + if (proto == IPPROTO_TCP) + sk = __inet_lookup(net, &tcp_hashinfo, skb, 0, + src4, tuple->ipv4.sport, + dst4, tuple->ipv4.dport, + dif, sdif, &refcounted); + else + sk = __udp4_lib_lookup(net, src4, tuple->ipv4.sport, + dst4, tuple->ipv4.dport, + dif, sdif, &udp_table, skb); +#if IS_ENABLED(CONFIG_IPV6) + } else { + struct in6_addr *src6 = (struct in6_addr *)&tuple->ipv6.saddr; + struct in6_addr *dst6 = (struct in6_addr *)&tuple->ipv6.daddr; + u16 hnum = ntohs(tuple->ipv6.dport); + int sdif = inet6_sdif(skb); + + if (proto == IPPROTO_TCP) + sk = __inet6_lookup(net, &tcp_hashinfo, skb, 0, + src6, tuple->ipv6.sport, + dst6, hnum, + dif, sdif, &refcounted); + else if (likely(ipv6_bpf_stub)) + sk = ipv6_bpf_stub->udp6_lib_lookup(net, + src6, tuple->ipv6.sport, + dst6, hnum, + dif, sdif, + &udp_table, skb); +#endif + } + + if (unlikely(sk && !refcounted && !sock_flag(sk, SOCK_RCU_FREE))) { + WARN_ONCE(1, "Found non-RCU, unreferenced socket!"); + sk = NULL; + } + return sk; +} + +/* bpf_sk_lookup performs the core lookup for different types of sockets, + * taking a reference on the socket if it doesn't have the flag SOCK_RCU_FREE. + * Returns the socket as an 'unsigned long' to simplify the casting in the + * callers to satisfy BPF_CALL declarations. + */ +static unsigned long +bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len, + u8 proto, u64 netns_id, u64 flags) +{ + struct net *caller_net; + struct sock *sk = NULL; + u8 family = AF_UNSPEC; + struct net *net; + + family = len == sizeof(tuple->ipv4) ? AF_INET : AF_INET6; + if (unlikely(family == AF_UNSPEC || netns_id > U32_MAX || flags)) + goto out; + + if (skb->dev) + caller_net = dev_net(skb->dev); + else + caller_net = sock_net(skb->sk); + if (netns_id) { + net = get_net_ns_by_id(caller_net, netns_id); + if (unlikely(!net)) + goto out; + sk = sk_lookup(net, tuple, skb, family, proto); + put_net(net); + } else { + net = caller_net; + sk = sk_lookup(net, tuple, skb, family, proto); + } + + if (sk) + sk = sk_to_full_sk(sk); +out: + return (unsigned long) sk; +} + +BPF_CALL_5(bpf_sk_lookup_tcp, struct sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP, netns_id, flags); +} + +static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = { + .func = bpf_sk_lookup_tcp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_sk_lookup_udp, struct sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP, netns_id, flags); +} + +static const struct bpf_func_proto bpf_sk_lookup_udp_proto = { + .func = bpf_sk_lookup_udp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_1(bpf_sk_release, struct sock *, sk) +{ + if (!sock_flag(sk, SOCK_RCU_FREE)) + sock_gen_put(sk); + return 0; +} + +static const struct bpf_func_proto bpf_sk_release_proto = { + .func = bpf_sk_release, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_SOCKET, +}; +#endif /* CONFIG_INET */ + bool bpf_helper_changes_pkt_data(void *func) { if (func == bpf_skb_vlan_push || @@ -4832,6 +4985,7 @@ bool bpf_helper_changes_pkt_data(void *func) func == bpf_xdp_adjust_head || func == bpf_xdp_adjust_meta || func == bpf_msg_pull_data || + func == bpf_msg_push_data || func == bpf_xdp_adjust_tail || #if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) func == bpf_lwt_seg6_store_bytes || @@ -4854,6 +5008,12 @@ bpf_base_func_proto(enum bpf_func_id func_id) return &bpf_map_update_elem_proto; case BPF_FUNC_map_delete_elem: return &bpf_map_delete_elem_proto; + case BPF_FUNC_map_push_elem: + return &bpf_map_push_elem_proto; + case BPF_FUNC_map_pop_elem: + return &bpf_map_pop_elem_proto; + case BPF_FUNC_map_peek_elem: + return &bpf_map_peek_elem_proto; case BPF_FUNC_get_prandom_u32: return &bpf_get_prandom_u32_proto; case BPF_FUNC_get_smp_processor_id: @@ -5019,6 +5179,14 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) case BPF_FUNC_skb_ancestor_cgroup_id: return &bpf_skb_ancestor_cgroup_id_proto; #endif +#ifdef CONFIG_INET + case BPF_FUNC_sk_lookup_tcp: + return &bpf_sk_lookup_tcp_proto; + case BPF_FUNC_sk_lookup_udp: + return &bpf_sk_lookup_udp_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; +#endif default: return bpf_base_func_proto(func_id); } @@ -5051,6 +5219,9 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) } } +const struct bpf_func_proto bpf_sock_map_update_proto __weak; +const struct bpf_func_proto bpf_sock_hash_update_proto __weak; + static const struct bpf_func_proto * sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) { @@ -5074,6 +5245,9 @@ sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) } } +const struct bpf_func_proto bpf_msg_redirect_map_proto __weak; +const struct bpf_func_proto bpf_msg_redirect_hash_proto __weak; + static const struct bpf_func_proto * sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) { @@ -5088,13 +5262,16 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_msg_cork_bytes_proto; case BPF_FUNC_msg_pull_data: return &bpf_msg_pull_data_proto; - case BPF_FUNC_get_local_storage: - return &bpf_get_local_storage_proto; + case BPF_FUNC_msg_push_data: + return &bpf_msg_push_data_proto; default: return bpf_base_func_proto(func_id); } } +const struct bpf_func_proto bpf_sk_redirect_map_proto __weak; +const struct bpf_func_proto bpf_sk_redirect_hash_proto __weak; + static const struct bpf_func_proto * sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) { @@ -5117,8 +5294,14 @@ sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_sk_redirect_map_proto; case BPF_FUNC_sk_redirect_hash: return &bpf_sk_redirect_hash_proto; - case BPF_FUNC_get_local_storage: - return &bpf_get_local_storage_proto; +#ifdef CONFIG_INET + case BPF_FUNC_sk_lookup_tcp: + return &bpf_sk_lookup_tcp_proto; + case BPF_FUNC_sk_lookup_udp: + return &bpf_sk_lookup_udp_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; +#endif default: return bpf_base_func_proto(func_id); } @@ -5299,6 +5482,46 @@ static bool sk_filter_is_valid_access(int off, int size, return bpf_skb_is_valid_access(off, size, type, prog, info); } +static bool cg_skb_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + switch (off) { + case bpf_ctx_range(struct __sk_buff, tc_classid): + case bpf_ctx_range(struct __sk_buff, data_meta): + case bpf_ctx_range(struct __sk_buff, flow_keys): + return false; + case bpf_ctx_range(struct __sk_buff, data): + case bpf_ctx_range(struct __sk_buff, data_end): + if (!capable(CAP_SYS_ADMIN)) + return false; + break; + } + + if (type == BPF_WRITE) { + switch (off) { + case bpf_ctx_range(struct __sk_buff, mark): + case bpf_ctx_range(struct __sk_buff, priority): + case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): + break; + default: + return false; + } + } + + switch (off) { + case bpf_ctx_range(struct __sk_buff, data): + info->reg_type = PTR_TO_PACKET; + break; + case bpf_ctx_range(struct __sk_buff, data_end): + info->reg_type = PTR_TO_PACKET_END; + break; + } + + return bpf_skb_is_valid_access(off, size, type, prog, info); +} + static bool lwt_is_valid_access(int off, int size, enum bpf_access_type type, const struct bpf_prog *prog, @@ -5394,23 +5617,38 @@ static bool __sock_filter_check_size(int off, int size, return size == size_default; } -static bool sock_filter_is_valid_access(int off, int size, - enum bpf_access_type type, - const struct bpf_prog *prog, - struct bpf_insn_access_aux *info) +bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type, + struct bpf_insn_access_aux *info) { if (off < 0 || off >= sizeof(struct bpf_sock)) return false; if (off % size != 0) return false; - if (!__sock_filter_check_attach_type(off, type, - prog->expected_attach_type)) - return false; if (!__sock_filter_check_size(off, size, info)) return false; return true; } +static bool sock_filter_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + if (!bpf_sock_is_valid_access(off, size, type, info)) + return false; + return __sock_filter_check_attach_type(off, type, + prog->expected_attach_type); +} + +static int bpf_noop_prologue(struct bpf_insn *insn_buf, bool direct_write, + const struct bpf_prog *prog) +{ + /* Neither direct read nor direct write requires any preliminary + * action. + */ + return 0; +} + static int bpf_unclone_prologue(struct bpf_insn *insn_buf, bool direct_write, const struct bpf_prog *prog, int drop_verdict) { @@ -6122,10 +6360,10 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, return insn - insn_buf; } -static u32 sock_filter_convert_ctx_access(enum bpf_access_type type, - const struct bpf_insn *si, - struct bpf_insn *insn_buf, - struct bpf_prog *prog, u32 *target_size) +u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) { struct bpf_insn *insn = insn_buf; int off; @@ -6835,22 +7073,22 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, switch (si->off) { case offsetof(struct sk_msg_md, data): - *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data), + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, data)); + offsetof(struct sk_msg, data)); break; case offsetof(struct sk_msg_md, data_end): - *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end), + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data_end), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, data_end)); + offsetof(struct sk_msg, data_end)); break; case offsetof(struct sk_msg_md, family): BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( - struct sk_msg_buff, sk), + struct sk_msg, sk), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, sk)); + offsetof(struct sk_msg, sk)); *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, offsetof(struct sock_common, skc_family)); break; @@ -6859,9 +7097,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( - struct sk_msg_buff, sk), + struct sk_msg, sk), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, sk)); + offsetof(struct sk_msg, sk)); *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, offsetof(struct sock_common, skc_daddr)); break; @@ -6871,9 +7109,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, skc_rcv_saddr) != 4); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( - struct sk_msg_buff, sk), + struct sk_msg, sk), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, sk)); + offsetof(struct sk_msg, sk)); *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, offsetof(struct sock_common, skc_rcv_saddr)); @@ -6888,9 +7126,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, off = si->off; off -= offsetof(struct sk_msg_md, remote_ip6[0]); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( - struct sk_msg_buff, sk), + struct sk_msg, sk), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, sk)); + offsetof(struct sk_msg, sk)); *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, offsetof(struct sock_common, skc_v6_daddr.s6_addr32[0]) + @@ -6909,9 +7147,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, off = si->off; off -= offsetof(struct sk_msg_md, local_ip6[0]); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( - struct sk_msg_buff, sk), + struct sk_msg, sk), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, sk)); + offsetof(struct sk_msg, sk)); *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, offsetof(struct sock_common, skc_v6_rcv_saddr.s6_addr32[0]) + @@ -6925,9 +7163,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( - struct sk_msg_buff, sk), + struct sk_msg, sk), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, sk)); + offsetof(struct sk_msg, sk)); *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, offsetof(struct sock_common, skc_dport)); #ifndef __BIG_ENDIAN_BITFIELD @@ -6939,9 +7177,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2); *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( - struct sk_msg_buff, sk), + struct sk_msg, sk), si->dst_reg, si->src_reg, - offsetof(struct sk_msg_buff, sk)); + offsetof(struct sk_msg, sk)); *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, offsetof(struct sock_common, skc_num)); break; @@ -6977,6 +7215,7 @@ const struct bpf_verifier_ops xdp_verifier_ops = { .get_func_proto = xdp_func_proto, .is_valid_access = xdp_is_valid_access, .convert_ctx_access = xdp_convert_ctx_access, + .gen_prologue = bpf_noop_prologue, }; const struct bpf_prog_ops xdp_prog_ops = { @@ -6985,7 +7224,7 @@ const struct bpf_prog_ops xdp_prog_ops = { const struct bpf_verifier_ops cg_skb_verifier_ops = { .get_func_proto = cg_skb_func_proto, - .is_valid_access = sk_filter_is_valid_access, + .is_valid_access = cg_skb_is_valid_access, .convert_ctx_access = bpf_convert_ctx_access, }; @@ -7037,7 +7276,7 @@ const struct bpf_prog_ops lwt_seg6local_prog_ops = { const struct bpf_verifier_ops cg_sock_verifier_ops = { .get_func_proto = sock_filter_func_proto, .is_valid_access = sock_filter_is_valid_access, - .convert_ctx_access = sock_filter_convert_ctx_access, + .convert_ctx_access = bpf_sock_convert_ctx_access, }; const struct bpf_prog_ops cg_sock_prog_ops = { @@ -7075,6 +7314,7 @@ const struct bpf_verifier_ops sk_msg_verifier_ops = { .get_func_proto = sk_msg_func_proto, .is_valid_access = sk_msg_is_valid_access, .convert_ctx_access = sk_msg_convert_ctx_access, + .gen_prologue = bpf_noop_prologue, }; const struct bpf_prog_ops sk_msg_prog_ops = { diff --git a/net/core/neighbour.c b/net/core/neighbour.c index 20e0d3308148..41954e42a2de 100644 --- a/net/core/neighbour.c +++ b/net/core/neighbour.c @@ -232,7 +232,8 @@ static void pneigh_queue_purge(struct sk_buff_head *list) } } -static void neigh_flush_dev(struct neigh_table *tbl, struct net_device *dev) +static void neigh_flush_dev(struct neigh_table *tbl, struct net_device *dev, + bool skip_perm) { int i; struct neigh_hash_table *nht; @@ -250,6 +251,10 @@ static void neigh_flush_dev(struct neigh_table *tbl, struct net_device *dev) np = &n->next; continue; } + if (skip_perm && n->nud_state & NUD_PERMANENT) { + np = &n->next; + continue; + } rcu_assign_pointer(*np, rcu_dereference_protected(n->next, lockdep_is_held(&tbl->lock))); @@ -285,21 +290,35 @@ static void neigh_flush_dev(struct neigh_table *tbl, struct net_device *dev) void neigh_changeaddr(struct neigh_table *tbl, struct net_device *dev) { write_lock_bh(&tbl->lock); - neigh_flush_dev(tbl, dev); + neigh_flush_dev(tbl, dev, false); write_unlock_bh(&tbl->lock); } EXPORT_SYMBOL(neigh_changeaddr); -int neigh_ifdown(struct neigh_table *tbl, struct net_device *dev) +static int __neigh_ifdown(struct neigh_table *tbl, struct net_device *dev, + bool skip_perm) { write_lock_bh(&tbl->lock); - neigh_flush_dev(tbl, dev); + neigh_flush_dev(tbl, dev, skip_perm); pneigh_ifdown_and_unlock(tbl, dev); del_timer_sync(&tbl->proxy_timer); pneigh_queue_purge(&tbl->proxy_queue); return 0; } + +int neigh_carrier_down(struct neigh_table *tbl, struct net_device *dev) +{ + __neigh_ifdown(tbl, dev, true); + return 0; +} +EXPORT_SYMBOL(neigh_carrier_down); + +int neigh_ifdown(struct neigh_table *tbl, struct net_device *dev) +{ + __neigh_ifdown(tbl, dev, false); + return 0; +} EXPORT_SYMBOL(neigh_ifdown); static struct neighbour *neigh_alloc(struct neigh_table *tbl, struct net_device *dev) @@ -1148,8 +1167,7 @@ int neigh_update(struct neighbour *neigh, const u8 *lladdr, u8 new, neigh->nud_state = new; err = 0; notify = old & NUD_VALID; - if (((old & (NUD_INCOMPLETE | NUD_PROBE)) || - (flags & NEIGH_UPDATE_F_ADMIN)) && + if ((old & (NUD_INCOMPLETE | NUD_PROBE)) && (new & NUD_FAILED)) { neigh_invalidate(neigh); notify = 1; @@ -2164,15 +2182,47 @@ errout: return err; } +static int neightbl_valid_dump_info(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ndtmsg *ndtm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndtm))) { + NL_SET_ERR_MSG(extack, "Invalid header for neighbor table dump request"); + return -EINVAL; + } + + ndtm = nlmsg_data(nlh); + if (ndtm->ndtm_pad1 || ndtm->ndtm_pad2) { + NL_SET_ERR_MSG(extack, "Invalid values in header for neighbor table dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ndtm))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in neighbor table dump request"); + return -EINVAL; + } + + return 0; +} + static int neightbl_dump_info(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); int family, tidx, nidx = 0; int tbl_skip = cb->args[0]; int neigh_skip = cb->args[1]; struct neigh_table *tbl; - family = ((struct rtgenmsg *) nlmsg_data(cb->nlh))->rtgen_family; + if (cb->strict_check) { + int err = neightbl_valid_dump_info(nlh, cb->extack); + + if (err < 0) + return err; + } + + family = ((struct rtgenmsg *)nlmsg_data(nlh))->rtgen_family; for (tidx = 0; tidx < NEIGH_NR_TABLES; tidx++) { struct neigh_parms *p; @@ -2185,7 +2235,7 @@ static int neightbl_dump_info(struct sk_buff *skb, struct netlink_callback *cb) continue; if (neightbl_fill_info(skb, tbl, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, RTM_NEWNEIGHTBL, + nlh->nlmsg_seq, RTM_NEWNEIGHTBL, NLM_F_MULTI) < 0) break; @@ -2200,7 +2250,7 @@ static int neightbl_dump_info(struct sk_buff *skb, struct netlink_callback *cb) if (neightbl_fill_param_info(skb, tbl, p, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNEIGHTBL, NLM_F_MULTI) < 0) goto out; @@ -2314,7 +2364,7 @@ static bool neigh_master_filtered(struct net_device *dev, int master_idx) if (!master_idx) return false; - master = netdev_master_upper_dev_get(dev); + master = dev ? netdev_master_upper_dev_get(dev) : NULL; if (!master || master->ifindex != master_idx) return true; @@ -2323,41 +2373,30 @@ static bool neigh_master_filtered(struct net_device *dev, int master_idx) static bool neigh_ifindex_filtered(struct net_device *dev, int filter_idx) { - if (filter_idx && dev->ifindex != filter_idx) + if (filter_idx && (!dev || dev->ifindex != filter_idx)) return true; return false; } +struct neigh_dump_filter { + int master_idx; + int dev_idx; +}; + static int neigh_dump_table(struct neigh_table *tbl, struct sk_buff *skb, - struct netlink_callback *cb) + struct netlink_callback *cb, + struct neigh_dump_filter *filter) { struct net *net = sock_net(skb->sk); - const struct nlmsghdr *nlh = cb->nlh; - struct nlattr *tb[NDA_MAX + 1]; struct neighbour *n; int rc, h, s_h = cb->args[1]; int idx, s_idx = idx = cb->args[2]; struct neigh_hash_table *nht; - int filter_master_idx = 0, filter_idx = 0; unsigned int flags = NLM_F_MULTI; - int err; - err = nlmsg_parse(nlh, sizeof(struct ndmsg), tb, NDA_MAX, NULL, NULL); - if (!err) { - if (tb[NDA_IFINDEX]) { - if (nla_len(tb[NDA_IFINDEX]) != sizeof(u32)) - return -EINVAL; - filter_idx = nla_get_u32(tb[NDA_IFINDEX]); - } - if (tb[NDA_MASTER]) { - if (nla_len(tb[NDA_MASTER]) != sizeof(u32)) - return -EINVAL; - filter_master_idx = nla_get_u32(tb[NDA_MASTER]); - } - if (filter_idx || filter_master_idx) - flags |= NLM_F_DUMP_FILTERED; - } + if (filter->dev_idx || filter->master_idx) + flags |= NLM_F_DUMP_FILTERED; rcu_read_lock_bh(); nht = rcu_dereference_bh(tbl->nht); @@ -2370,8 +2409,8 @@ static int neigh_dump_table(struct neigh_table *tbl, struct sk_buff *skb, n = rcu_dereference_bh(n->next)) { if (idx < s_idx || !net_eq(dev_net(n->dev), net)) goto next; - if (neigh_ifindex_filtered(n->dev, filter_idx) || - neigh_master_filtered(n->dev, filter_master_idx)) + if (neigh_ifindex_filtered(n->dev, filter->dev_idx) || + neigh_master_filtered(n->dev, filter->master_idx)) goto next; if (neigh_fill_info(skb, n, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, @@ -2393,12 +2432,17 @@ out: } static int pneigh_dump_table(struct neigh_table *tbl, struct sk_buff *skb, - struct netlink_callback *cb) + struct netlink_callback *cb, + struct neigh_dump_filter *filter) { struct pneigh_entry *n; struct net *net = sock_net(skb->sk); int rc, h, s_h = cb->args[3]; int idx, s_idx = idx = cb->args[4]; + unsigned int flags = NLM_F_MULTI; + + if (filter->dev_idx || filter->master_idx) + flags |= NLM_F_DUMP_FILTERED; read_lock_bh(&tbl->lock); @@ -2408,10 +2452,12 @@ static int pneigh_dump_table(struct neigh_table *tbl, struct sk_buff *skb, for (n = tbl->phash_buckets[h], idx = 0; n; n = n->next) { if (idx < s_idx || pneigh_net(n) != net) goto next; + if (neigh_ifindex_filtered(n->dev, filter->dev_idx) || + neigh_master_filtered(n->dev, filter->master_idx)) + goto next; if (pneigh_fill_info(skb, n, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, - RTM_NEWNEIGH, - NLM_F_MULTI, tbl) < 0) { + RTM_NEWNEIGH, flags, tbl) < 0) { read_unlock_bh(&tbl->lock); rc = -1; goto out; @@ -2430,22 +2476,91 @@ out: } +static int neigh_valid_dump_req(const struct nlmsghdr *nlh, + bool strict_check, + struct neigh_dump_filter *filter, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[NDA_MAX + 1]; + int err, i; + + if (strict_check) { + struct ndmsg *ndm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndm))) { + NL_SET_ERR_MSG(extack, "Invalid header for neighbor dump request"); + return -EINVAL; + } + + ndm = nlmsg_data(nlh); + if (ndm->ndm_pad1 || ndm->ndm_pad2 || ndm->ndm_ifindex || + ndm->ndm_state || ndm->ndm_flags || ndm->ndm_type) { + NL_SET_ERR_MSG(extack, "Invalid values in header for neighbor dump request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(struct ndmsg), tb, NDA_MAX, + NULL, extack); + } else { + err = nlmsg_parse(nlh, sizeof(struct ndmsg), tb, NDA_MAX, + NULL, extack); + } + if (err < 0) + return err; + + for (i = 0; i <= NDA_MAX; ++i) { + if (!tb[i]) + continue; + + /* all new attributes should require strict_check */ + switch (i) { + case NDA_IFINDEX: + if (nla_len(tb[i]) != sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid IFINDEX attribute in neighbor dump request"); + return -EINVAL; + } + filter->dev_idx = nla_get_u32(tb[i]); + break; + case NDA_MASTER: + if (nla_len(tb[i]) != sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid MASTER attribute in neighbor dump request"); + return -EINVAL; + } + filter->master_idx = nla_get_u32(tb[i]); + break; + default: + if (strict_check) { + NL_SET_ERR_MSG(extack, "Unsupported attribute in neighbor dump request"); + return -EINVAL; + } + } + } + + return 0; +} + static int neigh_dump_info(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; + struct neigh_dump_filter filter = {}; struct neigh_table *tbl; int t, family, s_t; int proxy = 0; int err; - family = ((struct rtgenmsg *) nlmsg_data(cb->nlh))->rtgen_family; + family = ((struct rtgenmsg *)nlmsg_data(nlh))->rtgen_family; /* check for full ndmsg structure presence, family member is * the same for both structures */ - if (nlmsg_len(cb->nlh) >= sizeof(struct ndmsg) && - ((struct ndmsg *) nlmsg_data(cb->nlh))->ndm_flags == NTF_PROXY) + if (nlmsg_len(nlh) >= sizeof(struct ndmsg) && + ((struct ndmsg *)nlmsg_data(nlh))->ndm_flags == NTF_PROXY) proxy = 1; + err = neigh_valid_dump_req(nlh, cb->strict_check, &filter, cb->extack); + if (err < 0 && cb->strict_check) + return err; + s_t = cb->args[0]; for (t = 0; t < NEIGH_NR_TABLES; t++) { @@ -2459,9 +2574,9 @@ static int neigh_dump_info(struct sk_buff *skb, struct netlink_callback *cb) memset(&cb->args[1], 0, sizeof(cb->args) - sizeof(cb->args[0])); if (proxy) - err = pneigh_dump_table(tbl, skb, cb); + err = pneigh_dump_table(tbl, skb, cb, &filter); else - err = neigh_dump_table(tbl, skb, cb); + err = neigh_dump_table(tbl, skb, cb, &filter); if (err < 0) break; } diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c index 670c84b1bfc2..fefe72774aeb 100644 --- a/net/core/net_namespace.c +++ b/net/core/net_namespace.c @@ -853,6 +853,12 @@ static int rtnl_net_dumpid(struct sk_buff *skb, struct netlink_callback *cb) .s_idx = cb->args[0], }; + if (cb->strict_check && + nlmsg_attrlen(cb->nlh, sizeof(struct rtgenmsg))) { + NL_SET_ERR_MSG(cb->extack, "Unknown data in network namespace id dump request"); + return -EINVAL; + } + spin_lock_bh(&net->nsid_lock); idr_for_each(&net->netns_ids, rtnl_net_dumpid_one, &net_cb); spin_unlock_bh(&net->nsid_lock); diff --git a/net/core/netclassid_cgroup.c b/net/core/netclassid_cgroup.c index 5e4f04004a49..7bf833598615 100644 --- a/net/core/netclassid_cgroup.c +++ b/net/core/netclassid_cgroup.c @@ -106,6 +106,7 @@ static int write_classid(struct cgroup_subsys_state *css, struct cftype *cft, iterate_fd(p->files, 0, update_classid_sock, (void *)(unsigned long)cs->classid); task_unlock(p); + cond_resched(); } css_task_iter_end(&it); diff --git a/net/core/netpoll.c b/net/core/netpoll.c index 3219a2932463..5da9552b186b 100644 --- a/net/core/netpoll.c +++ b/net/core/netpoll.c @@ -57,7 +57,6 @@ DEFINE_STATIC_SRCU(netpoll_srcu); MAX_UDP_CHUNK) static void zap_completion_queue(void); -static void netpoll_async_cleanup(struct work_struct *work); static unsigned int carrier_timeout = 4; module_param(carrier_timeout, uint, 0644); @@ -135,27 +134,9 @@ static void queue_process(struct work_struct *work) } } -/* - * Check whether delayed processing was scheduled for our NIC. If so, - * we attempt to grab the poll lock and use ->poll() to pump the card. - * If this fails, either we've recursed in ->poll() or it's already - * running on another CPU. - * - * Note: we don't mask interrupts with this lock because we're using - * trylock here and interrupts are already disabled in the softirq - * case. Further, we test the poll_owner to avoid recursion on UP - * systems where the lock doesn't exist. - */ static void poll_one_napi(struct napi_struct *napi) { - int work = 0; - - /* net_rx_action's ->poll() invocations and our's are - * synchronized by this test which is only made while - * holding the napi->poll_lock. - */ - if (!test_bit(NAPI_STATE_SCHED, &napi->state)) - return; + int work; /* If we set this bit but see that it has already been set, * that indicates that napi has been disabled and we need @@ -607,7 +588,6 @@ int __netpoll_setup(struct netpoll *np, struct net_device *ndev) np->dev = ndev; strlcpy(np->dev_name, ndev->name, IFNAMSIZ); - INIT_WORK(&np->cleanup_work, netpoll_async_cleanup); if (ndev->priv_flags & IFF_DISABLE_NETPOLL) { np_err(np, "%s doesn't support polling, aborting\n", @@ -806,10 +786,6 @@ void __netpoll_cleanup(struct netpoll *np) { struct netpoll_info *npinfo; - /* rtnl_dereference would be preferable here but - * rcu_cleanup_netpoll path can put us in here safely without - * holding the rtnl, so plain rcu_dereference it is - */ npinfo = rtnl_dereference(np->dev->npinfo); if (!npinfo) return; @@ -830,21 +806,16 @@ void __netpoll_cleanup(struct netpoll *np) } EXPORT_SYMBOL_GPL(__netpoll_cleanup); -static void netpoll_async_cleanup(struct work_struct *work) +void __netpoll_free(struct netpoll *np) { - struct netpoll *np = container_of(work, struct netpoll, cleanup_work); + ASSERT_RTNL(); - rtnl_lock(); + /* Wait for transmitting packets to finish before freeing. */ + synchronize_rcu_bh(); __netpoll_cleanup(np); - rtnl_unlock(); kfree(np); } - -void __netpoll_free_async(struct netpoll *np) -{ - schedule_work(&np->cleanup_work); -} -EXPORT_SYMBOL_GPL(__netpoll_free_async); +EXPORT_SYMBOL_GPL(__netpoll_free); void netpoll_cleanup(struct netpoll *np) { diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c index 35162e1b06ad..e01274bd5e3e 100644 --- a/net/core/rtnetlink.c +++ b/net/core/rtnetlink.c @@ -59,7 +59,7 @@ #include <net/rtnetlink.h> #include <net/net_namespace.h> -#define RTNL_MAX_TYPE 48 +#define RTNL_MAX_TYPE 49 #define RTNL_SLAVE_MAX_TYPE 36 struct rtnl_link { @@ -1878,8 +1878,52 @@ struct net *rtnl_get_net_ns_capable(struct sock *sk, int netnsid) } EXPORT_SYMBOL_GPL(rtnl_get_net_ns_capable); +static int rtnl_valid_dump_ifinfo_req(const struct nlmsghdr *nlh, + bool strict_check, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int hdrlen; + + if (strict_check) { + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid header for link dump"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change) { + NL_SET_ERR_MSG(extack, "Invalid values in header for link dump request"); + return -EINVAL; + } + if (ifm->ifi_index) { + NL_SET_ERR_MSG(extack, "Filter by device index not supported for link dumps"); + return -EINVAL; + } + + return nlmsg_parse_strict(nlh, sizeof(*ifm), tb, IFLA_MAX, + ifla_policy, extack); + } + + /* A hack to preserve kernel<->userspace interface. + * The correct header is ifinfomsg. It is consistent with rtnl_getlink. + * However, before Linux v3.9 the code here assumed rtgenmsg and that's + * what iproute2 < v3.9.0 used. + * We can detect the old iproute2. Even including the IFLA_EXT_MASK + * attribute, its netlink message is shorter than struct ifinfomsg. + */ + hdrlen = nlmsg_len(nlh) < sizeof(struct ifinfomsg) ? + sizeof(struct rtgenmsg) : sizeof(struct ifinfomsg); + + return nlmsg_parse(nlh, hdrlen, tb, IFLA_MAX, ifla_policy, extack); +} + static int rtnl_dump_ifinfo(struct sk_buff *skb, struct netlink_callback *cb) { + struct netlink_ext_ack *extack = cb->extack; + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); struct net *tgt_net = net; int h, s_h; @@ -1892,46 +1936,54 @@ static int rtnl_dump_ifinfo(struct sk_buff *skb, struct netlink_callback *cb) unsigned int flags = NLM_F_MULTI; int master_idx = 0; int netnsid = -1; - int err; - int hdrlen; + int err, i; s_h = cb->args[0]; s_idx = cb->args[1]; - /* A hack to preserve kernel<->userspace interface. - * The correct header is ifinfomsg. It is consistent with rtnl_getlink. - * However, before Linux v3.9 the code here assumed rtgenmsg and that's - * what iproute2 < v3.9.0 used. - * We can detect the old iproute2. Even including the IFLA_EXT_MASK - * attribute, its netlink message is shorter than struct ifinfomsg. - */ - hdrlen = nlmsg_len(cb->nlh) < sizeof(struct ifinfomsg) ? - sizeof(struct rtgenmsg) : sizeof(struct ifinfomsg); + err = rtnl_valid_dump_ifinfo_req(nlh, cb->strict_check, tb, extack); + if (err < 0) { + if (cb->strict_check) + return err; + + goto walk_entries; + } - if (nlmsg_parse(cb->nlh, hdrlen, tb, IFLA_MAX, - ifla_policy, NULL) >= 0) { - if (tb[IFLA_TARGET_NETNSID]) { - netnsid = nla_get_s32(tb[IFLA_TARGET_NETNSID]); + for (i = 0; i <= IFLA_MAX; ++i) { + if (!tb[i]) + continue; + + /* new attributes should only be added with strict checking */ + switch (i) { + case IFLA_TARGET_NETNSID: + netnsid = nla_get_s32(tb[i]); tgt_net = rtnl_get_net_ns_capable(skb->sk, netnsid); if (IS_ERR(tgt_net)) { - tgt_net = net; - netnsid = -1; + NL_SET_ERR_MSG(extack, "Invalid target network namespace id"); + return PTR_ERR(tgt_net); + } + break; + case IFLA_EXT_MASK: + ext_filter_mask = nla_get_u32(tb[i]); + break; + case IFLA_MASTER: + master_idx = nla_get_u32(tb[i]); + break; + case IFLA_LINKINFO: + kind_ops = linkinfo_to_kind_ops(tb[i]); + break; + default: + if (cb->strict_check) { + NL_SET_ERR_MSG(extack, "Unsupported attribute in link dump request"); + return -EINVAL; } } - - if (tb[IFLA_EXT_MASK]) - ext_filter_mask = nla_get_u32(tb[IFLA_EXT_MASK]); - - if (tb[IFLA_MASTER]) - master_idx = nla_get_u32(tb[IFLA_MASTER]); - - if (tb[IFLA_LINKINFO]) - kind_ops = linkinfo_to_kind_ops(tb[IFLA_LINKINFO]); - - if (master_idx || kind_ops) - flags |= NLM_F_DUMP_FILTERED; } + if (master_idx || kind_ops) + flags |= NLM_F_DUMP_FILTERED; + +walk_entries: for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { idx = 0; head = &tgt_net->dev_index_head[h]; @@ -1943,8 +1995,7 @@ static int rtnl_dump_ifinfo(struct sk_buff *skb, struct netlink_callback *cb) err = rtnl_fill_ifinfo(skb, dev, net, RTM_NEWLINK, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, 0, - flags, + nlh->nlmsg_seq, 0, flags, ext_filter_mask, 0, NULL, 0, netnsid); @@ -2852,6 +2903,12 @@ struct net_device *rtnl_create_link(struct net *net, else if (ops->get_num_rx_queues) num_rx_queues = ops->get_num_rx_queues(); + if (num_tx_queues < 1 || num_tx_queues > 4096) + return ERR_PTR(-EINVAL); + + if (num_rx_queues < 1 || num_rx_queues > 4096) + return ERR_PTR(-EINVAL); + dev = alloc_netdev_mqs(ops->priv_size, ifname, name_assign_type, ops->setup, num_tx_queues, num_rx_queues); if (!dev) @@ -3276,6 +3333,7 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb) int idx; int s_idx = cb->family; int type = cb->nlh->nlmsg_type - RTM_BASE; + int ret = 0; if (s_idx == 0) s_idx = 1; @@ -3308,12 +3366,13 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb) cb->prev_seq = 0; cb->seq = 0; } - if (dumpit(skb, cb)) + ret = dumpit(skb, cb); + if (ret < 0) break; } cb->family = idx; - return skb->len; + return skb->len ? : ret; } struct sk_buff *rtmsg_ifinfo_build_skb(int type, struct net_device *dev, @@ -3541,6 +3600,11 @@ static int rtnl_fdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; } + if (dev->type != ARPHRD_ETHER) { + NL_SET_ERR_MSG(extack, "FDB add only supported for Ethernet devices"); + return -EINVAL; + } + addr = nla_data(tb[NDA_LLADDR]); err = fdb_vid_parse(tb[NDA_VLAN], &vid, extack); @@ -3645,6 +3709,11 @@ static int rtnl_fdb_del(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; } + if (dev->type != ARPHRD_ETHER) { + NL_SET_ERR_MSG(extack, "FDB delete only supported for Ethernet devices"); + return -EINVAL; + } + addr = nla_data(tb[NDA_LLADDR]); err = fdb_vid_parse(tb[NDA_VLAN], &vid, extack); @@ -3742,14 +3811,100 @@ out: } EXPORT_SYMBOL(ndo_dflt_fdb_dump); +static int valid_fdb_dump_strict(const struct nlmsghdr *nlh, + int *br_idx, int *brport_idx, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[NDA_MAX + 1]; + struct ndmsg *ndm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndm))) { + NL_SET_ERR_MSG(extack, "Invalid header for fdb dump request"); + return -EINVAL; + } + + ndm = nlmsg_data(nlh); + if (ndm->ndm_pad1 || ndm->ndm_pad2 || ndm->ndm_state || + ndm->ndm_flags || ndm->ndm_type) { + NL_SET_ERR_MSG(extack, "Invalid values in header for fbd dump request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(struct ndmsg), tb, NDA_MAX, + NULL, extack); + if (err < 0) + return err; + + *brport_idx = ndm->ndm_ifindex; + for (i = 0; i <= NDA_MAX; ++i) { + if (!tb[i]) + continue; + + switch (i) { + case NDA_IFINDEX: + if (nla_len(tb[i]) != sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid IFINDEX attribute in fdb dump request"); + return -EINVAL; + } + *brport_idx = nla_get_u32(tb[NDA_IFINDEX]); + break; + case NDA_MASTER: + if (nla_len(tb[i]) != sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid MASTER attribute in fdb dump request"); + return -EINVAL; + } + *br_idx = nla_get_u32(tb[NDA_MASTER]); + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in fdb dump request"); + return -EINVAL; + } + } + + return 0; +} + +static int valid_fdb_dump_legacy(const struct nlmsghdr *nlh, + int *br_idx, int *brport_idx, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_MAX+1]; + int err; + + /* A hack to preserve kernel<->userspace interface. + * Before Linux v4.12 this code accepted ndmsg since iproute2 v3.3.0. + * However, ndmsg is shorter than ifinfomsg thus nlmsg_parse() bails. + * So, check for ndmsg with an optional u32 attribute (not used here). + * Fortunately these sizes don't conflict with the size of ifinfomsg + * with an optional attribute. + */ + if (nlmsg_len(nlh) != sizeof(struct ndmsg) && + (nlmsg_len(nlh) != sizeof(struct ndmsg) + + nla_attr_size(sizeof(u32)))) { + struct ifinfomsg *ifm; + + err = nlmsg_parse(nlh, sizeof(struct ifinfomsg), tb, IFLA_MAX, + ifla_policy, extack); + if (err < 0) { + return -EINVAL; + } else if (err == 0) { + if (tb[IFLA_MASTER]) + *br_idx = nla_get_u32(tb[IFLA_MASTER]); + } + + ifm = nlmsg_data(nlh); + *brport_idx = ifm->ifi_index; + } + return 0; +} + static int rtnl_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb) { struct net_device *dev; - struct nlattr *tb[IFLA_MAX+1]; struct net_device *br_dev = NULL; const struct net_device_ops *ops = NULL; const struct net_device_ops *cops = NULL; - struct ifinfomsg *ifm = nlmsg_data(cb->nlh); struct net *net = sock_net(skb->sk); struct hlist_head *head; int brport_idx = 0; @@ -3759,16 +3914,14 @@ static int rtnl_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb) int err = 0; int fidx = 0; - err = nlmsg_parse(cb->nlh, sizeof(struct ifinfomsg), tb, - IFLA_MAX, ifla_policy, NULL); - if (err < 0) { - return -EINVAL; - } else if (err == 0) { - if (tb[IFLA_MASTER]) - br_idx = nla_get_u32(tb[IFLA_MASTER]); - } - - brport_idx = ifm->ifi_index; + if (cb->strict_check) + err = valid_fdb_dump_strict(cb->nlh, &br_idx, &brport_idx, + cb->extack); + else + err = valid_fdb_dump_legacy(cb->nlh, &br_idx, &brport_idx, + cb->extack); + if (err < 0) + return err; if (br_idx) { br_dev = __dev_get_by_index(net, br_idx); @@ -3953,28 +4106,72 @@ nla_put_failure: } EXPORT_SYMBOL_GPL(ndo_dflt_bridge_getlink); +static int valid_bridge_getlink_req(const struct nlmsghdr *nlh, + bool strict_check, u32 *filter_mask, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_MAX+1]; + int err, i; + + if (strict_check) { + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid header for bridge link dump"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change || ifm->ifi_index) { + NL_SET_ERR_MSG(extack, "Invalid values in header for bridge link dump request"); + return -EINVAL; + } + + err = nlmsg_parse_strict(nlh, sizeof(struct ifinfomsg), tb, + IFLA_MAX, ifla_policy, extack); + } else { + err = nlmsg_parse(nlh, sizeof(struct ifinfomsg), tb, + IFLA_MAX, ifla_policy, extack); + } + if (err < 0) + return err; + + /* new attributes should only be added with strict checking */ + for (i = 0; i <= IFLA_MAX; ++i) { + if (!tb[i]) + continue; + + switch (i) { + case IFLA_EXT_MASK: + *filter_mask = nla_get_u32(tb[i]); + break; + default: + if (strict_check) { + NL_SET_ERR_MSG(extack, "Unsupported attribute in bridge link dump request"); + return -EINVAL; + } + } + } + + return 0; +} + static int rtnl_bridge_getlink(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); struct net_device *dev; int idx = 0; u32 portid = NETLINK_CB(cb->skb).portid; - u32 seq = cb->nlh->nlmsg_seq; + u32 seq = nlh->nlmsg_seq; u32 filter_mask = 0; int err; - if (nlmsg_len(cb->nlh) > sizeof(struct ifinfomsg)) { - struct nlattr *extfilt; - - extfilt = nlmsg_find_attr(cb->nlh, sizeof(struct ifinfomsg), - IFLA_EXT_MASK); - if (extfilt) { - if (nla_len(extfilt) < sizeof(filter_mask)) - return -EINVAL; - - filter_mask = nla_get_u32(extfilt); - } - } + err = valid_bridge_getlink_req(nlh, cb->strict_check, &filter_mask, + cb->extack); + if (err < 0 && cb->strict_check) + return err; rcu_read_lock(); for_each_netdev_rcu(net, dev) { @@ -4568,6 +4765,7 @@ static int rtnl_stats_get(struct sk_buff *skb, struct nlmsghdr *nlh, static int rtnl_stats_dump(struct sk_buff *skb, struct netlink_callback *cb) { + struct netlink_ext_ack *extack = cb->extack; int h, s_h, err, s_idx, s_idxattr, s_prividx; struct net *net = sock_net(skb->sk); unsigned int flags = NLM_F_MULTI; @@ -4584,13 +4782,32 @@ static int rtnl_stats_dump(struct sk_buff *skb, struct netlink_callback *cb) cb->seq = net->dev_base_seq; - if (nlmsg_len(cb->nlh) < sizeof(*ifsm)) + if (nlmsg_len(cb->nlh) < sizeof(*ifsm)) { + NL_SET_ERR_MSG(extack, "Invalid header for stats dump"); return -EINVAL; + } ifsm = nlmsg_data(cb->nlh); + + /* only requests using strict checks can pass data to influence + * the dump. The legacy exception is filter_mask. + */ + if (cb->strict_check) { + if (ifsm->pad1 || ifsm->pad2 || ifsm->ifindex) { + NL_SET_ERR_MSG(extack, "Invalid values in header for stats dump request"); + return -EINVAL; + } + if (nlmsg_attrlen(cb->nlh, sizeof(*ifsm))) { + NL_SET_ERR_MSG(extack, "Invalid attributes after stats header"); + return -EINVAL; + } + } + filter_mask = ifsm->filter_mask; - if (!filter_mask) + if (!filter_mask) { + NL_SET_ERR_MSG(extack, "Filter mask must be set for stats dump"); return -EINVAL; + } for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { idx = 0; diff --git a/net/core/skbuff.c b/net/core/skbuff.c index 0e937d3d85b5..946de0e24c87 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -1846,8 +1846,9 @@ int pskb_trim_rcsum_slow(struct sk_buff *skb, unsigned int len) if (skb->ip_summed == CHECKSUM_COMPLETE) { int delta = skb->len - len; - skb->csum = csum_sub(skb->csum, - skb_checksum(skb, len, delta, 0)); + skb->csum = csum_block_sub(skb->csum, + skb_checksum(skb, len, delta, 0), + len); } return __pskb_trim(skb, len); } @@ -4394,14 +4395,16 @@ EXPORT_SYMBOL_GPL(skb_complete_wifi_ack); */ bool skb_partial_csum_set(struct sk_buff *skb, u16 start, u16 off) { - if (unlikely(start > skb_headlen(skb)) || - unlikely((int)start + off > skb_headlen(skb) - 2)) { - net_warn_ratelimited("bad partial csum: csum=%u/%u len=%u\n", - start, off, skb_headlen(skb)); + u32 csum_end = (u32)start + (u32)off + sizeof(__sum16); + u32 csum_start = skb_headroom(skb) + (u32)start; + + if (unlikely(csum_start > U16_MAX || csum_end > skb_headlen(skb))) { + net_warn_ratelimited("bad partial csum: csum=%u/%u headroom=%u headlen=%u\n", + start, off, skb_headroom(skb), skb_headlen(skb)); return false; } skb->ip_summed = CHECKSUM_PARTIAL; - skb->csum_start = skb_headroom(skb) + start; + skb->csum_start = csum_start; skb->csum_offset = off; skb_set_transport_header(skb, start); return true; diff --git a/net/core/skmsg.c b/net/core/skmsg.c new file mode 100644 index 000000000000..56a99d0c9aa0 --- /dev/null +++ b/net/core/skmsg.c @@ -0,0 +1,802 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ + +#include <linux/skmsg.h> +#include <linux/skbuff.h> +#include <linux/scatterlist.h> + +#include <net/sock.h> +#include <net/tcp.h> + +static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce) +{ + if (msg->sg.end > msg->sg.start && + elem_first_coalesce < msg->sg.end) + return true; + + if (msg->sg.end < msg->sg.start && + (elem_first_coalesce > msg->sg.start || + elem_first_coalesce < msg->sg.end)) + return true; + + return false; +} + +int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len, + int elem_first_coalesce) +{ + struct page_frag *pfrag = sk_page_frag(sk); + int ret = 0; + + len -= msg->sg.size; + while (len > 0) { + struct scatterlist *sge; + u32 orig_offset; + int use, i; + + if (!sk_page_frag_refill(sk, pfrag)) + return -ENOMEM; + + orig_offset = pfrag->offset; + use = min_t(int, len, pfrag->size - orig_offset); + if (!sk_wmem_schedule(sk, use)) + return -ENOMEM; + + i = msg->sg.end; + sk_msg_iter_var_prev(i); + sge = &msg->sg.data[i]; + + if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) && + sg_page(sge) == pfrag->page && + sge->offset + sge->length == orig_offset) { + sge->length += use; + } else { + if (sk_msg_full(msg)) { + ret = -ENOSPC; + break; + } + + sge = &msg->sg.data[msg->sg.end]; + sg_unmark_end(sge); + sg_set_page(sge, pfrag->page, use, orig_offset); + get_page(pfrag->page); + sk_msg_iter_next(msg, end); + } + + sk_mem_charge(sk, use); + msg->sg.size += use; + pfrag->offset += use; + len -= use; + } + + return ret; +} +EXPORT_SYMBOL_GPL(sk_msg_alloc); + +int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, + u32 off, u32 len) +{ + int i = src->sg.start; + struct scatterlist *sge = sk_msg_elem(src, i); + u32 sge_len, sge_off; + + if (sk_msg_full(dst)) + return -ENOSPC; + + while (off) { + if (sge->length > off) + break; + off -= sge->length; + sk_msg_iter_var_next(i); + if (i == src->sg.end && off) + return -ENOSPC; + sge = sk_msg_elem(src, i); + } + + while (len) { + sge_len = sge->length - off; + sge_off = sge->offset + off; + if (sge_len > len) + sge_len = len; + off = 0; + len -= sge_len; + sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); + sk_mem_charge(sk, sge_len); + sk_msg_iter_var_next(i); + if (i == src->sg.end && len) + return -ENOSPC; + sge = sk_msg_elem(src, i); + } + + return 0; +} +EXPORT_SYMBOL_GPL(sk_msg_clone); + +void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes) +{ + int i = msg->sg.start; + + do { + struct scatterlist *sge = sk_msg_elem(msg, i); + + if (bytes < sge->length) { + sge->length -= bytes; + sge->offset += bytes; + sk_mem_uncharge(sk, bytes); + break; + } + + sk_mem_uncharge(sk, sge->length); + bytes -= sge->length; + sge->length = 0; + sge->offset = 0; + sk_msg_iter_var_next(i); + } while (bytes && i != msg->sg.end); + msg->sg.start = i; +} +EXPORT_SYMBOL_GPL(sk_msg_return_zero); + +void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes) +{ + int i = msg->sg.start; + + do { + struct scatterlist *sge = &msg->sg.data[i]; + int uncharge = (bytes < sge->length) ? bytes : sge->length; + + sk_mem_uncharge(sk, uncharge); + bytes -= uncharge; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); +} +EXPORT_SYMBOL_GPL(sk_msg_return); + +static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i, + bool charge) +{ + struct scatterlist *sge = sk_msg_elem(msg, i); + u32 len = sge->length; + + if (charge) + sk_mem_uncharge(sk, len); + if (!msg->skb) + put_page(sg_page(sge)); + memset(sge, 0, sizeof(*sge)); + return len; +} + +static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i, + bool charge) +{ + struct scatterlist *sge = sk_msg_elem(msg, i); + int freed = 0; + + while (msg->sg.size) { + msg->sg.size -= sge->length; + freed += sk_msg_free_elem(sk, msg, i, charge); + sk_msg_iter_var_next(i); + sk_msg_check_to_free(msg, i, msg->sg.size); + sge = sk_msg_elem(msg, i); + } + if (msg->skb) + consume_skb(msg->skb); + sk_msg_init(msg); + return freed; +} + +int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg) +{ + return __sk_msg_free(sk, msg, msg->sg.start, false); +} +EXPORT_SYMBOL_GPL(sk_msg_free_nocharge); + +int sk_msg_free(struct sock *sk, struct sk_msg *msg) +{ + return __sk_msg_free(sk, msg, msg->sg.start, true); +} +EXPORT_SYMBOL_GPL(sk_msg_free); + +static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, + u32 bytes, bool charge) +{ + struct scatterlist *sge; + u32 i = msg->sg.start; + + while (bytes) { + sge = sk_msg_elem(msg, i); + if (!sge->length) + break; + if (bytes < sge->length) { + if (charge) + sk_mem_uncharge(sk, bytes); + sge->length -= bytes; + sge->offset += bytes; + msg->sg.size -= bytes; + break; + } + + msg->sg.size -= sge->length; + bytes -= sge->length; + sk_msg_free_elem(sk, msg, i, charge); + sk_msg_iter_var_next(i); + sk_msg_check_to_free(msg, i, bytes); + } + msg->sg.start = i; +} + +void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes) +{ + __sk_msg_free_partial(sk, msg, bytes, true); +} +EXPORT_SYMBOL_GPL(sk_msg_free_partial); + +void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, + u32 bytes) +{ + __sk_msg_free_partial(sk, msg, bytes, false); +} + +void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len) +{ + int trim = msg->sg.size - len; + u32 i = msg->sg.end; + + if (trim <= 0) { + WARN_ON(trim < 0); + return; + } + + sk_msg_iter_var_prev(i); + msg->sg.size = len; + while (msg->sg.data[i].length && + trim >= msg->sg.data[i].length) { + trim -= msg->sg.data[i].length; + sk_msg_free_elem(sk, msg, i, true); + sk_msg_iter_var_prev(i); + if (!trim) + goto out; + } + + msg->sg.data[i].length -= trim; + sk_mem_uncharge(sk, trim); +out: + /* If we trim data before curr pointer update copybreak and current + * so that any future copy operations start at new copy location. + * However trimed data that has not yet been used in a copy op + * does not require an update. + */ + if (msg->sg.curr >= i) { + msg->sg.curr = i; + msg->sg.copybreak = msg->sg.data[i].length; + } + sk_msg_iter_var_next(i); + msg->sg.end = i; +} +EXPORT_SYMBOL_GPL(sk_msg_trim); + +int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, + struct sk_msg *msg, u32 bytes) +{ + int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg); + const int to_max_pages = MAX_MSG_FRAGS; + struct page *pages[MAX_MSG_FRAGS]; + ssize_t orig, copied, use, offset; + + orig = msg->sg.size; + while (bytes > 0) { + i = 0; + maxpages = to_max_pages - num_elems; + if (maxpages == 0) { + ret = -EFAULT; + goto out; + } + + copied = iov_iter_get_pages(from, pages, bytes, maxpages, + &offset); + if (copied <= 0) { + ret = -EFAULT; + goto out; + } + + iov_iter_advance(from, copied); + bytes -= copied; + msg->sg.size += copied; + + while (copied) { + use = min_t(int, copied, PAGE_SIZE - offset); + sg_set_page(&msg->sg.data[msg->sg.end], + pages[i], use, offset); + sg_unmark_end(&msg->sg.data[msg->sg.end]); + sk_mem_charge(sk, use); + + offset = 0; + copied -= use; + sk_msg_iter_next(msg, end); + num_elems++; + i++; + } + /* When zerocopy is mixed with sk_msg_*copy* operations we + * may have a copybreak set in this case clear and prefer + * zerocopy remainder when possible. + */ + msg->sg.copybreak = 0; + msg->sg.curr = msg->sg.end; + } +out: + /* Revert iov_iter updates, msg will need to use 'trim' later if it + * also needs to be cleared. + */ + if (ret) + iov_iter_revert(from, msg->sg.size - orig); + return ret; +} +EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter); + +int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, + struct sk_msg *msg, u32 bytes) +{ + int ret = -ENOSPC, i = msg->sg.curr; + struct scatterlist *sge; + u32 copy, buf_size; + void *to; + + do { + sge = sk_msg_elem(msg, i); + /* This is possible if a trim operation shrunk the buffer */ + if (msg->sg.copybreak >= sge->length) { + msg->sg.copybreak = 0; + sk_msg_iter_var_next(i); + if (i == msg->sg.end) + break; + sge = sk_msg_elem(msg, i); + } + + buf_size = sge->length - msg->sg.copybreak; + copy = (buf_size > bytes) ? bytes : buf_size; + to = sg_virt(sge) + msg->sg.copybreak; + msg->sg.copybreak += copy; + if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY) + ret = copy_from_iter_nocache(to, copy, from); + else + ret = copy_from_iter(to, copy, from); + if (ret != copy) { + ret = -EFAULT; + goto out; + } + bytes -= copy; + if (!bytes) + break; + msg->sg.copybreak = 0; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); +out: + msg->sg.curr = i; + return ret; +} +EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); + +static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) +{ + struct sock *sk = psock->sk; + int copied = 0, num_sge; + struct sk_msg *msg; + + msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC); + if (unlikely(!msg)) + return -EAGAIN; + if (!sk_rmem_schedule(sk, skb, skb->len)) { + kfree(msg); + return -EAGAIN; + } + + sk_msg_init(msg); + num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len); + if (unlikely(num_sge < 0)) { + kfree(msg); + return num_sge; + } + + sk_mem_charge(sk, skb->len); + copied = skb->len; + msg->sg.start = 0; + msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge; + msg->skb = skb; + + sk_psock_queue_msg(psock, msg); + sk->sk_data_ready(sk); + return copied; +} + +static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb, + u32 off, u32 len, bool ingress) +{ + if (ingress) + return sk_psock_skb_ingress(psock, skb); + else + return skb_send_sock_locked(psock->sk, skb, off, len); +} + +static void sk_psock_backlog(struct work_struct *work) +{ + struct sk_psock *psock = container_of(work, struct sk_psock, work); + struct sk_psock_work_state *state = &psock->work_state; + struct sk_buff *skb; + bool ingress; + u32 len, off; + int ret; + + /* Lock sock to avoid losing sk_socket during loop. */ + lock_sock(psock->sk); + if (state->skb) { + skb = state->skb; + len = state->len; + off = state->off; + state->skb = NULL; + goto start; + } + + while ((skb = skb_dequeue(&psock->ingress_skb))) { + len = skb->len; + off = 0; +start: + ingress = tcp_skb_bpf_ingress(skb); + do { + ret = -EIO; + if (likely(psock->sk->sk_socket)) + ret = sk_psock_handle_skb(psock, skb, off, + len, ingress); + if (ret <= 0) { + if (ret == -EAGAIN) { + state->skb = skb; + state->len = len; + state->off = off; + goto end; + } + /* Hard errors break pipe and stop xmit. */ + sk_psock_report_error(psock, ret ? -ret : EPIPE); + sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); + kfree_skb(skb); + goto end; + } + off += ret; + len -= ret; + } while (len); + + if (!ingress) + kfree_skb(skb); + } +end: + release_sock(psock->sk); +} + +struct sk_psock *sk_psock_init(struct sock *sk, int node) +{ + struct sk_psock *psock = kzalloc_node(sizeof(*psock), + GFP_ATOMIC | __GFP_NOWARN, + node); + if (!psock) + return NULL; + + psock->sk = sk; + psock->eval = __SK_NONE; + + INIT_LIST_HEAD(&psock->link); + spin_lock_init(&psock->link_lock); + + INIT_WORK(&psock->work, sk_psock_backlog); + INIT_LIST_HEAD(&psock->ingress_msg); + skb_queue_head_init(&psock->ingress_skb); + + sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED); + refcount_set(&psock->refcnt, 1); + + rcu_assign_sk_user_data(sk, psock); + sock_hold(sk); + + return psock; +} +EXPORT_SYMBOL_GPL(sk_psock_init); + +struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock) +{ + struct sk_psock_link *link; + + spin_lock_bh(&psock->link_lock); + link = list_first_entry_or_null(&psock->link, struct sk_psock_link, + list); + if (link) + list_del(&link->list); + spin_unlock_bh(&psock->link_lock); + return link; +} + +void __sk_psock_purge_ingress_msg(struct sk_psock *psock) +{ + struct sk_msg *msg, *tmp; + + list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) { + list_del(&msg->list); + sk_msg_free(psock->sk, msg); + kfree(msg); + } +} + +static void sk_psock_zap_ingress(struct sk_psock *psock) +{ + __skb_queue_purge(&psock->ingress_skb); + __sk_psock_purge_ingress_msg(psock); +} + +static void sk_psock_link_destroy(struct sk_psock *psock) +{ + struct sk_psock_link *link, *tmp; + + list_for_each_entry_safe(link, tmp, &psock->link, list) { + list_del(&link->list); + sk_psock_free_link(link); + } +} + +static void sk_psock_destroy_deferred(struct work_struct *gc) +{ + struct sk_psock *psock = container_of(gc, struct sk_psock, gc); + + /* No sk_callback_lock since already detached. */ + if (psock->parser.enabled) + strp_done(&psock->parser.strp); + + cancel_work_sync(&psock->work); + + psock_progs_drop(&psock->progs); + + sk_psock_link_destroy(psock); + sk_psock_cork_free(psock); + sk_psock_zap_ingress(psock); + + if (psock->sk_redir) + sock_put(psock->sk_redir); + sock_put(psock->sk); + kfree(psock); +} + +void sk_psock_destroy(struct rcu_head *rcu) +{ + struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu); + + INIT_WORK(&psock->gc, sk_psock_destroy_deferred); + schedule_work(&psock->gc); +} +EXPORT_SYMBOL_GPL(sk_psock_destroy); + +void sk_psock_drop(struct sock *sk, struct sk_psock *psock) +{ + rcu_assign_sk_user_data(sk, NULL); + sk_psock_cork_free(psock); + sk_psock_restore_proto(sk, psock); + + write_lock_bh(&sk->sk_callback_lock); + if (psock->progs.skb_parser) + sk_psock_stop_strp(sk, psock); + write_unlock_bh(&sk->sk_callback_lock); + sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); + + call_rcu_sched(&psock->rcu, sk_psock_destroy); +} +EXPORT_SYMBOL_GPL(sk_psock_drop); + +static int sk_psock_map_verd(int verdict, bool redir) +{ + switch (verdict) { + case SK_PASS: + return redir ? __SK_REDIRECT : __SK_PASS; + case SK_DROP: + default: + break; + } + + return __SK_DROP; +} + +int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, + struct sk_msg *msg) +{ + struct bpf_prog *prog; + int ret; + + preempt_disable(); + rcu_read_lock(); + prog = READ_ONCE(psock->progs.msg_parser); + if (unlikely(!prog)) { + ret = __SK_PASS; + goto out; + } + + sk_msg_compute_data_pointers(msg); + msg->sk = sk; + ret = BPF_PROG_RUN(prog, msg); + ret = sk_psock_map_verd(ret, msg->sk_redir); + psock->apply_bytes = msg->apply_bytes; + if (ret == __SK_REDIRECT) { + if (psock->sk_redir) + sock_put(psock->sk_redir); + psock->sk_redir = msg->sk_redir; + if (!psock->sk_redir) { + ret = __SK_DROP; + goto out; + } + sock_hold(psock->sk_redir); + } +out: + rcu_read_unlock(); + preempt_enable(); + return ret; +} +EXPORT_SYMBOL_GPL(sk_psock_msg_verdict); + +static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog, + struct sk_buff *skb) +{ + int ret; + + skb->sk = psock->sk; + bpf_compute_data_end_sk_skb(skb); + preempt_disable(); + ret = BPF_PROG_RUN(prog, skb); + preempt_enable(); + /* strparser clones the skb before handing it to a upper layer, + * meaning skb_orphan has been called. We NULL sk on the way out + * to ensure we don't trigger a BUG_ON() in skb/sk operations + * later and because we are not charging the memory of this skb + * to any socket yet. + */ + skb->sk = NULL; + return ret; +} + +static struct sk_psock *sk_psock_from_strp(struct strparser *strp) +{ + struct sk_psock_parser *parser; + + parser = container_of(strp, struct sk_psock_parser, strp); + return container_of(parser, struct sk_psock, parser); +} + +static void sk_psock_verdict_apply(struct sk_psock *psock, + struct sk_buff *skb, int verdict) +{ + struct sk_psock *psock_other; + struct sock *sk_other; + bool ingress; + + switch (verdict) { + case __SK_REDIRECT: + sk_other = tcp_skb_bpf_redirect_fetch(skb); + if (unlikely(!sk_other)) + goto out_free; + psock_other = sk_psock(sk_other); + if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || + !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) + goto out_free; + ingress = tcp_skb_bpf_ingress(skb); + if ((!ingress && sock_writeable(sk_other)) || + (ingress && + atomic_read(&sk_other->sk_rmem_alloc) <= + sk_other->sk_rcvbuf)) { + if (!ingress) + skb_set_owner_w(skb, sk_other); + skb_queue_tail(&psock_other->ingress_skb, skb); + schedule_work(&psock_other->work); + break; + } + /* fall-through */ + case __SK_DROP: + /* fall-through */ + default: +out_free: + kfree_skb(skb); + } +} + +static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb) +{ + struct sk_psock *psock = sk_psock_from_strp(strp); + struct bpf_prog *prog; + int ret = __SK_DROP; + + rcu_read_lock(); + prog = READ_ONCE(psock->progs.skb_verdict); + if (likely(prog)) { + skb_orphan(skb); + tcp_skb_bpf_redirect_clear(skb); + ret = sk_psock_bpf_run(psock, prog, skb); + ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); + } + rcu_read_unlock(); + sk_psock_verdict_apply(psock, skb, ret); +} + +static int sk_psock_strp_read_done(struct strparser *strp, int err) +{ + return err; +} + +static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb) +{ + struct sk_psock *psock = sk_psock_from_strp(strp); + struct bpf_prog *prog; + int ret = skb->len; + + rcu_read_lock(); + prog = READ_ONCE(psock->progs.skb_parser); + if (likely(prog)) + ret = sk_psock_bpf_run(psock, prog, skb); + rcu_read_unlock(); + return ret; +} + +/* Called with socket lock held. */ +static void sk_psock_data_ready(struct sock *sk) +{ + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (likely(psock)) { + write_lock_bh(&sk->sk_callback_lock); + strp_data_ready(&psock->parser.strp); + write_unlock_bh(&sk->sk_callback_lock); + } + rcu_read_unlock(); +} + +static void sk_psock_write_space(struct sock *sk) +{ + struct sk_psock *psock; + void (*write_space)(struct sock *sk); + + rcu_read_lock(); + psock = sk_psock(sk); + if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))) + schedule_work(&psock->work); + write_space = psock->saved_write_space; + rcu_read_unlock(); + write_space(sk); +} + +int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock) +{ + static const struct strp_callbacks cb = { + .rcv_msg = sk_psock_strp_read, + .read_sock_done = sk_psock_strp_read_done, + .parse_msg = sk_psock_strp_parse, + }; + + psock->parser.enabled = false; + return strp_init(&psock->parser.strp, sk, &cb); +} + +void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock) +{ + struct sk_psock_parser *parser = &psock->parser; + + if (parser->enabled) + return; + + parser->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = sk_psock_data_ready; + sk->sk_write_space = sk_psock_write_space; + parser->enabled = true; +} + +void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock) +{ + struct sk_psock_parser *parser = &psock->parser; + + if (!parser->enabled) + return; + + sk->sk_data_ready = parser->saved_data_ready; + parser->saved_data_ready = NULL; + strp_stop(&parser->strp); + parser->enabled = false; +} diff --git a/net/core/sock.c b/net/core/sock.c index 7e8796a6a089..6fcc4bc07d19 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -998,7 +998,7 @@ set_rcvbuf: cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); - sk->sk_max_pacing_rate = val; + sk->sk_max_pacing_rate = (val == ~0U) ? ~0UL : val; sk->sk_pacing_rate = min(sk->sk_pacing_rate, sk->sk_max_pacing_rate); break; @@ -1336,7 +1336,8 @@ int sock_getsockopt(struct socket *sock, int level, int optname, #endif case SO_MAX_PACING_RATE: - v.val = sk->sk_max_pacing_rate; + /* 32bit version */ + v.val = min_t(unsigned long, sk->sk_max_pacing_rate, ~0U); break; case SO_INCOMING_CPU: @@ -2238,67 +2239,6 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag) } EXPORT_SYMBOL(sk_page_frag_refill); -int sk_alloc_sg(struct sock *sk, int len, struct scatterlist *sg, - int sg_start, int *sg_curr_index, unsigned int *sg_curr_size, - int first_coalesce) -{ - int sg_curr = *sg_curr_index, use = 0, rc = 0; - unsigned int size = *sg_curr_size; - struct page_frag *pfrag; - struct scatterlist *sge; - - len -= size; - pfrag = sk_page_frag(sk); - - while (len > 0) { - unsigned int orig_offset; - - if (!sk_page_frag_refill(sk, pfrag)) { - rc = -ENOMEM; - goto out; - } - - use = min_t(int, len, pfrag->size - pfrag->offset); - - if (!sk_wmem_schedule(sk, use)) { - rc = -ENOMEM; - goto out; - } - - sk_mem_charge(sk, use); - size += use; - orig_offset = pfrag->offset; - pfrag->offset += use; - - sge = sg + sg_curr - 1; - if (sg_curr > first_coalesce && sg_page(sge) == pfrag->page && - sge->offset + sge->length == orig_offset) { - sge->length += use; - } else { - sge = sg + sg_curr; - sg_unmark_end(sge); - sg_set_page(sge, pfrag->page, use, orig_offset); - get_page(pfrag->page); - sg_curr++; - - if (sg_curr == MAX_SKB_FRAGS) - sg_curr = 0; - - if (sg_curr == sg_start) { - rc = -ENOSPC; - break; - } - } - - len -= use; - } -out: - *sg_curr_size = size; - *sg_curr_index = sg_curr; - return rc; -} -EXPORT_SYMBOL(sk_alloc_sg); - static void __lock_sock(struct sock *sk) __releases(&sk->sk_lock.slock) __acquires(&sk->sk_lock.slock) @@ -2810,8 +2750,8 @@ void sock_init_data(struct socket *sock, struct sock *sk) sk->sk_ll_usec = sysctl_net_busy_read; #endif - sk->sk_max_pacing_rate = ~0U; - sk->sk_pacing_rate = ~0U; + sk->sk_max_pacing_rate = ~0UL; + sk->sk_pacing_rate = ~0UL; sk->sk_pacing_shift = 10; sk->sk_incoming_cpu = -1; diff --git a/net/core/sock_map.c b/net/core/sock_map.c new file mode 100644 index 000000000000..be6092ac69f8 --- /dev/null +++ b/net/core/sock_map.c @@ -0,0 +1,1003 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ + +#include <linux/bpf.h> +#include <linux/filter.h> +#include <linux/errno.h> +#include <linux/file.h> +#include <linux/net.h> +#include <linux/workqueue.h> +#include <linux/skmsg.h> +#include <linux/list.h> +#include <linux/jhash.h> + +struct bpf_stab { + struct bpf_map map; + struct sock **sks; + struct sk_psock_progs progs; + raw_spinlock_t lock; +}; + +#define SOCK_CREATE_FLAG_MASK \ + (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY) + +static struct bpf_map *sock_map_alloc(union bpf_attr *attr) +{ + struct bpf_stab *stab; + u64 cost; + int err; + + if (!capable(CAP_NET_ADMIN)) + return ERR_PTR(-EPERM); + if (attr->max_entries == 0 || + attr->key_size != 4 || + attr->value_size != 4 || + attr->map_flags & ~SOCK_CREATE_FLAG_MASK) + return ERR_PTR(-EINVAL); + + stab = kzalloc(sizeof(*stab), GFP_USER); + if (!stab) + return ERR_PTR(-ENOMEM); + + bpf_map_init_from_attr(&stab->map, attr); + raw_spin_lock_init(&stab->lock); + + /* Make sure page count doesn't overflow. */ + cost = (u64) stab->map.max_entries * sizeof(struct sock *); + if (cost >= U32_MAX - PAGE_SIZE) { + err = -EINVAL; + goto free_stab; + } + + stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT; + err = bpf_map_precharge_memlock(stab->map.pages); + if (err) + goto free_stab; + + stab->sks = bpf_map_area_alloc(stab->map.max_entries * + sizeof(struct sock *), + stab->map.numa_node); + if (stab->sks) + return &stab->map; + err = -ENOMEM; +free_stab: + kfree(stab); + return ERR_PTR(err); +} + +int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog) +{ + u32 ufd = attr->target_fd; + struct bpf_map *map; + struct fd f; + int ret; + + f = fdget(ufd); + map = __bpf_map_get(f); + if (IS_ERR(map)) + return PTR_ERR(map); + ret = sock_map_prog_update(map, prog, attr->attach_type); + fdput(f); + return ret; +} + +static void sock_map_sk_acquire(struct sock *sk) + __acquires(&sk->sk_lock.slock) +{ + lock_sock(sk); + preempt_disable(); + rcu_read_lock(); +} + +static void sock_map_sk_release(struct sock *sk) + __releases(&sk->sk_lock.slock) +{ + rcu_read_unlock(); + preempt_enable(); + release_sock(sk); +} + +static void sock_map_add_link(struct sk_psock *psock, + struct sk_psock_link *link, + struct bpf_map *map, void *link_raw) +{ + link->link_raw = link_raw; + link->map = map; + spin_lock_bh(&psock->link_lock); + list_add_tail(&link->list, &psock->link); + spin_unlock_bh(&psock->link_lock); +} + +static void sock_map_del_link(struct sock *sk, + struct sk_psock *psock, void *link_raw) +{ + struct sk_psock_link *link, *tmp; + bool strp_stop = false; + + spin_lock_bh(&psock->link_lock); + list_for_each_entry_safe(link, tmp, &psock->link, list) { + if (link->link_raw == link_raw) { + struct bpf_map *map = link->map; + struct bpf_stab *stab = container_of(map, struct bpf_stab, + map); + if (psock->parser.enabled && stab->progs.skb_parser) + strp_stop = true; + list_del(&link->list); + sk_psock_free_link(link); + } + } + spin_unlock_bh(&psock->link_lock); + if (strp_stop) { + write_lock_bh(&sk->sk_callback_lock); + sk_psock_stop_strp(sk, psock); + write_unlock_bh(&sk->sk_callback_lock); + } +} + +static void sock_map_unref(struct sock *sk, void *link_raw) +{ + struct sk_psock *psock = sk_psock(sk); + + if (likely(psock)) { + sock_map_del_link(sk, psock, link_raw); + sk_psock_put(sk, psock); + } +} + +static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, + struct sock *sk) +{ + struct bpf_prog *msg_parser, *skb_parser, *skb_verdict; + bool skb_progs, sk_psock_is_new = false; + struct sk_psock *psock; + int ret; + + skb_verdict = READ_ONCE(progs->skb_verdict); + skb_parser = READ_ONCE(progs->skb_parser); + skb_progs = skb_parser && skb_verdict; + if (skb_progs) { + skb_verdict = bpf_prog_inc_not_zero(skb_verdict); + if (IS_ERR(skb_verdict)) + return PTR_ERR(skb_verdict); + skb_parser = bpf_prog_inc_not_zero(skb_parser); + if (IS_ERR(skb_parser)) { + bpf_prog_put(skb_verdict); + return PTR_ERR(skb_parser); + } + } + + msg_parser = READ_ONCE(progs->msg_parser); + if (msg_parser) { + msg_parser = bpf_prog_inc_not_zero(msg_parser); + if (IS_ERR(msg_parser)) { + ret = PTR_ERR(msg_parser); + goto out; + } + } + + psock = sk_psock_get_checked(sk); + if (IS_ERR(psock)) { + ret = PTR_ERR(psock); + goto out_progs; + } + + if (psock) { + if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) || + (skb_progs && READ_ONCE(psock->progs.skb_parser))) { + sk_psock_put(sk, psock); + ret = -EBUSY; + goto out_progs; + } + } else { + psock = sk_psock_init(sk, map->numa_node); + if (!psock) { + ret = -ENOMEM; + goto out_progs; + } + sk_psock_is_new = true; + } + + if (msg_parser) + psock_set_prog(&psock->progs.msg_parser, msg_parser); + if (sk_psock_is_new) { + ret = tcp_bpf_init(sk); + if (ret < 0) + goto out_drop; + } else { + tcp_bpf_reinit(sk); + } + + write_lock_bh(&sk->sk_callback_lock); + if (skb_progs && !psock->parser.enabled) { + ret = sk_psock_init_strp(sk, psock); + if (ret) { + write_unlock_bh(&sk->sk_callback_lock); + goto out_drop; + } + psock_set_prog(&psock->progs.skb_verdict, skb_verdict); + psock_set_prog(&psock->progs.skb_parser, skb_parser); + sk_psock_start_strp(sk, psock); + } + write_unlock_bh(&sk->sk_callback_lock); + return 0; +out_drop: + sk_psock_put(sk, psock); +out_progs: + if (msg_parser) + bpf_prog_put(msg_parser); +out: + if (skb_progs) { + bpf_prog_put(skb_verdict); + bpf_prog_put(skb_parser); + } + return ret; +} + +static void sock_map_free(struct bpf_map *map) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + int i; + + synchronize_rcu(); + rcu_read_lock(); + raw_spin_lock_bh(&stab->lock); + for (i = 0; i < stab->map.max_entries; i++) { + struct sock **psk = &stab->sks[i]; + struct sock *sk; + + sk = xchg(psk, NULL); + if (sk) + sock_map_unref(sk, psk); + } + raw_spin_unlock_bh(&stab->lock); + rcu_read_unlock(); + + bpf_map_area_free(stab->sks); + kfree(stab); +} + +static void sock_map_release_progs(struct bpf_map *map) +{ + psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs); +} + +static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (unlikely(key >= map->max_entries)) + return NULL; + return READ_ONCE(stab->sks[key]); +} + +static void *sock_map_lookup(struct bpf_map *map, void *key) +{ + return ERR_PTR(-EOPNOTSUPP); +} + +static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test, + struct sock **psk) +{ + struct sock *sk; + + raw_spin_lock_bh(&stab->lock); + sk = *psk; + if (!sk_test || sk_test == sk) + *psk = NULL; + raw_spin_unlock_bh(&stab->lock); + if (unlikely(!sk)) + return -EINVAL; + sock_map_unref(sk, psk); + return 0; +} + +static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk, + void *link_raw) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + + __sock_map_delete(stab, sk, link_raw); +} + +static int sock_map_delete_elem(struct bpf_map *map, void *key) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + u32 i = *(u32 *)key; + struct sock **psk; + + if (unlikely(i >= map->max_entries)) + return -EINVAL; + + psk = &stab->sks[i]; + return __sock_map_delete(stab, NULL, psk); +} + +static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + u32 i = key ? *(u32 *)key : U32_MAX; + u32 *key_next = next; + + if (i == stab->map.max_entries - 1) + return -ENOENT; + if (i >= stab->map.max_entries) + *key_next = 0; + else + *key_next = i + 1; + return 0; +} + +static int sock_map_update_common(struct bpf_map *map, u32 idx, + struct sock *sk, u64 flags) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + struct sk_psock_link *link; + struct sk_psock *psock; + struct sock *osk; + int ret; + + WARN_ON_ONCE(!rcu_read_lock_held()); + if (unlikely(flags > BPF_EXIST)) + return -EINVAL; + if (unlikely(idx >= map->max_entries)) + return -E2BIG; + + link = sk_psock_init_link(); + if (!link) + return -ENOMEM; + + ret = sock_map_link(map, &stab->progs, sk); + if (ret < 0) + goto out_free; + + psock = sk_psock(sk); + WARN_ON_ONCE(!psock); + + raw_spin_lock_bh(&stab->lock); + osk = stab->sks[idx]; + if (osk && flags == BPF_NOEXIST) { + ret = -EEXIST; + goto out_unlock; + } else if (!osk && flags == BPF_EXIST) { + ret = -ENOENT; + goto out_unlock; + } + + sock_map_add_link(psock, link, map, &stab->sks[idx]); + stab->sks[idx] = sk; + if (osk) + sock_map_unref(osk, &stab->sks[idx]); + raw_spin_unlock_bh(&stab->lock); + return 0; +out_unlock: + raw_spin_unlock_bh(&stab->lock); + if (psock) + sk_psock_put(sk, psock); +out_free: + sk_psock_free_link(link); + return ret; +} + +static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops) +{ + return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB || + ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB; +} + +static bool sock_map_sk_is_suitable(const struct sock *sk) +{ + return sk->sk_type == SOCK_STREAM && + sk->sk_protocol == IPPROTO_TCP; +} + +static int sock_map_update_elem(struct bpf_map *map, void *key, + void *value, u64 flags) +{ + u32 ufd = *(u32 *)value; + u32 idx = *(u32 *)key; + struct socket *sock; + struct sock *sk; + int ret; + + sock = sockfd_lookup(ufd, &ret); + if (!sock) + return ret; + sk = sock->sk; + if (!sk) { + ret = -EINVAL; + goto out; + } + if (!sock_map_sk_is_suitable(sk) || + sk->sk_state != TCP_ESTABLISHED) { + ret = -EOPNOTSUPP; + goto out; + } + + sock_map_sk_acquire(sk); + ret = sock_map_update_common(map, idx, sk, flags); + sock_map_sk_release(sk); +out: + fput(sock->file); + return ret; +} + +BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops, + struct bpf_map *, map, void *, key, u64, flags) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (likely(sock_map_sk_is_suitable(sops->sk) && + sock_map_op_okay(sops))) + return sock_map_update_common(map, *(u32 *)key, sops->sk, + flags); + return -EOPNOTSUPP; +} + +const struct bpf_func_proto bpf_sock_map_update_proto = { + .func = bpf_sock_map_update, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb, + struct bpf_map *, map, u32, key, u64, flags) +{ + struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); + + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + tcb->bpf.flags = flags; + tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key); + if (!tcb->bpf.sk_redir) + return SK_DROP; + return SK_PASS; +} + +const struct bpf_func_proto bpf_sk_redirect_map_proto = { + .func = bpf_sk_redirect_map, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg, + struct bpf_map *, map, u32, key, u64, flags) +{ + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + msg->flags = flags; + msg->sk_redir = __sock_map_lookup_elem(map, key); + if (!msg->sk_redir) + return SK_DROP; + return SK_PASS; +} + +const struct bpf_func_proto bpf_msg_redirect_map_proto = { + .func = bpf_msg_redirect_map, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +const struct bpf_map_ops sock_map_ops = { + .map_alloc = sock_map_alloc, + .map_free = sock_map_free, + .map_get_next_key = sock_map_get_next_key, + .map_update_elem = sock_map_update_elem, + .map_delete_elem = sock_map_delete_elem, + .map_lookup_elem = sock_map_lookup, + .map_release_uref = sock_map_release_progs, + .map_check_btf = map_check_no_btf, +}; + +struct bpf_htab_elem { + struct rcu_head rcu; + u32 hash; + struct sock *sk; + struct hlist_node node; + u8 key[0]; +}; + +struct bpf_htab_bucket { + struct hlist_head head; + raw_spinlock_t lock; +}; + +struct bpf_htab { + struct bpf_map map; + struct bpf_htab_bucket *buckets; + u32 buckets_num; + u32 elem_size; + struct sk_psock_progs progs; + atomic_t count; +}; + +static inline u32 sock_hash_bucket_hash(const void *key, u32 len) +{ + return jhash(key, len, 0); +} + +static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab, + u32 hash) +{ + return &htab->buckets[hash & (htab->buckets_num - 1)]; +} + +static struct bpf_htab_elem * +sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key, + u32 key_size) +{ + struct bpf_htab_elem *elem; + + hlist_for_each_entry_rcu(elem, head, node) { + if (elem->hash == hash && + !memcmp(&elem->key, key, key_size)) + return elem; + } + + return NULL; +} + +static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key) +{ + struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + u32 key_size = map->key_size, hash; + struct bpf_htab_bucket *bucket; + struct bpf_htab_elem *elem; + + WARN_ON_ONCE(!rcu_read_lock_held()); + + hash = sock_hash_bucket_hash(key, key_size); + bucket = sock_hash_select_bucket(htab, hash); + elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); + + return elem ? elem->sk : NULL; +} + +static void sock_hash_free_elem(struct bpf_htab *htab, + struct bpf_htab_elem *elem) +{ + atomic_dec(&htab->count); + kfree_rcu(elem, rcu); +} + +static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, + void *link_raw) +{ + struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + struct bpf_htab_elem *elem_probe, *elem = link_raw; + struct bpf_htab_bucket *bucket; + + WARN_ON_ONCE(!rcu_read_lock_held()); + bucket = sock_hash_select_bucket(htab, elem->hash); + + /* elem may be deleted in parallel from the map, but access here + * is okay since it's going away only after RCU grace period. + * However, we need to check whether it's still present. + */ + raw_spin_lock_bh(&bucket->lock); + elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash, + elem->key, map->key_size); + if (elem_probe && elem_probe == elem) { + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + sock_hash_free_elem(htab, elem); + } + raw_spin_unlock_bh(&bucket->lock); +} + +static int sock_hash_delete_elem(struct bpf_map *map, void *key) +{ + struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + u32 hash, key_size = map->key_size; + struct bpf_htab_bucket *bucket; + struct bpf_htab_elem *elem; + int ret = -ENOENT; + + hash = sock_hash_bucket_hash(key, key_size); + bucket = sock_hash_select_bucket(htab, hash); + + raw_spin_lock_bh(&bucket->lock); + elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); + if (elem) { + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + sock_hash_free_elem(htab, elem); + ret = 0; + } + raw_spin_unlock_bh(&bucket->lock); + return ret; +} + +static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab, + void *key, u32 key_size, + u32 hash, struct sock *sk, + struct bpf_htab_elem *old) +{ + struct bpf_htab_elem *new; + + if (atomic_inc_return(&htab->count) > htab->map.max_entries) { + if (!old) { + atomic_dec(&htab->count); + return ERR_PTR(-E2BIG); + } + } + + new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN, + htab->map.numa_node); + if (!new) { + atomic_dec(&htab->count); + return ERR_PTR(-ENOMEM); + } + memcpy(new->key, key, key_size); + new->sk = sk; + new->hash = hash; + return new; +} + +static int sock_hash_update_common(struct bpf_map *map, void *key, + struct sock *sk, u64 flags) +{ + struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + u32 key_size = map->key_size, hash; + struct bpf_htab_elem *elem, *elem_new; + struct bpf_htab_bucket *bucket; + struct sk_psock_link *link; + struct sk_psock *psock; + int ret; + + WARN_ON_ONCE(!rcu_read_lock_held()); + if (unlikely(flags > BPF_EXIST)) + return -EINVAL; + + link = sk_psock_init_link(); + if (!link) + return -ENOMEM; + + ret = sock_map_link(map, &htab->progs, sk); + if (ret < 0) + goto out_free; + + psock = sk_psock(sk); + WARN_ON_ONCE(!psock); + + hash = sock_hash_bucket_hash(key, key_size); + bucket = sock_hash_select_bucket(htab, hash); + + raw_spin_lock_bh(&bucket->lock); + elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); + if (elem && flags == BPF_NOEXIST) { + ret = -EEXIST; + goto out_unlock; + } else if (!elem && flags == BPF_EXIST) { + ret = -ENOENT; + goto out_unlock; + } + + elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem); + if (IS_ERR(elem_new)) { + ret = PTR_ERR(elem_new); + goto out_unlock; + } + + sock_map_add_link(psock, link, map, elem_new); + /* Add new element to the head of the list, so that + * concurrent search will find it before old elem. + */ + hlist_add_head_rcu(&elem_new->node, &bucket->head); + if (elem) { + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + sock_hash_free_elem(htab, elem); + } + raw_spin_unlock_bh(&bucket->lock); + return 0; +out_unlock: + raw_spin_unlock_bh(&bucket->lock); + sk_psock_put(sk, psock); +out_free: + sk_psock_free_link(link); + return ret; +} + +static int sock_hash_update_elem(struct bpf_map *map, void *key, + void *value, u64 flags) +{ + u32 ufd = *(u32 *)value; + struct socket *sock; + struct sock *sk; + int ret; + + sock = sockfd_lookup(ufd, &ret); + if (!sock) + return ret; + sk = sock->sk; + if (!sk) { + ret = -EINVAL; + goto out; + } + if (!sock_map_sk_is_suitable(sk) || + sk->sk_state != TCP_ESTABLISHED) { + ret = -EOPNOTSUPP; + goto out; + } + + sock_map_sk_acquire(sk); + ret = sock_hash_update_common(map, key, sk, flags); + sock_map_sk_release(sk); +out: + fput(sock->file); + return ret; +} + +static int sock_hash_get_next_key(struct bpf_map *map, void *key, + void *key_next) +{ + struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + struct bpf_htab_elem *elem, *elem_next; + u32 hash, key_size = map->key_size; + struct hlist_head *head; + int i = 0; + + if (!key) + goto find_first_elem; + hash = sock_hash_bucket_hash(key, key_size); + head = &sock_hash_select_bucket(htab, hash)->head; + elem = sock_hash_lookup_elem_raw(head, hash, key, key_size); + if (!elem) + goto find_first_elem; + + elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)), + struct bpf_htab_elem, node); + if (elem_next) { + memcpy(key_next, elem_next->key, key_size); + return 0; + } + + i = hash & (htab->buckets_num - 1); + i++; +find_first_elem: + for (; i < htab->buckets_num; i++) { + head = &sock_hash_select_bucket(htab, i)->head; + elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)), + struct bpf_htab_elem, node); + if (elem_next) { + memcpy(key_next, elem_next->key, key_size); + return 0; + } + } + + return -ENOENT; +} + +static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) +{ + struct bpf_htab *htab; + int i, err; + u64 cost; + + if (!capable(CAP_NET_ADMIN)) + return ERR_PTR(-EPERM); + if (attr->max_entries == 0 || + attr->key_size == 0 || + attr->value_size != 4 || + attr->map_flags & ~SOCK_CREATE_FLAG_MASK) + return ERR_PTR(-EINVAL); + if (attr->key_size > MAX_BPF_STACK) + return ERR_PTR(-E2BIG); + + htab = kzalloc(sizeof(*htab), GFP_USER); + if (!htab) + return ERR_PTR(-ENOMEM); + + bpf_map_init_from_attr(&htab->map, attr); + + htab->buckets_num = roundup_pow_of_two(htab->map.max_entries); + htab->elem_size = sizeof(struct bpf_htab_elem) + + round_up(htab->map.key_size, 8); + if (htab->buckets_num == 0 || + htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) { + err = -EINVAL; + goto free_htab; + } + + cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) + + (u64) htab->elem_size * htab->map.max_entries; + if (cost >= U32_MAX - PAGE_SIZE) { + err = -EINVAL; + goto free_htab; + } + + htab->buckets = bpf_map_area_alloc(htab->buckets_num * + sizeof(struct bpf_htab_bucket), + htab->map.numa_node); + if (!htab->buckets) { + err = -ENOMEM; + goto free_htab; + } + + for (i = 0; i < htab->buckets_num; i++) { + INIT_HLIST_HEAD(&htab->buckets[i].head); + raw_spin_lock_init(&htab->buckets[i].lock); + } + + return &htab->map; +free_htab: + kfree(htab); + return ERR_PTR(err); +} + +static void sock_hash_free(struct bpf_map *map) +{ + struct bpf_htab *htab = container_of(map, struct bpf_htab, map); + struct bpf_htab_bucket *bucket; + struct bpf_htab_elem *elem; + struct hlist_node *node; + int i; + + synchronize_rcu(); + rcu_read_lock(); + for (i = 0; i < htab->buckets_num; i++) { + bucket = sock_hash_select_bucket(htab, i); + raw_spin_lock_bh(&bucket->lock); + hlist_for_each_entry_safe(elem, node, &bucket->head, node) { + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + } + raw_spin_unlock_bh(&bucket->lock); + } + rcu_read_unlock(); + + bpf_map_area_free(htab->buckets); + kfree(htab); +} + +static void sock_hash_release_progs(struct bpf_map *map) +{ + psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs); +} + +BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops, + struct bpf_map *, map, void *, key, u64, flags) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (likely(sock_map_sk_is_suitable(sops->sk) && + sock_map_op_okay(sops))) + return sock_hash_update_common(map, key, sops->sk, flags); + return -EOPNOTSUPP; +} + +const struct bpf_func_proto bpf_sock_hash_update_proto = { + .func = bpf_sock_hash_update, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb, + struct bpf_map *, map, void *, key, u64, flags) +{ + struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); + + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + tcb->bpf.flags = flags; + tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key); + if (!tcb->bpf.sk_redir) + return SK_DROP; + return SK_PASS; +} + +const struct bpf_func_proto bpf_sk_redirect_hash_proto = { + .func = bpf_sk_redirect_hash, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg, + struct bpf_map *, map, void *, key, u64, flags) +{ + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + msg->flags = flags; + msg->sk_redir = __sock_hash_lookup_elem(map, key); + if (!msg->sk_redir) + return SK_DROP; + return SK_PASS; +} + +const struct bpf_func_proto bpf_msg_redirect_hash_proto = { + .func = bpf_msg_redirect_hash, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +const struct bpf_map_ops sock_hash_ops = { + .map_alloc = sock_hash_alloc, + .map_free = sock_hash_free, + .map_get_next_key = sock_hash_get_next_key, + .map_update_elem = sock_hash_update_elem, + .map_delete_elem = sock_hash_delete_elem, + .map_lookup_elem = sock_map_lookup, + .map_release_uref = sock_hash_release_progs, + .map_check_btf = map_check_no_btf, +}; + +static struct sk_psock_progs *sock_map_progs(struct bpf_map *map) +{ + switch (map->map_type) { + case BPF_MAP_TYPE_SOCKMAP: + return &container_of(map, struct bpf_stab, map)->progs; + case BPF_MAP_TYPE_SOCKHASH: + return &container_of(map, struct bpf_htab, map)->progs; + default: + break; + } + + return NULL; +} + +int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, + u32 which) +{ + struct sk_psock_progs *progs = sock_map_progs(map); + + if (!progs) + return -EOPNOTSUPP; + + switch (which) { + case BPF_SK_MSG_VERDICT: + psock_set_prog(&progs->msg_parser, prog); + break; + case BPF_SK_SKB_STREAM_PARSER: + psock_set_prog(&progs->skb_parser, prog); + break; + case BPF_SK_SKB_STREAM_VERDICT: + psock_set_prog(&progs->skb_verdict, prog); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link) +{ + switch (link->map->map_type) { + case BPF_MAP_TYPE_SOCKMAP: + return sock_map_delete_from_link(link->map, sk, + link->link_raw); + case BPF_MAP_TYPE_SOCKHASH: + return sock_hash_delete_from_link(link->map, sk, + link->link_raw); + default: + break; + } +} diff --git a/net/core/sysctl_net_core.c b/net/core/sysctl_net_core.c index b1a2c5e38530..37b4667128a3 100644 --- a/net/core/sysctl_net_core.c +++ b/net/core/sysctl_net_core.c @@ -279,7 +279,6 @@ static int proc_dointvec_minmax_bpf_enable(struct ctl_table *table, int write, return ret; } -# ifdef CONFIG_HAVE_EBPF_JIT static int proc_dointvec_minmax_bpf_restricted(struct ctl_table *table, int write, void __user *buffer, size_t *lenp, @@ -290,7 +289,6 @@ proc_dointvec_minmax_bpf_restricted(struct ctl_table *table, int write, return proc_dointvec_minmax(table, write, buffer, lenp, ppos); } -# endif #endif static struct ctl_table net_core_table[] = { @@ -397,6 +395,14 @@ static struct ctl_table net_core_table[] = { .extra2 = &one, }, # endif + { + .procname = "bpf_jit_limit", + .data = &bpf_jit_limit, + .maxlen = sizeof(int), + .mode = 0600, + .proc_handler = proc_dointvec_minmax_bpf_restricted, + .extra1 = &one, + }, #endif { .procname = "netdev_tstamp_prequeue", diff --git a/net/dccp/input.c b/net/dccp/input.c index d28d46bff6ab..85d6c879383d 100644 --- a/net/dccp/input.c +++ b/net/dccp/input.c @@ -606,11 +606,13 @@ int dccp_rcv_state_process(struct sock *sk, struct sk_buff *skb, if (sk->sk_state == DCCP_LISTEN) { if (dh->dccph_type == DCCP_PKT_REQUEST) { /* It is possible that we process SYN packets from backlog, - * so we need to make sure to disable BH right there. + * so we need to make sure to disable BH and RCU right there. */ + rcu_read_lock(); local_bh_disable(); acceptable = inet_csk(sk)->icsk_af_ops->conn_request(sk, skb) >= 0; local_bh_enable(); + rcu_read_unlock(); if (!acceptable) return 1; consume_skb(skb); diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c index b08feb219b44..8e08cea6f178 100644 --- a/net/dccp/ipv4.c +++ b/net/dccp/ipv4.c @@ -493,9 +493,11 @@ static int dccp_v4_send_response(const struct sock *sk, struct request_sock *req dh->dccph_checksum = dccp_v4_csum_finish(skb, ireq->ir_loc_addr, ireq->ir_rmt_addr); + rcu_read_lock(); err = ip_build_and_send_pkt(skb, sk, ireq->ir_loc_addr, ireq->ir_rmt_addr, - ireq_opt_deref(ireq)); + rcu_dereference(ireq->ireq_opt)); + rcu_read_unlock(); err = net_xmit_eval(err); } diff --git a/net/dccp/proto.c b/net/dccp/proto.c index 875858c8b059..43733accf58e 100644 --- a/net/dccp/proto.c +++ b/net/dccp/proto.c @@ -325,7 +325,7 @@ __poll_t dccp_poll(struct file *file, struct socket *sock, __poll_t mask; struct sock *sk = sock->sk; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); if (sk->sk_state == DCCP_LISTEN) return inet_csk_listen_poll(sk); diff --git a/net/dns_resolver/dns_key.c b/net/dns_resolver/dns_key.c index 7f4534828f6c..a65d553e730d 100644 --- a/net/dns_resolver/dns_key.c +++ b/net/dns_resolver/dns_key.c @@ -29,6 +29,7 @@ #include <linux/keyctl.h> #include <linux/err.h> #include <linux/seq_file.h> +#include <linux/dns_resolver.h> #include <keys/dns_resolver-type.h> #include <keys/user-type.h> #include "internal.h" @@ -48,27 +49,86 @@ const struct cred *dns_resolver_cache; /* * Preparse instantiation data for a dns_resolver key. * - * The data must be a NUL-terminated string, with the NUL char accounted in - * datalen. + * For normal hostname lookups, the data must be a NUL-terminated string, with + * the NUL char accounted in datalen. * * If the data contains a '#' characters, then we take the clause after each * one to be an option of the form 'key=value'. The actual data of interest is * the string leading up to the first '#'. For instance: * * "ip1,ip2,...#foo=bar" + * + * For server list requests, the data must begin with a NUL char and be + * followed by a byte indicating the version of the data format. Version 1 + * looks something like (note this is packed): + * + * u8 Non-string marker (ie. 0) + * u8 Content (DNS_PAYLOAD_IS_*) + * u8 Version (e.g. 1) + * u8 Source of server list + * u8 Lookup status of server list + * u8 Number of servers + * foreach-server { + * __le16 Name length + * __le16 Priority (as per SRV record, low first) + * __le16 Weight (as per SRV record, higher first) + * __le16 Port + * u8 Source of address list + * u8 Lookup status of address list + * u8 Protocol (DNS_SERVER_PROTOCOL_*) + * u8 Number of addresses + * char[] Name (not NUL-terminated) + * foreach-address { + * u8 Family (DNS_ADDRESS_IS_*) + * union { + * u8[4] ipv4_addr + * u8[16] ipv6_addr + * } + * } + * } + * */ static int dns_resolver_preparse(struct key_preparsed_payload *prep) { + const struct dns_payload_header *bin; struct user_key_payload *upayload; unsigned long derrno; int ret; int datalen = prep->datalen, result_len = 0; const char *data = prep->data, *end, *opt; + if (datalen <= 1 || !data) + return -EINVAL; + + if (data[0] == 0) { + /* It may be a server list. */ + if (datalen <= sizeof(*bin)) + return -EINVAL; + + bin = (const struct dns_payload_header *)data; + kenter("[%u,%u],%u", bin->content, bin->version, datalen); + if (bin->content != DNS_PAYLOAD_IS_SERVER_LIST) { + pr_warn_ratelimited( + "dns_resolver: Unsupported content type (%u)\n", + bin->content); + return -EINVAL; + } + + if (bin->version != 1) { + pr_warn_ratelimited( + "dns_resolver: Unsupported server list version (%u)\n", + bin->version); + return -EINVAL; + } + + result_len = datalen; + goto store_result; + } + kenter("'%*.*s',%u", datalen, datalen, data, datalen); - if (datalen <= 1 || !data || data[datalen - 1] != '\0') + if (!data || data[datalen - 1] != '\0') return -EINVAL; datalen--; @@ -144,6 +204,7 @@ dns_resolver_preparse(struct key_preparsed_payload *prep) return 0; } +store_result: kdebug("store result"); prep->quotalen = result_len; diff --git a/net/dns_resolver/dns_query.c b/net/dns_resolver/dns_query.c index 49da67034f29..76338c38738a 100644 --- a/net/dns_resolver/dns_query.c +++ b/net/dns_resolver/dns_query.c @@ -148,12 +148,9 @@ int dns_query(const char *type, const char *name, size_t namelen, if (_result) { ret = -ENOMEM; - *_result = kmalloc(len + 1, GFP_KERNEL); + *_result = kmemdup_nul(upayload->data, len, GFP_KERNEL); if (!*_result) goto put; - - memcpy(*_result, upayload->data, len); - (*_result)[len] = '\0'; } if (_expiry) diff --git a/net/dsa/legacy.c b/net/dsa/legacy.c index 8aa92b09db76..cb42939db776 100644 --- a/net/dsa/legacy.c +++ b/net/dsa/legacy.c @@ -686,8 +686,7 @@ static void dsa_shutdown(struct platform_device *pdev) #ifdef CONFIG_PM_SLEEP static int dsa_suspend(struct device *d) { - struct platform_device *pdev = to_platform_device(d); - struct dsa_switch_tree *dst = platform_get_drvdata(pdev); + struct dsa_switch_tree *dst = dev_get_drvdata(d); int i, ret = 0; for (i = 0; i < dst->pd->nr_chips; i++) { @@ -702,8 +701,7 @@ static int dsa_suspend(struct device *d) static int dsa_resume(struct device *d) { - struct platform_device *pdev = to_platform_device(d); - struct dsa_switch_tree *dst = platform_get_drvdata(pdev); + struct dsa_switch_tree *dst = dev_get_drvdata(d); int i, ret = 0; for (i = 0; i < dst->pd->nr_chips; i++) { diff --git a/net/dsa/slave.c b/net/dsa/slave.c index 3f840b6eea69..7d0c19e7edcf 100644 --- a/net/dsa/slave.c +++ b/net/dsa/slave.c @@ -722,7 +722,7 @@ static void dsa_slave_netpoll_cleanup(struct net_device *dev) p->netpoll = NULL; - __netpoll_free_async(netpoll); + __netpoll_free(netpoll); } static void dsa_slave_poll_controller(struct net_device *dev) @@ -1478,6 +1478,7 @@ static void dsa_slave_switchdev_event_work(struct work_struct *work) netdev_dbg(dev, "fdb add failed err=%d\n", err); break; } + fdb_info->offloaded = true; call_switchdev_notifiers(SWITCHDEV_FDB_OFFLOADED, dev, &fdb_info->info); break; diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile index 7446b98661d8..58629314eae9 100644 --- a/net/ipv4/Makefile +++ b/net/ipv4/Makefile @@ -63,6 +63,7 @@ obj-$(CONFIG_TCP_CONG_SCALABLE) += tcp_scalable.o obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o +obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o obj-$(CONFIG_NETLABEL) += cipso_ipv4.o obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \ diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c index e90c89ef8c08..850a6f13a082 100644 --- a/net/ipv4/arp.c +++ b/net/ipv4/arp.c @@ -1255,6 +1255,8 @@ static int arp_netdev_event(struct notifier_block *this, unsigned long event, change_info = ptr; if (change_info->flags_changed & IFF_NOARP) neigh_changeaddr(&arp_tbl, dev); + if (!netif_carrier_ok(dev)) + neigh_carrier_down(&arp_tbl, dev); break; default: break; diff --git a/net/ipv4/devinet.c b/net/ipv4/devinet.c index 44d931a3cd50..a34602ae27de 100644 --- a/net/ipv4/devinet.c +++ b/net/ipv4/devinet.c @@ -109,6 +109,7 @@ struct inet_fill_args { int event; unsigned int flags; int netnsid; + int ifindex; }; #define IN4_ADDR_HSIZE_SHIFT 8 @@ -782,7 +783,8 @@ static void set_ifa_lifetime(struct in_ifaddr *ifa, __u32 valid_lft, } static struct in_ifaddr *rtm_to_ifaddr(struct net *net, struct nlmsghdr *nlh, - __u32 *pvalid_lft, __u32 *pprefered_lft) + __u32 *pvalid_lft, __u32 *pprefered_lft, + struct netlink_ext_ack *extack) { struct nlattr *tb[IFA_MAX+1]; struct in_ifaddr *ifa; @@ -792,7 +794,7 @@ static struct in_ifaddr *rtm_to_ifaddr(struct net *net, struct nlmsghdr *nlh, int err; err = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_ipv4_policy, - NULL); + extack); if (err < 0) goto errout; @@ -897,7 +899,7 @@ static int inet_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh, ASSERT_RTNL(); - ifa = rtm_to_ifaddr(net, nlh, &valid_lft, &prefered_lft); + ifa = rtm_to_ifaddr(net, nlh, &valid_lft, &prefered_lft, extack); if (IS_ERR(ifa)) return PTR_ERR(ifa); @@ -1659,39 +1661,133 @@ nla_put_failure: return -EMSGSIZE; } +static int inet_valid_dump_ifaddr_req(const struct nlmsghdr *nlh, + struct inet_fill_args *fillargs, + struct net **tgt_net, struct sock *sk, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[IFA_MAX+1]; + struct ifaddrmsg *ifm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for address dump request"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->ifa_prefixlen || ifm->ifa_flags || ifm->ifa_scope) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid values in header for address dump request"); + return -EINVAL; + } + + fillargs->ifindex = ifm->ifa_index; + if (fillargs->ifindex) { + cb->answer_flags |= NLM_F_DUMP_FILTERED; + fillargs->flags |= NLM_F_DUMP_FILTERED; + } + + err = nlmsg_parse_strict(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv4_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= IFA_MAX; ++i) { + if (!tb[i]) + continue; + + if (i == IFA_TARGET_NETNSID) { + struct net *net; + + fillargs->netnsid = nla_get_s32(tb[i]); + + net = rtnl_get_net_ns_capable(sk, fillargs->netnsid); + if (IS_ERR(net)) { + fillargs->netnsid = -1; + NL_SET_ERR_MSG(extack, "ipv4: Invalid target network namespace id"); + return PTR_ERR(net); + } + *tgt_net = net; + } else { + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in dump request"); + return -EINVAL; + } + } + + return 0; +} + +static int in_dev_dump_addr(struct in_device *in_dev, struct sk_buff *skb, + struct netlink_callback *cb, int s_ip_idx, + struct inet_fill_args *fillargs) +{ + struct in_ifaddr *ifa; + int ip_idx = 0; + int err; + + for (ifa = in_dev->ifa_list; ifa; ifa = ifa->ifa_next, ip_idx++) { + if (ip_idx < s_ip_idx) + continue; + + err = inet_fill_ifaddr(skb, ifa, fillargs); + if (err < 0) + goto done; + + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + } + err = 0; + +done: + cb->args[2] = ip_idx; + + return err; +} + static int inet_dump_ifaddr(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct inet_fill_args fillargs = { .portid = NETLINK_CB(cb->skb).portid, - .seq = cb->nlh->nlmsg_seq, + .seq = nlh->nlmsg_seq, .event = RTM_NEWADDR, .flags = NLM_F_MULTI, .netnsid = -1, }; struct net *net = sock_net(skb->sk); - struct nlattr *tb[IFA_MAX+1]; struct net *tgt_net = net; int h, s_h; int idx, s_idx; - int ip_idx, s_ip_idx; + int s_ip_idx; struct net_device *dev; struct in_device *in_dev; - struct in_ifaddr *ifa; struct hlist_head *head; + int err = 0; s_h = cb->args[0]; s_idx = idx = cb->args[1]; - s_ip_idx = ip_idx = cb->args[2]; - - if (nlmsg_parse(cb->nlh, sizeof(struct ifaddrmsg), tb, IFA_MAX, - ifa_ipv4_policy, NULL) >= 0) { - if (tb[IFA_TARGET_NETNSID]) { - fillargs.netnsid = nla_get_s32(tb[IFA_TARGET_NETNSID]); + s_ip_idx = cb->args[2]; + + if (cb->strict_check) { + err = inet_valid_dump_ifaddr_req(nlh, &fillargs, &tgt_net, + skb->sk, cb); + if (err < 0) + goto put_tgt_net; + + err = 0; + if (fillargs.ifindex) { + dev = __dev_get_by_index(tgt_net, fillargs.ifindex); + if (!dev) { + err = -ENODEV; + goto put_tgt_net; + } - tgt_net = rtnl_get_net_ns_capable(skb->sk, - fillargs.netnsid); - if (IS_ERR(tgt_net)) - return PTR_ERR(tgt_net); + in_dev = __in_dev_get_rtnl(dev); + if (in_dev) { + err = in_dev_dump_addr(in_dev, skb, cb, s_ip_idx, + &fillargs); + } + goto put_tgt_net; } } @@ -1710,15 +1806,11 @@ static int inet_dump_ifaddr(struct sk_buff *skb, struct netlink_callback *cb) if (!in_dev) goto cont; - for (ifa = in_dev->ifa_list, ip_idx = 0; ifa; - ifa = ifa->ifa_next, ip_idx++) { - if (ip_idx < s_ip_idx) - continue; - if (inet_fill_ifaddr(skb, ifa, &fillargs) < 0) { - rcu_read_unlock(); - goto done; - } - nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + err = in_dev_dump_addr(in_dev, skb, cb, s_ip_idx, + &fillargs); + if (err < 0) { + rcu_read_unlock(); + goto done; } cont: idx++; @@ -1729,11 +1821,11 @@ cont: done: cb->args[0] = h; cb->args[1] = idx; - cb->args[2] = ip_idx; +put_tgt_net: if (fillargs.netnsid >= 0) put_net(tgt_net); - return skb->len; + return err < 0 ? err : skb->len; } static void rtmsg_ifa(int event, struct in_ifaddr *ifa, struct nlmsghdr *nlh, @@ -2035,6 +2127,7 @@ errout: static int inet_netconf_dump_devconf(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); int h, s_h; int idx, s_idx; @@ -2042,6 +2135,21 @@ static int inet_netconf_dump_devconf(struct sk_buff *skb, struct in_device *in_dev; struct hlist_head *head; + if (cb->strict_check) { + struct netlink_ext_ack *extack = cb->extack; + struct netconfmsg *ncm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for netconf dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ncm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid data after header in netconf dump request"); + return -EINVAL; + } + } + s_h = cb->args[0]; s_idx = idx = cb->args[1]; @@ -2061,7 +2169,7 @@ static int inet_netconf_dump_devconf(struct sk_buff *skb, if (inet_netconf_fill_devconf(skb, dev->ifindex, &in_dev->cnf, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNETCONF, NLM_F_MULTI, NETCONFA_ALL) < 0) { @@ -2078,7 +2186,7 @@ cont: if (inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_ALL, net->ipv4.devconf_all, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNETCONF, NLM_F_MULTI, NETCONFA_ALL) < 0) goto done; @@ -2089,7 +2197,7 @@ cont: if (inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_DEFAULT, net->ipv4.devconf_dflt, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNETCONF, NLM_F_MULTI, NETCONFA_ALL) < 0) goto done; diff --git a/net/ipv4/fib_frontend.c b/net/ipv4/fib_frontend.c index 30e2bcc3ef2a..6df95be96311 100644 --- a/net/ipv4/fib_frontend.c +++ b/net/ipv4/fib_frontend.c @@ -802,19 +802,115 @@ errout: return err; } +int ip_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, + struct fib_dump_filter *filter, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[RTA_MAX + 1]; + struct rtmsg *rtm; + int err, i; + + ASSERT_RTNL(); + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG(extack, "Invalid header for FIB dump request"); + return -EINVAL; + } + + rtm = nlmsg_data(nlh); + if (rtm->rtm_dst_len || rtm->rtm_src_len || rtm->rtm_tos || + rtm->rtm_scope) { + NL_SET_ERR_MSG(extack, "Invalid values in header for FIB dump request"); + return -EINVAL; + } + if (rtm->rtm_flags & ~(RTM_F_CLONED | RTM_F_PREFIX)) { + NL_SET_ERR_MSG(extack, "Invalid flags for FIB dump request"); + return -EINVAL; + } + + filter->dump_all_families = (rtm->rtm_family == AF_UNSPEC); + filter->flags = rtm->rtm_flags; + filter->protocol = rtm->rtm_protocol; + filter->rt_type = rtm->rtm_type; + filter->table_id = rtm->rtm_table; + + err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= RTA_MAX; ++i) { + int ifindex; + + if (!tb[i]) + continue; + + switch (i) { + case RTA_TABLE: + filter->table_id = nla_get_u32(tb[i]); + break; + case RTA_OIF: + ifindex = nla_get_u32(tb[i]); + filter->dev = __dev_get_by_index(net, ifindex); + if (!filter->dev) + return -ENODEV; + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in dump request"); + return -EINVAL; + } + } + + if (filter->flags || filter->protocol || filter->rt_type || + filter->table_id || filter->dev) { + filter->filter_set = 1; + cb->answer_flags = NLM_F_DUMP_FILTERED; + } + + return 0; +} +EXPORT_SYMBOL_GPL(ip_valid_fib_dump_req); + static int inet_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); + struct fib_dump_filter filter = {}; unsigned int h, s_h; unsigned int e = 0, s_e; struct fib_table *tb; struct hlist_head *head; int dumped = 0, err; - if (nlmsg_len(cb->nlh) >= sizeof(struct rtmsg) && - ((struct rtmsg *) nlmsg_data(cb->nlh))->rtm_flags & RTM_F_CLONED) + if (cb->strict_check) { + err = ip_valid_fib_dump_req(net, nlh, &filter, cb); + if (err < 0) + return err; + } else if (nlmsg_len(nlh) >= sizeof(struct rtmsg)) { + struct rtmsg *rtm = nlmsg_data(nlh); + + filter.flags = rtm->rtm_flags & (RTM_F_PREFIX | RTM_F_CLONED); + } + + /* fib entries are never clones and ipv4 does not use prefix flag */ + if (filter.flags & (RTM_F_PREFIX | RTM_F_CLONED)) return skb->len; + if (filter.table_id) { + tb = fib_get_table(net, filter.table_id); + if (!tb) { + if (filter.dump_all_families) + return skb->len; + + NL_SET_ERR_MSG(cb->extack, "ipv4: FIB table does not exist"); + return -ENOENT; + } + + err = fib_table_dump(tb, skb, cb, &filter); + return skb->len ? : err; + } + s_h = cb->args[0]; s_e = cb->args[1]; @@ -829,7 +925,7 @@ static int inet_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) if (dumped) memset(&cb->args[2], 0, sizeof(cb->args) - 2 * sizeof(cb->args[0])); - err = fib_table_dump(tb, skb, cb); + err = fib_table_dump(tb, skb, cb, &filter); if (err < 0) { if (likely(skb->len)) goto out; @@ -1253,7 +1349,8 @@ static int fib_inetaddr_event(struct notifier_block *this, unsigned long event, static int fib_netdev_event(struct notifier_block *this, unsigned long event, void *ptr) { struct net_device *dev = netdev_notifier_info_to_dev(ptr); - struct netdev_notifier_changeupper_info *info; + struct netdev_notifier_changeupper_info *upper_info = ptr; + struct netdev_notifier_info_ext *info_ext = ptr; struct in_device *in_dev; struct net *net = dev_net(dev); unsigned int flags; @@ -1288,16 +1385,19 @@ static int fib_netdev_event(struct notifier_block *this, unsigned long event, vo fib_sync_up(dev, RTNH_F_LINKDOWN); else fib_sync_down_dev(dev, event, false); - /* fall through */ + rt_cache_flush(net); + break; case NETDEV_CHANGEMTU: + fib_sync_mtu(dev, info_ext->ext.mtu); rt_cache_flush(net); break; case NETDEV_CHANGEUPPER: - info = ptr; + upper_info = ptr; /* flush all routes if dev is linked to or unlinked from * an L3 master device (e.g., VRF) */ - if (info->upper_dev && netif_is_l3_master(info->upper_dev)) + if (upper_info->upper_dev && + netif_is_l3_master(upper_info->upper_dev)) fib_disable_ip(dev, NETDEV_DOWN, true); break; } diff --git a/net/ipv4/fib_semantics.c b/net/ipv4/fib_semantics.c index bee8db979195..b5c3937ca6ec 100644 --- a/net/ipv4/fib_semantics.c +++ b/net/ipv4/fib_semantics.c @@ -208,7 +208,6 @@ static void rt_fibinfo_free_cpus(struct rtable __rcu * __percpu *rtp) static void free_fib_info_rcu(struct rcu_head *head) { struct fib_info *fi = container_of(head, struct fib_info, rcu); - struct dst_metrics *m; change_nexthops(fi) { if (nexthop_nh->nh_dev) @@ -219,9 +218,8 @@ static void free_fib_info_rcu(struct rcu_head *head) rt_fibinfo_free(&nexthop_nh->nh_rth_input); } endfor_nexthops(fi); - m = fi->fib_metrics; - if (m != &dst_default_metrics && refcount_dec_and_test(&m->refcnt)) - kfree(m); + ip_fib_metrics_put(fi->fib_metrics); + kfree(fi); } @@ -1020,13 +1018,6 @@ static bool fib_valid_prefsrc(struct fib_config *cfg, __be32 fib_prefsrc) return true; } -static int -fib_convert_metrics(struct fib_info *fi, const struct fib_config *cfg) -{ - return ip_metrics_convert(fi->fib_net, cfg->fc_mx, cfg->fc_mx_len, - fi->fib_metrics->metrics); -} - struct fib_info *fib_create_info(struct fib_config *cfg, struct netlink_ext_ack *extack) { @@ -1084,16 +1075,14 @@ struct fib_info *fib_create_info(struct fib_config *cfg, fi = kzalloc(sizeof(*fi)+nhs*sizeof(struct fib_nh), GFP_KERNEL); if (!fi) goto failure; - if (cfg->fc_mx) { - fi->fib_metrics = kzalloc(sizeof(*fi->fib_metrics), GFP_KERNEL); - if (unlikely(!fi->fib_metrics)) { - kfree(fi); - return ERR_PTR(err); - } - refcount_set(&fi->fib_metrics->refcnt, 1); - } else { - fi->fib_metrics = (struct dst_metrics *)&dst_default_metrics; + fi->fib_metrics = ip_fib_metrics_init(fi->fib_net, cfg->fc_mx, + cfg->fc_mx_len); + if (unlikely(IS_ERR(fi->fib_metrics))) { + err = PTR_ERR(fi->fib_metrics); + kfree(fi); + return ERR_PTR(err); } + fib_info_cnt++; fi->fib_net = net; fi->fib_protocol = cfg->fc_protocol; @@ -1112,10 +1101,6 @@ struct fib_info *fib_create_info(struct fib_config *cfg, goto failure; } endfor_nexthops(fi) - err = fib_convert_metrics(fi, cfg); - if (err) - goto failure; - if (cfg->fc_mp) { #ifdef CONFIG_IP_ROUTE_MULTIPATH err = fib_get_nhs(fi, cfg->fc_mp, cfg->fc_mp_len, cfg, extack); @@ -1472,6 +1457,56 @@ static int call_fib_nh_notifiers(struct fib_nh *fib_nh, return NOTIFY_DONE; } +/* Update the PMTU of exceptions when: + * - the new MTU of the first hop becomes smaller than the PMTU + * - the old MTU was the same as the PMTU, and it limited discovery of + * larger MTUs on the path. With that limit raised, we can now + * discover larger MTUs + * A special case is locked exceptions, for which the PMTU is smaller + * than the minimal accepted PMTU: + * - if the new MTU is greater than the PMTU, don't make any change + * - otherwise, unlock and set PMTU + */ +static void nh_update_mtu(struct fib_nh *nh, u32 new, u32 orig) +{ + struct fnhe_hash_bucket *bucket; + int i; + + bucket = rcu_dereference_protected(nh->nh_exceptions, 1); + if (!bucket) + return; + + for (i = 0; i < FNHE_HASH_SIZE; i++) { + struct fib_nh_exception *fnhe; + + for (fnhe = rcu_dereference_protected(bucket[i].chain, 1); + fnhe; + fnhe = rcu_dereference_protected(fnhe->fnhe_next, 1)) { + if (fnhe->fnhe_mtu_locked) { + if (new <= fnhe->fnhe_pmtu) { + fnhe->fnhe_pmtu = new; + fnhe->fnhe_mtu_locked = false; + } + } else if (new < fnhe->fnhe_pmtu || + orig == fnhe->fnhe_pmtu) { + fnhe->fnhe_pmtu = new; + } + } + } +} + +void fib_sync_mtu(struct net_device *dev, u32 orig_mtu) +{ + unsigned int hash = fib_devindex_hashfn(dev->ifindex); + struct hlist_head *head = &fib_info_devhash[hash]; + struct fib_nh *nh; + + hlist_for_each_entry(nh, head, nh_hash) { + if (nh->nh_dev == dev) + nh_update_mtu(nh, dev->mtu, orig_mtu); + } +} + /* Event force Flags Description * NETDEV_CHANGE 0 LINKDOWN Carrier OFF, not for scope host * NETDEV_DOWN 0 LINKDOWN|DEAD Link down, not for scope host diff --git a/net/ipv4/fib_trie.c b/net/ipv4/fib_trie.c index 5bc0c89e81e4..237c9f72b265 100644 --- a/net/ipv4/fib_trie.c +++ b/net/ipv4/fib_trie.c @@ -2003,12 +2003,17 @@ void fib_free_table(struct fib_table *tb) } static int fn_trie_dump_leaf(struct key_vector *l, struct fib_table *tb, - struct sk_buff *skb, struct netlink_callback *cb) + struct sk_buff *skb, struct netlink_callback *cb, + struct fib_dump_filter *filter) { + unsigned int flags = NLM_F_MULTI; __be32 xkey = htonl(l->key); struct fib_alias *fa; int i, s_i; + if (filter->filter_set) + flags |= NLM_F_DUMP_FILTERED; + s_i = cb->args[4]; i = 0; @@ -2016,25 +2021,35 @@ static int fn_trie_dump_leaf(struct key_vector *l, struct fib_table *tb, hlist_for_each_entry_rcu(fa, &l->leaf, fa_list) { int err; - if (i < s_i) { - i++; - continue; - } + if (i < s_i) + goto next; - if (tb->tb_id != fa->tb_id) { - i++; - continue; + if (tb->tb_id != fa->tb_id) + goto next; + + if (filter->filter_set) { + if (filter->rt_type && fa->fa_type != filter->rt_type) + goto next; + + if ((filter->protocol && + fa->fa_info->fib_protocol != filter->protocol)) + goto next; + + if (filter->dev && + !fib_info_nh_uses_dev(fa->fa_info, filter->dev)) + goto next; } err = fib_dump_info(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, RTM_NEWROUTE, tb->tb_id, fa->fa_type, xkey, KEYLENGTH - fa->fa_slen, - fa->fa_tos, fa->fa_info, NLM_F_MULTI); + fa->fa_tos, fa->fa_info, flags); if (err < 0) { cb->args[4] = i; return err; } +next: i++; } @@ -2044,7 +2059,7 @@ static int fn_trie_dump_leaf(struct key_vector *l, struct fib_table *tb, /* rcu_read_lock needs to be hold by caller from readside */ int fib_table_dump(struct fib_table *tb, struct sk_buff *skb, - struct netlink_callback *cb) + struct netlink_callback *cb, struct fib_dump_filter *filter) { struct trie *t = (struct trie *)tb->tb_data; struct key_vector *l, *tp = t->kv; @@ -2057,7 +2072,7 @@ int fib_table_dump(struct fib_table *tb, struct sk_buff *skb, while ((l = leaf_walk_rcu(&tp, key)) != NULL) { int err; - err = fn_trie_dump_leaf(l, tb, skb, cb); + err = fn_trie_dump_leaf(l, tb, skb, cb, filter); if (err < 0) { cb->args[3] = key; cb->args[2] = count; diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c index 4da39446da2d..765b2b32c4a4 100644 --- a/net/ipv4/igmp.c +++ b/net/ipv4/igmp.c @@ -111,13 +111,10 @@ #ifdef CONFIG_IP_MULTICAST /* Parameter names and values are taken from igmp-v2-06 draft */ -#define IGMP_V1_ROUTER_PRESENT_TIMEOUT (400*HZ) -#define IGMP_V2_ROUTER_PRESENT_TIMEOUT (400*HZ) #define IGMP_V2_UNSOLICITED_REPORT_INTERVAL (10*HZ) #define IGMP_V3_UNSOLICITED_REPORT_INTERVAL (1*HZ) +#define IGMP_QUERY_INTERVAL (125*HZ) #define IGMP_QUERY_RESPONSE_INTERVAL (10*HZ) -#define IGMP_QUERY_ROBUSTNESS_VARIABLE 2 - #define IGMP_INITIAL_REPORT_DELAY (1) @@ -935,13 +932,15 @@ static bool igmp_heard_query(struct in_device *in_dev, struct sk_buff *skb, max_delay = IGMP_QUERY_RESPONSE_INTERVAL; in_dev->mr_v1_seen = jiffies + - IGMP_V1_ROUTER_PRESENT_TIMEOUT; + (in_dev->mr_qrv * in_dev->mr_qi) + + in_dev->mr_qri; group = 0; } else { /* v2 router present */ max_delay = ih->code*(HZ/IGMP_TIMER_SCALE); in_dev->mr_v2_seen = jiffies + - IGMP_V2_ROUTER_PRESENT_TIMEOUT; + (in_dev->mr_qrv * in_dev->mr_qi) + + in_dev->mr_qri; } /* cancel the interface change timer */ in_dev->mr_ifc_count = 0; @@ -981,8 +980,21 @@ static bool igmp_heard_query(struct in_device *in_dev, struct sk_buff *skb, if (!max_delay) max_delay = 1; /* can't mod w/ 0 */ in_dev->mr_maxdelay = max_delay; - if (ih3->qrv) - in_dev->mr_qrv = ih3->qrv; + + /* RFC3376, 4.1.6. QRV and 4.1.7. QQIC, when the most recently + * received value was zero, use the default or statically + * configured value. + */ + in_dev->mr_qrv = ih3->qrv ?: net->ipv4.sysctl_igmp_qrv; + in_dev->mr_qi = IGMPV3_QQIC(ih3->qqic)*HZ ?: IGMP_QUERY_INTERVAL; + + /* RFC3376, 8.3. Query Response Interval: + * The number of seconds represented by the [Query Response + * Interval] must be less than the [Query Interval]. + */ + if (in_dev->mr_qri >= in_dev->mr_qi) + in_dev->mr_qri = (in_dev->mr_qi/HZ - 1)*HZ; + if (!group) { /* general query */ if (ih3->nsrcs) return true; /* no sources allowed */ @@ -1723,18 +1735,30 @@ void ip_mc_down(struct in_device *in_dev) ip_mc_dec_group(in_dev, IGMP_ALL_HOSTS); } -void ip_mc_init_dev(struct in_device *in_dev) -{ #ifdef CONFIG_IP_MULTICAST +static void ip_mc_reset(struct in_device *in_dev) +{ struct net *net = dev_net(in_dev->dev); + + in_dev->mr_qi = IGMP_QUERY_INTERVAL; + in_dev->mr_qri = IGMP_QUERY_RESPONSE_INTERVAL; + in_dev->mr_qrv = net->ipv4.sysctl_igmp_qrv; +} +#else +static void ip_mc_reset(struct in_device *in_dev) +{ +} #endif + +void ip_mc_init_dev(struct in_device *in_dev) +{ ASSERT_RTNL(); #ifdef CONFIG_IP_MULTICAST timer_setup(&in_dev->mr_gq_timer, igmp_gq_timer_expire, 0); timer_setup(&in_dev->mr_ifc_timer, igmp_ifc_timer_expire, 0); - in_dev->mr_qrv = net->ipv4.sysctl_igmp_qrv; #endif + ip_mc_reset(in_dev); spin_lock_init(&in_dev->mc_tomb_lock); } @@ -1744,15 +1768,10 @@ void ip_mc_init_dev(struct in_device *in_dev) void ip_mc_up(struct in_device *in_dev) { struct ip_mc_list *pmc; -#ifdef CONFIG_IP_MULTICAST - struct net *net = dev_net(in_dev->dev); -#endif ASSERT_RTNL(); -#ifdef CONFIG_IP_MULTICAST - in_dev->mr_qrv = net->ipv4.sysctl_igmp_qrv; -#endif + ip_mc_reset(in_dev); ip_mc_inc_group(in_dev, IGMP_ALL_HOSTS); for_each_pmc_rtnl(in_dev, pmc) { diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index dfd5009f96ef..15e7f7915a21 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -544,7 +544,8 @@ struct dst_entry *inet_csk_route_req(const struct sock *sk, struct ip_options_rcu *opt; struct rtable *rt; - opt = ireq_opt_deref(ireq); + rcu_read_lock(); + opt = rcu_dereference(ireq->ireq_opt); flowi4_init_output(fl4, ireq->ir_iif, ireq->ir_mark, RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, @@ -558,11 +559,13 @@ struct dst_entry *inet_csk_route_req(const struct sock *sk, goto no_route; if (opt && opt->opt.is_strictroute && rt->rt_uses_gateway) goto route_err; + rcu_read_unlock(); return &rt->dst; route_err: ip_rt_put(rt); no_route: + rcu_read_unlock(); __IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); return NULL; } diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index f5c9ef2586de..411dd7a90046 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -19,7 +19,7 @@ #include <linux/slab.h> #include <linux/wait.h> #include <linux/vmalloc.h> -#include <linux/bootmem.h> +#include <linux/memblock.h> #include <net/addrconf.h> #include <net/inet_connection_sock.h> diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c index c0fe5ad996f2..26c36cccabdc 100644 --- a/net/ipv4/ip_sockglue.c +++ b/net/ipv4/ip_sockglue.c @@ -149,7 +149,6 @@ static void ip_cmsg_recv_security(struct msghdr *msg, struct sk_buff *skb) static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) { struct sockaddr_in sin; - const struct iphdr *iph = ip_hdr(skb); __be16 *ports; int end; @@ -164,7 +163,7 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) ports = (__be16 *)skb_transport_header(skb); sin.sin_family = AF_INET; - sin.sin_addr.s_addr = iph->daddr; + sin.sin_addr.s_addr = ip_hdr(skb)->daddr; sin.sin_port = ports[1]; memset(sin.sin_zero, 0, sizeof(sin.sin_zero)); diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c index 5660adcf7a04..a6defbec4f1b 100644 --- a/net/ipv4/ipmr.c +++ b/net/ipv4/ipmr.c @@ -2527,8 +2527,34 @@ errout_free: static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) { + struct fib_dump_filter filter = {}; + int err; + + if (cb->strict_check) { + err = ip_valid_fib_dump_req(sock_net(skb->sk), cb->nlh, + &filter, cb); + if (err < 0) + return err; + } + + if (filter.table_id) { + struct mr_table *mrt; + + mrt = ipmr_get_table(sock_net(skb->sk), filter.table_id); + if (!mrt) { + if (filter.dump_all_families) + return skb->len; + + NL_SET_ERR_MSG(cb->extack, "ipv4: MR table does not exist"); + return -ENOENT; + } + err = mr_table_dump(mrt, skb, cb, _ipmr_fill_mroute, + &mfc_unres_lock, &filter); + return skb->len ? : err; + } + return mr_rtm_dumproute(skb, cb, ipmr_mr_table_iter, - _ipmr_fill_mroute, &mfc_unres_lock); + _ipmr_fill_mroute, &mfc_unres_lock, &filter); } static const struct nla_policy rtm_ipmr_policy[RTA_MAX + 1] = { @@ -2710,6 +2736,31 @@ static bool ipmr_fill_vif(struct mr_table *mrt, u32 vifid, struct sk_buff *skb) return true; } +static int ipmr_valid_dumplink(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for ipmr link dump"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in ipmr link dump"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change || ifm->ifi_index) { + NL_SET_ERR_MSG(extack, "Invalid values in header for ipmr link dump request"); + return -EINVAL; + } + + return 0; +} + static int ipmr_rtm_dumplink(struct sk_buff *skb, struct netlink_callback *cb) { struct net *net = sock_net(skb->sk); @@ -2718,6 +2769,13 @@ static int ipmr_rtm_dumplink(struct sk_buff *skb, struct netlink_callback *cb) unsigned int e = 0, s_e; struct mr_table *mrt; + if (cb->strict_check) { + int err = ipmr_valid_dumplink(cb->nlh, cb->extack); + + if (err < 0) + return err; + } + s_t = cb->args[0]; s_e = cb->args[1]; diff --git a/net/ipv4/ipmr_base.c b/net/ipv4/ipmr_base.c index 1ad9aa62a97b..3e614cc824f7 100644 --- a/net/ipv4/ipmr_base.c +++ b/net/ipv4/ipmr_base.c @@ -268,6 +268,81 @@ int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, } EXPORT_SYMBOL(mr_fill_mroute); +static bool mr_mfc_uses_dev(const struct mr_table *mrt, + const struct mr_mfc *c, + const struct net_device *dev) +{ + int ct; + + for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) { + if (VIF_EXISTS(mrt, ct) && c->mfc_un.res.ttls[ct] < 255) { + const struct vif_device *vif; + + vif = &mrt->vif_table[ct]; + if (vif->dev == dev) + return true; + } + } + return false; +} + +int mr_table_dump(struct mr_table *mrt, struct sk_buff *skb, + struct netlink_callback *cb, + int (*fill)(struct mr_table *mrt, struct sk_buff *skb, + u32 portid, u32 seq, struct mr_mfc *c, + int cmd, int flags), + spinlock_t *lock, struct fib_dump_filter *filter) +{ + unsigned int e = 0, s_e = cb->args[1]; + unsigned int flags = NLM_F_MULTI; + struct mr_mfc *mfc; + int err; + + if (filter->filter_set) + flags |= NLM_F_DUMP_FILTERED; + + list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list) { + if (e < s_e) + goto next_entry; + if (filter->dev && + !mr_mfc_uses_dev(mrt, mfc, filter->dev)) + goto next_entry; + + err = fill(mrt, skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, mfc, RTM_NEWROUTE, flags); + if (err < 0) + goto out; +next_entry: + e++; + } + + spin_lock_bh(lock); + list_for_each_entry(mfc, &mrt->mfc_unres_queue, list) { + if (e < s_e) + goto next_entry2; + if (filter->dev && + !mr_mfc_uses_dev(mrt, mfc, filter->dev)) + goto next_entry2; + + err = fill(mrt, skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, mfc, RTM_NEWROUTE, flags); + if (err < 0) { + spin_unlock_bh(lock); + goto out; + } +next_entry2: + e++; + } + spin_unlock_bh(lock); + err = 0; + e = 0; + +out: + cb->args[1] = e; + return err; +} +EXPORT_SYMBOL(mr_table_dump); + int mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb, struct mr_table *(*iter)(struct net *net, struct mr_table *mrt), @@ -275,53 +350,35 @@ int mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb, struct sk_buff *skb, u32 portid, u32 seq, struct mr_mfc *c, int cmd, int flags), - spinlock_t *lock) + spinlock_t *lock, struct fib_dump_filter *filter) { - unsigned int t = 0, e = 0, s_t = cb->args[0], s_e = cb->args[1]; + unsigned int t = 0, s_t = cb->args[0]; struct net *net = sock_net(skb->sk); struct mr_table *mrt; - struct mr_mfc *mfc; + int err; + + /* multicast does not track protocol or have route type other + * than RTN_MULTICAST + */ + if (filter->filter_set) { + if (filter->protocol || filter->flags || + (filter->rt_type && filter->rt_type != RTN_MULTICAST)) + return skb->len; + } rcu_read_lock(); for (mrt = iter(net, NULL); mrt; mrt = iter(net, mrt)) { if (t < s_t) goto next_table; - list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list) { - if (e < s_e) - goto next_entry; - if (fill(mrt, skb, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, mfc, - RTM_NEWROUTE, NLM_F_MULTI) < 0) - goto done; -next_entry: - e++; - } - e = 0; - s_e = 0; - - spin_lock_bh(lock); - list_for_each_entry(mfc, &mrt->mfc_unres_queue, list) { - if (e < s_e) - goto next_entry2; - if (fill(mrt, skb, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, mfc, - RTM_NEWROUTE, NLM_F_MULTI) < 0) { - spin_unlock_bh(lock); - goto done; - } -next_entry2: - e++; - } - spin_unlock_bh(lock); - e = 0; - s_e = 0; + + err = mr_table_dump(mrt, skb, cb, fill, lock, filter); + if (err < 0) + break; next_table: t++; } -done: rcu_read_unlock(); - cb->args[1] = e; cb->args[0] = t; return skb->len; diff --git a/net/ipv4/metrics.c b/net/ipv4/metrics.c index 04311f7067e2..6d218f5a2e71 100644 --- a/net/ipv4/metrics.c +++ b/net/ipv4/metrics.c @@ -5,8 +5,8 @@ #include <net/net_namespace.h> #include <net/tcp.h> -int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, int fc_mx_len, - u32 *metrics) +static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, + int fc_mx_len, u32 *metrics) { bool ecn_ca = false; struct nlattr *nla; @@ -52,4 +52,28 @@ int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, int fc_mx_len, return 0; } -EXPORT_SYMBOL_GPL(ip_metrics_convert); + +struct dst_metrics *ip_fib_metrics_init(struct net *net, struct nlattr *fc_mx, + int fc_mx_len) +{ + struct dst_metrics *fib_metrics; + int err; + + if (!fc_mx) + return (struct dst_metrics *)&dst_default_metrics; + + fib_metrics = kzalloc(sizeof(*fib_metrics), GFP_KERNEL); + if (unlikely(!fib_metrics)) + return ERR_PTR(-ENOMEM); + + err = ip_metrics_convert(net, fc_mx, fc_mx_len, fib_metrics->metrics); + if (!err) { + refcount_set(&fib_metrics->refcnt, 1); + } else { + kfree(fib_metrics); + fib_metrics = ERR_PTR(err); + } + + return fib_metrics; +} +EXPORT_SYMBOL_GPL(ip_fib_metrics_init); diff --git a/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c b/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c index 6115bf1ff6f0..78a67f961d86 100644 --- a/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c +++ b/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c @@ -264,7 +264,6 @@ nf_nat_ipv4_fn(void *priv, struct sk_buff *skb, return nf_nat_inet_fn(priv, skb, state); } -EXPORT_SYMBOL_GPL(nf_nat_ipv4_fn); static unsigned int nf_nat_ipv4_in(void *priv, struct sk_buff *skb, diff --git a/net/ipv4/netfilter/nf_nat_masquerade_ipv4.c b/net/ipv4/netfilter/nf_nat_masquerade_ipv4.c index ad3aeff152ed..a9d5e013e555 100644 --- a/net/ipv4/netfilter/nf_nat_masquerade_ipv4.c +++ b/net/ipv4/netfilter/nf_nat_masquerade_ipv4.c @@ -104,12 +104,26 @@ static int masq_device_event(struct notifier_block *this, return NOTIFY_DONE; } +static int inet_cmp(struct nf_conn *ct, void *ptr) +{ + struct in_ifaddr *ifa = (struct in_ifaddr *)ptr; + struct net_device *dev = ifa->ifa_dev->dev; + struct nf_conntrack_tuple *tuple; + + if (!device_cmp(ct, (void *)(long)dev->ifindex)) + return 0; + + tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + return ifa->ifa_address == tuple->dst.u3.ip; +} + static int masq_inet_event(struct notifier_block *this, unsigned long event, void *ptr) { struct in_device *idev = ((struct in_ifaddr *)ptr)->ifa_dev; - struct netdev_notifier_info info; + struct net *net = dev_net(idev->dev); /* The masq_dev_notifier will catch the case of the device going * down. So if the inetdev is dead and being destroyed we have @@ -119,8 +133,10 @@ static int masq_inet_event(struct notifier_block *this, if (idev->dead) return NOTIFY_DONE; - netdev_notifier_info_init(&info, idev->dev); - return masq_device_event(this, event, &info); + if (event == NETDEV_DOWN) + nf_ct_iterate_cleanup_net(net, inet_cmp, ptr, 0, 0); + + return NOTIFY_DONE; } static struct notifier_block masq_dev_notifier = { diff --git a/net/ipv4/netfilter/nf_nat_snmp_basic_main.c b/net/ipv4/netfilter/nf_nat_snmp_basic_main.c index ac110c1d55b5..a0aa13bcabda 100644 --- a/net/ipv4/netfilter/nf_nat_snmp_basic_main.c +++ b/net/ipv4/netfilter/nf_nat_snmp_basic_main.c @@ -60,6 +60,7 @@ MODULE_LICENSE("GPL"); MODULE_AUTHOR("James Morris <jmorris@intercode.com.au>"); MODULE_DESCRIPTION("Basic SNMP Application Layer Gateway"); MODULE_ALIAS("ip_nat_snmp_basic"); +MODULE_ALIAS_NFCT_HELPER("snmp_trap"); #define SNMP_PORT 161 #define SNMP_TRAP_PORT 162 diff --git a/net/ipv4/route.c b/net/ipv4/route.c index 048919713f4e..c0a9d26c06ce 100644 --- a/net/ipv4/route.c +++ b/net/ipv4/route.c @@ -1001,21 +1001,22 @@ out: kfree_skb(skb); static void __ip_rt_update_pmtu(struct rtable *rt, struct flowi4 *fl4, u32 mtu) { struct dst_entry *dst = &rt->dst; + u32 old_mtu = ipv4_mtu(dst); struct fib_result res; bool lock = false; if (ip_mtu_locked(dst)) return; - if (ipv4_mtu(dst) < mtu) + if (old_mtu < mtu) return; if (mtu < ip_rt_min_pmtu) { lock = true; - mtu = ip_rt_min_pmtu; + mtu = min(old_mtu, ip_rt_min_pmtu); } - if (rt->rt_pmtu == mtu && + if (rt->rt_pmtu == mtu && !lock && time_before(jiffies, dst->expires - ip_rt_mtu_expires / 2)) return; @@ -1476,12 +1477,9 @@ void rt_del_uncached_list(struct rtable *rt) static void ipv4_dst_destroy(struct dst_entry *dst) { - struct dst_metrics *p = (struct dst_metrics *)DST_METRICS_PTR(dst); struct rtable *rt = (struct rtable *)dst; - if (p != &dst_default_metrics && refcount_dec_and_test(&p->refcnt)) - kfree(p); - + ip_dst_metrics_put(dst); rt_del_uncached_list(rt); } @@ -1528,11 +1526,8 @@ static void rt_set_nexthop(struct rtable *rt, __be32 daddr, rt->rt_gateway = nh->nh_gw; rt->rt_uses_gateway = 1; } - dst_init_metrics(&rt->dst, fi->fib_metrics->metrics, true); - if (fi->fib_metrics != &dst_default_metrics) { - rt->dst._metrics |= DST_METRICS_REFCOUNTED; - refcount_inc(&fi->fib_metrics->refcnt); - } + ip_dst_init_metrics(&rt->dst, fi->fib_metrics); + #ifdef CONFIG_IP_ROUTE_CLASSID rt->dst.tclassid = nh->nh_tclassid; #endif diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c index b92f422f2fa8..891ed2f91467 100644 --- a/net/ipv4/sysctl_net_ipv4.c +++ b/net/ipv4/sysctl_net_ipv4.c @@ -48,6 +48,7 @@ static int tcp_syn_retries_max = MAX_TCP_SYNCNT; static int ip_ping_group_range_min[] = { 0, 0 }; static int ip_ping_group_range_max[] = { GID_T_MAX, GID_T_MAX }; static int comp_sack_nr_max = 255; +static u32 u32_max_div_HZ = UINT_MAX / HZ; /* obsolete */ static int sysctl_tcp_low_latency __read_mostly; @@ -745,9 +746,10 @@ static struct ctl_table ipv4_net_table[] = { { .procname = "tcp_probe_interval", .data = &init_net.ipv4.sysctl_tcp_probe_interval, - .maxlen = sizeof(int), + .maxlen = sizeof(u32), .mode = 0644, - .proc_handler = proc_dointvec, + .proc_handler = proc_douintvec_minmax, + .extra2 = &u32_max_div_HZ, }, { .procname = "igmp_link_local_mcast_reports", diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 43ef83b2330e..9e6bc4d6daa7 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -262,7 +262,7 @@ #include <linux/net.h> #include <linux/socket.h> #include <linux/random.h> -#include <linux/bootmem.h> +#include <linux/memblock.h> #include <linux/highmem.h> #include <linux/swap.h> #include <linux/cache.h> @@ -507,7 +507,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) const struct tcp_sock *tp = tcp_sk(sk); int state; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); state = inet_sk_state_load(sk); if (state == TCP_LISTEN) @@ -3111,10 +3111,10 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) { const struct tcp_sock *tp = tcp_sk(sk); /* iff sk_type == SOCK_STREAM */ const struct inet_connection_sock *icsk = inet_csk(sk); + unsigned long rate; u32 now; u64 rate64; bool slow; - u32 rate; memset(info, 0, sizeof(*info)); if (sk->sk_type != SOCK_STREAM) @@ -3124,11 +3124,11 @@ void tcp_get_info(struct sock *sk, struct tcp_info *info) /* Report meaningful fields for all TCP states, including listeners */ rate = READ_ONCE(sk->sk_pacing_rate); - rate64 = rate != ~0U ? rate : ~0ULL; + rate64 = (rate != ~0UL) ? rate : ~0ULL; info->tcpi_pacing_rate = rate64; rate = READ_ONCE(sk->sk_max_pacing_rate); - rate64 = rate != ~0U ? rate : ~0ULL; + rate64 = (rate != ~0UL) ? rate : ~0ULL; info->tcpi_max_pacing_rate = rate64; info->tcpi_reordering = tp->reordering; @@ -3254,8 +3254,8 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk) const struct tcp_sock *tp = tcp_sk(sk); struct sk_buff *stats; struct tcp_info info; + unsigned long rate; u64 rate64; - u32 rate; stats = alloc_skb(tcp_opt_stats_get_size(), GFP_ATOMIC); if (!stats) @@ -3274,7 +3274,7 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk) tp->total_retrans, TCP_NLA_PAD); rate = READ_ONCE(sk->sk_pacing_rate); - rate64 = rate != ~0U ? rate : ~0ULL; + rate64 = (rate != ~0UL) ? rate : ~0ULL; nla_put_u64_64bit(stats, TCP_NLA_PACING_RATE, rate64, TCP_NLA_PAD); rate64 = tcp_compute_delivery_rate(tp); diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c index a5786e3e2c16..9277abdd822a 100644 --- a/net/ipv4/tcp_bbr.c +++ b/net/ipv4/tcp_bbr.c @@ -129,7 +129,7 @@ static const u32 bbr_probe_rtt_mode_ms = 200; static const int bbr_min_tso_rate = 1200000; /* Pace at ~1% below estimated bw, on average, to reduce queue at bottleneck. */ -static const int bbr_pacing_marging_percent = 1; +static const int bbr_pacing_margin_percent = 1; /* We use a high_gain value of 2/ln(2) because it's the smallest pacing gain * that will allow a smoothly increasing pacing rate that will double each RTT @@ -214,12 +214,12 @@ static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain) rate *= mss; rate *= gain; rate >>= BBR_SCALE; - rate *= USEC_PER_SEC / 100 * (100 - bbr_pacing_marging_percent); + rate *= USEC_PER_SEC / 100 * (100 - bbr_pacing_margin_percent); return rate >> BW_SCALE; } /* Convert a BBR bw and gain factor to a pacing rate in bytes per second. */ -static u32 bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) +static unsigned long bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) { u64 rate = bw; @@ -258,7 +258,7 @@ static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) { struct tcp_sock *tp = tcp_sk(sk); struct bbr *bbr = inet_csk_ca(sk); - u32 rate = bbr_bw_to_pacing_rate(sk, bw, gain); + unsigned long rate = bbr_bw_to_pacing_rate(sk, bw, gain); if (unlikely(!bbr->has_seen_rtt && tp->srtt_us)) bbr_init_pacing_rate_from_rtt(sk); @@ -280,7 +280,7 @@ static u32 bbr_tso_segs_goal(struct sock *sk) /* Sort of tcp_tso_autosize() but ignoring * driver provided sk_gso_max_size. */ - bytes = min_t(u32, sk->sk_pacing_rate >> sk->sk_pacing_shift, + bytes = min_t(unsigned long, sk->sk_pacing_rate >> sk->sk_pacing_shift, GSO_MAX_SIZE - 1 - MAX_TCP_HEADER); segs = max_t(u32, bytes / tp->mss_cache, bbr_min_tso_segs(sk)); @@ -369,6 +369,39 @@ static u32 bbr_target_cwnd(struct sock *sk, u32 bw, int gain) return cwnd; } +/* With pacing at lower layers, there's often less data "in the network" than + * "in flight". With TSQ and departure time pacing at lower layers (e.g. fq), + * we often have several skbs queued in the pacing layer with a pre-scheduled + * earliest departure time (EDT). BBR adapts its pacing rate based on the + * inflight level that it estimates has already been "baked in" by previous + * departure time decisions. We calculate a rough estimate of the number of our + * packets that might be in the network at the earliest departure time for the + * next skb scheduled: + * in_network_at_edt = inflight_at_edt - (EDT - now) * bw + * If we're increasing inflight, then we want to know if the transmit of the + * EDT skb will push inflight above the target, so inflight_at_edt includes + * bbr_tso_segs_goal() from the skb departing at EDT. If decreasing inflight, + * then estimate if inflight will sink too low just before the EDT transmit. + */ +static u32 bbr_packets_in_net_at_edt(struct sock *sk, u32 inflight_now) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 now_ns, edt_ns, interval_us; + u32 interval_delivered, inflight_at_edt; + + now_ns = tp->tcp_clock_cache; + edt_ns = max(tp->tcp_wstamp_ns, now_ns); + interval_us = div_u64(edt_ns - now_ns, NSEC_PER_USEC); + interval_delivered = (u64)bbr_bw(sk) * interval_us >> BW_SCALE; + inflight_at_edt = inflight_now; + if (bbr->pacing_gain > BBR_UNIT) /* increasing inflight */ + inflight_at_edt += bbr_tso_segs_goal(sk); /* include EDT skb */ + if (interval_delivered >= inflight_at_edt) + return 0; + return inflight_at_edt - interval_delivered; +} + /* An optimization in BBR to reduce losses: On the first round of recovery, we * follow the packet conservation principle: send P packets per P packets acked. * After that, we slow-start and send at most 2*P packets per P packets acked. @@ -460,7 +493,7 @@ static bool bbr_is_next_cycle_phase(struct sock *sk, if (bbr->pacing_gain == BBR_UNIT) return is_full_length; /* just use wall clock time */ - inflight = rs->prior_in_flight; /* what was in-flight before ACK? */ + inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); bw = bbr_max_bw(sk); /* A pacing_gain > 1.0 probes for bw by trying to raise inflight to at @@ -488,8 +521,6 @@ static void bbr_advance_cycle_phase(struct sock *sk) bbr->cycle_idx = (bbr->cycle_idx + 1) & (CYCLE_LEN - 1); bbr->cycle_mstamp = tp->delivered_mstamp; - bbr->pacing_gain = bbr->lt_use_bw ? BBR_UNIT : - bbr_pacing_gain[bbr->cycle_idx]; } /* Gain cycling: cycle pacing gain to converge to fair share of available bw. */ @@ -507,8 +538,6 @@ static void bbr_reset_startup_mode(struct sock *sk) struct bbr *bbr = inet_csk_ca(sk); bbr->mode = BBR_STARTUP; - bbr->pacing_gain = bbr_high_gain; - bbr->cwnd_gain = bbr_high_gain; } static void bbr_reset_probe_bw_mode(struct sock *sk) @@ -516,8 +545,6 @@ static void bbr_reset_probe_bw_mode(struct sock *sk) struct bbr *bbr = inet_csk_ca(sk); bbr->mode = BBR_PROBE_BW; - bbr->pacing_gain = BBR_UNIT; - bbr->cwnd_gain = bbr_cwnd_gain; bbr->cycle_idx = CYCLE_LEN - 1 - prandom_u32_max(bbr_cycle_rand); bbr_advance_cycle_phase(sk); /* flip to next phase of gain cycle */ } @@ -735,13 +762,11 @@ static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs) if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { bbr->mode = BBR_DRAIN; /* drain queue we created */ - bbr->pacing_gain = bbr_drain_gain; /* pace slow to drain */ - bbr->cwnd_gain = bbr_high_gain; /* maintain cwnd */ tcp_sk(sk)->snd_ssthresh = bbr_target_cwnd(sk, bbr_max_bw(sk), BBR_UNIT); } /* fall through to check if in-flight is already small: */ if (bbr->mode == BBR_DRAIN && - tcp_packets_in_flight(tcp_sk(sk)) <= + bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= bbr_target_cwnd(sk, bbr_max_bw(sk), BBR_UNIT)) bbr_reset_probe_bw_mode(sk); /* we estimate queue is drained */ } @@ -798,8 +823,6 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) if (bbr_probe_rtt_mode_ms > 0 && filter_expired && !bbr->idle_restart && bbr->mode != BBR_PROBE_RTT) { bbr->mode = BBR_PROBE_RTT; /* dip, drain queue */ - bbr->pacing_gain = BBR_UNIT; - bbr->cwnd_gain = BBR_UNIT; bbr_save_cwnd(sk); /* note cwnd so we can restore it */ bbr->probe_rtt_done_stamp = 0; } @@ -827,6 +850,35 @@ static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) bbr->idle_restart = 0; } +static void bbr_update_gains(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + switch (bbr->mode) { + case BBR_STARTUP: + bbr->pacing_gain = bbr_high_gain; + bbr->cwnd_gain = bbr_high_gain; + break; + case BBR_DRAIN: + bbr->pacing_gain = bbr_drain_gain; /* slow, to drain */ + bbr->cwnd_gain = bbr_high_gain; /* keep cwnd */ + break; + case BBR_PROBE_BW: + bbr->pacing_gain = (bbr->lt_use_bw ? + BBR_UNIT : + bbr_pacing_gain[bbr->cycle_idx]); + bbr->cwnd_gain = bbr_cwnd_gain; + break; + case BBR_PROBE_RTT: + bbr->pacing_gain = BBR_UNIT; + bbr->cwnd_gain = BBR_UNIT; + break; + default: + WARN_ONCE(1, "BBR bad mode: %u\n", bbr->mode); + break; + } +} + static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) { bbr_update_bw(sk, rs); @@ -834,6 +886,7 @@ static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) bbr_check_full_bw_reached(sk, rs); bbr_check_drain(sk, rs); bbr_update_min_rtt(sk, rs); + bbr_update_gains(sk); } static void bbr_main(struct sock *sk, const struct rate_sample *rs) diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c new file mode 100644 index 000000000000..3b45fe530f91 --- /dev/null +++ b/net/ipv4/tcp_bpf.c @@ -0,0 +1,669 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ + +#include <linux/skmsg.h> +#include <linux/filter.h> +#include <linux/bpf.h> +#include <linux/init.h> +#include <linux/wait.h> + +#include <net/inet_common.h> + +static bool tcp_bpf_stream_read(const struct sock *sk) +{ + struct sk_psock *psock; + bool empty = true; + + rcu_read_lock(); + psock = sk_psock(sk); + if (likely(psock)) + empty = list_empty(&psock->ingress_msg); + rcu_read_unlock(); + return !empty; +} + +static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock, + int flags, long timeo, int *err) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int ret; + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + ret = sk_wait_event(sk, &timeo, + !list_empty(&psock->ingress_msg) || + !skb_queue_empty(&sk->sk_receive_queue), &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + return ret; +} + +int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, + struct msghdr *msg, int len, int flags) +{ + struct iov_iter *iter = &msg->msg_iter; + int peek = flags & MSG_PEEK; + int i, ret, copied = 0; + struct sk_msg *msg_rx; + + msg_rx = list_first_entry_or_null(&psock->ingress_msg, + struct sk_msg, list); + + while (copied != len) { + struct scatterlist *sge; + + if (unlikely(!msg_rx)) + break; + + i = msg_rx->sg.start; + do { + struct page *page; + int copy; + + sge = sk_msg_elem(msg_rx, i); + copy = sge->length; + page = sg_page(sge); + if (copied + copy > len) + copy = len - copied; + ret = copy_page_to_iter(page, sge->offset, copy, iter); + if (ret != copy) { + msg_rx->sg.start = i; + return -EFAULT; + } + + copied += copy; + if (likely(!peek)) { + sge->offset += copy; + sge->length -= copy; + sk_mem_uncharge(sk, copy); + msg_rx->sg.size -= copy; + + if (!sge->length) { + sk_msg_iter_var_next(i); + if (!msg_rx->skb) + put_page(page); + } + } else { + sk_msg_iter_var_next(i); + } + + if (copied == len) + break; + } while (i != msg_rx->sg.end); + + if (unlikely(peek)) { + msg_rx = list_next_entry(msg_rx, list); + continue; + } + + msg_rx->sg.start = i; + if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) { + list_del(&msg_rx->list); + if (msg_rx->skb) + consume_skb(msg_rx->skb); + kfree(msg_rx); + } + msg_rx = list_first_entry_or_null(&psock->ingress_msg, + struct sk_msg, list); + } + + return copied; +} +EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg); + +int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int nonblock, int flags, int *addr_len) +{ + struct sk_psock *psock; + int copied, ret; + + if (unlikely(flags & MSG_ERRQUEUE)) + return inet_recv_error(sk, msg, len, addr_len); + if (!skb_queue_empty(&sk->sk_receive_queue)) + return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + lock_sock(sk); +msg_bytes_ready: + copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags); + if (!copied) { + int data, err = 0; + long timeo; + + timeo = sock_rcvtimeo(sk, nonblock); + data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err); + if (data) { + if (skb_queue_empty(&sk->sk_receive_queue)) + goto msg_bytes_ready; + release_sock(sk); + sk_psock_put(sk, psock); + return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); + } + if (err) { + ret = err; + goto out; + } + copied = -EAGAIN; + } + ret = copied; +out: + release_sock(sk); + sk_psock_put(sk, psock); + return ret; +} + +static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, + struct sk_msg *msg, u32 apply_bytes, int flags) +{ + bool apply = apply_bytes; + struct scatterlist *sge; + u32 size, copied = 0; + struct sk_msg *tmp; + int i, ret = 0; + + tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL); + if (unlikely(!tmp)) + return -ENOMEM; + + lock_sock(sk); + tmp->sg.start = msg->sg.start; + i = msg->sg.start; + do { + sge = sk_msg_elem(msg, i); + size = (apply && apply_bytes < sge->length) ? + apply_bytes : sge->length; + if (!sk_wmem_schedule(sk, size)) { + if (!copied) + ret = -ENOMEM; + break; + } + + sk_mem_charge(sk, size); + sk_msg_xfer(tmp, msg, i, size); + copied += size; + if (sge->length) + get_page(sk_msg_page(tmp, i)); + sk_msg_iter_var_next(i); + tmp->sg.end = i; + if (apply) { + apply_bytes -= size; + if (!apply_bytes) + break; + } + } while (i != msg->sg.end); + + if (!ret) { + msg->sg.start = i; + msg->sg.size -= apply_bytes; + sk_psock_queue_msg(psock, tmp); + sk->sk_data_ready(sk); + } else { + sk_msg_free(sk, tmp); + kfree(tmp); + } + + release_sock(sk); + return ret; +} + +static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes, + int flags, bool uncharge) +{ + bool apply = apply_bytes; + struct scatterlist *sge; + struct page *page; + int size, ret = 0; + u32 off; + + while (1) { + sge = sk_msg_elem(msg, msg->sg.start); + size = (apply && apply_bytes < sge->length) ? + apply_bytes : sge->length; + off = sge->offset; + page = sg_page(sge); + + tcp_rate_check_app_limited(sk); +retry: + ret = do_tcp_sendpages(sk, page, off, size, flags); + if (ret <= 0) + return ret; + if (apply) + apply_bytes -= ret; + msg->sg.size -= ret; + sge->offset += ret; + sge->length -= ret; + if (uncharge) + sk_mem_uncharge(sk, ret); + if (ret != size) { + size -= ret; + off += ret; + goto retry; + } + if (!sge->length) { + put_page(page); + sk_msg_iter_next(msg, start); + sg_init_table(sge, 1); + if (msg->sg.start == msg->sg.end) + break; + } + if (apply && !apply_bytes) + break; + } + + return 0; +} + +static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg, + u32 apply_bytes, int flags, bool uncharge) +{ + int ret; + + lock_sock(sk); + ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge); + release_sock(sk); + return ret; +} + +int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, + u32 bytes, int flags) +{ + bool ingress = sk_msg_to_ingress(msg); + struct sk_psock *psock = sk_psock_get(sk); + int ret; + + if (unlikely(!psock)) { + sk_msg_free(sk, msg); + return 0; + } + ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) : + tcp_bpf_push_locked(sk, msg, bytes, flags, false); + sk_psock_put(sk, psock); + return ret; +} +EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir); + +static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, + struct sk_msg *msg, int *copied, int flags) +{ + bool cork = false, enospc = msg->sg.start == msg->sg.end; + struct sock *sk_redir; + u32 tosend; + int ret; + +more_data: + if (psock->eval == __SK_NONE) + psock->eval = sk_psock_msg_verdict(sk, psock, msg); + + if (msg->cork_bytes && + msg->cork_bytes > msg->sg.size && !enospc) { + psock->cork_bytes = msg->cork_bytes - msg->sg.size; + if (!psock->cork) { + psock->cork = kzalloc(sizeof(*psock->cork), + GFP_ATOMIC | __GFP_NOWARN); + if (!psock->cork) + return -ENOMEM; + } + memcpy(psock->cork, msg, sizeof(*msg)); + return 0; + } + + tosend = msg->sg.size; + if (psock->apply_bytes && psock->apply_bytes < tosend) + tosend = psock->apply_bytes; + + switch (psock->eval) { + case __SK_PASS: + ret = tcp_bpf_push(sk, msg, tosend, flags, true); + if (unlikely(ret)) { + *copied -= sk_msg_free(sk, msg); + break; + } + sk_msg_apply_bytes(psock, tosend); + break; + case __SK_REDIRECT: + sk_redir = psock->sk_redir; + sk_msg_apply_bytes(psock, tosend); + if (psock->cork) { + cork = true; + psock->cork = NULL; + } + sk_msg_return(sk, msg, tosend); + release_sock(sk); + ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags); + lock_sock(sk); + if (unlikely(ret < 0)) { + int free = sk_msg_free_nocharge(sk, msg); + + if (!cork) + *copied -= free; + } + if (cork) { + sk_msg_free(sk, msg); + kfree(msg); + msg = NULL; + ret = 0; + } + break; + case __SK_DROP: + default: + sk_msg_free_partial(sk, msg, tosend); + sk_msg_apply_bytes(psock, tosend); + *copied -= tosend; + return -EACCES; + } + + if (likely(!ret)) { + if (!psock->apply_bytes) { + psock->eval = __SK_NONE; + if (psock->sk_redir) { + sock_put(psock->sk_redir); + psock->sk_redir = NULL; + } + } + if (msg && + msg->sg.data[msg->sg.start].page_link && + msg->sg.data[msg->sg.start].length) + goto more_data; + } + return ret; +} + +static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct sk_msg tmp, *msg_tx = NULL; + int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS; + int copied = 0, err = 0; + struct sk_psock *psock; + long timeo; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return tcp_sendmsg(sk, msg, size); + + lock_sock(sk); + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + while (msg_data_left(msg)) { + bool enospc = false; + u32 copy, osize; + + if (sk->sk_err) { + err = -sk->sk_err; + goto out_err; + } + + copy = msg_data_left(msg); + if (!sk_stream_memory_free(sk)) + goto wait_for_sndbuf; + if (psock->cork) { + msg_tx = psock->cork; + } else { + msg_tx = &tmp; + sk_msg_init(msg_tx); + } + + osize = msg_tx->sg.size; + err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1); + if (err) { + if (err != -ENOSPC) + goto wait_for_memory; + enospc = true; + copy = msg_tx->sg.size - osize; + } + + err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx, + copy); + if (err < 0) { + sk_msg_trim(sk, msg_tx, osize); + goto out_err; + } + + copied += copy; + if (psock->cork_bytes) { + if (size > psock->cork_bytes) + psock->cork_bytes = 0; + else + psock->cork_bytes -= size; + if (psock->cork_bytes && !enospc) + goto out_err; + /* All cork bytes are accounted, rerun the prog. */ + psock->eval = __SK_NONE; + psock->cork_bytes = 0; + } + + err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags); + if (unlikely(err < 0)) + goto out_err; + continue; +wait_for_sndbuf: + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); +wait_for_memory: + err = sk_stream_wait_memory(sk, &timeo); + if (err) { + if (msg_tx && msg_tx != psock->cork) + sk_msg_free(sk, msg_tx); + goto out_err; + } + } +out_err: + if (err < 0) + err = sk_stream_error(sk, msg->msg_flags, err); + release_sock(sk); + sk_psock_put(sk, psock); + return copied ? copied : err; +} + +static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + struct sk_msg tmp, *msg = NULL; + int err = 0, copied = 0; + struct sk_psock *psock; + bool enospc = false; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return tcp_sendpage(sk, page, offset, size, flags); + + lock_sock(sk); + if (psock->cork) { + msg = psock->cork; + } else { + msg = &tmp; + sk_msg_init(msg); + } + + /* Catch case where ring is full and sendpage is stalled. */ + if (unlikely(sk_msg_full(msg))) + goto out_err; + + sk_msg_page_add(msg, page, size, offset); + sk_mem_charge(sk, size); + copied = size; + if (sk_msg_full(msg)) + enospc = true; + if (psock->cork_bytes) { + if (size > psock->cork_bytes) + psock->cork_bytes = 0; + else + psock->cork_bytes -= size; + if (psock->cork_bytes && !enospc) + goto out_err; + /* All cork bytes are accounted, rerun the prog. */ + psock->eval = __SK_NONE; + psock->cork_bytes = 0; + } + + err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags); +out_err: + release_sock(sk); + sk_psock_put(sk, psock); + return copied ? copied : err; +} + +static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock) +{ + struct sk_psock_link *link; + + sk_psock_cork_free(psock); + __sk_psock_purge_ingress_msg(psock); + while ((link = sk_psock_link_pop(psock))) { + sk_psock_unlink(sk, link); + sk_psock_free_link(link); + } +} + +static void tcp_bpf_unhash(struct sock *sk) +{ + void (*saved_unhash)(struct sock *sk); + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + if (sk->sk_prot->unhash) + sk->sk_prot->unhash(sk); + return; + } + + saved_unhash = psock->saved_unhash; + tcp_bpf_remove(sk, psock); + rcu_read_unlock(); + saved_unhash(sk); +} + +static void tcp_bpf_close(struct sock *sk, long timeout) +{ + void (*saved_close)(struct sock *sk, long timeout); + struct sk_psock *psock; + + lock_sock(sk); + rcu_read_lock(); + psock = sk_psock(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + release_sock(sk); + return sk->sk_prot->close(sk, timeout); + } + + saved_close = psock->saved_close; + tcp_bpf_remove(sk, psock); + rcu_read_unlock(); + release_sock(sk); + saved_close(sk, timeout); +} + +enum { + TCP_BPF_IPV4, + TCP_BPF_IPV6, + TCP_BPF_NUM_PROTS, +}; + +enum { + TCP_BPF_BASE, + TCP_BPF_TX, + TCP_BPF_NUM_CFGS, +}; + +static struct proto *tcpv6_prot_saved __read_mostly; +static DEFINE_SPINLOCK(tcpv6_prot_lock); +static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS]; + +static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], + struct proto *base) +{ + prot[TCP_BPF_BASE] = *base; + prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash; + prot[TCP_BPF_BASE].close = tcp_bpf_close; + prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; + prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read; + + prot[TCP_BPF_TX] = prot[TCP_BPF_BASE]; + prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg; + prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage; +} + +static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops) +{ + if (sk->sk_family == AF_INET6 && + unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) { + spin_lock_bh(&tcpv6_prot_lock); + if (likely(ops != tcpv6_prot_saved)) { + tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops); + smp_store_release(&tcpv6_prot_saved, ops); + } + spin_unlock_bh(&tcpv6_prot_lock); + } +} + +static int __init tcp_bpf_v4_build_proto(void) +{ + tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot); + return 0; +} +core_initcall(tcp_bpf_v4_build_proto); + +static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock) +{ + int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; + int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; + + sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]); +} + +static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock) +{ + int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; + int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; + + /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed + * or added requiring sk_prot hook updates. We keep original saved + * hooks in this case. + */ + sk->sk_prot = &tcp_bpf_prots[family][config]; +} + +static int tcp_bpf_assert_proto_ops(struct proto *ops) +{ + /* In order to avoid retpoline, we make assumptions when we call + * into ops if e.g. a psock is not present. Make sure they are + * indeed valid assumptions. + */ + return ops->recvmsg == tcp_recvmsg && + ops->sendmsg == tcp_sendmsg && + ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP; +} + +void tcp_bpf_reinit(struct sock *sk) +{ + struct sk_psock *psock; + + sock_owned_by_me(sk); + + rcu_read_lock(); + psock = sk_psock(sk); + tcp_bpf_reinit_sk_prot(sk, psock); + rcu_read_unlock(); +} + +int tcp_bpf_init(struct sock *sk) +{ + struct proto *ops = READ_ONCE(sk->sk_prot); + struct sk_psock *psock; + + sock_owned_by_me(sk); + + rcu_read_lock(); + psock = sk_psock(sk); + if (unlikely(!psock || psock->sk_proto || + tcp_bpf_assert_proto_ops(ops))) { + rcu_read_unlock(); + return -EINVAL; + } + tcp_bpf_check_v6_needs_rebuild(sk, ops); + tcp_bpf_update_sk_prot(sk, psock); + rcu_read_unlock(); + return 0; +} diff --git a/net/ipv4/tcp_cdg.c b/net/ipv4/tcp_cdg.c index 06fbe102a425..37eebd910396 100644 --- a/net/ipv4/tcp_cdg.c +++ b/net/ipv4/tcp_cdg.c @@ -146,7 +146,7 @@ static void tcp_cdg_hystart_update(struct sock *sk) return; if (hystart_detect & HYSTART_ACK_TRAIN) { - u32 now_us = div_u64(local_clock(), NSEC_PER_USEC); + u32 now_us = tp->tcp_mstamp; if (ca->last_ack == 0 || !tcp_is_cwnd_limited(sk)) { ca->last_ack = now_us; diff --git a/net/ipv4/tcp_dctcp.c b/net/ipv4/tcp_dctcp.c index ca61e2a659e7..cd4814f7e962 100644 --- a/net/ipv4/tcp_dctcp.c +++ b/net/ipv4/tcp_dctcp.c @@ -44,6 +44,7 @@ #include <linux/mm.h> #include <net/tcp.h> #include <linux/inet_diag.h> +#include "tcp_dctcp.h" #define DCTCP_MAX_ALPHA 1024U @@ -118,54 +119,6 @@ static u32 dctcp_ssthresh(struct sock *sk) return max(tp->snd_cwnd - ((tp->snd_cwnd * ca->dctcp_alpha) >> 11U), 2U); } -/* Minimal DCTP CE state machine: - * - * S: 0 <- last pkt was non-CE - * 1 <- last pkt was CE - */ - -static void dctcp_ce_state_0_to_1(struct sock *sk) -{ - struct dctcp *ca = inet_csk_ca(sk); - struct tcp_sock *tp = tcp_sk(sk); - - if (!ca->ce_state) { - /* State has changed from CE=0 to CE=1, force an immediate - * ACK to reflect the new CE state. If an ACK was delayed, - * send that first to reflect the prior CE state. - */ - if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) - __tcp_send_ack(sk, ca->prior_rcv_nxt); - inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; - } - - ca->prior_rcv_nxt = tp->rcv_nxt; - ca->ce_state = 1; - - tp->ecn_flags |= TCP_ECN_DEMAND_CWR; -} - -static void dctcp_ce_state_1_to_0(struct sock *sk) -{ - struct dctcp *ca = inet_csk_ca(sk); - struct tcp_sock *tp = tcp_sk(sk); - - if (ca->ce_state) { - /* State has changed from CE=1 to CE=0, force an immediate - * ACK to reflect the new CE state. If an ACK was delayed, - * send that first to reflect the prior CE state. - */ - if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) - __tcp_send_ack(sk, ca->prior_rcv_nxt); - inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; - } - - ca->prior_rcv_nxt = tp->rcv_nxt; - ca->ce_state = 0; - - tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR; -} - static void dctcp_update_alpha(struct sock *sk, u32 flags) { const struct tcp_sock *tp = tcp_sk(sk); @@ -230,12 +183,12 @@ static void dctcp_state(struct sock *sk, u8 new_state) static void dctcp_cwnd_event(struct sock *sk, enum tcp_ca_event ev) { + struct dctcp *ca = inet_csk_ca(sk); + switch (ev) { case CA_EVENT_ECN_IS_CE: - dctcp_ce_state_0_to_1(sk); - break; case CA_EVENT_ECN_NO_CE: - dctcp_ce_state_1_to_0(sk); + dctcp_ece_ack_update(sk, ev, &ca->prior_rcv_nxt, &ca->ce_state); break; default: /* Don't care for the rest. */ diff --git a/net/ipv4/tcp_dctcp.h b/net/ipv4/tcp_dctcp.h new file mode 100644 index 000000000000..d69a77cbd0c7 --- /dev/null +++ b/net/ipv4/tcp_dctcp.h @@ -0,0 +1,40 @@ +#ifndef _TCP_DCTCP_H +#define _TCP_DCTCP_H + +static inline void dctcp_ece_ack_cwr(struct sock *sk, u32 ce_state) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (ce_state == 1) + tp->ecn_flags |= TCP_ECN_DEMAND_CWR; + else + tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR; +} + +/* Minimal DCTP CE state machine: + * + * S: 0 <- last pkt was non-CE + * 1 <- last pkt was CE + */ +static inline void dctcp_ece_ack_update(struct sock *sk, enum tcp_ca_event evt, + u32 *prior_rcv_nxt, u32 *ce_state) +{ + u32 new_ce_state = (evt == CA_EVENT_ECN_IS_CE) ? 1 : 0; + + if (*ce_state != new_ce_state) { + /* CE state has changed, force an immediate ACK to + * reflect the new CE state. If an ACK was delayed, + * send that first to reflect the prior CE state. + */ + if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) { + dctcp_ece_ack_cwr(sk, *ce_state); + __tcp_send_ack(sk, *prior_rcv_nxt); + } + inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; + } + *prior_rcv_nxt = tcp_sk(sk)->rcv_nxt; + *ce_state = new_ce_state; + dctcp_ece_ack_cwr(sk, new_ce_state); +} + +#endif diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index bf1aac315490..2868ef28ce52 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -2979,8 +2979,8 @@ void tcp_rearm_rto(struct sock *sk) */ rto = usecs_to_jiffies(max_t(int, delta_us, 1)); } - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, rto, - TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, rto, + TCP_RTO_MAX, tcp_rtx_queue_head(sk)); } } @@ -3255,8 +3255,8 @@ static void tcp_ack_probe(struct sock *sk) } else { unsigned long when = tcp_probe0_when(sk, TCP_RTO_MAX); - inet_csk_reset_xmit_timer(sk, ICSK_TIME_PROBE0, - when, TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, + when, TCP_RTO_MAX, NULL); } } @@ -6002,11 +6002,13 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) if (th->fin) goto discard; /* It is possible that we process SYN packets from backlog, - * so we need to make sure to disable BH right there. + * so we need to make sure to disable BH and RCU right there. */ + rcu_read_lock(); local_bh_disable(); acceptable = icsk->icsk_af_ops->conn_request(sk, skb) >= 0; local_bh_enable(); + rcu_read_unlock(); if (!acceptable) return 1; diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 1f2496e8620d..de47038afdf0 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -943,9 +943,11 @@ static int tcp_v4_send_synack(const struct sock *sk, struct dst_entry *dst, if (skb) { __tcp_v4_send_check(skb, ireq->ir_loc_addr, ireq->ir_rmt_addr); + rcu_read_lock(); err = ip_build_and_send_pkt(skb, sk, ireq->ir_loc_addr, ireq->ir_rmt_addr, - ireq_opt_deref(ireq)); + rcu_dereference(ireq->ireq_opt)); + rcu_read_unlock(); err = net_xmit_eval(err); } diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 059b67af28b1..9c34b97d365d 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -52,9 +52,8 @@ void tcp_mstamp_refresh(struct tcp_sock *tp) { u64 val = tcp_clock_ns(); - /* departure time for next data packet */ - if (val > tp->tcp_wstamp_ns) - tp->tcp_wstamp_ns = val; + if (val > tp->tcp_clock_cache) + tp->tcp_clock_cache = val; val = div_u64(val, NSEC_PER_USEC); if (val > tp->tcp_mstamp) @@ -976,32 +975,26 @@ enum hrtimer_restart tcp_pace_kick(struct hrtimer *timer) return HRTIMER_NORESTART; } -static void tcp_internal_pacing(struct sock *sk) -{ - if (!tcp_needs_internal_pacing(sk)) - return; - hrtimer_start(&tcp_sk(sk)->pacing_timer, - ns_to_ktime(tcp_sk(sk)->tcp_wstamp_ns), - HRTIMER_MODE_ABS_PINNED_SOFT); - sock_hold(sk); -} - -static void tcp_update_skb_after_send(struct sock *sk, struct sk_buff *skb) +static void tcp_update_skb_after_send(struct sock *sk, struct sk_buff *skb, + u64 prior_wstamp) { struct tcp_sock *tp = tcp_sk(sk); skb->skb_mstamp_ns = tp->tcp_wstamp_ns; if (sk->sk_pacing_status != SK_PACING_NONE) { - u32 rate = sk->sk_pacing_rate; + unsigned long rate = sk->sk_pacing_rate; /* Original sch_fq does not pace first 10 MSS * Note that tp->data_segs_out overflows after 2^32 packets, * this is a minor annoyance. */ - if (rate != ~0U && rate && tp->data_segs_out >= 10) { - tp->tcp_wstamp_ns += div_u64((u64)skb->len * NSEC_PER_SEC, rate); + if (rate != ~0UL && rate && tp->data_segs_out >= 10) { + u64 len_ns = div64_ul((u64)skb->len * NSEC_PER_SEC, rate); + u64 credit = tp->tcp_wstamp_ns - prior_wstamp; - tcp_internal_pacing(sk); + /* take into account OS jitter */ + len_ns -= min_t(u64, len_ns / 2, credit); + tp->tcp_wstamp_ns += len_ns; } } list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); @@ -1030,6 +1023,7 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, struct sk_buff *oskb = NULL; struct tcp_md5sig_key *md5; struct tcphdr *th; + u64 prior_wstamp; int err; BUG_ON(!skb || !tcp_skb_pcount(skb)); @@ -1050,6 +1044,10 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, if (unlikely(!skb)) return -ENOBUFS; } + + prior_wstamp = tp->tcp_wstamp_ns; + tp->tcp_wstamp_ns = max(tp->tcp_wstamp_ns, tp->tcp_clock_cache); + skb->skb_mstamp_ns = tp->tcp_wstamp_ns; inet = inet_sk(sk); @@ -1166,7 +1164,7 @@ static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, err = net_xmit_eval(err); } if (!err && oskb) { - tcp_update_skb_after_send(sk, oskb); + tcp_update_skb_after_send(sk, oskb, prior_wstamp); tcp_rate_skb_sent(sk, oskb); } return err; @@ -1701,8 +1699,9 @@ static u32 tcp_tso_autosize(const struct sock *sk, unsigned int mss_now, { u32 bytes, segs; - bytes = min(sk->sk_pacing_rate >> sk->sk_pacing_shift, - sk->sk_gso_max_size - 1 - MAX_TCP_HEADER); + bytes = min_t(unsigned long, + sk->sk_pacing_rate >> sk->sk_pacing_shift, + sk->sk_gso_max_size - 1 - MAX_TCP_HEADER); /* Goal is to send at least one packet per ms, * not one big TSO packet every 100 ms. @@ -2175,10 +2174,23 @@ static int tcp_mtu_probe(struct sock *sk) return -1; } -static bool tcp_pacing_check(const struct sock *sk) +static bool tcp_pacing_check(struct sock *sk) { - return tcp_needs_internal_pacing(sk) && - hrtimer_is_queued(&tcp_sk(sk)->pacing_timer); + struct tcp_sock *tp = tcp_sk(sk); + + if (!tcp_needs_internal_pacing(sk)) + return false; + + if (tp->tcp_wstamp_ns <= tp->tcp_clock_cache) + return false; + + if (!hrtimer_is_queued(&tp->pacing_timer)) { + hrtimer_start(&tp->pacing_timer, + ns_to_ktime(tp->tcp_wstamp_ns), + HRTIMER_MODE_ABS_PINNED_SOFT); + sock_hold(sk); + } + return true; } /* TCP Small Queues : @@ -2195,10 +2207,12 @@ static bool tcp_pacing_check(const struct sock *sk) static bool tcp_small_queue_check(struct sock *sk, const struct sk_buff *skb, unsigned int factor) { - unsigned int limit; + unsigned long limit; - limit = max(2 * skb->truesize, sk->sk_pacing_rate >> sk->sk_pacing_shift); - limit = min_t(u32, limit, + limit = max_t(unsigned long, + 2 * skb->truesize, + sk->sk_pacing_rate >> sk->sk_pacing_shift); + limit = min_t(unsigned long, limit, sock_net(sk)->ipv4.sysctl_tcp_limit_output_bytes); limit <<= factor; @@ -2307,18 +2321,19 @@ static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, while ((skb = tcp_send_head(sk))) { unsigned int limit; + if (unlikely(tp->repair) && tp->repair_queue == TCP_SEND_QUEUE) { + /* "skb_mstamp_ns" is used as a start point for the retransmit timer */ + skb->skb_mstamp_ns = tp->tcp_wstamp_ns = tp->tcp_clock_cache; + list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); + goto repair; /* Skip network transmission */ + } + if (tcp_pacing_check(sk)) break; tso_segs = tcp_init_tso_segs(skb, mss_now); BUG_ON(!tso_segs); - if (unlikely(tp->repair) && tp->repair_queue == TCP_SEND_QUEUE) { - /* "skb_mstamp" is used as a start point for the retransmit timer */ - tcp_update_skb_after_send(sk, skb); - goto repair; /* Skip network transmission */ - } - cwnd_quota = tcp_cwnd_test(tp, skb); if (!cwnd_quota) { if (push_one == 2) @@ -2440,8 +2455,8 @@ bool tcp_schedule_loss_probe(struct sock *sk, bool advancing_rto) if (rto_delta_us > 0) timeout = min_t(u32, timeout, usecs_to_jiffies(rto_delta_us)); - inet_csk_reset_xmit_timer(sk, ICSK_TIME_LOSS_PROBE, timeout, - TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_LOSS_PROBE, timeout, + TCP_RTO_MAX, NULL); return true; } @@ -2890,7 +2905,7 @@ int __tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) } tcp_skb_tsorted_restore(skb); if (!err) { - tcp_update_skb_after_send(sk, skb); + tcp_update_skb_after_send(sk, skb, tp->tcp_wstamp_ns); tcp_rate_skb_sent(sk, skb); } } else { @@ -3005,9 +3020,10 @@ void tcp_xmit_retransmit_queue(struct sock *sk) if (skb == rtx_head && icsk->icsk_pending != ICSK_TIME_REO_TIMEOUT) - inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, - inet_csk(sk)->icsk_rto, - TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + inet_csk(sk)->icsk_rto, + TCP_RTO_MAX, + skb); } } @@ -3737,9 +3753,10 @@ void tcp_send_probe0(struct sock *sk) icsk->icsk_probes_out = 1; probe_max = TCP_RESOURCE_PROBE_INTERVAL; } - inet_csk_reset_xmit_timer(sk, ICSK_TIME_PROBE0, - tcp_probe0_when(sk, probe_max), - TCP_RTO_MAX); + tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, + tcp_probe0_when(sk, probe_max), + TCP_RTO_MAX, + NULL); } int tcp_rtx_synack(const struct sock *sk, struct request_sock *req) diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index 61023d50cd60..676020663ce8 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -360,7 +360,7 @@ static void tcp_probe_timer(struct sock *sk) */ start_ts = tcp_skb_timestamp(skb); if (!start_ts) - skb->skb_mstamp_ns = tp->tcp_wstamp_ns; + skb->skb_mstamp_ns = tp->tcp_clock_cache; else if (icsk->icsk_user_timeout && (s32)(tcp_time_stamp(tp) - start_ts) > icsk->icsk_user_timeout) goto abort; diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c index a5995bb2eaca..95df7f7f6328 100644 --- a/net/ipv4/tcp_ulp.c +++ b/net/ipv4/tcp_ulp.c @@ -6,7 +6,7 @@ * */ -#include<linux/module.h> +#include <linux/module.h> #include <linux/mm.h> #include <linux/types.h> #include <linux/list.h> @@ -29,18 +29,6 @@ static struct tcp_ulp_ops *tcp_ulp_find(const char *name) return NULL; } -static struct tcp_ulp_ops *tcp_ulp_find_id(const int ulp) -{ - struct tcp_ulp_ops *e; - - list_for_each_entry_rcu(e, &tcp_ulp_list, list) { - if (e->uid == ulp) - return e; - } - - return NULL; -} - static const struct tcp_ulp_ops *__tcp_ulp_find_autoload(const char *name) { const struct tcp_ulp_ops *ulp = NULL; @@ -63,18 +51,6 @@ static const struct tcp_ulp_ops *__tcp_ulp_find_autoload(const char *name) return ulp; } -static const struct tcp_ulp_ops *__tcp_ulp_lookup(const int uid) -{ - const struct tcp_ulp_ops *ulp; - - rcu_read_lock(); - ulp = tcp_ulp_find_id(uid); - if (!ulp || !try_module_get(ulp->owner)) - ulp = NULL; - rcu_read_unlock(); - return ulp; -} - /* Attach new upper layer protocol to the list * of available protocols. */ @@ -123,6 +99,10 @@ void tcp_cleanup_ulp(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); + /* No sock_owned_by_me() check here as at the time the + * stack calls this function, the socket is dead and + * about to be destroyed. + */ if (!icsk->icsk_ulp_ops) return; @@ -133,54 +113,35 @@ void tcp_cleanup_ulp(struct sock *sk) icsk->icsk_ulp_ops = NULL; } -/* Change upper layer protocol for socket */ -int tcp_set_ulp(struct sock *sk, const char *name) +static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops) { struct inet_connection_sock *icsk = inet_csk(sk); - const struct tcp_ulp_ops *ulp_ops; - int err = 0; + int err; + err = -EEXIST; if (icsk->icsk_ulp_ops) - return -EEXIST; - - ulp_ops = __tcp_ulp_find_autoload(name); - if (!ulp_ops) - return -ENOENT; - - if (!ulp_ops->user_visible) { - module_put(ulp_ops->owner); - return -ENOENT; - } + goto out_err; err = ulp_ops->init(sk); - if (err) { - module_put(ulp_ops->owner); - return err; - } + if (err) + goto out_err; icsk->icsk_ulp_ops = ulp_ops; return 0; +out_err: + module_put(ulp_ops->owner); + return err; } -int tcp_set_ulp_id(struct sock *sk, int ulp) +int tcp_set_ulp(struct sock *sk, const char *name) { - struct inet_connection_sock *icsk = inet_csk(sk); const struct tcp_ulp_ops *ulp_ops; - int err; - if (icsk->icsk_ulp_ops) - return -EEXIST; + sock_owned_by_me(sk); - ulp_ops = __tcp_ulp_lookup(ulp); + ulp_ops = __tcp_ulp_find_autoload(name); if (!ulp_ops) return -ENOENT; - err = ulp_ops->init(sk); - if (err) { - module_put(ulp_ops->owner); - return err; - } - - icsk->icsk_ulp_ops = ulp_ops; - return 0; + return __tcp_set_ulp(sk, ulp_ops); } diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 5fc4beb1c336..1976fddb9e00 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -81,7 +81,7 @@ #include <linux/uaccess.h> #include <asm/ioctls.h> -#include <linux/bootmem.h> +#include <linux/memblock.h> #include <linux/highmem.h> #include <linux/swap.h> #include <linux/types.h> @@ -609,8 +609,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) struct net *net = dev_net(skb->dev); sk = __udp4_lib_lookup(net, iph->daddr, uh->dest, - iph->saddr, uh->source, skb->dev->ifindex, 0, - udptable, NULL); + iph->saddr, uh->source, skb->dev->ifindex, + inet_sdif(skb), udptable, NULL); if (!sk) { __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); return; /* No socket for error */ @@ -1627,7 +1627,7 @@ busy_check: *err = error; return NULL; } -EXPORT_SYMBOL_GPL(__skb_recv_udp); +EXPORT_SYMBOL(__skb_recv_udp); /* * This should be easy, if there is something there we @@ -1889,7 +1889,7 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) return 0; } -static DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key); +DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key); void udp_encap_enable(void) { static_branch_enable(&udp_encap_needed_key); @@ -2120,8 +2120,24 @@ static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh, /* Note, we are only interested in != 0 or == 0, thus the * force to int. */ - return (__force int)skb_checksum_init_zero_check(skb, proto, uh->check, - inet_compute_pseudo); + err = (__force int)skb_checksum_init_zero_check(skb, proto, uh->check, + inet_compute_pseudo); + if (err) + return err; + + if (skb->ip_summed == CHECKSUM_COMPLETE && !skb->csum_valid) { + /* If SW calculated the value, we know it's bad */ + if (skb->csum_complete_sw) + return 1; + + /* HW says the value is bad. Let's validate that. + * skb->csum is no longer the full packet checksum, + * so don't treat it as such. + */ + skb_checksum_complete_unset(skb); + } + + return 0; } /* wrapper for udp_queue_rcv_skb tacking care of csum conversion and diff --git a/net/ipv4/udp_diag.c b/net/ipv4/udp_diag.c index d9ad986c7b2c..5cbb9be05295 100644 --- a/net/ipv4/udp_diag.c +++ b/net/ipv4/udp_diag.c @@ -42,6 +42,7 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, rcu_read_lock(); if (req->sdiag_family == AF_INET) + /* src and dst are swapped for historical reasons */ sk = __udp4_lib_lookup(net, req->id.idiag_src[0], req->id.idiag_sport, req->id.idiag_dst[0], req->id.idiag_dport, diff --git a/net/ipv4/udp_offload.c b/net/ipv4/udp_offload.c index 0c0522b79b43..802f2bc00d69 100644 --- a/net/ipv4/udp_offload.c +++ b/net/ipv4/udp_offload.c @@ -405,7 +405,7 @@ static struct sk_buff *udp4_gro_receive(struct list_head *head, { struct udphdr *uh = udp_gro_udphdr(skb); - if (unlikely(!uh)) + if (unlikely(!uh) || !static_branch_unlikely(&udp_encap_needed_key)) goto flush; /* Don't bother verifying checksum if we're going to flush anyway. */ diff --git a/net/ipv4/xfrm4_input.c b/net/ipv4/xfrm4_input.c index bcfc00e88756..f8de2482a529 100644 --- a/net/ipv4/xfrm4_input.c +++ b/net/ipv4/xfrm4_input.c @@ -67,6 +67,7 @@ int xfrm4_transport_finish(struct sk_buff *skb, int async) if (xo && (xo->flags & XFRM_GRO)) { skb_mac_header_rebuild(skb); + skb_reset_transport_header(skb); return 0; } diff --git a/net/ipv4/xfrm4_mode_transport.c b/net/ipv4/xfrm4_mode_transport.c index 3d36644890bb..1ad2c2c4e250 100644 --- a/net/ipv4/xfrm4_mode_transport.c +++ b/net/ipv4/xfrm4_mode_transport.c @@ -46,7 +46,6 @@ static int xfrm4_transport_output(struct xfrm_state *x, struct sk_buff *skb) static int xfrm4_transport_input(struct xfrm_state *x, struct sk_buff *skb) { int ihl = skb->data - skb_transport_header(skb); - struct xfrm_offload *xo = xfrm_offload(skb); if (skb->transport_header != skb->network_header) { memmove(skb_transport_header(skb), @@ -54,8 +53,7 @@ static int xfrm4_transport_input(struct xfrm_state *x, struct sk_buff *skb) skb->network_header = skb->transport_header; } ip_hdr(skb)->tot_len = htons(skb->len + ihl); - if (!xo || !(xo->flags & XFRM_GRO)) - skb_reset_transport_header(skb); + skb_reset_transport_header(skb); return 0; } diff --git a/net/ipv6/addrconf.c b/net/ipv6/addrconf.c index a9a317322388..63a808d5af15 100644 --- a/net/ipv6/addrconf.c +++ b/net/ipv6/addrconf.c @@ -666,6 +666,7 @@ errout: static int inet6_netconf_dump_devconf(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); int h, s_h; int idx, s_idx; @@ -673,6 +674,21 @@ static int inet6_netconf_dump_devconf(struct sk_buff *skb, struct inet6_dev *idev; struct hlist_head *head; + if (cb->strict_check) { + struct netlink_ext_ack *extack = cb->extack; + struct netconfmsg *ncm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header in netconf dump request"); + return -EINVAL; + } + } + s_h = cb->args[0]; s_idx = idx = cb->args[1]; @@ -692,7 +708,7 @@ static int inet6_netconf_dump_devconf(struct sk_buff *skb, if (inet6_netconf_fill_devconf(skb, dev->ifindex, &idev->cnf, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNETCONF, NLM_F_MULTI, NETCONFA_ALL) < 0) { @@ -709,7 +725,7 @@ cont: if (inet6_netconf_fill_devconf(skb, NETCONFA_IFINDEX_ALL, net->ipv6.devconf_all, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNETCONF, NLM_F_MULTI, NETCONFA_ALL) < 0) goto done; @@ -720,7 +736,7 @@ cont: if (inet6_netconf_fill_devconf(skb, NETCONFA_IFINDEX_DEFAULT, net->ipv6.devconf_dflt, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNETCONF, NLM_F_MULTI, NETCONFA_ALL) < 0) goto done; @@ -4793,12 +4809,20 @@ static inline int inet6_ifaddr_msgsize(void) + nla_total_size(4) /* IFA_RT_PRIORITY */; } +enum addr_type_t { + UNICAST_ADDR, + MULTICAST_ADDR, + ANYCAST_ADDR, +}; + struct inet6_fill_args { u32 portid; u32 seq; int event; unsigned int flags; int netnsid; + int ifindex; + enum addr_type_t type; }; static int inet6_fill_ifaddr(struct sk_buff *skb, struct inet6_ifaddr *ifa, @@ -4930,66 +4954,56 @@ static int inet6_fill_ifacaddr(struct sk_buff *skb, struct ifacaddr6 *ifaca, return 0; } -enum addr_type_t { - UNICAST_ADDR, - MULTICAST_ADDR, - ANYCAST_ADDR, -}; - /* called with rcu_read_lock() */ static int in6_dump_addrs(struct inet6_dev *idev, struct sk_buff *skb, - struct netlink_callback *cb, enum addr_type_t type, - int s_ip_idx, int *p_ip_idx, int netnsid) + struct netlink_callback *cb, int s_ip_idx, + struct inet6_fill_args *fillargs) { - struct inet6_fill_args fillargs = { - .portid = NETLINK_CB(cb->skb).portid, - .seq = cb->nlh->nlmsg_seq, - .flags = NLM_F_MULTI, - .netnsid = netnsid, - }; struct ifmcaddr6 *ifmca; struct ifacaddr6 *ifaca; + int ip_idx = 0; int err = 1; - int ip_idx = *p_ip_idx; read_lock_bh(&idev->lock); - switch (type) { + switch (fillargs->type) { case UNICAST_ADDR: { struct inet6_ifaddr *ifa; - fillargs.event = RTM_NEWADDR; + fillargs->event = RTM_NEWADDR; /* unicast address incl. temp addr */ list_for_each_entry(ifa, &idev->addr_list, if_list) { - if (++ip_idx < s_ip_idx) - continue; - err = inet6_fill_ifaddr(skb, ifa, &fillargs); + if (ip_idx < s_ip_idx) + goto next; + err = inet6_fill_ifaddr(skb, ifa, fillargs); if (err < 0) break; nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +next: + ip_idx++; } break; } case MULTICAST_ADDR: - fillargs.event = RTM_GETMULTICAST; + fillargs->event = RTM_GETMULTICAST; /* multicast address */ for (ifmca = idev->mc_list; ifmca; ifmca = ifmca->next, ip_idx++) { if (ip_idx < s_ip_idx) continue; - err = inet6_fill_ifmcaddr(skb, ifmca, &fillargs); + err = inet6_fill_ifmcaddr(skb, ifmca, fillargs); if (err < 0) break; } break; case ANYCAST_ADDR: - fillargs.event = RTM_GETANYCAST; + fillargs->event = RTM_GETANYCAST; /* anycast address */ for (ifaca = idev->ac_list; ifaca; ifaca = ifaca->aca_next, ip_idx++) { if (ip_idx < s_ip_idx) continue; - err = inet6_fill_ifacaddr(skb, ifaca, &fillargs); + err = inet6_fill_ifacaddr(skb, ifaca, fillargs); if (err < 0) break; } @@ -4998,36 +5012,109 @@ static int in6_dump_addrs(struct inet6_dev *idev, struct sk_buff *skb, break; } read_unlock_bh(&idev->lock); - *p_ip_idx = ip_idx; + cb->args[2] = ip_idx; return err; } +static int inet6_valid_dump_ifaddr_req(const struct nlmsghdr *nlh, + struct inet6_fill_args *fillargs, + struct net **tgt_net, struct sock *sk, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[IFA_MAX+1]; + struct ifaddrmsg *ifm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for address dump request"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->ifa_prefixlen || ifm->ifa_flags || ifm->ifa_scope) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for address dump request"); + return -EINVAL; + } + + fillargs->ifindex = ifm->ifa_index; + if (fillargs->ifindex) { + cb->answer_flags |= NLM_F_DUMP_FILTERED; + fillargs->flags |= NLM_F_DUMP_FILTERED; + } + + err = nlmsg_parse_strict(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= IFA_MAX; ++i) { + if (!tb[i]) + continue; + + if (i == IFA_TARGET_NETNSID) { + struct net *net; + + fillargs->netnsid = nla_get_s32(tb[i]); + net = rtnl_get_net_ns_capable(sk, fillargs->netnsid); + if (IS_ERR(net)) { + fillargs->netnsid = -1; + NL_SET_ERR_MSG_MOD(extack, "Invalid target network namespace id"); + return PTR_ERR(net); + } + *tgt_net = net; + } else { + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in dump request"); + return -EINVAL; + } + } + + return 0; +} + static int inet6_dump_addr(struct sk_buff *skb, struct netlink_callback *cb, enum addr_type_t type) { + const struct nlmsghdr *nlh = cb->nlh; + struct inet6_fill_args fillargs = { + .portid = NETLINK_CB(cb->skb).portid, + .seq = cb->nlh->nlmsg_seq, + .flags = NLM_F_MULTI, + .netnsid = -1, + .type = type, + }; struct net *net = sock_net(skb->sk); - struct nlattr *tb[IFA_MAX+1]; struct net *tgt_net = net; - int netnsid = -1; + int idx, s_idx, s_ip_idx; int h, s_h; - int idx, ip_idx; - int s_idx, s_ip_idx; struct net_device *dev; struct inet6_dev *idev; struct hlist_head *head; + int err = 0; s_h = cb->args[0]; s_idx = idx = cb->args[1]; - s_ip_idx = ip_idx = cb->args[2]; + s_ip_idx = cb->args[2]; - if (nlmsg_parse(cb->nlh, sizeof(struct ifaddrmsg), tb, IFA_MAX, - ifa_ipv6_policy, NULL) >= 0) { - if (tb[IFA_TARGET_NETNSID]) { - netnsid = nla_get_s32(tb[IFA_TARGET_NETNSID]); + if (cb->strict_check) { + err = inet6_valid_dump_ifaddr_req(nlh, &fillargs, &tgt_net, + skb->sk, cb); + if (err < 0) + goto put_tgt_net; - tgt_net = rtnl_get_net_ns_capable(skb->sk, netnsid); - if (IS_ERR(tgt_net)) - return PTR_ERR(tgt_net); + err = 0; + if (fillargs.ifindex) { + dev = __dev_get_by_index(tgt_net, fillargs.ifindex); + if (!dev) { + err = -ENODEV; + goto put_tgt_net; + } + idev = __in6_dev_get(dev); + if (idev) { + err = in6_dump_addrs(idev, skb, cb, s_ip_idx, + &fillargs); + } + goto put_tgt_net; } } @@ -5041,13 +5128,12 @@ static int inet6_dump_addr(struct sk_buff *skb, struct netlink_callback *cb, goto cont; if (h > s_h || idx > s_idx) s_ip_idx = 0; - ip_idx = 0; idev = __in6_dev_get(dev); if (!idev) goto cont; - if (in6_dump_addrs(idev, skb, cb, type, - s_ip_idx, &ip_idx, netnsid) < 0) + if (in6_dump_addrs(idev, skb, cb, s_ip_idx, + &fillargs) < 0) goto done; cont: idx++; @@ -5057,11 +5143,11 @@ done: rcu_read_unlock(); cb->args[0] = h; cb->args[1] = idx; - cb->args[2] = ip_idx; - if (netnsid >= 0) +put_tgt_net: + if (fillargs.netnsid >= 0) put_net(tgt_net); - return skb->len; + return err < 0 ? err : skb->len; } static int inet6_dump_ifaddr(struct sk_buff *skb, struct netlink_callback *cb) @@ -5592,6 +5678,31 @@ nla_put_failure: return -EMSGSIZE; } +static int inet6_valid_dump_ifinfo(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for link dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change || ifm->ifi_index) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for dump request"); + return -EINVAL; + } + + return 0; +} + static int inet6_dump_ifinfo(struct sk_buff *skb, struct netlink_callback *cb) { struct net *net = sock_net(skb->sk); @@ -5601,6 +5712,16 @@ static int inet6_dump_ifinfo(struct sk_buff *skb, struct netlink_callback *cb) struct inet6_dev *idev; struct hlist_head *head; + /* only requests using strict checking can pass data to + * influence the dump + */ + if (cb->strict_check) { + int err = inet6_valid_dump_ifinfo(cb->nlh, cb->extack); + + if (err < 0) + return err; + } + s_h = cb->args[0]; s_idx = cb->args[1]; diff --git a/net/ipv6/addrlabel.c b/net/ipv6/addrlabel.c index 1d6ced37ad71..0d1ee82ee55b 100644 --- a/net/ipv6/addrlabel.c +++ b/net/ipv6/addrlabel.c @@ -458,20 +458,52 @@ static int ip6addrlbl_fill(struct sk_buff *skb, return 0; } +static int ip6addrlbl_valid_dump_req(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ifaddrlblmsg *ifal; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifal))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for address label dump request"); + return -EINVAL; + } + + ifal = nlmsg_data(nlh); + if (ifal->__ifal_reserved || ifal->ifal_prefixlen || + ifal->ifal_flags || ifal->ifal_index || ifal->ifal_seq) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for address label dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ifal))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header for address label dump requewst"); + return -EINVAL; + } + + return 0; +} + static int ip6addrlbl_dump(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); struct ip6addrlbl_entry *p; int idx = 0, s_idx = cb->args[0]; int err; + if (cb->strict_check) { + err = ip6addrlbl_valid_dump_req(nlh, cb->extack); + if (err < 0) + return err; + } + rcu_read_lock(); hlist_for_each_entry_rcu(p, &net->ipv6.ip6addrlbl_table.head, list) { if (idx >= s_idx) { err = ip6addrlbl_fill(skb, p, net->ipv6.ip6addrlbl_table.seq, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWADDRLABEL, NLM_F_MULTI); if (err < 0) diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index e9c8cfdf4b4c..3f4d61017a69 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -901,6 +901,7 @@ static const struct ipv6_stub ipv6_stub_impl = { static const struct ipv6_bpf_stub ipv6_bpf_stub_impl = { .inet6_bind = __inet6_bind, + .udp6_lib_lookup = __udp6_lib_lookup, }; static int __init inet6_init(void) diff --git a/net/ipv6/ip6_checksum.c b/net/ipv6/ip6_checksum.c index 547515e8450a..377717045f8f 100644 --- a/net/ipv6/ip6_checksum.c +++ b/net/ipv6/ip6_checksum.c @@ -88,8 +88,24 @@ int udp6_csum_init(struct sk_buff *skb, struct udphdr *uh, int proto) * Note, we are only interested in != 0 or == 0, thus the * force to int. */ - return (__force int)skb_checksum_init_zero_check(skb, proto, uh->check, - ip6_compute_pseudo); + err = (__force int)skb_checksum_init_zero_check(skb, proto, uh->check, + ip6_compute_pseudo); + if (err) + return err; + + if (skb->ip_summed == CHECKSUM_COMPLETE && !skb->csum_valid) { + /* If SW calculated the value, we know it's bad */ + if (skb->csum_complete_sw) + return 1; + + /* HW says the value is bad. Let's validate that. + * skb->csum is no longer the full packet checksum, + * so don't treat is as such. + */ + skb_checksum_complete_unset(skb); + } + + return 0; } EXPORT_SYMBOL(udp6_csum_init); diff --git a/net/ipv6/ip6_fib.c b/net/ipv6/ip6_fib.c index 5516f55e214b..1b8bc008b53b 100644 --- a/net/ipv6/ip6_fib.c +++ b/net/ipv6/ip6_fib.c @@ -29,6 +29,7 @@ #include <linux/list.h> #include <linux/slab.h> +#include <net/ip.h> #include <net/ipv6.h> #include <net/ndisc.h> #include <net/addrconf.h> @@ -46,6 +47,7 @@ struct fib6_cleaner { int (*func)(struct fib6_info *, void *arg); int sernum; void *arg; + bool skip_notify; }; #ifdef CONFIG_IPV6_SUBTREES @@ -160,8 +162,6 @@ struct fib6_info *fib6_info_alloc(gfp_t gfp_flags) } INIT_LIST_HEAD(&f6i->fib6_siblings); - f6i->fib6_metrics = (struct dst_metrics *)&dst_default_metrics; - atomic_inc(&f6i->fib6_ref); return f6i; @@ -171,7 +171,6 @@ void fib6_info_destroy_rcu(struct rcu_head *head) { struct fib6_info *f6i = container_of(head, struct fib6_info, rcu); struct rt6_exception_bucket *bucket; - struct dst_metrics *m; WARN_ON(f6i->fib6_node); @@ -196,6 +195,8 @@ void fib6_info_destroy_rcu(struct rcu_head *head) *ppcpu_rt = NULL; } } + + free_percpu(f6i->rt6i_pcpu); } lwtstate_put(f6i->fib6_nh.nh_lwtstate); @@ -203,9 +204,7 @@ void fib6_info_destroy_rcu(struct rcu_head *head) if (f6i->fib6_nh.nh_dev) dev_put(f6i->fib6_nh.nh_dev); - m = f6i->fib6_metrics; - if (m != &dst_default_metrics && refcount_dec_and_test(&m->refcnt)) - kfree(m); + ip_fib_metrics_put(f6i->fib6_metrics); kfree(f6i); } @@ -568,17 +567,31 @@ static int fib6_dump_table(struct fib6_table *table, struct sk_buff *skb, static int inet6_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); + struct rt6_rtnl_dump_arg arg = {}; unsigned int h, s_h; unsigned int e = 0, s_e; - struct rt6_rtnl_dump_arg arg; struct fib6_walker *w; struct fib6_table *tb; struct hlist_head *head; int res = 0; - s_h = cb->args[0]; - s_e = cb->args[1]; + if (cb->strict_check) { + int err; + + err = ip_valid_fib_dump_req(net, nlh, &arg.filter, cb); + if (err < 0) + return err; + } else if (nlmsg_len(nlh) >= sizeof(struct rtmsg)) { + struct rtmsg *rtm = nlmsg_data(nlh); + + arg.filter.flags = rtm->rtm_flags & (RTM_F_PREFIX|RTM_F_CLONED); + } + + /* fib entries are never clones */ + if (arg.filter.flags & RTM_F_CLONED) + return skb->len; w = (void *)cb->args[2]; if (!w) { @@ -604,6 +617,23 @@ static int inet6_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) arg.net = net; w->args = &arg; + if (arg.filter.table_id) { + tb = fib6_get_table(net, arg.filter.table_id); + if (!tb) { + if (arg.filter.dump_all_families) + return skb->len; + + NL_SET_ERR_MSG_MOD(cb->extack, "FIB table does not exist"); + return -ENOENT; + } + + res = fib6_dump_table(tb, skb, cb); + goto out; + } + + s_h = cb->args[0]; + s_e = cb->args[1]; + rcu_read_lock(); for (h = s_h; h < FIB6_TABLE_HASHSZ; h++, s_e = 0) { e = 0; @@ -613,16 +643,16 @@ static int inet6_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) goto next; res = fib6_dump_table(tb, skb, cb); if (res != 0) - goto out; + goto out_unlock; next: e++; } } -out: +out_unlock: rcu_read_unlock(); cb->args[1] = e; cb->args[0] = h; - +out: res = res < 0 ? res : skb->len; if (res <= 0) fib6_dump_end(cb); @@ -1952,6 +1982,7 @@ static int fib6_clean_node(struct fib6_walker *w) struct fib6_cleaner *c = container_of(w, struct fib6_cleaner, w); struct nl_info info = { .nl_net = c->net, + .skip_notify = c->skip_notify, }; if (c->sernum != FIB6_NO_SERNUM_CHANGE && @@ -2003,7 +2034,7 @@ static int fib6_clean_node(struct fib6_walker *w) static void fib6_clean_tree(struct net *net, struct fib6_node *root, int (*func)(struct fib6_info *, void *arg), - int sernum, void *arg) + int sernum, void *arg, bool skip_notify) { struct fib6_cleaner c; @@ -2015,13 +2046,14 @@ static void fib6_clean_tree(struct net *net, struct fib6_node *root, c.sernum = sernum; c.arg = arg; c.net = net; + c.skip_notify = skip_notify; fib6_walk(net, &c.w); } static void __fib6_clean_all(struct net *net, int (*func)(struct fib6_info *, void *), - int sernum, void *arg) + int sernum, void *arg, bool skip_notify) { struct fib6_table *table; struct hlist_head *head; @@ -2033,7 +2065,7 @@ static void __fib6_clean_all(struct net *net, hlist_for_each_entry_rcu(table, head, tb6_hlist) { spin_lock_bh(&table->tb6_lock); fib6_clean_tree(net, &table->tb6_root, - func, sernum, arg); + func, sernum, arg, skip_notify); spin_unlock_bh(&table->tb6_lock); } } @@ -2043,14 +2075,21 @@ static void __fib6_clean_all(struct net *net, void fib6_clean_all(struct net *net, int (*func)(struct fib6_info *, void *), void *arg) { - __fib6_clean_all(net, func, FIB6_NO_SERNUM_CHANGE, arg); + __fib6_clean_all(net, func, FIB6_NO_SERNUM_CHANGE, arg, false); +} + +void fib6_clean_all_skip_notify(struct net *net, + int (*func)(struct fib6_info *, void *), + void *arg) +{ + __fib6_clean_all(net, func, FIB6_NO_SERNUM_CHANGE, arg, true); } static void fib6_flush_trees(struct net *net) { int new_sernum = fib6_new_sernum(net); - __fib6_clean_all(net, NULL, new_sernum, NULL); + __fib6_clean_all(net, NULL, new_sernum, NULL, false); } /* diff --git a/net/ipv6/ip6_tunnel.c b/net/ipv6/ip6_tunnel.c index a0b6932c3afd..a9d06d4dd057 100644 --- a/net/ipv6/ip6_tunnel.c +++ b/net/ipv6/ip6_tunnel.c @@ -1184,11 +1184,6 @@ route_lookup: } skb_dst_set(skb, dst); - if (encap_limit >= 0) { - init_tel_txopt(&opt, encap_limit); - ipv6_push_frag_opts(skb, &opt.ops, &proto); - } - if (hop_limit == 0) { if (skb->protocol == htons(ETH_P_IP)) hop_limit = ip_hdr(skb)->ttl; @@ -1210,6 +1205,11 @@ route_lookup: if (err) return err; + if (encap_limit >= 0) { + init_tel_txopt(&opt, encap_limit); + ipv6_push_frag_opts(skb, &opt.ops, &proto); + } + skb_push(skb, sizeof(struct ipv6hdr)); skb_reset_network_header(skb); ipv6h = ipv6_hdr(skb); diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c index 6f07b8380425..e2ea691e42c6 100644 --- a/net/ipv6/ip6mr.c +++ b/net/ipv6/ip6mr.c @@ -2457,6 +2457,33 @@ errout: static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; + struct fib_dump_filter filter = {}; + int err; + + if (cb->strict_check) { + err = ip_valid_fib_dump_req(sock_net(skb->sk), nlh, + &filter, cb); + if (err < 0) + return err; + } + + if (filter.table_id) { + struct mr_table *mrt; + + mrt = ip6mr_get_table(sock_net(skb->sk), filter.table_id); + if (!mrt) { + if (filter.dump_all_families) + return skb->len; + + NL_SET_ERR_MSG_MOD(cb->extack, "MR table does not exist"); + return -ENOENT; + } + err = mr_table_dump(mrt, skb, cb, _ip6mr_fill_mroute, + &mfc_unres_lock, &filter); + return skb->len ? : err; + } + return mr_rtm_dumproute(skb, cb, ip6mr_mr_table_iter, - _ip6mr_fill_mroute, &mfc_unres_lock); + _ip6mr_fill_mroute, &mfc_unres_lock, &filter); } diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c index 6895e1dc0b03..21f6deb2aec9 100644 --- a/net/ipv6/mcast.c +++ b/net/ipv6/mcast.c @@ -2436,17 +2436,17 @@ static int ip6_mc_leave_src(struct sock *sk, struct ipv6_mc_socklist *iml, { int err; - /* callers have the socket lock and rtnl lock - * so no other readers or writers of iml or its sflist - */ + write_lock_bh(&iml->sflock); if (!iml->sflist) { /* any-source empty exclude case */ - return ip6_mc_del_src(idev, &iml->addr, iml->sfmode, 0, NULL, 0); + err = ip6_mc_del_src(idev, &iml->addr, iml->sfmode, 0, NULL, 0); + } else { + err = ip6_mc_del_src(idev, &iml->addr, iml->sfmode, + iml->sflist->sl_count, iml->sflist->sl_addr, 0); + sock_kfree_s(sk, iml->sflist, IP6_SFLSIZE(iml->sflist->sl_max)); + iml->sflist = NULL; } - err = ip6_mc_del_src(idev, &iml->addr, iml->sfmode, - iml->sflist->sl_count, iml->sflist->sl_addr, 0); - sock_kfree_s(sk, iml->sflist, IP6_SFLSIZE(iml->sflist->sl_max)); - iml->sflist = NULL; + write_unlock_bh(&iml->sflock); return err; } diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c index 51863ada15a4..659ecf4e4b3c 100644 --- a/net/ipv6/ndisc.c +++ b/net/ipv6/ndisc.c @@ -1732,10 +1732,9 @@ int ndisc_rcv(struct sk_buff *skb) return 0; } - memset(NEIGH_CB(skb), 0, sizeof(struct neighbour_cb)); - switch (msg->icmph.icmp6_type) { case NDISC_NEIGHBOUR_SOLICITATION: + memset(NEIGH_CB(skb), 0, sizeof(struct neighbour_cb)); ndisc_recv_ns(skb); break; @@ -1784,6 +1783,8 @@ static int ndisc_netdev_event(struct notifier_block *this, unsigned long event, change_info = ptr; if (change_info->flags_changed & IFF_NOARP) neigh_changeaddr(&nd_tbl, dev); + if (!netif_carrier_ok(dev)) + neigh_carrier_down(&nd_tbl, dev); break; case NETDEV_DOWN: neigh_ifdown(&nd_tbl, dev); diff --git a/net/ipv6/netfilter/ip6t_ipv6header.c b/net/ipv6/netfilter/ip6t_ipv6header.c index 8b147440fbdc..af737b47b9b5 100644 --- a/net/ipv6/netfilter/ip6t_ipv6header.c +++ b/net/ipv6/netfilter/ip6t_ipv6header.c @@ -65,7 +65,10 @@ ipv6header_mt6(const struct sk_buff *skb, struct xt_action_param *par) } hp = skb_header_pointer(skb, ptr, sizeof(_hdr), &_hdr); - BUG_ON(hp == NULL); + if (!hp) { + par->hotdrop = true; + return false; + } /* Calculate the header length */ if (nexthdr == NEXTHDR_FRAGMENT) diff --git a/net/ipv6/netfilter/ip6t_rt.c b/net/ipv6/netfilter/ip6t_rt.c index 2c99b94eeca3..21bf6bf04323 100644 --- a/net/ipv6/netfilter/ip6t_rt.c +++ b/net/ipv6/netfilter/ip6t_rt.c @@ -137,7 +137,10 @@ static bool rt_mt6(const struct sk_buff *skb, struct xt_action_param *par) sizeof(_addr), &_addr); - BUG_ON(ap == NULL); + if (ap == NULL) { + par->hotdrop = true; + return false; + } if (ipv6_addr_equal(ap, &rtinfo->addrs[i])) { pr_debug("i=%d temp=%d;\n", i, temp); @@ -166,7 +169,10 @@ static bool rt_mt6(const struct sk_buff *skb, struct xt_action_param *par) + temp * sizeof(_addr), sizeof(_addr), &_addr); - BUG_ON(ap == NULL); + if (ap == NULL) { + par->hotdrop = true; + return false; + } if (!ipv6_addr_equal(ap, &rtinfo->addrs[temp])) break; diff --git a/net/ipv6/netfilter/nf_nat_masquerade_ipv6.c b/net/ipv6/netfilter/nf_nat_masquerade_ipv6.c index e6eb7cf9b54f..3e4bf2286abe 100644 --- a/net/ipv6/netfilter/nf_nat_masquerade_ipv6.c +++ b/net/ipv6/netfilter/nf_nat_masquerade_ipv6.c @@ -87,18 +87,30 @@ static struct notifier_block masq_dev_notifier = { struct masq_dev_work { struct work_struct work; struct net *net; + struct in6_addr addr; int ifindex; }; +static int inet_cmp(struct nf_conn *ct, void *work) +{ + struct masq_dev_work *w = (struct masq_dev_work *)work; + struct nf_conntrack_tuple *tuple; + + if (!device_cmp(ct, (void *)(long)w->ifindex)) + return 0; + + tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + return ipv6_addr_equal(&w->addr, &tuple->dst.u3.in6); +} + static void iterate_cleanup_work(struct work_struct *work) { struct masq_dev_work *w; - long index; w = container_of(work, struct masq_dev_work, work); - index = w->ifindex; - nf_ct_iterate_cleanup_net(w->net, device_cmp, (void *)index, 0, 0); + nf_ct_iterate_cleanup_net(w->net, inet_cmp, (void *)w, 0, 0); put_net(w->net); kfree(w); @@ -147,6 +159,7 @@ static int masq_inet_event(struct notifier_block *this, INIT_WORK(&w->work, iterate_cleanup_work); w->ifindex = dev->ifindex; w->net = net; + w->addr = ifa->addr; schedule_work(&w->work); return NOTIFY_DONE; diff --git a/net/ipv6/raw.c b/net/ipv6/raw.c index 413d98bf24f4..5e0efd3954e9 100644 --- a/net/ipv6/raw.c +++ b/net/ipv6/raw.c @@ -651,8 +651,6 @@ static int rawv6_send_hdrinc(struct sock *sk, struct msghdr *msg, int length, skb->priority = sk->sk_priority; skb->mark = sk->sk_mark; skb->tstamp = sockc->transmit_time; - skb_dst_set(skb, &rt->dst); - *dstp = NULL; skb_put(skb, length); skb_reset_network_header(skb); @@ -665,8 +663,14 @@ static int rawv6_send_hdrinc(struct sock *sk, struct msghdr *msg, int length, skb->transport_header = skb->network_header; err = memcpy_from_msg(iph, msg, length); - if (err) - goto error_fault; + if (err) { + err = -EFAULT; + kfree_skb(skb); + goto error; + } + + skb_dst_set(skb, &rt->dst); + *dstp = NULL; /* if egress device is enslaved to an L3 master device pass the * skb to its handler for processing @@ -675,21 +679,28 @@ static int rawv6_send_hdrinc(struct sock *sk, struct msghdr *msg, int length, if (unlikely(!skb)) return 0; + /* Acquire rcu_read_lock() in case we need to use rt->rt6i_idev + * in the error path. Since skb has been freed, the dst could + * have been queued for deletion. + */ + rcu_read_lock(); IP6_UPD_PO_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUT, skb->len); err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, net, sk, skb, NULL, rt->dst.dev, dst_output); if (err > 0) err = net_xmit_errno(err); - if (err) - goto error; + if (err) { + IP6_INC_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS); + rcu_read_unlock(); + goto error_check; + } + rcu_read_unlock(); out: return 0; -error_fault: - err = -EFAULT; - kfree_skb(skb); error: IP6_INC_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS); +error_check: if (err == -ENOBUFS && !np->recverr) err = 0; return err; diff --git a/net/ipv6/route.c b/net/ipv6/route.c index 6311a7fc5f63..2a7423c39456 100644 --- a/net/ipv6/route.c +++ b/net/ipv6/route.c @@ -364,14 +364,11 @@ EXPORT_SYMBOL(ip6_dst_alloc); static void ip6_dst_destroy(struct dst_entry *dst) { - struct dst_metrics *p = (struct dst_metrics *)DST_METRICS_PTR(dst); struct rt6_info *rt = (struct rt6_info *)dst; struct fib6_info *from; struct inet6_dev *idev; - if (p != &dst_default_metrics && refcount_dec_and_test(&p->refcnt)) - kfree(p); - + ip_dst_metrics_put(dst); rt6_uncached_list_del(rt); idev = rt->rt6i_idev; @@ -520,10 +517,11 @@ static void rt6_probe_deferred(struct work_struct *w) static void rt6_probe(struct fib6_info *rt) { - struct __rt6_probe_work *work; + struct __rt6_probe_work *work = NULL; const struct in6_addr *nh_gw; struct neighbour *neigh; struct net_device *dev; + struct inet6_dev *idev; /* * Okay, this does not seem to be appropriate @@ -539,15 +537,12 @@ static void rt6_probe(struct fib6_info *rt) nh_gw = &rt->fib6_nh.nh_gw; dev = rt->fib6_nh.nh_dev; rcu_read_lock_bh(); + idev = __in6_dev_get(dev); neigh = __ipv6_neigh_lookup_noref(dev, nh_gw); if (neigh) { - struct inet6_dev *idev; - if (neigh->nud_state & NUD_VALID) goto out; - idev = __in6_dev_get(dev); - work = NULL; write_lock(&neigh->lock); if (!(neigh->nud_state & NUD_VALID) && time_after(jiffies, @@ -557,11 +552,13 @@ static void rt6_probe(struct fib6_info *rt) __neigh_set_probe_once(neigh); } write_unlock(&neigh->lock); - } else { + } else if (time_after(jiffies, rt->last_probe + + idev->cnf.rtr_probe_interval)) { work = kmalloc(sizeof(*work), GFP_ATOMIC); } if (work) { + rt->last_probe = jiffies; INIT_WORK(&work->work, rt6_probe_deferred); work->target = *nh_gw; dev_hold(dev); @@ -978,11 +975,7 @@ static void rt6_set_from(struct rt6_info *rt, struct fib6_info *from) { rt->rt6i_flags &= ~RTF_EXPIRES; rcu_assign_pointer(rt->from, from); - dst_init_metrics(&rt->dst, from->fib6_metrics->metrics, true); - if (from->fib6_metrics != &dst_default_metrics) { - rt->dst._metrics |= DST_METRICS_REFCOUNTED; - refcount_inc(&from->fib6_metrics->refcnt); - } + ip_dst_init_metrics(&rt->dst, from->fib6_metrics); } /* Caller must already hold reference to @ort */ @@ -2705,24 +2698,6 @@ out: return entries > rt_max_size; } -static int ip6_convert_metrics(struct net *net, struct fib6_info *rt, - struct fib6_config *cfg) -{ - struct dst_metrics *p; - - if (!cfg->fc_mx) - return 0; - - p = kzalloc(sizeof(*rt->fib6_metrics), GFP_KERNEL); - if (unlikely(!p)) - return -ENOMEM; - - refcount_set(&p->refcnt, 1); - rt->fib6_metrics = p; - - return ip_metrics_convert(net, cfg->fc_mx, cfg->fc_mx_len, p->metrics); -} - static struct rt6_info *ip6_nh_lookup_table(struct net *net, struct fib6_config *cfg, const struct in6_addr *gw_addr, @@ -2770,6 +2745,8 @@ static int ip6_route_check_nh_onlink(struct net *net, grt = ip6_nh_lookup_table(net, cfg, gw_addr, tbid, 0); if (grt) { if (!grt->dst.error && + /* ignore match if it is the default route */ + grt->from && !ipv6_addr_any(&grt->from->fib6_dst.addr) && (grt->rt6i_flags & flags || dev != grt->dst.dev)) { NL_SET_ERR_MSG(extack, "Nexthop has invalid gateway or device mismatch"); @@ -2998,13 +2975,17 @@ static struct fib6_info *ip6_route_info_create(struct fib6_config *cfg, if (!rt) goto out; + rt->fib6_metrics = ip_fib_metrics_init(net, cfg->fc_mx, cfg->fc_mx_len); + if (IS_ERR(rt->fib6_metrics)) { + err = PTR_ERR(rt->fib6_metrics); + /* Do not leave garbage there. */ + rt->fib6_metrics = (struct dst_metrics *)&dst_default_metrics; + goto out; + } + if (cfg->fc_flags & RTF_ADDRCONF) rt->dst_nocount = true; - err = ip6_convert_metrics(net, rt, cfg); - if (err < 0) - goto out; - if (cfg->fc_flags & RTF_EXPIRES) fib6_set_expires(rt, jiffies + clock_t_to_jiffies(cfg->fc_expires)); @@ -3727,6 +3708,7 @@ struct fib6_info *addrconf_f6i_alloc(struct net *net, if (!f6i) return ERR_PTR(-ENOMEM); + f6i->fib6_metrics = ip_fib_metrics_init(net, NULL, 0); f6i->dst_nocount = true; f6i->dst_host = true; f6i->fib6_protocol = RTPROT_KERNEL; @@ -4046,8 +4028,12 @@ void rt6_sync_down_dev(struct net_device *dev, unsigned long event) .event = event, }, }; + struct net *net = dev_net(dev); - fib6_clean_all(dev_net(dev), fib6_ifdown, &arg); + if (net->ipv6.sysctl.skip_notify_on_dev_down) + fib6_clean_all_skip_notify(net, fib6_ifdown, &arg); + else + fib6_clean_all(net, fib6_ifdown, &arg); } void rt6_disable_ip(struct net_device *dev, unsigned long event) @@ -4137,7 +4123,7 @@ static int rtm_to_fib6_config(struct sk_buff *skb, struct nlmsghdr *nlh, int err; err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_ipv6_policy, - NULL); + extack); if (err < 0) goto errout; @@ -4289,11 +4275,6 @@ static int ip6_route_info_append(struct net *net, if (!nh) return -ENOMEM; nh->fib6_info = rt; - err = ip6_convert_metrics(net, rt, r_cfg); - if (err) { - kfree(nh); - return err; - } memcpy(&nh->r_cfg, r_cfg, sizeof(*r_cfg)); list_add_tail(&nh->next, rt6_nh_list); @@ -4788,28 +4769,52 @@ nla_put_failure: return -EMSGSIZE; } +static bool fib6_info_uses_dev(const struct fib6_info *f6i, + const struct net_device *dev) +{ + if (f6i->fib6_nh.nh_dev == dev) + return true; + + if (f6i->fib6_nsiblings) { + struct fib6_info *sibling, *next_sibling; + + list_for_each_entry_safe(sibling, next_sibling, + &f6i->fib6_siblings, fib6_siblings) { + if (sibling->fib6_nh.nh_dev == dev) + return true; + } + } + + return false; +} + int rt6_dump_route(struct fib6_info *rt, void *p_arg) { struct rt6_rtnl_dump_arg *arg = (struct rt6_rtnl_dump_arg *) p_arg; + struct fib_dump_filter *filter = &arg->filter; + unsigned int flags = NLM_F_MULTI; struct net *net = arg->net; if (rt == net->ipv6.fib6_null_entry) return 0; - if (nlmsg_len(arg->cb->nlh) >= sizeof(struct rtmsg)) { - struct rtmsg *rtm = nlmsg_data(arg->cb->nlh); - - /* user wants prefix routes only */ - if (rtm->rtm_flags & RTM_F_PREFIX && - !(rt->fib6_flags & RTF_PREFIX_RT)) { - /* success since this is not a prefix route */ + if ((filter->flags & RTM_F_PREFIX) && + !(rt->fib6_flags & RTF_PREFIX_RT)) { + /* success since this is not a prefix route */ + return 1; + } + if (filter->filter_set) { + if ((filter->rt_type && rt->fib6_type != filter->rt_type) || + (filter->dev && !fib6_info_uses_dev(rt, filter->dev)) || + (filter->protocol && rt->fib6_protocol != filter->protocol)) { return 1; } + flags |= NLM_F_DUMP_FILTERED; } return rt6_fill_node(net, arg->skb, rt, NULL, NULL, NULL, 0, RTM_NEWROUTE, NETLINK_CB(arg->cb->skb).portid, - arg->cb->nlh->nlmsg_seq, NLM_F_MULTI); + arg->cb->nlh->nlmsg_seq, flags); } static int inet6_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, @@ -5056,7 +5061,10 @@ int ipv6_sysctl_rtcache_flush(struct ctl_table *ctl, int write, return 0; } -struct ctl_table ipv6_route_table_template[] = { +static int zero; +static int one = 1; + +static struct ctl_table ipv6_route_table_template[] = { { .procname = "flush", .data = &init_net.ipv6.sysctl.flush_delay, @@ -5127,6 +5135,15 @@ struct ctl_table ipv6_route_table_template[] = { .mode = 0644, .proc_handler = proc_dointvec_ms_jiffies, }, + { + .procname = "skip_notify_on_dev_down", + .data = &init_net.ipv6.sysctl.skip_notify_on_dev_down, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + .extra1 = &zero, + .extra2 = &one, + }, { } }; @@ -5150,6 +5167,7 @@ struct ctl_table * __net_init ipv6_route_sysctl_init(struct net *net) table[7].data = &net->ipv6.sysctl.ip6_rt_mtu_expires; table[8].data = &net->ipv6.sysctl.ip6_rt_min_advmss; table[9].data = &net->ipv6.sysctl.ip6_rt_gc_min_interval; + table[10].data = &net->ipv6.sysctl.skip_notify_on_dev_down; /* Don't export sysctls to unprivileged users */ if (net->user_ns != &init_user_ns) @@ -5214,6 +5232,7 @@ static int __net_init ip6_route_net_init(struct net *net) net->ipv6.sysctl.ip6_rt_gc_elasticity = 9; net->ipv6.sysctl.ip6_rt_mtu_expires = 10*60*HZ; net->ipv6.sysctl.ip6_rt_min_advmss = IPV6_MIN_MTU - 20 - 40; + net->ipv6.sysctl.skip_notify_on_dev_down = 0; net->ipv6.ip6_rt_gc_expire = 30*HZ; diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 28c4aa5078fc..d2d97d07ef27 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -478,7 +478,7 @@ void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt, struct net *net = dev_net(skb->dev); sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source, - inet6_iif(skb), 0, udptable, skb); + inet6_iif(skb), inet6_sdif(skb), udptable, skb); if (!sk) { __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), ICMP6_MIB_INERRORS); @@ -548,7 +548,7 @@ static __inline__ void udpv6_err(struct sk_buff *skb, __udp6_lib_err(skb, opt, type, code, offset, info, &udp_table); } -static DEFINE_STATIC_KEY_FALSE(udpv6_encap_needed_key); +DEFINE_STATIC_KEY_FALSE(udpv6_encap_needed_key); void udpv6_encap_enable(void) { static_branch_enable(&udpv6_encap_needed_key); @@ -766,11 +766,9 @@ static int udp6_unicast_rcv_skb(struct sock *sk, struct sk_buff *skb, ret = udpv6_queue_rcv_skb(sk, skb); - /* a return value > 0 means to resubmit the input, but - * it wants the return to be -protocol, or 0 - */ + /* a return value > 0 means to resubmit the input */ if (ret > 0) - return -ret; + return ret; return 0; } diff --git a/net/ipv6/udp_offload.c b/net/ipv6/udp_offload.c index 95dee9ca8d22..1b8e161ac527 100644 --- a/net/ipv6/udp_offload.c +++ b/net/ipv6/udp_offload.c @@ -119,7 +119,7 @@ static struct sk_buff *udp6_gro_receive(struct list_head *head, { struct udphdr *uh = udp_gro_udphdr(skb); - if (unlikely(!uh)) + if (unlikely(!uh) || !static_branch_unlikely(&udpv6_encap_needed_key)) goto flush; /* Don't bother verifying checksum if we're going to flush anyway. */ diff --git a/net/ipv6/xfrm6_input.c b/net/ipv6/xfrm6_input.c index 841f4a07438e..9ef490dddcea 100644 --- a/net/ipv6/xfrm6_input.c +++ b/net/ipv6/xfrm6_input.c @@ -59,6 +59,7 @@ int xfrm6_transport_finish(struct sk_buff *skb, int async) if (xo && (xo->flags & XFRM_GRO)) { skb_mac_header_rebuild(skb); + skb_reset_transport_header(skb); return -1; } diff --git a/net/ipv6/xfrm6_mode_transport.c b/net/ipv6/xfrm6_mode_transport.c index 9ad07a91708e..3c29da5defe6 100644 --- a/net/ipv6/xfrm6_mode_transport.c +++ b/net/ipv6/xfrm6_mode_transport.c @@ -51,7 +51,6 @@ static int xfrm6_transport_output(struct xfrm_state *x, struct sk_buff *skb) static int xfrm6_transport_input(struct xfrm_state *x, struct sk_buff *skb) { int ihl = skb->data - skb_transport_header(skb); - struct xfrm_offload *xo = xfrm_offload(skb); if (skb->transport_header != skb->network_header) { memmove(skb_transport_header(skb), @@ -60,8 +59,7 @@ static int xfrm6_transport_input(struct xfrm_state *x, struct sk_buff *skb) } ipv6_hdr(skb)->payload_len = htons(skb->len + ihl - sizeof(struct ipv6hdr)); - if (!xo || !(xo->flags & XFRM_GRO)) - skb_reset_transport_header(skb); + skb_reset_transport_header(skb); return 0; } diff --git a/net/ipv6/xfrm6_output.c b/net/ipv6/xfrm6_output.c index 5959ce9620eb..6a74080005cf 100644 --- a/net/ipv6/xfrm6_output.c +++ b/net/ipv6/xfrm6_output.c @@ -170,9 +170,11 @@ static int __xfrm6_output(struct net *net, struct sock *sk, struct sk_buff *skb) if (toobig && xfrm6_local_dontfrag(skb)) { xfrm6_local_rxpmtu(skb, mtu); + kfree_skb(skb); return -EMSGSIZE; } else if (!skb->ignore_df && toobig && skb->sk) { xfrm_local_error(skb, mtu); + kfree_skb(skb); return -EMSGSIZE; } diff --git a/net/ipv6/xfrm6_policy.c b/net/ipv6/xfrm6_policy.c index ef3defaf43b9..d35bcf92969c 100644 --- a/net/ipv6/xfrm6_policy.c +++ b/net/ipv6/xfrm6_policy.c @@ -146,8 +146,8 @@ _decode_session6(struct sk_buff *skb, struct flowi *fl, int reverse) fl6->daddr = reverse ? hdr->saddr : hdr->daddr; fl6->saddr = reverse ? hdr->daddr : hdr->saddr; - while (nh + offset + 1 < skb->data || - pskb_may_pull(skb, nh + offset + 1 - skb->data)) { + while (nh + offset + sizeof(*exthdr) < skb->data || + pskb_may_pull(skb, nh + offset + sizeof(*exthdr) - skb->data)) { nh = skb_network_header(skb); exthdr = (struct ipv6_opt_hdr *)(nh + offset); diff --git a/net/iucv/af_iucv.c b/net/iucv/af_iucv.c index 45115c125569..0bed4cc20603 100644 --- a/net/iucv/af_iucv.c +++ b/net/iucv/af_iucv.c @@ -1504,7 +1504,7 @@ __poll_t iucv_sock_poll(struct file *file, struct socket *sock, struct sock *sk = sock->sk; __poll_t mask = 0; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); if (sk->sk_state == IUCV_LISTEN) return iucv_accept_poll(sk); diff --git a/net/llc/af_llc.c b/net/llc/af_llc.c index 1beeea9549fa..b99e73a7e7e0 100644 --- a/net/llc/af_llc.c +++ b/net/llc/af_llc.c @@ -730,7 +730,6 @@ static int llc_ui_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, struct sk_buff *skb = NULL; struct sock *sk = sock->sk; struct llc_sock *llc = llc_sk(sk); - unsigned long cpu_flags; size_t copied = 0; u32 peek_seq = 0; u32 *seq, skb_len; @@ -855,9 +854,8 @@ static int llc_ui_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, goto copy_uaddr; if (!(flags & MSG_PEEK)) { - spin_lock_irqsave(&sk->sk_receive_queue.lock, cpu_flags); - sk_eat_skb(sk, skb); - spin_unlock_irqrestore(&sk->sk_receive_queue.lock, cpu_flags); + skb_unlink(skb, &sk->sk_receive_queue); + kfree_skb(skb); *seq = 0; } @@ -878,9 +876,8 @@ copy_uaddr: llc_cmsg_rcv(msg, skb); if (!(flags & MSG_PEEK)) { - spin_lock_irqsave(&sk->sk_receive_queue.lock, cpu_flags); - sk_eat_skb(sk, skb); - spin_unlock_irqrestore(&sk->sk_receive_queue.lock, cpu_flags); + skb_unlink(skb, &sk->sk_receive_queue); + kfree_skb(skb); *seq = 0; } diff --git a/net/llc/llc_conn.c b/net/llc/llc_conn.c index c0ac522b48a1..4ff89cb7c86f 100644 --- a/net/llc/llc_conn.c +++ b/net/llc/llc_conn.c @@ -734,6 +734,7 @@ void llc_sap_add_socket(struct llc_sap *sap, struct sock *sk) llc_sk(sk)->sap = sap; spin_lock_bh(&sap->sk_lock); + sock_set_flag(sk, SOCK_RCU_FREE); sap->sk_count++; sk_nulls_add_node_rcu(sk, laddr_hb); hlist_add_head(&llc->dev_hash_node, dev_hb); diff --git a/net/mac80211/Kconfig b/net/mac80211/Kconfig index 76e30f4797fb..f869e35d0974 100644 --- a/net/mac80211/Kconfig +++ b/net/mac80211/Kconfig @@ -27,20 +27,6 @@ config MAC80211_RC_MINSTREL ---help--- This option enables the 'minstrel' TX rate control algorithm -config MAC80211_RC_MINSTREL_HT - bool "Minstrel 802.11n support" if EXPERT - depends on MAC80211_RC_MINSTREL - default y - ---help--- - This option enables the 'minstrel_ht' TX rate control algorithm - -config MAC80211_RC_MINSTREL_VHT - bool "Minstrel 802.11ac support" if EXPERT - depends on MAC80211_RC_MINSTREL_HT - default n - ---help--- - This option enables VHT in the 'minstrel_ht' TX rate control algorithm - choice prompt "Default rate control algorithm" depends on MAC80211_HAS_RC @@ -62,8 +48,7 @@ endchoice config MAC80211_RC_DEFAULT string - default "minstrel_ht" if MAC80211_RC_DEFAULT_MINSTREL && MAC80211_RC_MINSTREL_HT - default "minstrel" if MAC80211_RC_DEFAULT_MINSTREL + default "minstrel_ht" if MAC80211_RC_DEFAULT_MINSTREL default "" endif diff --git a/net/mac80211/Makefile b/net/mac80211/Makefile index bb707789ef2b..4f03ebe732fa 100644 --- a/net/mac80211/Makefile +++ b/net/mac80211/Makefile @@ -53,13 +53,14 @@ mac80211-$(CONFIG_PM) += pm.o CFLAGS_trace.o := -I$(src) -rc80211_minstrel-y := rc80211_minstrel.o -rc80211_minstrel-$(CONFIG_MAC80211_DEBUGFS) += rc80211_minstrel_debugfs.o +rc80211_minstrel-y := \ + rc80211_minstrel.o \ + rc80211_minstrel_ht.o -rc80211_minstrel_ht-y := rc80211_minstrel_ht.o -rc80211_minstrel_ht-$(CONFIG_MAC80211_DEBUGFS) += rc80211_minstrel_ht_debugfs.o +rc80211_minstrel-$(CONFIG_MAC80211_DEBUGFS) += \ + rc80211_minstrel_debugfs.o \ + rc80211_minstrel_ht_debugfs.o mac80211-$(CONFIG_MAC80211_RC_MINSTREL) += $(rc80211_minstrel-y) -mac80211-$(CONFIG_MAC80211_RC_MINSTREL_HT) += $(rc80211_minstrel_ht-y) ccflags-y += -DDEBUG diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c index 504627e2117f..51622333d460 100644 --- a/net/mac80211/cfg.c +++ b/net/mac80211/cfg.c @@ -425,7 +425,7 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev, case NL80211_IFTYPE_AP: case NL80211_IFTYPE_AP_VLAN: /* Keys without a station are used for TX only */ - if (key->sta && test_sta_flag(key->sta, WLAN_STA_MFP)) + if (sta && test_sta_flag(sta, WLAN_STA_MFP)) key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT; break; case NL80211_IFTYPE_ADHOC: @@ -790,6 +790,48 @@ static int ieee80211_set_probe_resp(struct ieee80211_sub_if_data *sdata, return 0; } +static int ieee80211_set_ftm_responder_params( + struct ieee80211_sub_if_data *sdata, + const u8 *lci, size_t lci_len, + const u8 *civicloc, size_t civicloc_len) +{ + struct ieee80211_ftm_responder_params *new, *old; + struct ieee80211_bss_conf *bss_conf; + u8 *pos; + int len; + + if ((!lci || !lci_len) && (!civicloc || !civicloc_len)) + return 1; + + bss_conf = &sdata->vif.bss_conf; + old = bss_conf->ftmr_params; + len = lci_len + civicloc_len; + + new = kzalloc(sizeof(*new) + len, GFP_KERNEL); + if (!new) + return -ENOMEM; + + pos = (u8 *)(new + 1); + if (lci_len) { + new->lci_len = lci_len; + new->lci = pos; + memcpy(pos, lci, lci_len); + pos += lci_len; + } + + if (civicloc_len) { + new->civicloc_len = civicloc_len; + new->civicloc = pos; + memcpy(pos, civicloc, civicloc_len); + pos += civicloc_len; + } + + bss_conf->ftmr_params = new; + kfree(old); + + return 0; +} + static int ieee80211_assign_beacon(struct ieee80211_sub_if_data *sdata, struct cfg80211_beacon_data *params, const struct ieee80211_csa_settings *csa) @@ -863,6 +905,20 @@ static int ieee80211_assign_beacon(struct ieee80211_sub_if_data *sdata, if (err == 0) changed |= BSS_CHANGED_AP_PROBE_RESP; + if (params->ftm_responder != -1) { + sdata->vif.bss_conf.ftm_responder = params->ftm_responder; + err = ieee80211_set_ftm_responder_params(sdata, + params->lci, + params->lci_len, + params->civicloc, + params->civicloc_len); + + if (err < 0) + return err; + + changed |= BSS_CHANGED_FTM_RESPONDER; + } + rcu_assign_pointer(sdata->u.ap.beacon, new); if (old) @@ -1063,6 +1119,9 @@ static int ieee80211_stop_ap(struct wiphy *wiphy, struct net_device *dev) kfree_rcu(old_probe_resp, rcu_head); sdata->u.ap.driver_smps_mode = IEEE80211_SMPS_OFF; + kfree(sdata->vif.bss_conf.ftmr_params); + sdata->vif.bss_conf.ftmr_params = NULL; + __sta_info_flush(sdata, true); ieee80211_free_keys(sdata, true); @@ -2875,6 +2934,20 @@ cfg80211_beacon_dup(struct cfg80211_beacon_data *beacon) memcpy(pos, beacon->probe_resp, beacon->probe_resp_len); pos += beacon->probe_resp_len; } + if (beacon->ftm_responder) + new_beacon->ftm_responder = beacon->ftm_responder; + if (beacon->lci) { + new_beacon->lci_len = beacon->lci_len; + new_beacon->lci = pos; + memcpy(pos, beacon->lci, beacon->lci_len); + pos += beacon->lci_len; + } + if (beacon->civicloc) { + new_beacon->civicloc_len = beacon->civicloc_len; + new_beacon->civicloc = pos; + memcpy(pos, beacon->civicloc, beacon->civicloc_len); + pos += beacon->civicloc_len; + } return new_beacon; } @@ -3765,6 +3838,17 @@ out: return ret; } +static int +ieee80211_get_ftm_responder_stats(struct wiphy *wiphy, + struct net_device *dev, + struct cfg80211_ftm_responder_stats *ftm_stats) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + return drv_get_ftm_responder_stats(local, sdata, ftm_stats); +} + const struct cfg80211_ops mac80211_config_ops = { .add_virtual_intf = ieee80211_add_iface, .del_virtual_intf = ieee80211_del_iface, @@ -3859,4 +3943,5 @@ const struct cfg80211_ops mac80211_config_ops = { .set_multicast_to_unicast = ieee80211_set_multicast_to_unicast, .tx_control_port = ieee80211_tx_control_port, .get_txq_stats = ieee80211_get_txq_stats, + .get_ftm_responder_stats = ieee80211_get_ftm_responder_stats, }; diff --git a/net/mac80211/driver-ops.h b/net/mac80211/driver-ops.h index e42c641b6190..0b1747a2313d 100644 --- a/net/mac80211/driver-ops.h +++ b/net/mac80211/driver-ops.h @@ -1183,6 +1183,22 @@ static inline int drv_can_aggregate_in_amsdu(struct ieee80211_local *local, return local->ops->can_aggregate_in_amsdu(&local->hw, head, skb); } +static inline int +drv_get_ftm_responder_stats(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_ftm_responder_stats *ftm_stats) +{ + u32 ret = -EOPNOTSUPP; + + if (local->ops->get_ftm_responder_stats) + ret = local->ops->get_ftm_responder_stats(&local->hw, + &sdata->vif, + ftm_stats); + trace_drv_get_ftm_responder_stats(local, sdata, ftm_stats); + + return ret; +} + static inline int drv_start_nan(struct ieee80211_local *local, struct ieee80211_sub_if_data *sdata, struct cfg80211_nan_conf *conf) diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h index f40a2167935f..10a05062e4a0 100644 --- a/net/mac80211/ieee80211_i.h +++ b/net/mac80211/ieee80211_i.h @@ -377,6 +377,7 @@ struct ieee80211_mgd_auth_data { u8 key[WLAN_KEY_LEN_WEP104]; u8 key_len, key_idx; bool done; + bool peer_confirmed; bool timeout_started; u16 sae_trans, sae_status; diff --git a/net/mac80211/iface.c b/net/mac80211/iface.c index 5e6cf2cee965..5836ddeac9e3 100644 --- a/net/mac80211/iface.c +++ b/net/mac80211/iface.c @@ -1756,7 +1756,8 @@ int ieee80211_if_add(struct ieee80211_local *local, const char *name, if (local->ops->wake_tx_queue && type != NL80211_IFTYPE_AP_VLAN && - type != NL80211_IFTYPE_MONITOR) + (type != NL80211_IFTYPE_MONITOR || + (params->flags & MONITOR_FLAG_ACTIVE))) txq_size += sizeof(struct txq_info) + local->hw.txq_data_size; diff --git a/net/mac80211/main.c b/net/mac80211/main.c index 77381017bac7..83e71e6b2ebe 100644 --- a/net/mac80211/main.c +++ b/net/mac80211/main.c @@ -1203,8 +1203,10 @@ int ieee80211_register_hw(struct ieee80211_hw *hw) continue; sband = kmemdup(sband, sizeof(*sband), GFP_KERNEL); - if (!sband) + if (!sband) { + result = -ENOMEM; goto fail_rate; + } wiphy_dbg(hw->wiphy, "copying sband (band %d) due to VHT EXT NSS BW flag\n", band); @@ -1373,18 +1375,12 @@ static int __init ieee80211_init(void) if (ret) return ret; - ret = rc80211_minstrel_ht_init(); - if (ret) - goto err_minstrel; - ret = ieee80211_iface_init(); if (ret) goto err_netdev; return 0; err_netdev: - rc80211_minstrel_ht_exit(); - err_minstrel: rc80211_minstrel_exit(); return ret; @@ -1392,7 +1388,6 @@ static int __init ieee80211_init(void) static void __exit ieee80211_exit(void) { - rc80211_minstrel_ht_exit(); rc80211_minstrel_exit(); ieee80211s_stop(); diff --git a/net/mac80211/mesh.h b/net/mac80211/mesh.h index ee56f18cad3f..21526630bf65 100644 --- a/net/mac80211/mesh.h +++ b/net/mac80211/mesh.h @@ -217,7 +217,8 @@ void mesh_rmc_free(struct ieee80211_sub_if_data *sdata); int mesh_rmc_init(struct ieee80211_sub_if_data *sdata); void ieee80211s_init(void); void ieee80211s_update_metric(struct ieee80211_local *local, - struct sta_info *sta, struct sk_buff *skb); + struct sta_info *sta, + struct ieee80211_tx_status *st); void ieee80211_mesh_init_sdata(struct ieee80211_sub_if_data *sdata); void ieee80211_mesh_teardown_sdata(struct ieee80211_sub_if_data *sdata); int ieee80211_start_mesh(struct ieee80211_sub_if_data *sdata); diff --git a/net/mac80211/mesh_hwmp.c b/net/mac80211/mesh_hwmp.c index daf9db3c8f24..6950cd0bf594 100644 --- a/net/mac80211/mesh_hwmp.c +++ b/net/mac80211/mesh_hwmp.c @@ -295,15 +295,12 @@ int mesh_path_error_tx(struct ieee80211_sub_if_data *sdata, } void ieee80211s_update_metric(struct ieee80211_local *local, - struct sta_info *sta, struct sk_buff *skb) + struct sta_info *sta, + struct ieee80211_tx_status *st) { - struct ieee80211_tx_info *txinfo = IEEE80211_SKB_CB(skb); - struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_tx_info *txinfo = st->info; int failed; - if (!ieee80211_is_data(hdr->frame_control)) - return; - failed = !(txinfo->flags & IEEE80211_TX_STAT_ACK); /* moving average, scaled to 100. diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c index 89dac799a85f..d2bc8d57c87e 100644 --- a/net/mac80211/mlme.c +++ b/net/mac80211/mlme.c @@ -2761,13 +2761,40 @@ static void ieee80211_auth_challenge(struct ieee80211_sub_if_data *sdata, auth_data->key_idx, tx_flags); } +static bool ieee80211_mark_sta_auth(struct ieee80211_sub_if_data *sdata, + const u8 *bssid) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct sta_info *sta; + + sdata_info(sdata, "authenticated\n"); + ifmgd->auth_data->done = true; + ifmgd->auth_data->timeout = jiffies + IEEE80211_AUTH_WAIT_ASSOC; + ifmgd->auth_data->timeout_started = true; + run_again(sdata, ifmgd->auth_data->timeout); + + /* move station state to auth */ + mutex_lock(&sdata->local->sta_mtx); + sta = sta_info_get(sdata, bssid); + if (!sta) { + WARN_ONCE(1, "%s: STA %pM not found", sdata->name, bssid); + return false; + } + if (sta_info_move_state(sta, IEEE80211_STA_AUTH)) { + sdata_info(sdata, "failed moving %pM to auth\n", bssid); + return false; + } + mutex_unlock(&sdata->local->sta_mtx); + + return true; +} + static void ieee80211_rx_mgmt_auth(struct ieee80211_sub_if_data *sdata, struct ieee80211_mgmt *mgmt, size_t len) { struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; u8 bssid[ETH_ALEN]; u16 auth_alg, auth_transaction, status_code; - struct sta_info *sta; struct ieee80211_event event = { .type = MLME_EVENT, .u.mlme.data = AUTH_EVENT, @@ -2791,7 +2818,11 @@ static void ieee80211_rx_mgmt_auth(struct ieee80211_sub_if_data *sdata, status_code = le16_to_cpu(mgmt->u.auth.status_code); if (auth_alg != ifmgd->auth_data->algorithm || - auth_transaction != ifmgd->auth_data->expected_transaction) { + (auth_alg != WLAN_AUTH_SAE && + auth_transaction != ifmgd->auth_data->expected_transaction) || + (auth_alg == WLAN_AUTH_SAE && + (auth_transaction < ifmgd->auth_data->expected_transaction || + auth_transaction > 2))) { sdata_info(sdata, "%pM unexpected authentication state: alg %d (expected %d) transact %d (expected %d)\n", mgmt->sa, auth_alg, ifmgd->auth_data->algorithm, auth_transaction, @@ -2834,35 +2865,17 @@ static void ieee80211_rx_mgmt_auth(struct ieee80211_sub_if_data *sdata, event.u.mlme.status = MLME_SUCCESS; drv_event_callback(sdata->local, sdata, &event); - sdata_info(sdata, "authenticated\n"); - ifmgd->auth_data->done = true; - ifmgd->auth_data->timeout = jiffies + IEEE80211_AUTH_WAIT_ASSOC; - ifmgd->auth_data->timeout_started = true; - run_again(sdata, ifmgd->auth_data->timeout); - - if (ifmgd->auth_data->algorithm == WLAN_AUTH_SAE && - ifmgd->auth_data->expected_transaction != 2) { - /* - * Report auth frame to user space for processing since another - * round of Authentication frames is still needed. - */ - cfg80211_rx_mlme_mgmt(sdata->dev, (u8 *)mgmt, len); - return; + if (ifmgd->auth_data->algorithm != WLAN_AUTH_SAE || + (auth_transaction == 2 && + ifmgd->auth_data->expected_transaction == 2)) { + if (!ieee80211_mark_sta_auth(sdata, bssid)) + goto out_err; + } else if (ifmgd->auth_data->algorithm == WLAN_AUTH_SAE && + auth_transaction == 2) { + sdata_info(sdata, "SAE peer confirmed\n"); + ifmgd->auth_data->peer_confirmed = true; } - /* move station state to auth */ - mutex_lock(&sdata->local->sta_mtx); - sta = sta_info_get(sdata, bssid); - if (!sta) { - WARN_ONCE(1, "%s: STA %pM not found", sdata->name, bssid); - goto out_err; - } - if (sta_info_move_state(sta, IEEE80211_STA_AUTH)) { - sdata_info(sdata, "failed moving %pM to auth\n", bssid); - goto out_err; - } - mutex_unlock(&sdata->local->sta_mtx); - cfg80211_rx_mlme_mgmt(sdata->dev, (u8 *)mgmt, len); return; out_err: @@ -4878,6 +4891,7 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, struct ieee80211_mgd_auth_data *auth_data; u16 auth_alg; int err; + bool cont_auth; /* prepare auth data structure */ @@ -4912,6 +4926,9 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, return -EOPNOTSUPP; } + if (ifmgd->assoc_data) + return -EBUSY; + auth_data = kzalloc(sizeof(*auth_data) + req->auth_data_len + req->ie_len, GFP_KERNEL); if (!auth_data) @@ -4931,6 +4948,13 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, auth_data->data_len += req->auth_data_len - 4; } + /* Check if continuing authentication or trying to authenticate with the + * same BSS that we were in the process of authenticating with and avoid + * removal and re-addition of the STA entry in + * ieee80211_prep_connection(). + */ + cont_auth = ifmgd->auth_data && req->bss == ifmgd->auth_data->bss; + if (req->ie && req->ie_len) { memcpy(&auth_data->data[auth_data->data_len], req->ie, req->ie_len); @@ -4947,18 +4971,26 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, /* try to authenticate/probe */ - if ((ifmgd->auth_data && !ifmgd->auth_data->done) || - ifmgd->assoc_data) { - err = -EBUSY; - goto err_free; + if (ifmgd->auth_data) { + if (cont_auth && req->auth_type == NL80211_AUTHTYPE_SAE) { + auth_data->peer_confirmed = + ifmgd->auth_data->peer_confirmed; + } + ieee80211_destroy_auth_data(sdata, cont_auth); } - if (ifmgd->auth_data) - ieee80211_destroy_auth_data(sdata, false); - /* prep auth_data so we don't go into idle on disassoc */ ifmgd->auth_data = auth_data; + /* If this is continuation of an ongoing SAE authentication exchange + * (i.e., request to send SAE Confirm) and the peer has already + * confirmed, mark authentication completed since we are about to send + * out SAE Confirm. + */ + if (cont_auth && req->auth_type == NL80211_AUTHTYPE_SAE && + auth_data->peer_confirmed && auth_data->sae_trans == 2) + ieee80211_mark_sta_auth(sdata, req->bss->bssid); + if (ifmgd->associated) { u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; @@ -4976,7 +5008,7 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, sdata_info(sdata, "authenticate with %pM\n", req->bss->bssid); - err = ieee80211_prep_connection(sdata, req->bss, false, false); + err = ieee80211_prep_connection(sdata, req->bss, cont_auth, false); if (err) goto err_clear; @@ -4997,7 +5029,6 @@ int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, mutex_lock(&sdata->local->mtx); ieee80211_vif_release_channel(sdata); mutex_unlock(&sdata->local->mtx); - err_free: kfree(auth_data); return err; } diff --git a/net/mac80211/rate.h b/net/mac80211/rate.h index 8212bfeb71d6..d59198191a79 100644 --- a/net/mac80211/rate.h +++ b/net/mac80211/rate.h @@ -95,18 +95,5 @@ static inline void rc80211_minstrel_exit(void) } #endif -#ifdef CONFIG_MAC80211_RC_MINSTREL_HT -int rc80211_minstrel_ht_init(void); -void rc80211_minstrel_ht_exit(void); -#else -static inline int rc80211_minstrel_ht_init(void) -{ - return 0; -} -static inline void rc80211_minstrel_ht_exit(void) -{ -} -#endif - #endif /* IEEE80211_RATE_H */ diff --git a/net/mac80211/rc80211_minstrel.c b/net/mac80211/rc80211_minstrel.c index 07fb219327d6..a34e9c2ca626 100644 --- a/net/mac80211/rc80211_minstrel.c +++ b/net/mac80211/rc80211_minstrel.c @@ -167,12 +167,6 @@ minstrel_calc_rate_stats(struct minstrel_rate_stats *mrs) if (unlikely(!mrs->att_hist)) { mrs->prob_ewma = cur_prob; } else { - /* update exponential weighted moving variance */ - mrs->prob_ewmv = minstrel_ewmv(mrs->prob_ewmv, - cur_prob, - mrs->prob_ewma, - EWMA_LEVEL); - /*update exponential weighted moving avarage */ mrs->prob_ewma = minstrel_ewma(mrs->prob_ewma, cur_prob, @@ -572,141 +566,6 @@ minstrel_rate_init(void *priv, struct ieee80211_supported_band *sband, minstrel_update_rates(mp, mi); } -static void * -minstrel_alloc_sta(void *priv, struct ieee80211_sta *sta, gfp_t gfp) -{ - struct ieee80211_supported_band *sband; - struct minstrel_sta_info *mi; - struct minstrel_priv *mp = priv; - struct ieee80211_hw *hw = mp->hw; - int max_rates = 0; - int i; - - mi = kzalloc(sizeof(struct minstrel_sta_info), gfp); - if (!mi) - return NULL; - - for (i = 0; i < NUM_NL80211_BANDS; i++) { - sband = hw->wiphy->bands[i]; - if (sband && sband->n_bitrates > max_rates) - max_rates = sband->n_bitrates; - } - - mi->r = kcalloc(max_rates, sizeof(struct minstrel_rate), gfp); - if (!mi->r) - goto error; - - mi->sample_table = kmalloc_array(max_rates, SAMPLE_COLUMNS, gfp); - if (!mi->sample_table) - goto error1; - - mi->last_stats_update = jiffies; - return mi; - -error1: - kfree(mi->r); -error: - kfree(mi); - return NULL; -} - -static void -minstrel_free_sta(void *priv, struct ieee80211_sta *sta, void *priv_sta) -{ - struct minstrel_sta_info *mi = priv_sta; - - kfree(mi->sample_table); - kfree(mi->r); - kfree(mi); -} - -static void -minstrel_init_cck_rates(struct minstrel_priv *mp) -{ - static const int bitrates[4] = { 10, 20, 55, 110 }; - struct ieee80211_supported_band *sband; - u32 rate_flags = ieee80211_chandef_rate_flags(&mp->hw->conf.chandef); - int i, j; - - sband = mp->hw->wiphy->bands[NL80211_BAND_2GHZ]; - if (!sband) - return; - - for (i = 0, j = 0; i < sband->n_bitrates; i++) { - struct ieee80211_rate *rate = &sband->bitrates[i]; - - if (rate->flags & IEEE80211_RATE_ERP_G) - continue; - - if ((rate_flags & sband->bitrates[i].flags) != rate_flags) - continue; - - for (j = 0; j < ARRAY_SIZE(bitrates); j++) { - if (rate->bitrate != bitrates[j]) - continue; - - mp->cck_rates[j] = i; - break; - } - } -} - -static void * -minstrel_alloc(struct ieee80211_hw *hw, struct dentry *debugfsdir) -{ - struct minstrel_priv *mp; - - mp = kzalloc(sizeof(struct minstrel_priv), GFP_ATOMIC); - if (!mp) - return NULL; - - /* contention window settings - * Just an approximation. Using the per-queue values would complicate - * the calculations and is probably unnecessary */ - mp->cw_min = 15; - mp->cw_max = 1023; - - /* number of packets (in %) to use for sampling other rates - * sample less often for non-mrr packets, because the overhead - * is much higher than with mrr */ - mp->lookaround_rate = 5; - mp->lookaround_rate_mrr = 10; - - /* maximum time that the hw is allowed to stay in one MRR segment */ - mp->segment_size = 6000; - - if (hw->max_rate_tries > 0) - mp->max_retry = hw->max_rate_tries; - else - /* safe default, does not necessarily have to match hw properties */ - mp->max_retry = 7; - - if (hw->max_rates >= 4) - mp->has_mrr = true; - - mp->hw = hw; - mp->update_interval = 100; - -#ifdef CONFIG_MAC80211_DEBUGFS - mp->fixed_rate_idx = (u32) -1; - mp->dbg_fixed_rate = debugfs_create_u32("fixed_rate_idx", - 0666, debugfsdir, &mp->fixed_rate_idx); -#endif - - minstrel_init_cck_rates(mp); - - return mp; -} - -static void -minstrel_free(void *priv) -{ -#ifdef CONFIG_MAC80211_DEBUGFS - debugfs_remove(((struct minstrel_priv *)priv)->dbg_fixed_rate); -#endif - kfree(priv); -} - static u32 minstrel_get_expected_throughput(void *priv_sta) { struct minstrel_sta_info *mi = priv_sta; @@ -725,29 +584,8 @@ static u32 minstrel_get_expected_throughput(void *priv_sta) } const struct rate_control_ops mac80211_minstrel = { - .name = "minstrel", .tx_status_ext = minstrel_tx_status, .get_rate = minstrel_get_rate, .rate_init = minstrel_rate_init, - .alloc = minstrel_alloc, - .free = minstrel_free, - .alloc_sta = minstrel_alloc_sta, - .free_sta = minstrel_free_sta, -#ifdef CONFIG_MAC80211_DEBUGFS - .add_sta_debugfs = minstrel_add_sta_debugfs, - .remove_sta_debugfs = minstrel_remove_sta_debugfs, -#endif .get_expected_throughput = minstrel_get_expected_throughput, }; - -int __init -rc80211_minstrel_init(void) -{ - return ieee80211_rate_control_register(&mac80211_minstrel); -} - -void -rc80211_minstrel_exit(void) -{ - ieee80211_rate_control_unregister(&mac80211_minstrel); -} diff --git a/net/mac80211/rc80211_minstrel.h b/net/mac80211/rc80211_minstrel.h index be6c3f35f48b..23ec953e3a24 100644 --- a/net/mac80211/rc80211_minstrel.h +++ b/net/mac80211/rc80211_minstrel.h @@ -35,19 +35,6 @@ minstrel_ewma(int old, int new, int weight) return old + incr; } -/* - * Perform EWMV (Exponentially Weighted Moving Variance) calculation - */ -static inline int -minstrel_ewmv(int old_ewmv, int cur_prob, int prob_ewma, int weight) -{ - int diff, incr; - - diff = cur_prob - prob_ewma; - incr = (EWMA_DIV - weight) * diff / EWMA_DIV; - return weight * (old_ewmv + MINSTREL_TRUNC(diff * incr)) / EWMA_DIV; -} - struct minstrel_rate_stats { /* current / last sampling period attempts/success counters */ u16 attempts, last_attempts; @@ -56,11 +43,8 @@ struct minstrel_rate_stats { /* total attempts/success counters */ u32 att_hist, succ_hist; - /* statistis of packet delivery probability - * prob_ewma - exponential weighted moving average of prob - * prob_ewmsd - exp. weighted moving standard deviation of prob */ + /* prob_ewma - exponential weighted moving average of prob */ u16 prob_ewma; - u16 prob_ewmv; /* maximum retry counts */ u8 retry_count; @@ -109,11 +93,6 @@ struct minstrel_sta_info { /* sampling table */ u8 *sample_table; - -#ifdef CONFIG_MAC80211_DEBUGFS - struct dentry *dbg_stats; - struct dentry *dbg_stats_csv; -#endif }; struct minstrel_priv { @@ -137,7 +116,6 @@ struct minstrel_priv { * - setting will be applied on next update */ u32 fixed_rate_idx; - struct dentry *dbg_fixed_rate; #endif }; @@ -146,17 +124,8 @@ struct minstrel_debugfs_info { char buf[]; }; -/* Get EWMSD (Exponentially Weighted Moving Standard Deviation) * 10 */ -static inline int -minstrel_get_ewmsd10(struct minstrel_rate_stats *mrs) -{ - unsigned int ewmv = mrs->prob_ewmv; - return int_sqrt(MINSTREL_TRUNC(ewmv * 1000 * 1000)); -} - extern const struct rate_control_ops mac80211_minstrel; void minstrel_add_sta_debugfs(void *priv, void *priv_sta, struct dentry *dir); -void minstrel_remove_sta_debugfs(void *priv, void *priv_sta); /* Recalculate success probabilities and counters for a given rate using EWMA */ void minstrel_calc_rate_stats(struct minstrel_rate_stats *mrs); @@ -165,7 +134,5 @@ int minstrel_get_tp_avg(struct minstrel_rate *mr, int prob_ewma); /* debugfs */ int minstrel_stats_open(struct inode *inode, struct file *file); int minstrel_stats_csv_open(struct inode *inode, struct file *file); -ssize_t minstrel_stats_read(struct file *file, char __user *buf, size_t len, loff_t *ppos); -int minstrel_stats_release(struct inode *inode, struct file *file); #endif diff --git a/net/mac80211/rc80211_minstrel_debugfs.c b/net/mac80211/rc80211_minstrel_debugfs.c index 9ad7d63d3e5b..c8afd85b51a0 100644 --- a/net/mac80211/rc80211_minstrel_debugfs.c +++ b/net/mac80211/rc80211_minstrel_debugfs.c @@ -54,22 +54,6 @@ #include <net/mac80211.h> #include "rc80211_minstrel.h" -ssize_t -minstrel_stats_read(struct file *file, char __user *buf, size_t len, loff_t *ppos) -{ - struct minstrel_debugfs_info *ms; - - ms = file->private_data; - return simple_read_from_buffer(buf, len, ppos, ms->buf, ms->len); -} - -int -minstrel_stats_release(struct inode *inode, struct file *file) -{ - kfree(file->private_data); - return 0; -} - int minstrel_stats_open(struct inode *inode, struct file *file) { @@ -86,14 +70,13 @@ minstrel_stats_open(struct inode *inode, struct file *file) p = ms->buf; p += sprintf(p, "\n"); p += sprintf(p, - "best __________rate_________ ________statistics________ ____last_____ ______sum-of________\n"); + "best __________rate_________ ____statistics___ ____last_____ ______sum-of________\n"); p += sprintf(p, - "rate [name idx airtime max_tp] [avg(tp) avg(prob) sd(prob)] [retry|suc|att] [#success | #attempts]\n"); + "rate [name idx airtime max_tp] [avg(tp) avg(prob)] [retry|suc|att] [#success | #attempts]\n"); for (i = 0; i < mi->n_rates; i++) { struct minstrel_rate *mr = &mi->r[i]; struct minstrel_rate_stats *mrs = &mi->r[i].stats; - unsigned int prob_ewmsd; *(p++) = (i == mi->max_tp_rate[0]) ? 'A' : ' '; *(p++) = (i == mi->max_tp_rate[1]) ? 'B' : ' '; @@ -109,15 +92,13 @@ minstrel_stats_open(struct inode *inode, struct file *file) tp_max = minstrel_get_tp_avg(mr, MINSTREL_FRAC(100,100)); tp_avg = minstrel_get_tp_avg(mr, mrs->prob_ewma); eprob = MINSTREL_TRUNC(mrs->prob_ewma * 1000); - prob_ewmsd = minstrel_get_ewmsd10(mrs); - p += sprintf(p, "%4u.%1u %4u.%1u %3u.%1u %3u.%1u" + p += sprintf(p, "%4u.%1u %4u.%1u %3u.%1u" " %3u %3u %-3u " "%9llu %-9llu\n", tp_max / 10, tp_max % 10, tp_avg / 10, tp_avg % 10, eprob / 10, eprob % 10, - prob_ewmsd / 10, prob_ewmsd % 10, mrs->retry_count, mrs->last_success, mrs->last_attempts, @@ -135,14 +116,6 @@ minstrel_stats_open(struct inode *inode, struct file *file) return 0; } -static const struct file_operations minstrel_stat_fops = { - .owner = THIS_MODULE, - .open = minstrel_stats_open, - .read = minstrel_stats_read, - .release = minstrel_stats_release, - .llseek = default_llseek, -}; - int minstrel_stats_csv_open(struct inode *inode, struct file *file) { @@ -161,7 +134,6 @@ minstrel_stats_csv_open(struct inode *inode, struct file *file) for (i = 0; i < mi->n_rates; i++) { struct minstrel_rate *mr = &mi->r[i]; struct minstrel_rate_stats *mrs = &mi->r[i].stats; - unsigned int prob_ewmsd; p += sprintf(p, "%s" ,((i == mi->max_tp_rate[0]) ? "A" : "")); p += sprintf(p, "%s" ,((i == mi->max_tp_rate[1]) ? "B" : "")); @@ -177,14 +149,12 @@ minstrel_stats_csv_open(struct inode *inode, struct file *file) tp_max = minstrel_get_tp_avg(mr, MINSTREL_FRAC(100,100)); tp_avg = minstrel_get_tp_avg(mr, mrs->prob_ewma); eprob = MINSTREL_TRUNC(mrs->prob_ewma * 1000); - prob_ewmsd = minstrel_get_ewmsd10(mrs); - p += sprintf(p, "%u.%u,%u.%u,%u.%u,%u.%u,%u,%u,%u," + p += sprintf(p, "%u.%u,%u.%u,%u.%u,%u,%u,%u," "%llu,%llu,%d,%d\n", tp_max / 10, tp_max % 10, tp_avg / 10, tp_avg % 10, eprob / 10, eprob % 10, - prob_ewmsd / 10, prob_ewmsd % 10, mrs->retry_count, mrs->last_success, mrs->last_attempts, @@ -200,33 +170,3 @@ minstrel_stats_csv_open(struct inode *inode, struct file *file) return 0; } - -static const struct file_operations minstrel_stat_csv_fops = { - .owner = THIS_MODULE, - .open = minstrel_stats_csv_open, - .read = minstrel_stats_read, - .release = minstrel_stats_release, - .llseek = default_llseek, -}; - -void -minstrel_add_sta_debugfs(void *priv, void *priv_sta, struct dentry *dir) -{ - struct minstrel_sta_info *mi = priv_sta; - - mi->dbg_stats = debugfs_create_file("rc_stats", 0444, dir, mi, - &minstrel_stat_fops); - - mi->dbg_stats_csv = debugfs_create_file("rc_stats_csv", 0444, dir, mi, - &minstrel_stat_csv_fops); -} - -void -minstrel_remove_sta_debugfs(void *priv, void *priv_sta) -{ - struct minstrel_sta_info *mi = priv_sta; - - debugfs_remove(mi->dbg_stats); - - debugfs_remove(mi->dbg_stats_csv); -} diff --git a/net/mac80211/rc80211_minstrel_ht.c b/net/mac80211/rc80211_minstrel_ht.c index 67ebdeaffbbc..f466ec37d161 100644 --- a/net/mac80211/rc80211_minstrel_ht.c +++ b/net/mac80211/rc80211_minstrel_ht.c @@ -52,22 +52,23 @@ _streams - 1 /* MCS rate information for an MCS group */ -#define MCS_GROUP(_streams, _sgi, _ht40) \ +#define MCS_GROUP(_streams, _sgi, _ht40, _s) \ [GROUP_IDX(_streams, _sgi, _ht40)] = { \ .streams = _streams, \ + .shift = _s, \ .flags = \ IEEE80211_TX_RC_MCS | \ (_sgi ? IEEE80211_TX_RC_SHORT_GI : 0) | \ (_ht40 ? IEEE80211_TX_RC_40_MHZ_WIDTH : 0), \ .duration = { \ - MCS_DURATION(_streams, _sgi, _ht40 ? 54 : 26), \ - MCS_DURATION(_streams, _sgi, _ht40 ? 108 : 52), \ - MCS_DURATION(_streams, _sgi, _ht40 ? 162 : 78), \ - MCS_DURATION(_streams, _sgi, _ht40 ? 216 : 104), \ - MCS_DURATION(_streams, _sgi, _ht40 ? 324 : 156), \ - MCS_DURATION(_streams, _sgi, _ht40 ? 432 : 208), \ - MCS_DURATION(_streams, _sgi, _ht40 ? 486 : 234), \ - MCS_DURATION(_streams, _sgi, _ht40 ? 540 : 260) \ + MCS_DURATION(_streams, _sgi, _ht40 ? 54 : 26) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 108 : 52) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 162 : 78) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 216 : 104) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 324 : 156) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 432 : 208) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 486 : 234) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 540 : 260) >> _s \ } \ } @@ -80,9 +81,10 @@ #define BW2VBPS(_bw, r3, r2, r1) \ (_bw == BW_80 ? r3 : _bw == BW_40 ? r2 : r1) -#define VHT_GROUP(_streams, _sgi, _bw) \ +#define VHT_GROUP(_streams, _sgi, _bw, _s) \ [VHT_GROUP_IDX(_streams, _sgi, _bw)] = { \ .streams = _streams, \ + .shift = _s, \ .flags = \ IEEE80211_TX_RC_VHT_MCS | \ (_sgi ? IEEE80211_TX_RC_SHORT_GI : 0) | \ @@ -90,25 +92,25 @@ _bw == BW_40 ? IEEE80211_TX_RC_40_MHZ_WIDTH : 0), \ .duration = { \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 117, 54, 26)), \ + BW2VBPS(_bw, 117, 54, 26)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 234, 108, 52)), \ + BW2VBPS(_bw, 234, 108, 52)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 351, 162, 78)), \ + BW2VBPS(_bw, 351, 162, 78)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 468, 216, 104)), \ + BW2VBPS(_bw, 468, 216, 104)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 702, 324, 156)), \ + BW2VBPS(_bw, 702, 324, 156)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 936, 432, 208)), \ + BW2VBPS(_bw, 936, 432, 208)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 1053, 486, 234)), \ + BW2VBPS(_bw, 1053, 486, 234)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 1170, 540, 260)), \ + BW2VBPS(_bw, 1170, 540, 260)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 1404, 648, 312)), \ + BW2VBPS(_bw, 1404, 648, 312)) >> _s, \ MCS_DURATION(_streams, _sgi, \ - BW2VBPS(_bw, 1560, 720, 346)) \ + BW2VBPS(_bw, 1560, 720, 346)) >> _s \ } \ } @@ -121,28 +123,27 @@ (CCK_DURATION((_bitrate > 10 ? 20 : 10), false, 60) + \ CCK_DURATION(_bitrate, _short, AVG_PKT_SIZE)) -#define CCK_DURATION_LIST(_short) \ - CCK_ACK_DURATION(10, _short), \ - CCK_ACK_DURATION(20, _short), \ - CCK_ACK_DURATION(55, _short), \ - CCK_ACK_DURATION(110, _short) +#define CCK_DURATION_LIST(_short, _s) \ + CCK_ACK_DURATION(10, _short) >> _s, \ + CCK_ACK_DURATION(20, _short) >> _s, \ + CCK_ACK_DURATION(55, _short) >> _s, \ + CCK_ACK_DURATION(110, _short) >> _s -#define CCK_GROUP \ +#define CCK_GROUP(_s) \ [MINSTREL_CCK_GROUP] = { \ - .streams = 0, \ + .streams = 1, \ .flags = 0, \ + .shift = _s, \ .duration = { \ - CCK_DURATION_LIST(false), \ - CCK_DURATION_LIST(true) \ + CCK_DURATION_LIST(false, _s), \ + CCK_DURATION_LIST(true, _s) \ } \ } -#ifdef CONFIG_MAC80211_RC_MINSTREL_VHT static bool minstrel_vht_only = true; module_param(minstrel_vht_only, bool, 0644); MODULE_PARM_DESC(minstrel_vht_only, "Use only VHT rates when VHT is supported by sta."); -#endif /* * To enable sufficiently targeted rate sampling, MCS rates are divided into @@ -153,49 +154,47 @@ MODULE_PARM_DESC(minstrel_vht_only, * BW -> SGI -> #streams */ const struct mcs_group minstrel_mcs_groups[] = { - MCS_GROUP(1, 0, BW_20), - MCS_GROUP(2, 0, BW_20), - MCS_GROUP(3, 0, BW_20), + MCS_GROUP(1, 0, BW_20, 5), + MCS_GROUP(2, 0, BW_20, 4), + MCS_GROUP(3, 0, BW_20, 4), - MCS_GROUP(1, 1, BW_20), - MCS_GROUP(2, 1, BW_20), - MCS_GROUP(3, 1, BW_20), + MCS_GROUP(1, 1, BW_20, 5), + MCS_GROUP(2, 1, BW_20, 4), + MCS_GROUP(3, 1, BW_20, 4), - MCS_GROUP(1, 0, BW_40), - MCS_GROUP(2, 0, BW_40), - MCS_GROUP(3, 0, BW_40), + MCS_GROUP(1, 0, BW_40, 4), + MCS_GROUP(2, 0, BW_40, 4), + MCS_GROUP(3, 0, BW_40, 4), - MCS_GROUP(1, 1, BW_40), - MCS_GROUP(2, 1, BW_40), - MCS_GROUP(3, 1, BW_40), + MCS_GROUP(1, 1, BW_40, 4), + MCS_GROUP(2, 1, BW_40, 4), + MCS_GROUP(3, 1, BW_40, 4), - CCK_GROUP, + CCK_GROUP(8), -#ifdef CONFIG_MAC80211_RC_MINSTREL_VHT - VHT_GROUP(1, 0, BW_20), - VHT_GROUP(2, 0, BW_20), - VHT_GROUP(3, 0, BW_20), + VHT_GROUP(1, 0, BW_20, 5), + VHT_GROUP(2, 0, BW_20, 4), + VHT_GROUP(3, 0, BW_20, 4), - VHT_GROUP(1, 1, BW_20), - VHT_GROUP(2, 1, BW_20), - VHT_GROUP(3, 1, BW_20), + VHT_GROUP(1, 1, BW_20, 5), + VHT_GROUP(2, 1, BW_20, 4), + VHT_GROUP(3, 1, BW_20, 4), - VHT_GROUP(1, 0, BW_40), - VHT_GROUP(2, 0, BW_40), - VHT_GROUP(3, 0, BW_40), + VHT_GROUP(1, 0, BW_40, 4), + VHT_GROUP(2, 0, BW_40, 4), + VHT_GROUP(3, 0, BW_40, 4), - VHT_GROUP(1, 1, BW_40), - VHT_GROUP(2, 1, BW_40), - VHT_GROUP(3, 1, BW_40), + VHT_GROUP(1, 1, BW_40, 4), + VHT_GROUP(2, 1, BW_40, 4), + VHT_GROUP(3, 1, BW_40, 4), - VHT_GROUP(1, 0, BW_80), - VHT_GROUP(2, 0, BW_80), - VHT_GROUP(3, 0, BW_80), + VHT_GROUP(1, 0, BW_80, 4), + VHT_GROUP(2, 0, BW_80, 4), + VHT_GROUP(3, 0, BW_80, 4), - VHT_GROUP(1, 1, BW_80), - VHT_GROUP(2, 1, BW_80), - VHT_GROUP(3, 1, BW_80), -#endif + VHT_GROUP(1, 1, BW_80, 4), + VHT_GROUP(2, 1, BW_80, 4), + VHT_GROUP(3, 1, BW_80, 4), }; static u8 sample_table[SAMPLE_COLUMNS][MCS_GROUP_RATES] __read_mostly; @@ -282,7 +281,8 @@ minstrel_ht_get_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, break; /* short preamble */ - if (!(mi->supported[group] & BIT(idx))) + if ((mi->supported[group] & BIT(idx + 4)) && + (rate->flags & IEEE80211_TX_RC_USE_SHORT_PREAMBLE)) idx += 4; } return &mi->groups[group].rates[idx]; @@ -311,7 +311,8 @@ minstrel_ht_get_tp_avg(struct minstrel_ht_sta *mi, int group, int rate, if (group != MINSTREL_CCK_GROUP) nsecs = 1000 * mi->overhead / MINSTREL_TRUNC(mi->avg_ampdu_len); - nsecs += minstrel_mcs_groups[group].duration[rate]; + nsecs += minstrel_mcs_groups[group].duration[rate] << + minstrel_mcs_groups[group].shift; /* * For the throughput calculation, limit the probability value to 90% to @@ -759,12 +760,19 @@ minstrel_ht_tx_status(void *priv, struct ieee80211_supported_band *sband, minstrel_ht_update_rates(mp, mi); } +static inline int +minstrel_get_duration(int index) +{ + const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES]; + unsigned int duration = group->duration[index % MCS_GROUP_RATES]; + return duration << group->shift; +} + static void minstrel_calc_retransmit(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, int index) { struct minstrel_rate_stats *mrs; - const struct mcs_group *group; unsigned int tx_time, tx_time_rtscts, tx_time_data; unsigned int cw = mp->cw_min; unsigned int ctime = 0; @@ -783,8 +791,7 @@ minstrel_calc_retransmit(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, mrs->retry_count_rtscts = 2; mrs->retry_updated = true; - group = &minstrel_mcs_groups[index / MCS_GROUP_RATES]; - tx_time_data = group->duration[index % MCS_GROUP_RATES] * ampdu_len / 1000; + tx_time_data = minstrel_get_duration(index) * ampdu_len / 1000; /* Contention time for first 2 tries */ ctime = (t_slot * cw) >> 1; @@ -878,20 +885,24 @@ minstrel_ht_get_max_amsdu_len(struct minstrel_ht_sta *mi) int group = mi->max_prob_rate / MCS_GROUP_RATES; const struct mcs_group *g = &minstrel_mcs_groups[group]; int rate = mi->max_prob_rate % MCS_GROUP_RATES; + unsigned int duration; /* Disable A-MSDU if max_prob_rate is bad */ if (mi->groups[group].rates[rate].prob_ewma < MINSTREL_FRAC(50, 100)) return 1; + duration = g->duration[rate]; + duration <<= g->shift; + /* If the rate is slower than single-stream MCS1, make A-MSDU limit small */ - if (g->duration[rate] > MCS_DURATION(1, 0, 52)) + if (duration > MCS_DURATION(1, 0, 52)) return 500; /* * If the rate is slower than single-stream MCS4, limit A-MSDU to usual * data packet size */ - if (g->duration[rate] > MCS_DURATION(1, 0, 104)) + if (duration > MCS_DURATION(1, 0, 104)) return 1600; /* @@ -899,7 +910,7 @@ minstrel_ht_get_max_amsdu_len(struct minstrel_ht_sta *mi) * rate success probability is less than 75%, limit A-MSDU to twice the usual * data packet size */ - if (g->duration[rate] > MCS_DURATION(1, 0, 260) || + if (duration > MCS_DURATION(1, 0, 260) || (minstrel_ht_get_prob_ewma(mi, mi->max_tp_rate[0]) < MINSTREL_FRAC(75, 100))) return 3200; @@ -946,13 +957,6 @@ minstrel_ht_update_rates(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) rate_control_set_rates(mp->hw, mi->sta, rates); } -static inline int -minstrel_get_duration(int index) -{ - const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES]; - return group->duration[index % MCS_GROUP_RATES]; -} - static int minstrel_get_sample_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) { @@ -1000,10 +1004,13 @@ minstrel_get_sample_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) return -1; /* - * Do not sample if the probability is already higher than 95% - * to avoid wasting airtime. + * Do not sample if the probability is already higher than 95%, + * or if the rate is 3 times slower than the current max probability + * rate, to avoid wasting airtime. */ - if (mrs->prob_ewma > MINSTREL_FRAC(95, 100)) + sample_dur = minstrel_get_duration(sample_idx); + if (mrs->prob_ewma > MINSTREL_FRAC(95, 100) || + minstrel_get_duration(mi->max_prob_rate) * 3 < sample_dur) return -1; /* @@ -1013,7 +1020,6 @@ minstrel_get_sample_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) cur_max_tp_streams = minstrel_mcs_groups[tp_rate1 / MCS_GROUP_RATES].streams; - sample_dur = minstrel_get_duration(sample_idx); if (sample_dur >= minstrel_get_duration(tp_rate2) && (cur_max_tp_streams - 1 < minstrel_mcs_groups[sample_group].streams || @@ -1077,18 +1083,23 @@ minstrel_ht_get_rate(void *priv, struct ieee80211_sta *sta, void *priv_sta, return; sample_group = &minstrel_mcs_groups[sample_idx / MCS_GROUP_RATES]; + sample_idx %= MCS_GROUP_RATES; + + if (sample_group == &minstrel_mcs_groups[MINSTREL_CCK_GROUP] && + (sample_idx >= 4) != txrc->short_preamble) + return; + info->flags |= IEEE80211_TX_CTL_RATE_CTRL_PROBE; rate->count = 1; - if (sample_idx / MCS_GROUP_RATES == MINSTREL_CCK_GROUP) { + if (sample_group == &minstrel_mcs_groups[MINSTREL_CCK_GROUP]) { int idx = sample_idx % ARRAY_SIZE(mp->cck_rates); rate->idx = mp->cck_rates[idx]; } else if (sample_group->flags & IEEE80211_TX_RC_VHT_MCS) { ieee80211_rate_set_vht(rate, sample_idx % MCS_GROUP_RATES, sample_group->streams); } else { - rate->idx = sample_idx % MCS_GROUP_RATES + - (sample_group->streams - 1) * 8; + rate->idx = sample_idx + (sample_group->streams - 1) * 8; } rate->flags = sample_group->flags; @@ -1130,14 +1141,14 @@ minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband, struct minstrel_ht_sta_priv *msp = priv_sta; struct minstrel_ht_sta *mi = &msp->ht; struct ieee80211_mcs_info *mcs = &sta->ht_cap.mcs; - u16 sta_cap = sta->ht_cap.cap; + u16 ht_cap = sta->ht_cap.cap; struct ieee80211_sta_vht_cap *vht_cap = &sta->vht_cap; - struct sta_info *sinfo = container_of(sta, struct sta_info, sta); int use_vht; int n_supported = 0; int ack_dur; int stbc; int i; + bool ldpc; /* fall back to the old minstrel for legacy stations */ if (!sta->ht_cap.ht_supported) @@ -1145,12 +1156,10 @@ minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband, BUILD_BUG_ON(ARRAY_SIZE(minstrel_mcs_groups) != MINSTREL_GROUPS_NB); -#ifdef CONFIG_MAC80211_RC_MINSTREL_VHT if (vht_cap->vht_supported) use_vht = vht_cap->vht_mcs.tx_mcs_map != cpu_to_le16(~0); else -#endif - use_vht = 0; + use_vht = 0; msp->is_ht = true; memset(mi, 0, sizeof(*mi)); @@ -1175,16 +1184,22 @@ minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband, } mi->sample_tries = 4; - /* TODO tx_flags for vht - ATM the RC API is not fine-grained enough */ if (!use_vht) { - stbc = (sta_cap & IEEE80211_HT_CAP_RX_STBC) >> + stbc = (ht_cap & IEEE80211_HT_CAP_RX_STBC) >> IEEE80211_HT_CAP_RX_STBC_SHIFT; - mi->tx_flags |= stbc << IEEE80211_TX_CTL_STBC_SHIFT; - if (sta_cap & IEEE80211_HT_CAP_LDPC_CODING) - mi->tx_flags |= IEEE80211_TX_CTL_LDPC; + ldpc = ht_cap & IEEE80211_HT_CAP_LDPC_CODING; + } else { + stbc = (vht_cap->cap & IEEE80211_VHT_CAP_RXSTBC_MASK) >> + IEEE80211_VHT_CAP_RXSTBC_SHIFT; + + ldpc = vht_cap->cap & IEEE80211_VHT_CAP_RXLDPC; } + mi->tx_flags |= stbc << IEEE80211_TX_CTL_STBC_SHIFT; + if (ldpc) + mi->tx_flags |= IEEE80211_TX_CTL_LDPC; + for (i = 0; i < ARRAY_SIZE(mi->groups); i++) { u32 gflags = minstrel_mcs_groups[i].flags; int bw, nss; @@ -1197,10 +1212,10 @@ minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband, if (gflags & IEEE80211_TX_RC_SHORT_GI) { if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH) { - if (!(sta_cap & IEEE80211_HT_CAP_SGI_40)) + if (!(ht_cap & IEEE80211_HT_CAP_SGI_40)) continue; } else { - if (!(sta_cap & IEEE80211_HT_CAP_SGI_20)) + if (!(ht_cap & IEEE80211_HT_CAP_SGI_20)) continue; } } @@ -1217,10 +1232,9 @@ minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband, /* HT rate */ if (gflags & IEEE80211_TX_RC_MCS) { -#ifdef CONFIG_MAC80211_RC_MINSTREL_VHT if (use_vht && minstrel_vht_only) continue; -#endif + mi->supported[i] = mcs->rx_mask[nss - 1]; if (mi->supported[i]) n_supported++; @@ -1258,8 +1272,7 @@ minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband, if (!n_supported) goto use_legacy; - if (test_sta_flag(sinfo, WLAN_STA_SHORT_PREAMBLE)) - mi->cck_supported_short |= mi->cck_supported_short << 4; + mi->supported[MINSTREL_CCK_GROUP] |= mi->cck_supported_short << 4; /* create an initial rate table with the lowest supported rates */ minstrel_ht_update_stats(mp, mi); @@ -1340,16 +1353,88 @@ minstrel_ht_free_sta(void *priv, struct ieee80211_sta *sta, void *priv_sta) kfree(msp); } +static void +minstrel_ht_init_cck_rates(struct minstrel_priv *mp) +{ + static const int bitrates[4] = { 10, 20, 55, 110 }; + struct ieee80211_supported_band *sband; + u32 rate_flags = ieee80211_chandef_rate_flags(&mp->hw->conf.chandef); + int i, j; + + sband = mp->hw->wiphy->bands[NL80211_BAND_2GHZ]; + if (!sband) + return; + + for (i = 0; i < sband->n_bitrates; i++) { + struct ieee80211_rate *rate = &sband->bitrates[i]; + + if (rate->flags & IEEE80211_RATE_ERP_G) + continue; + + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + + for (j = 0; j < ARRAY_SIZE(bitrates); j++) { + if (rate->bitrate != bitrates[j]) + continue; + + mp->cck_rates[j] = i; + break; + } + } +} + static void * minstrel_ht_alloc(struct ieee80211_hw *hw, struct dentry *debugfsdir) { - return mac80211_minstrel.alloc(hw, debugfsdir); + struct minstrel_priv *mp; + + mp = kzalloc(sizeof(struct minstrel_priv), GFP_ATOMIC); + if (!mp) + return NULL; + + /* contention window settings + * Just an approximation. Using the per-queue values would complicate + * the calculations and is probably unnecessary */ + mp->cw_min = 15; + mp->cw_max = 1023; + + /* number of packets (in %) to use for sampling other rates + * sample less often for non-mrr packets, because the overhead + * is much higher than with mrr */ + mp->lookaround_rate = 5; + mp->lookaround_rate_mrr = 10; + + /* maximum time that the hw is allowed to stay in one MRR segment */ + mp->segment_size = 6000; + + if (hw->max_rate_tries > 0) + mp->max_retry = hw->max_rate_tries; + else + /* safe default, does not necessarily have to match hw properties */ + mp->max_retry = 7; + + if (hw->max_rates >= 4) + mp->has_mrr = true; + + mp->hw = hw; + mp->update_interval = 100; + +#ifdef CONFIG_MAC80211_DEBUGFS + mp->fixed_rate_idx = (u32) -1; + debugfs_create_u32("fixed_rate_idx", S_IRUGO | S_IWUGO, debugfsdir, + &mp->fixed_rate_idx); +#endif + + minstrel_ht_init_cck_rates(mp); + + return mp; } static void minstrel_ht_free(void *priv) { - mac80211_minstrel.free(priv); + kfree(priv); } static u32 minstrel_ht_get_expected_throughput(void *priv_sta) @@ -1384,7 +1469,6 @@ static const struct rate_control_ops mac80211_minstrel_ht = { .free = minstrel_ht_free, #ifdef CONFIG_MAC80211_DEBUGFS .add_sta_debugfs = minstrel_ht_add_sta_debugfs, - .remove_sta_debugfs = minstrel_ht_remove_sta_debugfs, #endif .get_expected_throughput = minstrel_ht_get_expected_throughput, }; @@ -1409,14 +1493,14 @@ static void __init init_sample_table(void) } int __init -rc80211_minstrel_ht_init(void) +rc80211_minstrel_init(void) { init_sample_table(); return ieee80211_rate_control_register(&mac80211_minstrel_ht); } void -rc80211_minstrel_ht_exit(void) +rc80211_minstrel_exit(void) { ieee80211_rate_control_unregister(&mac80211_minstrel_ht); } diff --git a/net/mac80211/rc80211_minstrel_ht.h b/net/mac80211/rc80211_minstrel_ht.h index de1646c42e82..26b7a3244b47 100644 --- a/net/mac80211/rc80211_minstrel_ht.h +++ b/net/mac80211/rc80211_minstrel_ht.h @@ -15,11 +15,7 @@ */ #define MINSTREL_MAX_STREAMS 3 #define MINSTREL_HT_STREAM_GROUPS 4 /* BW(=2) * SGI(=2) */ -#ifdef CONFIG_MAC80211_RC_MINSTREL_VHT #define MINSTREL_VHT_STREAM_GROUPS 6 /* BW(=3) * SGI(=2) */ -#else -#define MINSTREL_VHT_STREAM_GROUPS 0 -#endif #define MINSTREL_HT_GROUPS_NB (MINSTREL_MAX_STREAMS * \ MINSTREL_HT_STREAM_GROUPS) @@ -34,16 +30,13 @@ #define MINSTREL_CCK_GROUP (MINSTREL_HT_GROUP_0 + MINSTREL_HT_GROUPS_NB) #define MINSTREL_VHT_GROUP_0 (MINSTREL_CCK_GROUP + 1) -#ifdef CONFIG_MAC80211_RC_MINSTREL_VHT #define MCS_GROUP_RATES 10 -#else -#define MCS_GROUP_RATES 8 -#endif struct mcs_group { - u32 flags; - unsigned int streams; - unsigned int duration[MCS_GROUP_RATES]; + u16 flags; + u8 streams; + u8 shift; + u16 duration[MCS_GROUP_RATES]; }; extern const struct mcs_group minstrel_mcs_groups[]; @@ -110,17 +103,12 @@ struct minstrel_ht_sta_priv { struct minstrel_ht_sta ht; struct minstrel_sta_info legacy; }; -#ifdef CONFIG_MAC80211_DEBUGFS - struct dentry *dbg_stats; - struct dentry *dbg_stats_csv; -#endif void *ratelist; void *sample_table; bool is_ht; }; void minstrel_ht_add_sta_debugfs(void *priv, void *priv_sta, struct dentry *dir); -void minstrel_ht_remove_sta_debugfs(void *priv, void *priv_sta); int minstrel_ht_get_tp_avg(struct minstrel_ht_sta *mi, int group, int rate, int prob_ewma); diff --git a/net/mac80211/rc80211_minstrel_ht_debugfs.c b/net/mac80211/rc80211_minstrel_ht_debugfs.c index bfcc03152dc6..57820a5f2c16 100644 --- a/net/mac80211/rc80211_minstrel_ht_debugfs.c +++ b/net/mac80211/rc80211_minstrel_ht_debugfs.c @@ -15,6 +15,22 @@ #include "rc80211_minstrel.h" #include "rc80211_minstrel_ht.h" +static ssize_t +minstrel_stats_read(struct file *file, char __user *buf, size_t len, loff_t *ppos) +{ + struct minstrel_debugfs_info *ms; + + ms = file->private_data; + return simple_read_from_buffer(buf, len, ppos, ms->buf, ms->len); +} + +static int +minstrel_stats_release(struct inode *inode, struct file *file) +{ + kfree(file->private_data); + return 0; +} + static char * minstrel_ht_stats_dump(struct minstrel_ht_sta *mi, int i, char *p) { @@ -41,7 +57,7 @@ minstrel_ht_stats_dump(struct minstrel_ht_sta *mi, int i, char *p) struct minstrel_rate_stats *mrs = &mi->groups[i].rates[j]; static const int bitrates[4] = { 10, 20, 55, 110 }; int idx = i * MCS_GROUP_RATES + j; - unsigned int prob_ewmsd; + unsigned int duration; if (!(mi->supported[i] & BIT(j))) continue; @@ -79,21 +95,21 @@ minstrel_ht_stats_dump(struct minstrel_ht_sta *mi, int i, char *p) p += sprintf(p, " %3u ", idx); /* tx_time[rate(i)] in usec */ - tx_time = DIV_ROUND_CLOSEST(mg->duration[j], 1000); + duration = mg->duration[j]; + duration <<= mg->shift; + tx_time = DIV_ROUND_CLOSEST(duration, 1000); p += sprintf(p, "%6u ", tx_time); tp_max = minstrel_ht_get_tp_avg(mi, i, j, MINSTREL_FRAC(100, 100)); tp_avg = minstrel_ht_get_tp_avg(mi, i, j, mrs->prob_ewma); eprob = MINSTREL_TRUNC(mrs->prob_ewma * 1000); - prob_ewmsd = minstrel_get_ewmsd10(mrs); - p += sprintf(p, "%4u.%1u %4u.%1u %3u.%1u %3u.%1u" + p += sprintf(p, "%4u.%1u %4u.%1u %3u.%1u" " %3u %3u %-3u " "%9llu %-9llu\n", tp_max / 10, tp_max % 10, tp_avg / 10, tp_avg % 10, eprob / 10, eprob % 10, - prob_ewmsd / 10, prob_ewmsd % 10, mrs->retry_count, mrs->last_success, mrs->last_attempts, @@ -130,9 +146,9 @@ minstrel_ht_stats_open(struct inode *inode, struct file *file) p += sprintf(p, "\n"); p += sprintf(p, - " best ____________rate__________ ________statistics________ _____last____ ______sum-of________\n"); + " best ____________rate__________ ____statistics___ _____last____ ______sum-of________\n"); p += sprintf(p, - "mode guard # rate [name idx airtime max_tp] [avg(tp) avg(prob) sd(prob)] [retry|suc|att] [#success | #attempts]\n"); + "mode guard # rate [name idx airtime max_tp] [avg(tp) avg(prob)] [retry|suc|att] [#success | #attempts]\n"); p = minstrel_ht_stats_dump(mi, MINSTREL_CCK_GROUP, p); for (i = 0; i < MINSTREL_CCK_GROUP; i++) @@ -187,7 +203,7 @@ minstrel_ht_stats_csv_dump(struct minstrel_ht_sta *mi, int i, char *p) struct minstrel_rate_stats *mrs = &mi->groups[i].rates[j]; static const int bitrates[4] = { 10, 20, 55, 110 }; int idx = i * MCS_GROUP_RATES + j; - unsigned int prob_ewmsd; + unsigned int duration; if (!(mi->supported[i] & BIT(j))) continue; @@ -222,20 +238,21 @@ minstrel_ht_stats_csv_dump(struct minstrel_ht_sta *mi, int i, char *p) } p += sprintf(p, "%u,", idx); - tx_time = DIV_ROUND_CLOSEST(mg->duration[j], 1000); + + duration = mg->duration[j]; + duration <<= mg->shift; + tx_time = DIV_ROUND_CLOSEST(duration, 1000); p += sprintf(p, "%u,", tx_time); tp_max = minstrel_ht_get_tp_avg(mi, i, j, MINSTREL_FRAC(100, 100)); tp_avg = minstrel_ht_get_tp_avg(mi, i, j, mrs->prob_ewma); eprob = MINSTREL_TRUNC(mrs->prob_ewma * 1000); - prob_ewmsd = minstrel_get_ewmsd10(mrs); - p += sprintf(p, "%u.%u,%u.%u,%u.%u,%u.%u,%u,%u," + p += sprintf(p, "%u.%u,%u.%u,%u.%u,%u,%u," "%u,%llu,%llu,", tp_max / 10, tp_max % 10, tp_avg / 10, tp_avg % 10, eprob / 10, eprob % 10, - prob_ewmsd / 10, prob_ewmsd % 10, mrs->retry_count, mrs->last_success, mrs->last_attempts, @@ -303,17 +320,8 @@ minstrel_ht_add_sta_debugfs(void *priv, void *priv_sta, struct dentry *dir) { struct minstrel_ht_sta_priv *msp = priv_sta; - msp->dbg_stats = debugfs_create_file("rc_stats", 0444, dir, msp, - &minstrel_ht_stat_fops); - msp->dbg_stats_csv = debugfs_create_file("rc_stats_csv", 0444, dir, msp, - &minstrel_ht_stat_csv_fops); -} - -void -minstrel_ht_remove_sta_debugfs(void *priv, void *priv_sta) -{ - struct minstrel_ht_sta_priv *msp = priv_sta; - - debugfs_remove(msp->dbg_stats); - debugfs_remove(msp->dbg_stats_csv); + debugfs_create_file("rc_stats", 0444, dir, msp, + &minstrel_ht_stat_fops); + debugfs_create_file("rc_stats_csv", 0444, dir, msp, + &minstrel_ht_stat_csv_fops); } diff --git a/net/mac80211/rx.c b/net/mac80211/rx.c index a0ca27aeb732..3bd3b5769797 100644 --- a/net/mac80211/rx.c +++ b/net/mac80211/rx.c @@ -2458,8 +2458,9 @@ ieee80211_deliver_skb(struct ieee80211_rx_data *rx) if (!xmit_skb) net_info_ratelimited("%s: failed to clone multicast frame\n", dev->name); - } else if (!is_multicast_ether_addr(ehdr->h_dest)) { - dsta = sta_info_get(sdata, skb->data); + } else if (!is_multicast_ether_addr(ehdr->h_dest) && + !ether_addr_equal(ehdr->h_dest, ehdr->h_source)) { + dsta = sta_info_get(sdata, ehdr->h_dest); if (dsta) { /* * The destination station is associated to @@ -4240,11 +4241,10 @@ static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx, if (fast_rx->internal_forward) { struct sk_buff *xmit_skb = NULL; - bool multicast = is_multicast_ether_addr(skb->data); - - if (multicast) { + if (is_multicast_ether_addr(addrs.da)) { xmit_skb = skb_copy(skb, GFP_ATOMIC); - } else if (sta_info_get(rx->sdata, skb->data)) { + } else if (!ether_addr_equal(addrs.da, addrs.sa) && + sta_info_get(rx->sdata, addrs.da)) { xmit_skb = skb; skb = NULL; } diff --git a/net/mac80211/status.c b/net/mac80211/status.c index 9a6d7208bf4f..aa4afbf0abaf 100644 --- a/net/mac80211/status.c +++ b/net/mac80211/status.c @@ -479,11 +479,6 @@ static void ieee80211_report_ack_skb(struct ieee80211_local *local, if (!skb) return; - if (dropped) { - dev_kfree_skb_any(skb); - return; - } - if (info->flags & IEEE80211_TX_INTFL_NL80211_FRAME_TX) { u64 cookie = IEEE80211_SKB_CB(skb)->ack.cookie; struct ieee80211_sub_if_data *sdata; @@ -507,6 +502,8 @@ static void ieee80211_report_ack_skb(struct ieee80211_local *local, rcu_read_unlock(); dev_kfree_skb_any(skb); + } else if (dropped) { + dev_kfree_skb_any(skb); } else { /* consumes skb */ skb_complete_wifi_ack(skb, acked); @@ -811,7 +808,7 @@ static void __ieee80211_tx_status(struct ieee80211_hw *hw, rate_control_tx_status(local, sband, status); if (ieee80211_vif_is_mesh(&sta->sdata->vif)) - ieee80211s_update_metric(local, sta, skb); + ieee80211s_update_metric(local, sta, status); if (!(info->flags & IEEE80211_TX_CTL_INJECTED) && acked) ieee80211_frame_acked(sta, skb); @@ -972,6 +969,8 @@ void ieee80211_tx_status_ext(struct ieee80211_hw *hw, } rate_control_tx_status(local, sband, status); + if (ieee80211_vif_is_mesh(&sta->sdata->vif)) + ieee80211s_update_metric(local, sta, status); } if (acked || noack_success) { @@ -988,6 +987,25 @@ void ieee80211_tx_status_ext(struct ieee80211_hw *hw, } EXPORT_SYMBOL(ieee80211_tx_status_ext); +void ieee80211_tx_rate_update(struct ieee80211_hw *hw, + struct ieee80211_sta *pubsta, + struct ieee80211_tx_info *info) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_supported_band *sband = hw->wiphy->bands[info->band]; + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_tx_status status = { + .info = info, + .sta = pubsta, + }; + + rate_control_tx_status(local, sband, &status); + + if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) + sta->tx_stats.last_rate = info->status.rates[0]; +} +EXPORT_SYMBOL(ieee80211_tx_rate_update); + void ieee80211_report_low_ack(struct ieee80211_sta *pubsta, u32 num_packets) { struct sta_info *sta = container_of(pubsta, struct sta_info, sta); diff --git a/net/mac80211/tdls.c b/net/mac80211/tdls.c index 5cd5e6e5834e..6c647f425e05 100644 --- a/net/mac80211/tdls.c +++ b/net/mac80211/tdls.c @@ -16,6 +16,7 @@ #include "ieee80211_i.h" #include "driver-ops.h" #include "rate.h" +#include "wme.h" /* give usermode some time for retries in setting up the TDLS session */ #define TDLS_PEER_SETUP_TIMEOUT (15 * HZ) @@ -1010,14 +1011,13 @@ ieee80211_tdls_prep_mgmt_packet(struct wiphy *wiphy, struct net_device *dev, switch (action_code) { case WLAN_TDLS_SETUP_REQUEST: case WLAN_TDLS_SETUP_RESPONSE: - skb_set_queue_mapping(skb, IEEE80211_AC_BK); - skb->priority = 2; + skb->priority = 256 + 2; break; default: - skb_set_queue_mapping(skb, IEEE80211_AC_VI); - skb->priority = 5; + skb->priority = 256 + 5; break; } + skb_set_queue_mapping(skb, ieee80211_select_queue(sdata, skb)); /* * Set the WLAN_TDLS_TEARDOWN flag to indicate a teardown in progress. diff --git a/net/mac80211/trace.h b/net/mac80211/trace.h index 0ab69a1964f8..588c51a67c89 100644 --- a/net/mac80211/trace.h +++ b/net/mac80211/trace.h @@ -2600,6 +2600,29 @@ TRACE_EVENT(drv_wake_tx_queue, ) ); +TRACE_EVENT(drv_get_ftm_responder_stats, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_ftm_responder_stats *ftm_stats), + + TP_ARGS(local, sdata, ftm_stats), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG + ) +); + #endif /* !__MAC80211_DRIVER_TRACE || TRACE_HEADER_MULTI_READ */ #undef TRACE_INCLUDE_PATH diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c index c42bfa1dcd2c..e0ccee23fbcd 100644 --- a/net/mac80211/tx.c +++ b/net/mac80211/tx.c @@ -214,6 +214,7 @@ ieee80211_tx_h_dynamic_ps(struct ieee80211_tx_data *tx) { struct ieee80211_local *local = tx->local; struct ieee80211_if_managed *ifmgd; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); /* driver doesn't support power save */ if (!ieee80211_hw_check(&local->hw, SUPPORTS_PS)) @@ -242,6 +243,9 @@ ieee80211_tx_h_dynamic_ps(struct ieee80211_tx_data *tx) if (tx->sdata->vif.type != NL80211_IFTYPE_STATION) return TX_CONTINUE; + if (unlikely(info->flags & IEEE80211_TX_INTFL_OFFCHAN_TX_OK)) + return TX_CONTINUE; + ifmgd = &tx->sdata->u.mgd; /* @@ -1915,7 +1919,7 @@ static bool ieee80211_tx(struct ieee80211_sub_if_data *sdata, sdata->vif.hw_queue[skb_get_queue_mapping(skb)]; if (invoke_tx_handlers_early(&tx)) - return false; + return true; if (ieee80211_queue_skb(local, sdata, tx.sta, tx.skb)) return true; diff --git a/net/mac80211/util.c b/net/mac80211/util.c index 36a3c2ada515..bec424316ea4 100644 --- a/net/mac80211/util.c +++ b/net/mac80211/util.c @@ -264,6 +264,9 @@ static void __ieee80211_wake_txqs(struct ieee80211_sub_if_data *sdata, int ac) for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { struct ieee80211_txq *txq = sta->sta.txq[i]; + if (!txq) + continue; + txqi = to_txq_info(txq); if (ac != txq->ac) @@ -2175,6 +2178,11 @@ int ieee80211_reconfig(struct ieee80211_local *local) case NL80211_IFTYPE_AP: changed |= BSS_CHANGED_SSID | BSS_CHANGED_P2P_PS; + if (sdata->vif.bss_conf.ftm_responder == 1 && + wiphy_ext_feature_isset(sdata->local->hw.wiphy, + NL80211_EXT_FEATURE_ENABLE_FTM_RESPONDER)) + changed |= BSS_CHANGED_FTM_RESPONDER; + if (sdata->vif.type == NL80211_IFTYPE_AP) { changed |= BSS_CHANGED_AP_PROBE_RESP; diff --git a/net/mac802154/llsec.c b/net/mac802154/llsec.c index 2fb703d70803..7e29f88dbf6a 100644 --- a/net/mac802154/llsec.c +++ b/net/mac802154/llsec.c @@ -146,18 +146,18 @@ llsec_key_alloc(const struct ieee802154_llsec_key *template) goto err_tfm; } - key->tfm0 = crypto_alloc_skcipher("ctr(aes)", 0, CRYPTO_ALG_ASYNC); + key->tfm0 = crypto_alloc_sync_skcipher("ctr(aes)", 0, 0); if (IS_ERR(key->tfm0)) goto err_tfm; - if (crypto_skcipher_setkey(key->tfm0, template->key, + if (crypto_sync_skcipher_setkey(key->tfm0, template->key, IEEE802154_LLSEC_KEY_SIZE)) goto err_tfm0; return key; err_tfm0: - crypto_free_skcipher(key->tfm0); + crypto_free_sync_skcipher(key->tfm0); err_tfm: for (i = 0; i < ARRAY_SIZE(key->tfm); i++) if (key->tfm[i]) @@ -177,7 +177,7 @@ static void llsec_key_release(struct kref *ref) for (i = 0; i < ARRAY_SIZE(key->tfm); i++) crypto_free_aead(key->tfm[i]); - crypto_free_skcipher(key->tfm0); + crypto_free_sync_skcipher(key->tfm0); kzfree(key); } @@ -622,7 +622,7 @@ llsec_do_encrypt_unauth(struct sk_buff *skb, const struct mac802154_llsec *sec, { u8 iv[16]; struct scatterlist src; - SKCIPHER_REQUEST_ON_STACK(req, key->tfm0); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, key->tfm0); int err, datalen; unsigned char *data; @@ -632,7 +632,7 @@ llsec_do_encrypt_unauth(struct sk_buff *skb, const struct mac802154_llsec *sec, datalen = skb_tail_pointer(skb) - data; sg_init_one(&src, data, datalen); - skcipher_request_set_tfm(req, key->tfm0); + skcipher_request_set_sync_tfm(req, key->tfm0); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &src, &src, datalen, iv); err = crypto_skcipher_encrypt(req); @@ -840,7 +840,7 @@ llsec_do_decrypt_unauth(struct sk_buff *skb, const struct mac802154_llsec *sec, unsigned char *data; int datalen; struct scatterlist src; - SKCIPHER_REQUEST_ON_STACK(req, key->tfm0); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, key->tfm0); int err; llsec_geniv(iv, dev_addr, &hdr->sec); @@ -849,7 +849,7 @@ llsec_do_decrypt_unauth(struct sk_buff *skb, const struct mac802154_llsec *sec, sg_init_one(&src, data, datalen); - skcipher_request_set_tfm(req, key->tfm0); + skcipher_request_set_sync_tfm(req, key->tfm0); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &src, &src, datalen, iv); diff --git a/net/mac802154/llsec.h b/net/mac802154/llsec.h index 6f3b658e3279..8be46d74dc39 100644 --- a/net/mac802154/llsec.h +++ b/net/mac802154/llsec.h @@ -29,7 +29,7 @@ struct mac802154_llsec_key { /* one tfm for each authsize (4/8/16) */ struct crypto_aead *tfm[3]; - struct crypto_skcipher *tfm0; + struct crypto_sync_skcipher *tfm0; struct kref ref; }; diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c index 8fbe6cdbe255..7d55d4c04088 100644 --- a/net/mpls/af_mpls.c +++ b/net/mpls/af_mpls.c @@ -1223,7 +1223,7 @@ static int mpls_netconf_get_devconf(struct sk_buff *in_skb, int err; err = nlmsg_parse(nlh, sizeof(*ncm), tb, NETCONFA_MAX, - devconf_mpls_policy, NULL); + devconf_mpls_policy, extack); if (err < 0) goto errout; @@ -1263,6 +1263,7 @@ errout: static int mpls_netconf_dump_devconf(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); struct hlist_head *head; struct net_device *dev; @@ -1270,6 +1271,21 @@ static int mpls_netconf_dump_devconf(struct sk_buff *skb, int idx, s_idx; int h, s_h; + if (cb->strict_check) { + struct netlink_ext_ack *extack = cb->extack; + struct netconfmsg *ncm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header in netconf dump request"); + return -EINVAL; + } + } + s_h = cb->args[0]; s_idx = idx = cb->args[1]; @@ -1286,7 +1302,7 @@ static int mpls_netconf_dump_devconf(struct sk_buff *skb, goto cont; if (mpls_netconf_fill_devconf(skb, mdev, NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, + nlh->nlmsg_seq, RTM_NEWNETCONF, NLM_F_MULTI, NETCONFA_ALL) < 0) { @@ -2015,30 +2031,140 @@ nla_put_failure: return -EMSGSIZE; } +#if IS_ENABLED(CONFIG_INET) +static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, + struct fib_dump_filter *filter, + struct netlink_callback *cb) +{ + return ip_valid_fib_dump_req(net, nlh, filter, cb); +} +#else +static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, + struct fib_dump_filter *filter, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[RTA_MAX + 1]; + struct rtmsg *rtm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for FIB dump request"); + return -EINVAL; + } + + rtm = nlmsg_data(nlh); + if (rtm->rtm_dst_len || rtm->rtm_src_len || rtm->rtm_tos || + rtm->rtm_table || rtm->rtm_scope || rtm->rtm_type || + rtm->rtm_flags) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for FIB dump request"); + return -EINVAL; + } + + if (rtm->rtm_protocol) { + filter->protocol = rtm->rtm_protocol; + filter->filter_set = 1; + cb->answer_flags = NLM_F_DUMP_FILTERED; + } + + err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_mpls_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= RTA_MAX; ++i) { + int ifindex; + + if (i == RTA_OIF) { + ifindex = nla_get_u32(tb[i]); + filter->dev = __dev_get_by_index(net, ifindex); + if (!filter->dev) + return -ENODEV; + filter->filter_set = 1; + } else if (tb[i]) { + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in dump request"); + return -EINVAL; + } + } + + return 0; +} +#endif + +static bool mpls_rt_uses_dev(struct mpls_route *rt, + const struct net_device *dev) +{ + struct net_device *nh_dev; + + if (rt->rt_nhn == 1) { + struct mpls_nh *nh = rt->rt_nh; + + nh_dev = rtnl_dereference(nh->nh_dev); + if (dev == nh_dev) + return true; + } else { + for_nexthops(rt) { + nh_dev = rtnl_dereference(nh->nh_dev); + if (nh_dev == dev) + return true; + } endfor_nexthops(rt); + } + + return false; +} + static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) { + const struct nlmsghdr *nlh = cb->nlh; struct net *net = sock_net(skb->sk); struct mpls_route __rcu **platform_label; + struct fib_dump_filter filter = {}; + unsigned int flags = NLM_F_MULTI; size_t platform_labels; unsigned int index; ASSERT_RTNL(); + if (cb->strict_check) { + int err; + + err = mpls_valid_fib_dump_req(net, nlh, &filter, cb); + if (err < 0) + return err; + + /* for MPLS, there is only 1 table with fixed type and flags. + * If either are set in the filter then return nothing. + */ + if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) || + (filter.rt_type && filter.rt_type != RTN_UNICAST) || + filter.flags) + return skb->len; + } + index = cb->args[0]; if (index < MPLS_LABEL_FIRST_UNRESERVED) index = MPLS_LABEL_FIRST_UNRESERVED; platform_label = rtnl_dereference(net->mpls.platform_label); platform_labels = net->mpls.platform_labels; + + if (filter.filter_set) + flags |= NLM_F_DUMP_FILTERED; + for (; index < platform_labels; index++) { struct mpls_route *rt; + rt = rtnl_dereference(platform_label[index]); if (!rt) continue; + if ((filter.dev && !mpls_rt_uses_dev(rt, filter.dev)) || + (filter.protocol && rt->rt_protocol != filter.protocol)) + continue; + if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, RTM_NEWROUTE, - index, rt, NLM_F_MULTI) < 0) + index, rt, flags) < 0) break; } cb->args[0] = index; diff --git a/net/ncsi/Kconfig b/net/ncsi/Kconfig index 08a8a6031fd7..7f2b46108a24 100644 --- a/net/ncsi/Kconfig +++ b/net/ncsi/Kconfig @@ -10,3 +10,9 @@ config NET_NCSI support. Enable this only if your system connects to a network device via NCSI and the ethernet driver you're using supports the protocol explicitly. +config NCSI_OEM_CMD_GET_MAC + bool "Get NCSI OEM MAC Address" + depends on NET_NCSI + ---help--- + This allows to get MAC address from NCSI firmware and set them back to + controller. diff --git a/net/ncsi/internal.h b/net/ncsi/internal.h index 8055e3965cef..1dae77c54009 100644 --- a/net/ncsi/internal.h +++ b/net/ncsi/internal.h @@ -68,6 +68,17 @@ enum { NCSI_MODE_MAX }; +/* OEM Vendor Manufacture ID */ +#define NCSI_OEM_MFR_MLX_ID 0x8119 +#define NCSI_OEM_MFR_BCM_ID 0x113d +/* Broadcom specific OEM Command */ +#define NCSI_OEM_BCM_CMD_GMA 0x01 /* CMD ID for Get MAC */ +/* OEM Command payload lengths*/ +#define NCSI_OEM_BCM_CMD_GMA_LEN 12 +/* Mac address offset in OEM response */ +#define BCM_MAC_ADDR_OFFSET 28 + + struct ncsi_channel_version { u32 version; /* Supported BCD encoded NCSI version */ u32 alpha2; /* Supported BCD encoded NCSI version */ @@ -171,6 +182,8 @@ struct ncsi_package; #define NCSI_RESERVED_CHANNEL 0x1f #define NCSI_CHANNEL_INDEX(c) ((c) & ((1 << NCSI_PACKAGE_SHIFT) - 1)) #define NCSI_TO_CHANNEL(p, c) (((p) << NCSI_PACKAGE_SHIFT) | (c)) +#define NCSI_MAX_PACKAGE 8 +#define NCSI_MAX_CHANNEL 32 struct ncsi_channel { unsigned char id; @@ -216,11 +229,15 @@ struct ncsi_request { bool used; /* Request that has been assigned */ unsigned int flags; /* NCSI request property */ #define NCSI_REQ_FLAG_EVENT_DRIVEN 1 +#define NCSI_REQ_FLAG_NETLINK_DRIVEN 2 struct ncsi_dev_priv *ndp; /* Associated NCSI device */ struct sk_buff *cmd; /* Associated NCSI command packet */ struct sk_buff *rsp; /* Associated NCSI response packet */ struct timer_list timer; /* Timer on waiting for response */ bool enabled; /* Time has been enabled or not */ + u32 snd_seq; /* netlink sending sequence number */ + u32 snd_portid; /* netlink portid of sender */ + struct nlmsghdr nlhdr; /* netlink message header */ }; enum { @@ -236,6 +253,7 @@ enum { ncsi_dev_state_probe_dp, ncsi_dev_state_config_sp = 0x0301, ncsi_dev_state_config_cis, + ncsi_dev_state_config_oem_gma, ncsi_dev_state_config_clear_vids, ncsi_dev_state_config_svf, ncsi_dev_state_config_ev, @@ -269,6 +287,7 @@ struct ncsi_dev_priv { #define NCSI_DEV_PROBED 1 /* Finalized NCSI topology */ #define NCSI_DEV_HWA 2 /* Enabled HW arbitration */ #define NCSI_DEV_RESHUFFLE 4 + unsigned int gma_flag; /* OEM GMA flag */ spinlock_t lock; /* Protect the NCSI device */ #if IS_ENABLED(CONFIG_IPV6) unsigned int inet6_addr_num; /* Number of IPv6 addresses */ @@ -305,6 +324,8 @@ struct ncsi_cmd_arg { unsigned short words[8]; unsigned int dwords[4]; }; + unsigned char *data; /* NCSI OEM data */ + struct genl_info *info; /* Netlink information */ }; extern struct list_head ncsi_dev_list; diff --git a/net/ncsi/ncsi-cmd.c b/net/ncsi/ncsi-cmd.c index 7567ca63aae2..356af474e43c 100644 --- a/net/ncsi/ncsi-cmd.c +++ b/net/ncsi/ncsi-cmd.c @@ -17,6 +17,7 @@ #include <net/ncsi.h> #include <net/net_namespace.h> #include <net/sock.h> +#include <net/genetlink.h> #include "internal.h" #include "ncsi-pkt.h" @@ -211,6 +212,25 @@ static int ncsi_cmd_handler_snfc(struct sk_buff *skb, return 0; } +static int ncsi_cmd_handler_oem(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_oem_pkt *cmd; + unsigned int len; + + len = sizeof(struct ncsi_cmd_pkt_hdr) + 4; + if (nca->payload < 26) + len += 26; + else + len += nca->payload; + + cmd = skb_put_zero(skb, len); + memcpy(&cmd->mfr_id, nca->data, nca->payload); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + static struct ncsi_cmd_handler { unsigned char type; int payload; @@ -244,7 +264,7 @@ static struct ncsi_cmd_handler { { NCSI_PKT_CMD_GNS, 0, ncsi_cmd_handler_default }, { NCSI_PKT_CMD_GNPTS, 0, ncsi_cmd_handler_default }, { NCSI_PKT_CMD_GPS, 0, ncsi_cmd_handler_default }, - { NCSI_PKT_CMD_OEM, 0, NULL }, + { NCSI_PKT_CMD_OEM, -1, ncsi_cmd_handler_oem }, { NCSI_PKT_CMD_PLDM, 0, NULL }, { NCSI_PKT_CMD_GPUUID, 0, ncsi_cmd_handler_default } }; @@ -316,12 +336,24 @@ int ncsi_xmit_cmd(struct ncsi_cmd_arg *nca) return -ENOENT; } - /* Get packet payload length and allocate the request */ - nca->payload = nch->payload; + /* Get packet payload length and allocate the request + * It is expected that if length set as negative in + * handler structure means caller is initializing it + * and setting length in nca before calling xmit function + */ + if (nch->payload >= 0) + nca->payload = nch->payload; nr = ncsi_alloc_command(nca); if (!nr) return -ENOMEM; + /* track netlink information */ + if (nca->req_flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + nr->snd_seq = nca->info->snd_seq; + nr->snd_portid = nca->info->snd_portid; + nr->nlhdr = *nca->info->nlhdr; + } + /* Prepare the packet */ nca->id = nr->id; ret = nch->handler(nr->cmd, nca); diff --git a/net/ncsi/ncsi-manage.c b/net/ncsi/ncsi-manage.c index 091284760d21..bfc43b28c7a6 100644 --- a/net/ncsi/ncsi-manage.c +++ b/net/ncsi/ncsi-manage.c @@ -19,6 +19,7 @@ #include <net/addrconf.h> #include <net/ipv6.h> #include <net/if_inet6.h> +#include <net/genetlink.h> #include "internal.h" #include "ncsi-pkt.h" @@ -406,6 +407,9 @@ static void ncsi_request_timeout(struct timer_list *t) { struct ncsi_request *nr = from_timer(nr, t, timer); struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_cmd_pkt *cmd; + struct ncsi_package *np; + struct ncsi_channel *nc; unsigned long flags; /* If the request already had associated response, @@ -419,6 +423,18 @@ static void ncsi_request_timeout(struct timer_list *t) } spin_unlock_irqrestore(&ndp->lock, flags); + if (nr->flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + if (nr->cmd) { + /* Find the package */ + cmd = (struct ncsi_cmd_pkt *) + skb_network_header(nr->cmd); + ncsi_find_package_and_channel(ndp, + cmd->cmd.common.channel, + &np, &nc); + ncsi_send_netlink_timeout(nr, np, nc); + } + } + /* Release the request */ ncsi_free_request(nr); } @@ -635,6 +651,72 @@ static int set_one_vid(struct ncsi_dev_priv *ndp, struct ncsi_channel *nc, return 0; } +#if IS_ENABLED(CONFIG_NCSI_OEM_CMD_GET_MAC) + +/* NCSI OEM Command APIs */ +static int ncsi_oem_gma_handler_bcm(struct ncsi_cmd_arg *nca) +{ + unsigned char data[NCSI_OEM_BCM_CMD_GMA_LEN]; + int ret = 0; + + nca->payload = NCSI_OEM_BCM_CMD_GMA_LEN; + + memset(data, 0, NCSI_OEM_BCM_CMD_GMA_LEN); + *(unsigned int *)data = ntohl(NCSI_OEM_MFR_BCM_ID); + data[5] = NCSI_OEM_BCM_CMD_GMA; + + nca->data = data; + + ret = ncsi_xmit_cmd(nca); + if (ret) + netdev_err(nca->ndp->ndev.dev, + "NCSI: Failed to transmit cmd 0x%x during configure\n", + nca->type); + return ret; +} + +/* OEM Command handlers initialization */ +static struct ncsi_oem_gma_handler { + unsigned int mfr_id; + int (*handler)(struct ncsi_cmd_arg *nca); +} ncsi_oem_gma_handlers[] = { + { NCSI_OEM_MFR_BCM_ID, ncsi_oem_gma_handler_bcm } +}; + +static int ncsi_gma_handler(struct ncsi_cmd_arg *nca, unsigned int mf_id) +{ + struct ncsi_oem_gma_handler *nch = NULL; + int i; + + /* This function should only be called once, return if flag set */ + if (nca->ndp->gma_flag == 1) + return -1; + + /* Find gma handler for given manufacturer id */ + for (i = 0; i < ARRAY_SIZE(ncsi_oem_gma_handlers); i++) { + if (ncsi_oem_gma_handlers[i].mfr_id == mf_id) { + if (ncsi_oem_gma_handlers[i].handler) + nch = &ncsi_oem_gma_handlers[i]; + break; + } + } + + if (!nch) { + netdev_err(nca->ndp->ndev.dev, + "NCSI: No GMA handler available for MFR-ID (0x%x)\n", + mf_id); + return -1; + } + + /* Set the flag for GMA command which should only be called once */ + nca->ndp->gma_flag = 1; + + /* Get Mac address from NCSI device */ + return nch->handler(nca); +} + +#endif /* CONFIG_NCSI_OEM_CMD_GET_MAC */ + static void ncsi_configure_channel(struct ncsi_dev_priv *ndp) { struct ncsi_dev *nd = &ndp->ndev; @@ -685,7 +767,23 @@ static void ncsi_configure_channel(struct ncsi_dev_priv *ndp) goto error; } + nd->state = ncsi_dev_state_config_oem_gma; + break; + case ncsi_dev_state_config_oem_gma: nd->state = ncsi_dev_state_config_clear_vids; + ret = -1; + +#if IS_ENABLED(CONFIG_NCSI_OEM_CMD_GET_MAC) + nca.type = NCSI_PKT_CMD_OEM; + nca.package = np->id; + nca.channel = nc->id; + ndp->pending_req_num = 1; + ret = ncsi_gma_handler(&nca, nc->version.mf_id); +#endif /* CONFIG_NCSI_OEM_CMD_GET_MAC */ + + if (ret < 0) + schedule_work(&ndp->work); + break; case ncsi_dev_state_config_clear_vids: case ncsi_dev_state_config_svf: diff --git a/net/ncsi/ncsi-netlink.c b/net/ncsi/ncsi-netlink.c index 32cb7751d216..33314381b4f5 100644 --- a/net/ncsi/ncsi-netlink.c +++ b/net/ncsi/ncsi-netlink.c @@ -19,6 +19,7 @@ #include <uapi/linux/ncsi.h> #include "internal.h" +#include "ncsi-pkt.h" #include "ncsi-netlink.h" static struct genl_family ncsi_genl_family; @@ -28,6 +29,7 @@ static const struct nla_policy ncsi_genl_policy[NCSI_ATTR_MAX + 1] = { [NCSI_ATTR_PACKAGE_LIST] = { .type = NLA_NESTED }, [NCSI_ATTR_PACKAGE_ID] = { .type = NLA_U32 }, [NCSI_ATTR_CHANNEL_ID] = { .type = NLA_U32 }, + [NCSI_ATTR_DATA] = { .type = NLA_BINARY, .len = 2048 }, }; static struct ncsi_dev_priv *ndp_from_ifindex(struct net *net, u32 ifindex) @@ -365,6 +367,202 @@ static int ncsi_clear_interface_nl(struct sk_buff *msg, struct genl_info *info) return 0; } +static int ncsi_send_cmd_nl(struct sk_buff *msg, struct genl_info *info) +{ + struct ncsi_dev_priv *ndp; + struct ncsi_pkt_hdr *hdr; + struct ncsi_cmd_arg nca; + unsigned char *data; + u32 package_id; + u32 channel_id; + int len, ret; + + if (!info || !info->attrs) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_IFINDEX]) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_PACKAGE_ID]) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_CHANNEL_ID]) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_DATA]) { + ret = -EINVAL; + goto out; + } + + ndp = ndp_from_ifindex(get_net(sock_net(msg->sk)), + nla_get_u32(info->attrs[NCSI_ATTR_IFINDEX])); + if (!ndp) { + ret = -ENODEV; + goto out; + } + + package_id = nla_get_u32(info->attrs[NCSI_ATTR_PACKAGE_ID]); + channel_id = nla_get_u32(info->attrs[NCSI_ATTR_CHANNEL_ID]); + + if (package_id >= NCSI_MAX_PACKAGE || channel_id >= NCSI_MAX_CHANNEL) { + ret = -ERANGE; + goto out_netlink; + } + + len = nla_len(info->attrs[NCSI_ATTR_DATA]); + if (len < sizeof(struct ncsi_pkt_hdr)) { + netdev_info(ndp->ndev.dev, "NCSI: no command to send %u\n", + package_id); + ret = -EINVAL; + goto out_netlink; + } else { + data = (unsigned char *)nla_data(info->attrs[NCSI_ATTR_DATA]); + } + + hdr = (struct ncsi_pkt_hdr *)data; + + nca.ndp = ndp; + nca.package = (unsigned char)package_id; + nca.channel = (unsigned char)channel_id; + nca.type = hdr->type; + nca.req_flags = NCSI_REQ_FLAG_NETLINK_DRIVEN; + nca.info = info; + nca.payload = ntohs(hdr->length); + nca.data = data + sizeof(*hdr); + + ret = ncsi_xmit_cmd(&nca); +out_netlink: + if (ret != 0) { + netdev_err(ndp->ndev.dev, + "NCSI: Error %d sending command\n", + ret); + ncsi_send_netlink_err(ndp->ndev.dev, + info->snd_seq, + info->snd_portid, + info->nlhdr, + ret); + } +out: + return ret; +} + +int ncsi_send_netlink_rsp(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc) +{ + struct sk_buff *skb; + struct net *net; + void *hdr; + int rc; + + net = dev_net(nr->rsp->dev); + + skb = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + hdr = genlmsg_put(skb, nr->snd_portid, nr->snd_seq, + &ncsi_genl_family, 0, NCSI_CMD_SEND_CMD); + if (!hdr) { + kfree_skb(skb); + return -EMSGSIZE; + } + + nla_put_u32(skb, NCSI_ATTR_IFINDEX, nr->rsp->dev->ifindex); + if (np) + nla_put_u32(skb, NCSI_ATTR_PACKAGE_ID, np->id); + if (nc) + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, nc->id); + else + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, NCSI_RESERVED_CHANNEL); + + rc = nla_put(skb, NCSI_ATTR_DATA, nr->rsp->len, (void *)nr->rsp->data); + if (rc) + goto err; + + genlmsg_end(skb, hdr); + return genlmsg_unicast(net, skb, nr->snd_portid); + +err: + kfree_skb(skb); + return rc; +} + +int ncsi_send_netlink_timeout(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc) +{ + struct sk_buff *skb; + struct net *net; + void *hdr; + + skb = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + hdr = genlmsg_put(skb, nr->snd_portid, nr->snd_seq, + &ncsi_genl_family, 0, NCSI_CMD_SEND_CMD); + if (!hdr) { + kfree_skb(skb); + return -EMSGSIZE; + } + + net = dev_net(nr->cmd->dev); + + nla_put_u32(skb, NCSI_ATTR_IFINDEX, nr->cmd->dev->ifindex); + + if (np) + nla_put_u32(skb, NCSI_ATTR_PACKAGE_ID, np->id); + else + nla_put_u32(skb, NCSI_ATTR_PACKAGE_ID, + NCSI_PACKAGE_INDEX((((struct ncsi_pkt_hdr *) + nr->cmd->data)->channel))); + + if (nc) + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, nc->id); + else + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, NCSI_RESERVED_CHANNEL); + + genlmsg_end(skb, hdr); + return genlmsg_unicast(net, skb, nr->snd_portid); +} + +int ncsi_send_netlink_err(struct net_device *dev, + u32 snd_seq, + u32 snd_portid, + struct nlmsghdr *nlhdr, + int err) +{ + struct nlmsghdr *nlh; + struct nlmsgerr *nle; + struct sk_buff *skb; + struct net *net; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + net = dev_net(dev); + + nlh = nlmsg_put(skb, snd_portid, snd_seq, + NLMSG_ERROR, sizeof(*nle), 0); + nle = (struct nlmsgerr *)nlmsg_data(nlh); + nle->error = err; + memcpy(&nle->msg, nlhdr, sizeof(*nlh)); + + nlmsg_end(skb, nlh); + + return nlmsg_unicast(net->genl_sock, skb, snd_portid); +} + static const struct genl_ops ncsi_ops[] = { { .cmd = NCSI_CMD_PKG_INFO, @@ -385,6 +583,12 @@ static const struct genl_ops ncsi_ops[] = { .doit = ncsi_clear_interface_nl, .flags = GENL_ADMIN_PERM, }, + { + .cmd = NCSI_CMD_SEND_CMD, + .policy = ncsi_genl_policy, + .doit = ncsi_send_cmd_nl, + .flags = GENL_ADMIN_PERM, + }, }; static struct genl_family ncsi_genl_family __ro_after_init = { diff --git a/net/ncsi/ncsi-netlink.h b/net/ncsi/ncsi-netlink.h index 91a5c256f8c4..c4a46887a932 100644 --- a/net/ncsi/ncsi-netlink.h +++ b/net/ncsi/ncsi-netlink.h @@ -14,6 +14,18 @@ #include "internal.h" +int ncsi_send_netlink_rsp(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc); +int ncsi_send_netlink_timeout(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc); +int ncsi_send_netlink_err(struct net_device *dev, + u32 snd_seq, + u32 snd_portid, + struct nlmsghdr *nlhdr, + int err); + int ncsi_init_netlink(struct net_device *dev); int ncsi_unregister_netlink(struct net_device *dev); diff --git a/net/ncsi/ncsi-pkt.h b/net/ncsi/ncsi-pkt.h index 91b4b66438df..4d3f06be38bd 100644 --- a/net/ncsi/ncsi-pkt.h +++ b/net/ncsi/ncsi-pkt.h @@ -151,6 +151,28 @@ struct ncsi_cmd_snfc_pkt { unsigned char pad[22]; }; +/* OEM Request Command as per NCSI Specification */ +struct ncsi_cmd_oem_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be32 mfr_id; /* Manufacture ID */ + unsigned char data[]; /* OEM Payload Data */ +}; + +/* OEM Response Packet as per NCSI Specification */ +struct ncsi_rsp_oem_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Command header */ + __be32 mfr_id; /* Manufacture ID */ + unsigned char data[]; /* Payload data */ +}; + +/* Broadcom Response Data */ +struct ncsi_rsp_oem_bcm_pkt { + unsigned char ver; /* Payload Version */ + unsigned char type; /* OEM Command type */ + __be16 len; /* Payload Length */ + unsigned char data[]; /* Cmd specific Data */ +}; + /* Get Link Status */ struct ncsi_rsp_gls_pkt { struct ncsi_rsp_pkt_hdr rsp; /* Response header */ diff --git a/net/ncsi/ncsi-rsp.c b/net/ncsi/ncsi-rsp.c index 930c1d3796f0..77e07ba3f493 100644 --- a/net/ncsi/ncsi-rsp.c +++ b/net/ncsi/ncsi-rsp.c @@ -16,9 +16,11 @@ #include <net/ncsi.h> #include <net/net_namespace.h> #include <net/sock.h> +#include <net/genetlink.h> #include "internal.h" #include "ncsi-pkt.h" +#include "ncsi-netlink.h" static int ncsi_validate_rsp_pkt(struct ncsi_request *nr, unsigned short payload) @@ -32,15 +34,25 @@ static int ncsi_validate_rsp_pkt(struct ncsi_request *nr, * before calling this function. */ h = (struct ncsi_rsp_pkt_hdr *)skb_network_header(nr->rsp); - if (h->common.revision != NCSI_PKT_REVISION) + + if (h->common.revision != NCSI_PKT_REVISION) { + netdev_dbg(nr->ndp->ndev.dev, + "NCSI: unsupported header revision\n"); return -EINVAL; - if (ntohs(h->common.length) != payload) + } + if (ntohs(h->common.length) != payload) { + netdev_dbg(nr->ndp->ndev.dev, + "NCSI: payload length mismatched\n"); return -EINVAL; + } /* Check on code and reason */ if (ntohs(h->code) != NCSI_PKT_RSP_C_COMPLETED || - ntohs(h->reason) != NCSI_PKT_RSP_R_NO_ERROR) - return -EINVAL; + ntohs(h->reason) != NCSI_PKT_RSP_R_NO_ERROR) { + netdev_dbg(nr->ndp->ndev.dev, + "NCSI: non zero response/reason code\n"); + return -EPERM; + } /* Validate checksum, which might be zeroes if the * sender doesn't support checksum according to NCSI @@ -52,8 +64,11 @@ static int ncsi_validate_rsp_pkt(struct ncsi_request *nr, checksum = ncsi_calculate_checksum((unsigned char *)h, sizeof(*h) + payload - 4); - if (*pchecksum != htonl(checksum)) + + if (*pchecksum != htonl(checksum)) { + netdev_dbg(nr->ndp->ndev.dev, "NCSI: checksum mismatched\n"); return -EINVAL; + } return 0; } @@ -596,6 +611,87 @@ static int ncsi_rsp_handler_snfc(struct ncsi_request *nr) return 0; } +/* Response handler for Broadcom command Get Mac Address */ +static int ncsi_rsp_handler_oem_bcm_gma(struct ncsi_request *nr) +{ + struct ncsi_dev_priv *ndp = nr->ndp; + struct net_device *ndev = ndp->ndev.dev; + const struct net_device_ops *ops = ndev->netdev_ops; + struct ncsi_rsp_oem_pkt *rsp; + struct sockaddr saddr; + int ret = 0; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + + saddr.sa_family = ndev->type; + ndev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + memcpy(saddr.sa_data, &rsp->data[BCM_MAC_ADDR_OFFSET], ETH_ALEN); + /* Increase mac address by 1 for BMC's address */ + saddr.sa_data[ETH_ALEN - 1]++; + ret = ops->ndo_set_mac_address(ndev, &saddr); + if (ret < 0) + netdev_warn(ndev, "NCSI: 'Writing mac address to device failed\n"); + + return ret; +} + +/* Response handler for Broadcom card */ +static int ncsi_rsp_handler_oem_bcm(struct ncsi_request *nr) +{ + struct ncsi_rsp_oem_bcm_pkt *bcm; + struct ncsi_rsp_oem_pkt *rsp; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + bcm = (struct ncsi_rsp_oem_bcm_pkt *)(rsp->data); + + if (bcm->type == NCSI_OEM_BCM_CMD_GMA) + return ncsi_rsp_handler_oem_bcm_gma(nr); + return 0; +} + +static struct ncsi_rsp_oem_handler { + unsigned int mfr_id; + int (*handler)(struct ncsi_request *nr); +} ncsi_rsp_oem_handlers[] = { + { NCSI_OEM_MFR_MLX_ID, NULL }, + { NCSI_OEM_MFR_BCM_ID, ncsi_rsp_handler_oem_bcm } +}; + +/* Response handler for OEM command */ +static int ncsi_rsp_handler_oem(struct ncsi_request *nr) +{ + struct ncsi_rsp_oem_handler *nrh = NULL; + struct ncsi_rsp_oem_pkt *rsp; + unsigned int mfr_id, i; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + mfr_id = ntohl(rsp->mfr_id); + + /* Check for manufacturer id and Find the handler */ + for (i = 0; i < ARRAY_SIZE(ncsi_rsp_oem_handlers); i++) { + if (ncsi_rsp_oem_handlers[i].mfr_id == mfr_id) { + if (ncsi_rsp_oem_handlers[i].handler) + nrh = &ncsi_rsp_oem_handlers[i]; + else + nrh = NULL; + + break; + } + } + + if (!nrh) { + netdev_err(nr->ndp->ndev.dev, "Received unrecognized OEM packet with MFR-ID (0x%x)\n", + mfr_id); + return -ENOENT; + } + + /* Process the packet */ + return nrh->handler(nr); +} + static int ncsi_rsp_handler_gvi(struct ncsi_request *nr) { struct ncsi_rsp_gvi_pkt *rsp; @@ -900,6 +996,26 @@ static int ncsi_rsp_handler_gpuuid(struct ncsi_request *nr) return 0; } +static int ncsi_rsp_handler_netlink(struct ncsi_request *nr) +{ + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_rsp_pkt *rsp; + struct ncsi_package *np; + struct ncsi_channel *nc; + int ret; + + /* Find the package */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + &np, &nc); + if (!np) + return -ENODEV; + + ret = ncsi_send_netlink_rsp(nr, np, nc); + + return ret; +} + static struct ncsi_rsp_handler { unsigned char type; int payload; @@ -932,7 +1048,7 @@ static struct ncsi_rsp_handler { { NCSI_PKT_RSP_GNS, 172, ncsi_rsp_handler_gns }, { NCSI_PKT_RSP_GNPTS, 172, ncsi_rsp_handler_gnpts }, { NCSI_PKT_RSP_GPS, 8, ncsi_rsp_handler_gps }, - { NCSI_PKT_RSP_OEM, 0, NULL }, + { NCSI_PKT_RSP_OEM, -1, ncsi_rsp_handler_oem }, { NCSI_PKT_RSP_PLDM, 0, NULL }, { NCSI_PKT_RSP_GPUUID, 20, ncsi_rsp_handler_gpuuid } }; @@ -1002,6 +1118,17 @@ int ncsi_rcv_rsp(struct sk_buff *skb, struct net_device *dev, netdev_warn(ndp->ndev.dev, "NCSI: 'bad' packet ignored for type 0x%x\n", hdr->type); + + if (nr->flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + if (ret == -EPERM) + goto out_netlink; + else + ncsi_send_netlink_err(ndp->ndev.dev, + nr->snd_seq, + nr->snd_portid, + &nr->nlhdr, + ret); + } goto out; } @@ -1011,6 +1138,17 @@ int ncsi_rcv_rsp(struct sk_buff *skb, struct net_device *dev, netdev_err(ndp->ndev.dev, "NCSI: Handler for packet type 0x%x returned %d\n", hdr->type, ret); + +out_netlink: + if (nr->flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + ret = ncsi_rsp_handler_netlink(nr); + if (ret) { + netdev_err(ndp->ndev.dev, + "NCSI: Netlink handler for packet type 0x%x returned %d\n", + hdr->type, ret); + } + } + out: ncsi_free_request(nr); return ret; diff --git a/net/netfilter/Kconfig b/net/netfilter/Kconfig index f61c306de1d0..2ab870ef233a 100644 --- a/net/netfilter/Kconfig +++ b/net/netfilter/Kconfig @@ -625,6 +625,13 @@ config NFT_FIB_INET The lookup will be delegated to the IPv4 or IPv6 FIB depending on the protocol of the packet. +config NFT_XFRM + tristate "Netfilter nf_tables xfrm/IPSec security association matching" + depends on XFRM + help + This option adds an expression that you can use to extract properties + of a packets security association. + config NFT_SOCKET tristate "Netfilter nf_tables socket match support" depends on IPV6 || IPV6=n diff --git a/net/netfilter/Makefile b/net/netfilter/Makefile index 16895e045b66..4ddf3ef51ece 100644 --- a/net/netfilter/Makefile +++ b/net/netfilter/Makefile @@ -113,6 +113,7 @@ obj-$(CONFIG_NFT_FIB_NETDEV) += nft_fib_netdev.o obj-$(CONFIG_NFT_SOCKET) += nft_socket.o obj-$(CONFIG_NFT_OSF) += nft_osf.o obj-$(CONFIG_NFT_TPROXY) += nft_tproxy.o +obj-$(CONFIG_NFT_XFRM) += nft_xfrm.o # nf_tables netdev obj-$(CONFIG_NFT_DUP_NETDEV) += nft_dup_netdev.o diff --git a/net/netfilter/ipset/ip_set_hash_gen.h b/net/netfilter/ipset/ip_set_hash_gen.h index 8a33dac4e805..e287da68d5fa 100644 --- a/net/netfilter/ipset/ip_set_hash_gen.h +++ b/net/netfilter/ipset/ip_set_hash_gen.h @@ -15,7 +15,7 @@ #define __ipset_dereference_protected(p, c) rcu_dereference_protected(p, c) #define ipset_dereference_protected(p, set) \ - __ipset_dereference_protected(p, spin_is_locked(&(set)->lock)) + __ipset_dereference_protected(p, lockdep_is_held(&(set)->lock)) #define rcu_dereference_bh_nfnl(p) rcu_dereference_bh_check(p, 1) diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c index 62eefea48973..83395bf6dc35 100644 --- a/net/netfilter/ipvs/ip_vs_ctl.c +++ b/net/netfilter/ipvs/ip_vs_ctl.c @@ -3234,7 +3234,7 @@ static int ip_vs_genl_dump_dests(struct sk_buff *skb, /* Try to find the service for which to dump destinations */ if (nlmsg_parse(cb->nlh, GENL_HDRLEN, attrs, IPVS_CMD_ATTR_MAX, - ip_vs_cmd_policy, NULL)) + ip_vs_cmd_policy, cb->extack)) goto out_err; diff --git a/net/netfilter/ipvs/ip_vs_sync.c b/net/netfilter/ipvs/ip_vs_sync.c index d4020c5e831d..2526be6b3d90 100644 --- a/net/netfilter/ipvs/ip_vs_sync.c +++ b/net/netfilter/ipvs/ip_vs_sync.c @@ -1616,7 +1616,7 @@ ip_vs_receive(struct socket *sock, char *buffer, const size_t buflen) EnterFunction(7); /* Receive a packet */ - iov_iter_kvec(&msg.msg_iter, READ | ITER_KVEC, &iov, 1, buflen); + iov_iter_kvec(&msg.msg_iter, READ, &iov, 1, buflen); len = sock_recvmsg(sock, &msg, MSG_DONTWAIT); if (len < 0) return len; diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c index a676d5f76bdc..ca1168d67fac 100644 --- a/net/netfilter/nf_conntrack_core.c +++ b/net/netfilter/nf_conntrack_core.c @@ -379,7 +379,7 @@ bool nf_ct_get_tuplepr(const struct sk_buff *skb, unsigned int nhoff, return false; } - l4proto = __nf_ct_l4proto_find(l3num, protonum); + l4proto = __nf_ct_l4proto_find(protonum); ret = nf_ct_get_tuple(skb, nhoff, protoff, l3num, protonum, net, tuple, l4proto); @@ -539,7 +539,7 @@ destroy_conntrack(struct nf_conntrack *nfct) nf_ct_tmpl_free(ct); return; } - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct)); + l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); if (l4proto->destroy) l4proto->destroy(ct); @@ -840,7 +840,7 @@ static int nf_ct_resolve_clash(struct net *net, struct sk_buff *skb, enum ip_conntrack_info oldinfo; struct nf_conn *loser_ct = nf_ct_get(skb, &oldinfo); - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct)); + l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); if (l4proto->allow_clash && !nf_ct_is_dying(ct) && atomic_inc_not_zero(&ct->ct_general.use)) { @@ -1109,7 +1109,7 @@ static bool gc_worker_can_early_drop(const struct nf_conn *ct) if (!test_bit(IPS_ASSURED_BIT, &ct->status)) return true; - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct)); + l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); if (l4proto->can_early_drop && l4proto->can_early_drop(ct)) return true; @@ -1370,12 +1370,6 @@ init_conntrack(struct net *net, struct nf_conn *tmpl, timeout_ext = tmpl ? nf_ct_timeout_find(tmpl) : NULL; - if (!l4proto->new(ct, skb, dataoff)) { - nf_conntrack_free(ct); - pr_debug("can't track with proto module\n"); - return NULL; - } - if (timeout_ext) nf_ct_timeout_ext_add(ct, rcu_dereference(timeout_ext->timeout), GFP_ATOMIC); @@ -1436,12 +1430,12 @@ init_conntrack(struct net *net, struct nf_conn *tmpl, /* On success, returns 0, sets skb->_nfct | ctinfo */ static int -resolve_normal_ct(struct net *net, struct nf_conn *tmpl, +resolve_normal_ct(struct nf_conn *tmpl, struct sk_buff *skb, unsigned int dataoff, - u_int16_t l3num, u_int8_t protonum, - const struct nf_conntrack_l4proto *l4proto) + const struct nf_conntrack_l4proto *l4proto, + const struct nf_hook_state *state) { const struct nf_conntrack_zone *zone; struct nf_conntrack_tuple tuple; @@ -1452,17 +1446,18 @@ resolve_normal_ct(struct net *net, struct nf_conn *tmpl, u32 hash; if (!nf_ct_get_tuple(skb, skb_network_offset(skb), - dataoff, l3num, protonum, net, &tuple, l4proto)) { + dataoff, state->pf, protonum, state->net, + &tuple, l4proto)) { pr_debug("Can't get tuple\n"); return 0; } /* look for tuple match */ zone = nf_ct_zone_tmpl(tmpl, skb, &tmp); - hash = hash_conntrack_raw(&tuple, net); - h = __nf_conntrack_find_get(net, zone, &tuple, hash); + hash = hash_conntrack_raw(&tuple, state->net); + h = __nf_conntrack_find_get(state->net, zone, &tuple, hash); if (!h) { - h = init_conntrack(net, tmpl, &tuple, l4proto, + h = init_conntrack(state->net, tmpl, &tuple, l4proto, skb, dataoff, hash); if (!h) return 0; @@ -1491,13 +1486,45 @@ resolve_normal_ct(struct net *net, struct nf_conn *tmpl, return 0; } +/* + * icmp packets need special treatment to handle error messages that are + * related to a connection. + * + * Callers need to check if skb has a conntrack assigned when this + * helper returns; in such case skb belongs to an already known connection. + */ +static unsigned int __cold +nf_conntrack_handle_icmp(struct nf_conn *tmpl, + struct sk_buff *skb, + unsigned int dataoff, + u8 protonum, + const struct nf_hook_state *state) +{ + int ret; + + if (state->pf == NFPROTO_IPV4 && protonum == IPPROTO_ICMP) + ret = nf_conntrack_icmpv4_error(tmpl, skb, dataoff, state); +#if IS_ENABLED(CONFIG_IPV6) + else if (state->pf == NFPROTO_IPV6 && protonum == IPPROTO_ICMPV6) + ret = nf_conntrack_icmpv6_error(tmpl, skb, dataoff, state); +#endif + else + return NF_ACCEPT; + + if (ret <= 0) { + NF_CT_STAT_INC_ATOMIC(state->net, error); + NF_CT_STAT_INC_ATOMIC(state->net, invalid); + } + + return ret; +} + unsigned int -nf_conntrack_in(struct net *net, u_int8_t pf, unsigned int hooknum, - struct sk_buff *skb) +nf_conntrack_in(struct sk_buff *skb, const struct nf_hook_state *state) { const struct nf_conntrack_l4proto *l4proto; - struct nf_conn *ct, *tmpl; enum ip_conntrack_info ctinfo; + struct nf_conn *ct, *tmpl; u_int8_t protonum; int dataoff, ret; @@ -1506,32 +1533,28 @@ nf_conntrack_in(struct net *net, u_int8_t pf, unsigned int hooknum, /* Previously seen (loopback or untracked)? Ignore. */ if ((tmpl && !nf_ct_is_template(tmpl)) || ctinfo == IP_CT_UNTRACKED) { - NF_CT_STAT_INC_ATOMIC(net, ignore); + NF_CT_STAT_INC_ATOMIC(state->net, ignore); return NF_ACCEPT; } skb->_nfct = 0; } /* rcu_read_lock()ed by nf_hook_thresh */ - dataoff = get_l4proto(skb, skb_network_offset(skb), pf, &protonum); + dataoff = get_l4proto(skb, skb_network_offset(skb), state->pf, &protonum); if (dataoff <= 0) { pr_debug("not prepared to track yet or error occurred\n"); - NF_CT_STAT_INC_ATOMIC(net, error); - NF_CT_STAT_INC_ATOMIC(net, invalid); + NF_CT_STAT_INC_ATOMIC(state->net, error); + NF_CT_STAT_INC_ATOMIC(state->net, invalid); ret = NF_ACCEPT; goto out; } - l4proto = __nf_ct_l4proto_find(pf, protonum); + l4proto = __nf_ct_l4proto_find(protonum); - /* It may be an special packet, error, unclean... - * inverse of the return code tells to the netfilter - * core what to do with the packet. */ - if (l4proto->error != NULL) { - ret = l4proto->error(net, tmpl, skb, dataoff, pf, hooknum); + if (protonum == IPPROTO_ICMP || protonum == IPPROTO_ICMPV6) { + ret = nf_conntrack_handle_icmp(tmpl, skb, dataoff, + protonum, state); if (ret <= 0) { - NF_CT_STAT_INC_ATOMIC(net, error); - NF_CT_STAT_INC_ATOMIC(net, invalid); ret = -ret; goto out; } @@ -1540,10 +1563,11 @@ nf_conntrack_in(struct net *net, u_int8_t pf, unsigned int hooknum, goto out; } repeat: - ret = resolve_normal_ct(net, tmpl, skb, dataoff, pf, protonum, l4proto); + ret = resolve_normal_ct(tmpl, skb, dataoff, + protonum, l4proto, state); if (ret < 0) { /* Too stressed to deal. */ - NF_CT_STAT_INC_ATOMIC(net, drop); + NF_CT_STAT_INC_ATOMIC(state->net, drop); ret = NF_DROP; goto out; } @@ -1551,21 +1575,21 @@ repeat: ct = nf_ct_get(skb, &ctinfo); if (!ct) { /* Not valid part of a connection */ - NF_CT_STAT_INC_ATOMIC(net, invalid); + NF_CT_STAT_INC_ATOMIC(state->net, invalid); ret = NF_ACCEPT; goto out; } - ret = l4proto->packet(ct, skb, dataoff, ctinfo); + ret = l4proto->packet(ct, skb, dataoff, ctinfo, state); if (ret <= 0) { /* Invalid: inverse of the return code tells * the netfilter core what to do */ pr_debug("nf_conntrack_in: Can't track with proto module\n"); nf_conntrack_put(&ct->ct_general); skb->_nfct = 0; - NF_CT_STAT_INC_ATOMIC(net, invalid); + NF_CT_STAT_INC_ATOMIC(state->net, invalid); if (ret == -NF_DROP) - NF_CT_STAT_INC_ATOMIC(net, drop); + NF_CT_STAT_INC_ATOMIC(state->net, drop); /* Special case: TCP tracker reports an attempt to reopen a * closed/aborted connection. We have to go back and create a * fresh conntrack. @@ -1594,8 +1618,7 @@ bool nf_ct_invert_tuplepr(struct nf_conntrack_tuple *inverse, rcu_read_lock(); ret = nf_ct_invert_tuple(inverse, orig, - __nf_ct_l4proto_find(orig->src.l3num, - orig->dst.protonum)); + __nf_ct_l4proto_find(orig->dst.protonum)); rcu_read_unlock(); return ret; } @@ -1752,7 +1775,7 @@ static int nf_conntrack_update(struct net *net, struct sk_buff *skb) if (dataoff <= 0) return -1; - l4proto = nf_ct_l4proto_find_get(l3num, l4num); + l4proto = nf_ct_l4proto_find_get(l4num); if (!nf_ct_get_tuple(skb, skb_network_offset(skb), dataoff, l3num, l4num, net, &tuple, l4proto)) diff --git a/net/netfilter/nf_conntrack_expect.c b/net/netfilter/nf_conntrack_expect.c index 27b84231db10..3034038bfdf0 100644 --- a/net/netfilter/nf_conntrack_expect.c +++ b/net/netfilter/nf_conntrack_expect.c @@ -610,8 +610,7 @@ static int exp_seq_show(struct seq_file *s, void *v) expect->tuple.src.l3num, expect->tuple.dst.protonum); print_tuple(s, &expect->tuple, - __nf_ct_l4proto_find(expect->tuple.src.l3num, - expect->tuple.dst.protonum)); + __nf_ct_l4proto_find(expect->tuple.dst.protonum)); if (expect->flags & NF_CT_EXPECT_PERMANENT) { seq_puts(s, "PERMANENT"); diff --git a/net/netfilter/nf_conntrack_netlink.c b/net/netfilter/nf_conntrack_netlink.c index 036207ecaf16..4ae8e528943a 100644 --- a/net/netfilter/nf_conntrack_netlink.c +++ b/net/netfilter/nf_conntrack_netlink.c @@ -135,8 +135,7 @@ static int ctnetlink_dump_tuples(struct sk_buff *skb, ret = ctnetlink_dump_tuples_ip(skb, tuple); if (ret >= 0) { - l4proto = __nf_ct_l4proto_find(tuple->src.l3num, - tuple->dst.protonum); + l4proto = __nf_ct_l4proto_find(tuple->dst.protonum); ret = ctnetlink_dump_tuples_proto(skb, tuple, l4proto); } rcu_read_unlock(); @@ -184,7 +183,7 @@ static int ctnetlink_dump_protoinfo(struct sk_buff *skb, struct nf_conn *ct) struct nlattr *nest_proto; int ret; - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct)); + l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); if (!l4proto->to_nlattr) return 0; @@ -592,7 +591,7 @@ static size_t ctnetlink_proto_size(const struct nf_conn *ct) len = nla_policy_len(cta_ip_nla_policy, CTA_IP_MAX + 1); len *= 3u; /* ORIG, REPLY, MASTER */ - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct)); + l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); len += l4proto->nlattr_size; if (l4proto->nlattr_tuple_size) { len4 = l4proto->nlattr_tuple_size(); @@ -821,6 +820,7 @@ static int ctnetlink_done(struct netlink_callback *cb) } struct ctnetlink_filter { + u8 family; struct { u_int32_t val; u_int32_t mask; @@ -828,31 +828,39 @@ struct ctnetlink_filter { }; static struct ctnetlink_filter * -ctnetlink_alloc_filter(const struct nlattr * const cda[]) +ctnetlink_alloc_filter(const struct nlattr * const cda[], u8 family) { -#ifdef CONFIG_NF_CONNTRACK_MARK struct ctnetlink_filter *filter; +#ifndef CONFIG_NF_CONNTRACK_MARK + if (cda[CTA_MARK] && cda[CTA_MARK_MASK]) + return ERR_PTR(-EOPNOTSUPP); +#endif + filter = kzalloc(sizeof(*filter), GFP_KERNEL); if (filter == NULL) return ERR_PTR(-ENOMEM); - filter->mark.val = ntohl(nla_get_be32(cda[CTA_MARK])); - filter->mark.mask = ntohl(nla_get_be32(cda[CTA_MARK_MASK])); + filter->family = family; - return filter; -#else - return ERR_PTR(-EOPNOTSUPP); +#ifdef CONFIG_NF_CONNTRACK_MARK + if (cda[CTA_MARK] && cda[CTA_MARK_MASK]) { + filter->mark.val = ntohl(nla_get_be32(cda[CTA_MARK])); + filter->mark.mask = ntohl(nla_get_be32(cda[CTA_MARK_MASK])); + } #endif + return filter; } static int ctnetlink_start(struct netlink_callback *cb) { const struct nlattr * const *cda = cb->data; struct ctnetlink_filter *filter = NULL; + struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + u8 family = nfmsg->nfgen_family; - if (cda[CTA_MARK] && cda[CTA_MARK_MASK]) { - filter = ctnetlink_alloc_filter(cda); + if (family || (cda[CTA_MARK] && cda[CTA_MARK_MASK])) { + filter = ctnetlink_alloc_filter(cda, family); if (IS_ERR(filter)) return PTR_ERR(filter); } @@ -866,13 +874,24 @@ static int ctnetlink_filter_match(struct nf_conn *ct, void *data) struct ctnetlink_filter *filter = data; if (filter == NULL) - return 1; + goto out; + + /* Match entries of a given L3 protocol number. + * If it is not specified, ie. l3proto == 0, + * then match everything. + */ + if (filter->family && nf_ct_l3num(ct) != filter->family) + goto ignore_entry; #ifdef CONFIG_NF_CONNTRACK_MARK - if ((ct->mark & filter->mark.mask) == filter->mark.val) - return 1; + if ((ct->mark & filter->mark.mask) != filter->mark.val) + goto ignore_entry; #endif +out: + return 1; + +ignore_entry: return 0; } @@ -883,8 +902,6 @@ ctnetlink_dump_table(struct sk_buff *skb, struct netlink_callback *cb) struct nf_conn *ct, *last; struct nf_conntrack_tuple_hash *h; struct hlist_nulls_node *n; - struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); - u_int8_t l3proto = nfmsg->nfgen_family; struct nf_conn *nf_ct_evict[8]; int res, i; spinlock_t *lockp; @@ -923,11 +940,6 @@ restart: if (!net_eq(net, nf_ct_net(ct))) continue; - /* Dump entries of a given L3 protocol number. - * If it is not specified, ie. l3proto == 0, - * then dump everything. */ - if (l3proto && nf_ct_l3num(ct) != l3proto) - continue; if (cb->args[1]) { if (ct != last) continue; @@ -1048,7 +1060,7 @@ static int ctnetlink_parse_tuple_proto(struct nlattr *attr, tuple->dst.protonum = nla_get_u8(tb[CTA_PROTO_NUM]); rcu_read_lock(); - l4proto = __nf_ct_l4proto_find(tuple->src.l3num, tuple->dst.protonum); + l4proto = __nf_ct_l4proto_find(tuple->dst.protonum); if (likely(l4proto->nlattr_to_tuple)) { ret = nla_validate_nested(attr, CTA_PROTO_MAX, @@ -1213,12 +1225,12 @@ static int ctnetlink_flush_iterate(struct nf_conn *ct, void *data) static int ctnetlink_flush_conntrack(struct net *net, const struct nlattr * const cda[], - u32 portid, int report) + u32 portid, int report, u8 family) { struct ctnetlink_filter *filter = NULL; - if (cda[CTA_MARK] && cda[CTA_MARK_MASK]) { - filter = ctnetlink_alloc_filter(cda); + if (family || (cda[CTA_MARK] && cda[CTA_MARK_MASK])) { + filter = ctnetlink_alloc_filter(cda, family); if (IS_ERR(filter)) return PTR_ERR(filter); } @@ -1257,7 +1269,7 @@ static int ctnetlink_del_conntrack(struct net *net, struct sock *ctnl, else { return ctnetlink_flush_conntrack(net, cda, NETLINK_CB(skb).portid, - nlmsg_report(nlh)); + nlmsg_report(nlh), u3); } if (err < 0) @@ -1696,7 +1708,7 @@ static int ctnetlink_change_protoinfo(struct nf_conn *ct, return err; rcu_read_lock(); - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct)); + l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); if (l4proto->from_nlattr) err = l4proto->from_nlattr(tb, ct); rcu_read_unlock(); @@ -2656,8 +2668,7 @@ static int ctnetlink_exp_dump_mask(struct sk_buff *skb, rcu_read_lock(); ret = ctnetlink_dump_tuples_ip(skb, &m); if (ret >= 0) { - l4proto = __nf_ct_l4proto_find(tuple->src.l3num, - tuple->dst.protonum); + l4proto = __nf_ct_l4proto_find(tuple->dst.protonum); ret = ctnetlink_dump_tuples_proto(skb, &m, l4proto); } rcu_read_unlock(); diff --git a/net/netfilter/nf_conntrack_proto.c b/net/netfilter/nf_conntrack_proto.c index 51c5d7eec0a3..40643af7137e 100644 --- a/net/netfilter/nf_conntrack_proto.c +++ b/net/netfilter/nf_conntrack_proto.c @@ -43,7 +43,7 @@ extern unsigned int nf_conntrack_net_id; -static struct nf_conntrack_l4proto __rcu **nf_ct_protos[NFPROTO_NUMPROTO] __read_mostly; +static struct nf_conntrack_l4proto __rcu *nf_ct_protos[MAX_NF_CT_PROTO + 1] __read_mostly; static DEFINE_MUTEX(nf_ct_proto_mutex); @@ -124,23 +124,21 @@ void nf_ct_l4proto_log_invalid(const struct sk_buff *skb, EXPORT_SYMBOL_GPL(nf_ct_l4proto_log_invalid); #endif -const struct nf_conntrack_l4proto * -__nf_ct_l4proto_find(u_int16_t l3proto, u_int8_t l4proto) +const struct nf_conntrack_l4proto *__nf_ct_l4proto_find(u8 l4proto) { - if (unlikely(l3proto >= NFPROTO_NUMPROTO || nf_ct_protos[l3proto] == NULL)) + if (unlikely(l4proto >= ARRAY_SIZE(nf_ct_protos))) return &nf_conntrack_l4proto_generic; - return rcu_dereference(nf_ct_protos[l3proto][l4proto]); + return rcu_dereference(nf_ct_protos[l4proto]); } EXPORT_SYMBOL_GPL(__nf_ct_l4proto_find); -const struct nf_conntrack_l4proto * -nf_ct_l4proto_find_get(u_int16_t l3num, u_int8_t l4num) +const struct nf_conntrack_l4proto *nf_ct_l4proto_find_get(u8 l4num) { const struct nf_conntrack_l4proto *p; rcu_read_lock(); - p = __nf_ct_l4proto_find(l3num, l4num); + p = __nf_ct_l4proto_find(l4num); if (!try_module_get(p->me)) p = &nf_conntrack_l4proto_generic; rcu_read_unlock(); @@ -159,8 +157,7 @@ static int kill_l4proto(struct nf_conn *i, void *data) { const struct nf_conntrack_l4proto *l4proto; l4proto = data; - return nf_ct_protonum(i) == l4proto->l4proto && - nf_ct_l3num(i) == l4proto->l3proto; + return nf_ct_protonum(i) == l4proto->l4proto; } static struct nf_proto_net *nf_ct_l4proto_net(struct net *net, @@ -219,48 +216,20 @@ int nf_ct_l4proto_register_one(const struct nf_conntrack_l4proto *l4proto) { int ret = 0; - if (l4proto->l3proto >= ARRAY_SIZE(nf_ct_protos)) - return -EBUSY; - if ((l4proto->to_nlattr && l4proto->nlattr_size == 0) || (l4proto->tuple_to_nlattr && !l4proto->nlattr_tuple_size)) return -EINVAL; mutex_lock(&nf_ct_proto_mutex); - if (!nf_ct_protos[l4proto->l3proto]) { - /* l3proto may be loaded latter. */ - struct nf_conntrack_l4proto __rcu **proto_array; - int i; - - proto_array = - kmalloc_array(MAX_NF_CT_PROTO, - sizeof(struct nf_conntrack_l4proto *), - GFP_KERNEL); - if (proto_array == NULL) { - ret = -ENOMEM; - goto out_unlock; - } - - for (i = 0; i < MAX_NF_CT_PROTO; i++) - RCU_INIT_POINTER(proto_array[i], - &nf_conntrack_l4proto_generic); - - /* Before making proto_array visible to lockless readers, - * we must make sure its content is committed to memory. - */ - smp_wmb(); - - nf_ct_protos[l4proto->l3proto] = proto_array; - } else if (rcu_dereference_protected( - nf_ct_protos[l4proto->l3proto][l4proto->l4proto], + if (rcu_dereference_protected( + nf_ct_protos[l4proto->l4proto], lockdep_is_held(&nf_ct_proto_mutex) ) != &nf_conntrack_l4proto_generic) { ret = -EBUSY; goto out_unlock; } - rcu_assign_pointer(nf_ct_protos[l4proto->l3proto][l4proto->l4proto], - l4proto); + rcu_assign_pointer(nf_ct_protos[l4proto->l4proto], l4proto); out_unlock: mutex_unlock(&nf_ct_proto_mutex); return ret; @@ -274,7 +243,7 @@ int nf_ct_l4proto_pernet_register_one(struct net *net, struct nf_proto_net *pn = NULL; if (l4proto->init_net) { - ret = l4proto->init_net(net, l4proto->l3proto); + ret = l4proto->init_net(net); if (ret < 0) goto out; } @@ -296,13 +265,13 @@ EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_register_one); static void __nf_ct_l4proto_unregister_one(const struct nf_conntrack_l4proto *l4proto) { - BUG_ON(l4proto->l3proto >= ARRAY_SIZE(nf_ct_protos)); + BUG_ON(l4proto->l4proto >= ARRAY_SIZE(nf_ct_protos)); BUG_ON(rcu_dereference_protected( - nf_ct_protos[l4proto->l3proto][l4proto->l4proto], + nf_ct_protos[l4proto->l4proto], lockdep_is_held(&nf_ct_proto_mutex) ) != l4proto); - rcu_assign_pointer(nf_ct_protos[l4proto->l3proto][l4proto->l4proto], + rcu_assign_pointer(nf_ct_protos[l4proto->l4proto], &nf_conntrack_l4proto_generic); } @@ -352,7 +321,7 @@ static int nf_ct_l4proto_register(const struct nf_conntrack_l4proto * const l4proto[], unsigned int num_proto) { - int ret = -EINVAL, ver; + int ret = -EINVAL; unsigned int i; for (i = 0; i < num_proto; i++) { @@ -361,9 +330,8 @@ nf_ct_l4proto_register(const struct nf_conntrack_l4proto * const l4proto[], break; } if (i != num_proto) { - ver = l4proto[i]->l3proto == PF_INET6 ? 6 : 4; - pr_err("nf_conntrack_ipv%d: can't register l4 %d proto.\n", - ver, l4proto[i]->l4proto); + pr_err("nf_conntrack: can't register l4 %d proto.\n", + l4proto[i]->l4proto); nf_ct_l4proto_unregister(l4proto, i); } return ret; @@ -382,9 +350,8 @@ int nf_ct_l4proto_pernet_register(struct net *net, break; } if (i != num_proto) { - pr_err("nf_conntrack_proto_%d %d: pernet registration failed\n", - l4proto[i]->l4proto, - l4proto[i]->l3proto == PF_INET6 ? 6 : 4); + pr_err("nf_conntrack %d: pernet registration failed\n", + l4proto[i]->l4proto); nf_ct_l4proto_pernet_unregister(net, l4proto, i); } return ret; @@ -455,7 +422,7 @@ static unsigned int ipv4_conntrack_in(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) { - return nf_conntrack_in(state->net, PF_INET, state->hook, skb); + return nf_conntrack_in(skb, state); } static unsigned int ipv4_conntrack_local(void *priv, @@ -477,7 +444,7 @@ static unsigned int ipv4_conntrack_local(void *priv, return NF_ACCEPT; } - return nf_conntrack_in(state->net, PF_INET, state->hook, skb); + return nf_conntrack_in(skb, state); } /* Connection tracking may drop packets, but never alters them, so @@ -690,14 +657,14 @@ static unsigned int ipv6_conntrack_in(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) { - return nf_conntrack_in(state->net, PF_INET6, state->hook, skb); + return nf_conntrack_in(skb, state); } static unsigned int ipv6_conntrack_local(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) { - return nf_conntrack_in(state->net, PF_INET6, state->hook, skb); + return nf_conntrack_in(skb, state); } static unsigned int ipv6_helper(void *priv, @@ -911,37 +878,26 @@ void nf_ct_netns_put(struct net *net, uint8_t nfproto) EXPORT_SYMBOL_GPL(nf_ct_netns_put); static const struct nf_conntrack_l4proto * const builtin_l4proto[] = { - &nf_conntrack_l4proto_tcp4, - &nf_conntrack_l4proto_udp4, + &nf_conntrack_l4proto_tcp, + &nf_conntrack_l4proto_udp, &nf_conntrack_l4proto_icmp, #ifdef CONFIG_NF_CT_PROTO_DCCP - &nf_conntrack_l4proto_dccp4, + &nf_conntrack_l4proto_dccp, #endif #ifdef CONFIG_NF_CT_PROTO_SCTP - &nf_conntrack_l4proto_sctp4, + &nf_conntrack_l4proto_sctp, #endif #ifdef CONFIG_NF_CT_PROTO_UDPLITE - &nf_conntrack_l4proto_udplite4, + &nf_conntrack_l4proto_udplite, #endif #if IS_ENABLED(CONFIG_IPV6) - &nf_conntrack_l4proto_tcp6, - &nf_conntrack_l4proto_udp6, &nf_conntrack_l4proto_icmpv6, -#ifdef CONFIG_NF_CT_PROTO_DCCP - &nf_conntrack_l4proto_dccp6, -#endif -#ifdef CONFIG_NF_CT_PROTO_SCTP - &nf_conntrack_l4proto_sctp6, -#endif -#ifdef CONFIG_NF_CT_PROTO_UDPLITE - &nf_conntrack_l4proto_udplite6, -#endif #endif /* CONFIG_IPV6 */ }; int nf_conntrack_proto_init(void) { - int ret = 0; + int ret = 0, i; ret = nf_register_sockopt(&so_getorigdst); if (ret < 0) @@ -952,6 +908,11 @@ int nf_conntrack_proto_init(void) if (ret < 0) goto cleanup_sockopt; #endif + + for (i = 0; i < ARRAY_SIZE(nf_ct_protos); i++) + RCU_INIT_POINTER(nf_ct_protos[i], + &nf_conntrack_l4proto_generic); + ret = nf_ct_l4proto_register(builtin_l4proto, ARRAY_SIZE(builtin_l4proto)); if (ret < 0) @@ -969,17 +930,10 @@ cleanup_sockopt: void nf_conntrack_proto_fini(void) { - unsigned int i; - nf_unregister_sockopt(&so_getorigdst); #if IS_ENABLED(CONFIG_IPV6) nf_unregister_sockopt(&so_getorigdst6); #endif - /* No need to call nf_ct_l4proto_unregister(), the register - * tables are free'd here anyway. - */ - for (i = 0; i < ARRAY_SIZE(nf_ct_protos); i++) - kfree(nf_ct_protos[i]); } int nf_conntrack_proto_pernet_init(struct net *net) @@ -988,8 +942,7 @@ int nf_conntrack_proto_pernet_init(struct net *net) struct nf_proto_net *pn = nf_ct_l4proto_net(net, &nf_conntrack_l4proto_generic); - err = nf_conntrack_l4proto_generic.init_net(net, - nf_conntrack_l4proto_generic.l3proto); + err = nf_conntrack_l4proto_generic.init_net(net); if (err < 0) return err; err = nf_ct_l4proto_register_sysctl(net, diff --git a/net/netfilter/nf_conntrack_proto_dccp.c b/net/netfilter/nf_conntrack_proto_dccp.c index f3f91ed2c21a..171e9e122e5f 100644 --- a/net/netfilter/nf_conntrack_proto_dccp.c +++ b/net/netfilter/nf_conntrack_proto_dccp.c @@ -389,18 +389,15 @@ static inline struct nf_dccp_net *dccp_pernet(struct net *net) return &net->ct.nf_ct_proto.dccp; } -static bool dccp_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) +static noinline bool +dccp_new(struct nf_conn *ct, const struct sk_buff *skb, + const struct dccp_hdr *dh) { struct net *net = nf_ct_net(ct); struct nf_dccp_net *dn; - struct dccp_hdr _dh, *dh; const char *msg; u_int8_t state; - dh = skb_header_pointer(skb, dataoff, sizeof(_dh), &_dh); - BUG_ON(dh == NULL); - state = dccp_state_table[CT_DCCP_ROLE_CLIENT][dh->dccph_type][CT_DCCP_NONE]; switch (state) { default: @@ -438,8 +435,51 @@ static u64 dccp_ack_seq(const struct dccp_hdr *dh) ntohl(dhack->dccph_ack_nr_low); } -static int dccp_packet(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff, enum ip_conntrack_info ctinfo) +static bool dccp_error(const struct dccp_hdr *dh, + struct sk_buff *skb, unsigned int dataoff, + const struct nf_hook_state *state) +{ + unsigned int dccp_len = skb->len - dataoff; + unsigned int cscov; + const char *msg; + + if (dh->dccph_doff * 4 < sizeof(struct dccp_hdr) || + dh->dccph_doff * 4 > dccp_len) { + msg = "nf_ct_dccp: truncated/malformed packet "; + goto out_invalid; + } + + cscov = dccp_len; + if (dh->dccph_cscov) { + cscov = (dh->dccph_cscov - 1) * 4; + if (cscov > dccp_len) { + msg = "nf_ct_dccp: bad checksum coverage "; + goto out_invalid; + } + } + + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_checksum_partial(skb, state->hook, dataoff, cscov, + IPPROTO_DCCP, state->pf)) { + msg = "nf_ct_dccp: bad checksum "; + goto out_invalid; + } + + if (dh->dccph_type >= DCCP_PKT_INVALID) { + msg = "nf_ct_dccp: reserved packet type "; + goto out_invalid; + } + return false; +out_invalid: + nf_l4proto_log_invalid(skb, state->net, state->pf, + IPPROTO_DCCP, "%s", msg); + return true; +} + +static int dccp_packet(struct nf_conn *ct, struct sk_buff *skb, + unsigned int dataoff, enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); struct dccp_hdr _dh, *dh; @@ -448,8 +488,15 @@ static int dccp_packet(struct nf_conn *ct, const struct sk_buff *skb, unsigned int *timeouts; dh = skb_header_pointer(skb, dataoff, sizeof(_dh), &_dh); - BUG_ON(dh == NULL); + if (!dh) + return NF_DROP; + + if (dccp_error(dh, skb, dataoff, state)) + return -NF_ACCEPT; + type = dh->dccph_type; + if (!nf_ct_is_confirmed(ct) && !dccp_new(ct, skb, dh)) + return -NF_ACCEPT; if (type == DCCP_PKT_RESET && !test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) { @@ -527,55 +574,6 @@ static int dccp_packet(struct nf_conn *ct, const struct sk_buff *skb, return NF_ACCEPT; } -static int dccp_error(struct net *net, struct nf_conn *tmpl, - struct sk_buff *skb, unsigned int dataoff, - u_int8_t pf, unsigned int hooknum) -{ - struct dccp_hdr _dh, *dh; - unsigned int dccp_len = skb->len - dataoff; - unsigned int cscov; - const char *msg; - - dh = skb_header_pointer(skb, dataoff, sizeof(_dh), &_dh); - if (dh == NULL) { - msg = "nf_ct_dccp: short packet "; - goto out_invalid; - } - - if (dh->dccph_doff * 4 < sizeof(struct dccp_hdr) || - dh->dccph_doff * 4 > dccp_len) { - msg = "nf_ct_dccp: truncated/malformed packet "; - goto out_invalid; - } - - cscov = dccp_len; - if (dh->dccph_cscov) { - cscov = (dh->dccph_cscov - 1) * 4; - if (cscov > dccp_len) { - msg = "nf_ct_dccp: bad checksum coverage "; - goto out_invalid; - } - } - - if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING && - nf_checksum_partial(skb, hooknum, dataoff, cscov, IPPROTO_DCCP, - pf)) { - msg = "nf_ct_dccp: bad checksum "; - goto out_invalid; - } - - if (dh->dccph_type >= DCCP_PKT_INVALID) { - msg = "nf_ct_dccp: reserved packet type "; - goto out_invalid; - } - - return NF_ACCEPT; - -out_invalid: - nf_l4proto_log_invalid(skb, net, pf, IPPROTO_DCCP, "%s", msg); - return -NF_ACCEPT; -} - static bool dccp_can_early_drop(const struct nf_conn *ct) { switch (ct->proto.dccp.state) { @@ -814,7 +812,7 @@ static int dccp_kmemdup_sysctl_table(struct net *net, struct nf_proto_net *pn, return 0; } -static int dccp_init_net(struct net *net, u_int16_t proto) +static int dccp_init_net(struct net *net) { struct nf_dccp_net *dn = dccp_pernet(net); struct nf_proto_net *pn = &dn->pn; @@ -844,45 +842,9 @@ static struct nf_proto_net *dccp_get_net_proto(struct net *net) return &net->ct.nf_ct_proto.dccp.pn; } -const struct nf_conntrack_l4proto nf_conntrack_l4proto_dccp4 = { - .l3proto = AF_INET, - .l4proto = IPPROTO_DCCP, - .new = dccp_new, - .packet = dccp_packet, - .error = dccp_error, - .can_early_drop = dccp_can_early_drop, -#ifdef CONFIG_NF_CONNTRACK_PROCFS - .print_conntrack = dccp_print_conntrack, -#endif -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) - .nlattr_size = DCCP_NLATTR_SIZE, - .to_nlattr = dccp_to_nlattr, - .from_nlattr = nlattr_to_dccp, - .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, - .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, - .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, - .nla_policy = nf_ct_port_nla_policy, -#endif -#ifdef CONFIG_NF_CONNTRACK_TIMEOUT - .ctnl_timeout = { - .nlattr_to_obj = dccp_timeout_nlattr_to_obj, - .obj_to_nlattr = dccp_timeout_obj_to_nlattr, - .nlattr_max = CTA_TIMEOUT_DCCP_MAX, - .obj_size = sizeof(unsigned int) * CT_DCCP_MAX, - .nla_policy = dccp_timeout_nla_policy, - }, -#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = dccp_init_net, - .get_net_proto = dccp_get_net_proto, -}; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_dccp4); - -const struct nf_conntrack_l4proto nf_conntrack_l4proto_dccp6 = { - .l3proto = AF_INET6, +const struct nf_conntrack_l4proto nf_conntrack_l4proto_dccp = { .l4proto = IPPROTO_DCCP, - .new = dccp_new, .packet = dccp_packet, - .error = dccp_error, .can_early_drop = dccp_can_early_drop, #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = dccp_print_conntrack, @@ -908,4 +870,3 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_dccp6 = { .init_net = dccp_init_net, .get_net_proto = dccp_get_net_proto, }; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_dccp6); diff --git a/net/netfilter/nf_conntrack_proto_generic.c b/net/netfilter/nf_conntrack_proto_generic.c index 1df3244ecd07..e10e867e0b55 100644 --- a/net/netfilter/nf_conntrack_proto_generic.c +++ b/net/netfilter/nf_conntrack_proto_generic.c @@ -44,12 +44,19 @@ static bool generic_pkt_to_tuple(const struct sk_buff *skb, /* Returns verdict for packet, or -1 for invalid. */ static int generic_packet(struct nf_conn *ct, - const struct sk_buff *skb, + struct sk_buff *skb, unsigned int dataoff, - enum ip_conntrack_info ctinfo) + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { const unsigned int *timeout = nf_ct_timeout_lookup(ct); + if (!nf_generic_should_process(nf_ct_protonum(ct))) { + pr_warn_once("conntrack: generic helper won't handle protocol %d. Please consider loading the specific helper module.\n", + nf_ct_protonum(ct)); + return -NF_ACCEPT; + } + if (!timeout) timeout = &generic_pernet(nf_ct_net(ct))->timeout; @@ -57,19 +64,6 @@ static int generic_packet(struct nf_conn *ct, return NF_ACCEPT; } -/* Called when a new connection for this protocol found. */ -static bool generic_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) -{ - bool ret; - - ret = nf_generic_should_process(nf_ct_protonum(ct)); - if (!ret) - pr_warn_once("conntrack: generic helper won't handle protocol %d. Please consider loading the specific helper module.\n", - nf_ct_protonum(ct)); - return ret; -} - #ifdef CONFIG_NF_CONNTRACK_TIMEOUT #include <linux/netfilter/nfnetlink.h> @@ -142,7 +136,7 @@ static int generic_kmemdup_sysctl_table(struct nf_proto_net *pn, return 0; } -static int generic_init_net(struct net *net, u_int16_t proto) +static int generic_init_net(struct net *net) { struct nf_generic_net *gn = generic_pernet(net); struct nf_proto_net *pn = &gn->pn; @@ -159,11 +153,9 @@ static struct nf_proto_net *generic_get_net_proto(struct net *net) const struct nf_conntrack_l4proto nf_conntrack_l4proto_generic = { - .l3proto = PF_UNSPEC, .l4proto = 255, .pkt_to_tuple = generic_pkt_to_tuple, .packet = generic_packet, - .new = generic_new, #ifdef CONFIG_NF_CONNTRACK_TIMEOUT .ctnl_timeout = { .nlattr_to_obj = generic_timeout_nlattr_to_obj, diff --git a/net/netfilter/nf_conntrack_proto_gre.c b/net/netfilter/nf_conntrack_proto_gre.c index 650eb4fba2c5..9b48dc8b4b88 100644 --- a/net/netfilter/nf_conntrack_proto_gre.c +++ b/net/netfilter/nf_conntrack_proto_gre.c @@ -233,10 +233,26 @@ static unsigned int *gre_get_timeouts(struct net *net) /* Returns verdict for packet, and may modify conntrack */ static int gre_packet(struct nf_conn *ct, - const struct sk_buff *skb, + struct sk_buff *skb, unsigned int dataoff, - enum ip_conntrack_info ctinfo) + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { + if (state->pf != NFPROTO_IPV4) + return -NF_ACCEPT; + + if (!nf_ct_is_confirmed(ct)) { + unsigned int *timeouts = nf_ct_timeout_lookup(ct); + + if (!timeouts) + timeouts = gre_get_timeouts(nf_ct_net(ct)); + + /* initialize to sane value. Ideally a conntrack helper + * (e.g. in case of pptp) is increasing them */ + ct->proto.gre.stream_timeout = timeouts[GRE_CT_REPLIED]; + ct->proto.gre.timeout = timeouts[GRE_CT_UNREPLIED]; + } + /* If we've seen traffic both ways, this is a GRE connection. * Extend timeout. */ if (ct->status & IPS_SEEN_REPLY) { @@ -252,26 +268,6 @@ static int gre_packet(struct nf_conn *ct, return NF_ACCEPT; } -/* Called when a new connection for this protocol found. */ -static bool gre_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) -{ - unsigned int *timeouts = nf_ct_timeout_lookup(ct); - - if (!timeouts) - timeouts = gre_get_timeouts(nf_ct_net(ct)); - - pr_debug(": "); - nf_ct_dump_tuple(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); - - /* initialize to sane value. Ideally a conntrack helper - * (e.g. in case of pptp) is increasing them */ - ct->proto.gre.stream_timeout = timeouts[GRE_CT_REPLIED]; - ct->proto.gre.timeout = timeouts[GRE_CT_UNREPLIED]; - - return true; -} - /* Called when a conntrack entry has already been removed from the hashes * and is about to be deleted from memory */ static void gre_destroy(struct nf_conn *ct) @@ -336,7 +332,7 @@ gre_timeout_nla_policy[CTA_TIMEOUT_GRE_MAX+1] = { }; #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ -static int gre_init_net(struct net *net, u_int16_t proto) +static int gre_init_net(struct net *net) { struct netns_proto_gre *net_gre = gre_pernet(net); int i; @@ -351,14 +347,12 @@ static int gre_init_net(struct net *net, u_int16_t proto) /* protocol helper struct */ static const struct nf_conntrack_l4proto nf_conntrack_l4proto_gre4 = { - .l3proto = AF_INET, .l4proto = IPPROTO_GRE, .pkt_to_tuple = gre_pkt_to_tuple, #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = gre_print_conntrack, #endif .packet = gre_packet, - .new = gre_new, .destroy = gre_destroy, .me = THIS_MODULE, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) diff --git a/net/netfilter/nf_conntrack_proto_icmp.c b/net/netfilter/nf_conntrack_proto_icmp.c index 43c7e1a217b9..3598520bd19b 100644 --- a/net/netfilter/nf_conntrack_proto_icmp.c +++ b/net/netfilter/nf_conntrack_proto_icmp.c @@ -72,34 +72,17 @@ static bool icmp_invert_tuple(struct nf_conntrack_tuple *tuple, return true; } -static unsigned int *icmp_get_timeouts(struct net *net) -{ - return &icmp_pernet(net)->timeout; -} - /* Returns verdict for packet, or -1 for invalid. */ static int icmp_packet(struct nf_conn *ct, - const struct sk_buff *skb, + struct sk_buff *skb, unsigned int dataoff, - enum ip_conntrack_info ctinfo) + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { /* Do not immediately delete the connection after the first successful reply to avoid excessive conntrackd traffic and also to handle correctly ICMP echo reply duplicates. */ unsigned int *timeout = nf_ct_timeout_lookup(ct); - - if (!timeout) - timeout = icmp_get_timeouts(nf_ct_net(ct)); - - nf_ct_refresh_acct(ct, ctinfo, skb, *timeout); - - return NF_ACCEPT; -} - -/* Called when a new connection for this protocol found. */ -static bool icmp_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) -{ static const u_int8_t valid_new[] = { [ICMP_ECHO] = 1, [ICMP_TIMESTAMP] = 1, @@ -107,21 +90,29 @@ static bool icmp_new(struct nf_conn *ct, const struct sk_buff *skb, [ICMP_ADDRESS] = 1 }; + if (state->pf != NFPROTO_IPV4) + return -NF_ACCEPT; + if (ct->tuplehash[0].tuple.dst.u.icmp.type >= sizeof(valid_new) || !valid_new[ct->tuplehash[0].tuple.dst.u.icmp.type]) { /* Can't create a new ICMP `conn' with this. */ pr_debug("icmp: can't create new conn with type %u\n", ct->tuplehash[0].tuple.dst.u.icmp.type); nf_ct_dump_tuple_ip(&ct->tuplehash[0].tuple); - return false; + return -NF_ACCEPT; } - return true; + + if (!timeout) + timeout = &icmp_pernet(nf_ct_net(ct))->timeout; + + nf_ct_refresh_acct(ct, ctinfo, skb, *timeout); + return NF_ACCEPT; } /* Returns conntrack if it dealt with ICMP, and filled in skb fields */ static int -icmp_error_message(struct net *net, struct nf_conn *tmpl, struct sk_buff *skb, - unsigned int hooknum) +icmp_error_message(struct nf_conn *tmpl, struct sk_buff *skb, + const struct nf_hook_state *state) { struct nf_conntrack_tuple innertuple, origtuple; const struct nf_conntrack_l4proto *innerproto; @@ -137,13 +128,13 @@ icmp_error_message(struct net *net, struct nf_conn *tmpl, struct sk_buff *skb, if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb) + ip_hdrlen(skb) + sizeof(struct icmphdr), - PF_INET, net, &origtuple)) { + PF_INET, state->net, &origtuple)) { pr_debug("icmp_error_message: failed to get tuple\n"); return -NF_ACCEPT; } /* rcu_read_lock()ed by nf_hook_thresh */ - innerproto = __nf_ct_l4proto_find(PF_INET, origtuple.dst.protonum); + innerproto = __nf_ct_l4proto_find(origtuple.dst.protonum); /* Ordinarily, we'd expect the inverted tupleproto, but it's been preserved inside the ICMP. */ @@ -154,7 +145,7 @@ icmp_error_message(struct net *net, struct nf_conn *tmpl, struct sk_buff *skb, ctinfo = IP_CT_RELATED; - h = nf_conntrack_find_get(net, zone, &innertuple); + h = nf_conntrack_find_get(state->net, zone, &innertuple); if (!h) { pr_debug("icmp_error_message: no match\n"); return -NF_ACCEPT; @@ -168,17 +159,18 @@ icmp_error_message(struct net *net, struct nf_conn *tmpl, struct sk_buff *skb, return NF_ACCEPT; } -static void icmp_error_log(const struct sk_buff *skb, struct net *net, - u8 pf, const char *msg) +static void icmp_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) { - nf_l4proto_log_invalid(skb, net, pf, IPPROTO_ICMP, "%s", msg); + nf_l4proto_log_invalid(skb, state->net, state->pf, + IPPROTO_ICMP, "%s", msg); } /* Small and modified version of icmp_rcv */ -static int -icmp_error(struct net *net, struct nf_conn *tmpl, - struct sk_buff *skb, unsigned int dataoff, - u8 pf, unsigned int hooknum) +int nf_conntrack_icmpv4_error(struct nf_conn *tmpl, + struct sk_buff *skb, unsigned int dataoff, + const struct nf_hook_state *state) { const struct icmphdr *icmph; struct icmphdr _ih; @@ -186,14 +178,15 @@ icmp_error(struct net *net, struct nf_conn *tmpl, /* Not enough header? */ icmph = skb_header_pointer(skb, ip_hdrlen(skb), sizeof(_ih), &_ih); if (icmph == NULL) { - icmp_error_log(skb, net, pf, "short packet"); + icmp_error_log(skb, state, "short packet"); return -NF_ACCEPT; } /* See ip_conntrack_proto_tcp.c */ - if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING && - nf_ip_checksum(skb, hooknum, dataoff, 0)) { - icmp_error_log(skb, net, pf, "bad hw icmp checksum"); + if (state->net->ct.sysctl_checksum && + state->hook == NF_INET_PRE_ROUTING && + nf_ip_checksum(skb, state->hook, dataoff, 0)) { + icmp_error_log(skb, state, "bad hw icmp checksum"); return -NF_ACCEPT; } @@ -204,7 +197,7 @@ icmp_error(struct net *net, struct nf_conn *tmpl, * discarded. */ if (icmph->type > NR_ICMP_TYPES) { - icmp_error_log(skb, net, pf, "invalid icmp type"); + icmp_error_log(skb, state, "invalid icmp type"); return -NF_ACCEPT; } @@ -216,7 +209,7 @@ icmp_error(struct net *net, struct nf_conn *tmpl, icmph->type != ICMP_REDIRECT) return NF_ACCEPT; - return icmp_error_message(net, tmpl, skb, hooknum); + return icmp_error_message(tmpl, skb, state); } #if IS_ENABLED(CONFIG_NF_CT_NETLINK) @@ -342,7 +335,7 @@ static int icmp_kmemdup_sysctl_table(struct nf_proto_net *pn, return 0; } -static int icmp_init_net(struct net *net, u_int16_t proto) +static int icmp_init_net(struct net *net) { struct nf_icmp_net *in = icmp_pernet(net); struct nf_proto_net *pn = &in->pn; @@ -359,13 +352,10 @@ static struct nf_proto_net *icmp_get_net_proto(struct net *net) const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmp = { - .l3proto = PF_INET, .l4proto = IPPROTO_ICMP, .pkt_to_tuple = icmp_pkt_to_tuple, .invert_tuple = icmp_invert_tuple, .packet = icmp_packet, - .new = icmp_new, - .error = icmp_error, .destroy = NULL, .me = NULL, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) diff --git a/net/netfilter/nf_conntrack_proto_icmpv6.c b/net/netfilter/nf_conntrack_proto_icmpv6.c index 97e40f77d678..378618feed5d 100644 --- a/net/netfilter/nf_conntrack_proto_icmpv6.c +++ b/net/netfilter/nf_conntrack_proto_icmpv6.c @@ -92,11 +92,31 @@ static unsigned int *icmpv6_get_timeouts(struct net *net) /* Returns verdict for packet, or -1 for invalid. */ static int icmpv6_packet(struct nf_conn *ct, - const struct sk_buff *skb, - unsigned int dataoff, - enum ip_conntrack_info ctinfo) + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { unsigned int *timeout = nf_ct_timeout_lookup(ct); + static const u8 valid_new[] = { + [ICMPV6_ECHO_REQUEST - 128] = 1, + [ICMPV6_NI_QUERY - 128] = 1 + }; + + if (state->pf != NFPROTO_IPV6) + return -NF_ACCEPT; + + if (!nf_ct_is_confirmed(ct)) { + int type = ct->tuplehash[0].tuple.dst.u.icmp.type - 128; + + if (type < 0 || type >= sizeof(valid_new) || !valid_new[type]) { + /* Can't create a new ICMPv6 `conn' with this. */ + pr_debug("icmpv6: can't create new conn with type %u\n", + type + 128); + nf_ct_dump_tuple_ipv6(&ct->tuplehash[0].tuple); + return -NF_ACCEPT; + } + } if (!timeout) timeout = icmpv6_get_timeouts(nf_ct_net(ct)); @@ -109,26 +129,6 @@ static int icmpv6_packet(struct nf_conn *ct, return NF_ACCEPT; } -/* Called when a new connection for this protocol found. */ -static bool icmpv6_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) -{ - static const u_int8_t valid_new[] = { - [ICMPV6_ECHO_REQUEST - 128] = 1, - [ICMPV6_NI_QUERY - 128] = 1 - }; - int type = ct->tuplehash[0].tuple.dst.u.icmp.type - 128; - - if (type < 0 || type >= sizeof(valid_new) || !valid_new[type]) { - /* Can't create a new ICMPv6 `conn' with this. */ - pr_debug("icmpv6: can't create new conn with type %u\n", - type + 128); - nf_ct_dump_tuple_ipv6(&ct->tuplehash[0].tuple); - return false; - } - return true; -} - static int icmpv6_error_message(struct net *net, struct nf_conn *tmpl, struct sk_buff *skb, @@ -153,7 +153,7 @@ icmpv6_error_message(struct net *net, struct nf_conn *tmpl, } /* rcu_read_lock()ed by nf_hook_thresh */ - inproto = __nf_ct_l4proto_find(PF_INET6, origtuple.dst.protonum); + inproto = __nf_ct_l4proto_find(origtuple.dst.protonum); /* Ordinarily, we'd expect the inverted tupleproto, but it's been preserved inside the ICMP. */ @@ -179,16 +179,18 @@ icmpv6_error_message(struct net *net, struct nf_conn *tmpl, return NF_ACCEPT; } -static void icmpv6_error_log(const struct sk_buff *skb, struct net *net, - u8 pf, const char *msg) +static void icmpv6_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) { - nf_l4proto_log_invalid(skb, net, pf, IPPROTO_ICMPV6, "%s", msg); + nf_l4proto_log_invalid(skb, state->net, state->pf, + IPPROTO_ICMPV6, "%s", msg); } -static int -icmpv6_error(struct net *net, struct nf_conn *tmpl, - struct sk_buff *skb, unsigned int dataoff, - u8 pf, unsigned int hooknum) +int nf_conntrack_icmpv6_error(struct nf_conn *tmpl, + struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) { const struct icmp6hdr *icmp6h; struct icmp6hdr _ih; @@ -196,13 +198,14 @@ icmpv6_error(struct net *net, struct nf_conn *tmpl, icmp6h = skb_header_pointer(skb, dataoff, sizeof(_ih), &_ih); if (icmp6h == NULL) { - icmpv6_error_log(skb, net, pf, "short packet"); + icmpv6_error_log(skb, state, "short packet"); return -NF_ACCEPT; } - if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING && - nf_ip6_checksum(skb, hooknum, dataoff, IPPROTO_ICMPV6)) { - icmpv6_error_log(skb, net, pf, "ICMPv6 checksum failed"); + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_ip6_checksum(skb, state->hook, dataoff, IPPROTO_ICMPV6)) { + icmpv6_error_log(skb, state, "ICMPv6 checksum failed"); return -NF_ACCEPT; } @@ -217,7 +220,7 @@ icmpv6_error(struct net *net, struct nf_conn *tmpl, if (icmp6h->icmp6_type >= 128) return NF_ACCEPT; - return icmpv6_error_message(net, tmpl, skb, dataoff); + return icmpv6_error_message(state->net, tmpl, skb, dataoff); } #if IS_ENABLED(CONFIG_NF_CT_NETLINK) @@ -343,7 +346,7 @@ static int icmpv6_kmemdup_sysctl_table(struct nf_proto_net *pn, return 0; } -static int icmpv6_init_net(struct net *net, u_int16_t proto) +static int icmpv6_init_net(struct net *net) { struct nf_icmp_net *in = icmpv6_pernet(net); struct nf_proto_net *pn = &in->pn; @@ -360,13 +363,10 @@ static struct nf_proto_net *icmpv6_get_net_proto(struct net *net) const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmpv6 = { - .l3proto = PF_INET6, .l4proto = IPPROTO_ICMPV6, .pkt_to_tuple = icmpv6_pkt_to_tuple, .invert_tuple = icmpv6_invert_tuple, .packet = icmpv6_packet, - .new = icmpv6_new, - .error = icmpv6_error, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = icmpv6_tuple_to_nlattr, .nlattr_tuple_size = icmpv6_nlattr_tuple_size, diff --git a/net/netfilter/nf_conntrack_proto_sctp.c b/net/netfilter/nf_conntrack_proto_sctp.c index e4d738d34cd0..3d719d3eb9a3 100644 --- a/net/netfilter/nf_conntrack_proto_sctp.c +++ b/net/netfilter/nf_conntrack_proto_sctp.c @@ -273,11 +273,100 @@ static int sctp_new_state(enum ip_conntrack_dir dir, return sctp_conntracks[dir][i][cur_state]; } +/* Don't need lock here: this conntrack not in circulation yet */ +static noinline bool +sctp_new(struct nf_conn *ct, const struct sk_buff *skb, + const struct sctphdr *sh, unsigned int dataoff) +{ + enum sctp_conntrack new_state; + const struct sctp_chunkhdr *sch; + struct sctp_chunkhdr _sch; + u32 offset, count; + + memset(&ct->proto.sctp, 0, sizeof(ct->proto.sctp)); + new_state = SCTP_CONNTRACK_MAX; + for_each_sctp_chunk(skb, sch, _sch, offset, dataoff, count) { + new_state = sctp_new_state(IP_CT_DIR_ORIGINAL, + SCTP_CONNTRACK_NONE, sch->type); + + /* Invalid: delete conntrack */ + if (new_state == SCTP_CONNTRACK_NONE || + new_state == SCTP_CONNTRACK_MAX) { + pr_debug("nf_conntrack_sctp: invalid new deleting.\n"); + return false; + } + + /* Copy the vtag into the state info */ + if (sch->type == SCTP_CID_INIT) { + struct sctp_inithdr _inithdr, *ih; + /* Sec 8.5.1 (A) */ + if (sh->vtag) + return false; + + ih = skb_header_pointer(skb, offset + sizeof(_sch), + sizeof(_inithdr), &_inithdr); + if (!ih) + return false; + + pr_debug("Setting vtag %x for new conn\n", + ih->init_tag); + + ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = ih->init_tag; + } else if (sch->type == SCTP_CID_HEARTBEAT) { + pr_debug("Setting vtag %x for secondary conntrack\n", + sh->vtag); + ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL] = sh->vtag; + } else { + /* If it is a shutdown ack OOTB packet, we expect a return + shutdown complete, otherwise an ABORT Sec 8.4 (5) and (8) */ + pr_debug("Setting vtag %x for new conn OOTB\n", + sh->vtag); + ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = sh->vtag; + } + + ct->proto.sctp.state = new_state; + } + + return true; +} + +static bool sctp_error(struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) +{ + const struct sctphdr *sh; + const char *logmsg; + + if (skb->len < dataoff + sizeof(struct sctphdr)) { + logmsg = "nf_ct_sctp: short packet "; + goto out_invalid; + } + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + skb->ip_summed == CHECKSUM_NONE) { + if (!skb_make_writable(skb, dataoff + sizeof(struct sctphdr))) { + logmsg = "nf_ct_sctp: failed to read header "; + goto out_invalid; + } + sh = (const struct sctphdr *)(skb->data + dataoff); + if (sh->checksum != sctp_compute_cksum(skb, dataoff)) { + logmsg = "nf_ct_sctp: bad CRC "; + goto out_invalid; + } + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + return false; +out_invalid: + nf_l4proto_log_invalid(skb, state->net, state->pf, IPPROTO_SCTP, "%s", logmsg); + return true; +} + /* Returns verdict for packet, or -NF_ACCEPT for invalid. */ static int sctp_packet(struct nf_conn *ct, - const struct sk_buff *skb, + struct sk_buff *skb, unsigned int dataoff, - enum ip_conntrack_info ctinfo) + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { enum sctp_conntrack new_state, old_state; enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); @@ -289,6 +378,9 @@ static int sctp_packet(struct nf_conn *ct, unsigned int *timeouts; unsigned long map[256 / sizeof(unsigned long)] = { 0 }; + if (sctp_error(skb, dataoff, state)) + return -NF_ACCEPT; + sh = skb_header_pointer(skb, dataoff, sizeof(_sctph), &_sctph); if (sh == NULL) goto out; @@ -296,6 +388,17 @@ static int sctp_packet(struct nf_conn *ct, if (do_basic_checks(ct, skb, dataoff, map) != 0) goto out; + if (!nf_ct_is_confirmed(ct)) { + /* If an OOTB packet has any of these chunks discard (Sec 8.4) */ + if (test_bit(SCTP_CID_ABORT, map) || + test_bit(SCTP_CID_SHUTDOWN_COMPLETE, map) || + test_bit(SCTP_CID_COOKIE_ACK, map)) + return -NF_ACCEPT; + + if (!sctp_new(ct, skb, sh, dataoff)) + return -NF_ACCEPT; + } + /* Check the verification tag (Sec 8.5) */ if (!test_bit(SCTP_CID_INIT, map) && !test_bit(SCTP_CID_SHUTDOWN_COMPLETE, map) && @@ -397,110 +500,6 @@ out: return -NF_ACCEPT; } -/* Called when a new connection for this protocol found. */ -static bool sctp_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) -{ - enum sctp_conntrack new_state; - const struct sctphdr *sh; - struct sctphdr _sctph; - const struct sctp_chunkhdr *sch; - struct sctp_chunkhdr _sch; - u_int32_t offset, count; - unsigned long map[256 / sizeof(unsigned long)] = { 0 }; - - sh = skb_header_pointer(skb, dataoff, sizeof(_sctph), &_sctph); - if (sh == NULL) - return false; - - if (do_basic_checks(ct, skb, dataoff, map) != 0) - return false; - - /* If an OOTB packet has any of these chunks discard (Sec 8.4) */ - if (test_bit(SCTP_CID_ABORT, map) || - test_bit(SCTP_CID_SHUTDOWN_COMPLETE, map) || - test_bit(SCTP_CID_COOKIE_ACK, map)) - return false; - - memset(&ct->proto.sctp, 0, sizeof(ct->proto.sctp)); - new_state = SCTP_CONNTRACK_MAX; - for_each_sctp_chunk (skb, sch, _sch, offset, dataoff, count) { - /* Don't need lock here: this conntrack not in circulation yet */ - new_state = sctp_new_state(IP_CT_DIR_ORIGINAL, - SCTP_CONNTRACK_NONE, sch->type); - - /* Invalid: delete conntrack */ - if (new_state == SCTP_CONNTRACK_NONE || - new_state == SCTP_CONNTRACK_MAX) { - pr_debug("nf_conntrack_sctp: invalid new deleting.\n"); - return false; - } - - /* Copy the vtag into the state info */ - if (sch->type == SCTP_CID_INIT) { - struct sctp_inithdr _inithdr, *ih; - /* Sec 8.5.1 (A) */ - if (sh->vtag) - return false; - - ih = skb_header_pointer(skb, offset + sizeof(_sch), - sizeof(_inithdr), &_inithdr); - if (!ih) - return false; - - pr_debug("Setting vtag %x for new conn\n", - ih->init_tag); - - ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = ih->init_tag; - } else if (sch->type == SCTP_CID_HEARTBEAT) { - pr_debug("Setting vtag %x for secondary conntrack\n", - sh->vtag); - ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL] = sh->vtag; - } - /* If it is a shutdown ack OOTB packet, we expect a return - shutdown complete, otherwise an ABORT Sec 8.4 (5) and (8) */ - else { - pr_debug("Setting vtag %x for new conn OOTB\n", - sh->vtag); - ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = sh->vtag; - } - - ct->proto.sctp.state = new_state; - } - - return true; -} - -static int sctp_error(struct net *net, struct nf_conn *tpl, struct sk_buff *skb, - unsigned int dataoff, - u8 pf, unsigned int hooknum) -{ - const struct sctphdr *sh; - const char *logmsg; - - if (skb->len < dataoff + sizeof(struct sctphdr)) { - logmsg = "nf_ct_sctp: short packet "; - goto out_invalid; - } - if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING && - skb->ip_summed == CHECKSUM_NONE) { - if (!skb_make_writable(skb, dataoff + sizeof(struct sctphdr))) { - logmsg = "nf_ct_sctp: failed to read header "; - goto out_invalid; - } - sh = (const struct sctphdr *)(skb->data + dataoff); - if (sh->checksum != sctp_compute_cksum(skb, dataoff)) { - logmsg = "nf_ct_sctp: bad CRC "; - goto out_invalid; - } - skb->ip_summed = CHECKSUM_UNNECESSARY; - } - return NF_ACCEPT; -out_invalid: - nf_l4proto_log_invalid(skb, net, pf, IPPROTO_SCTP, "%s", logmsg); - return -NF_ACCEPT; -} - static bool sctp_can_early_drop(const struct nf_conn *ct) { switch (ct->proto.sctp.state) { @@ -735,7 +734,7 @@ static int sctp_kmemdup_sysctl_table(struct nf_proto_net *pn, return 0; } -static int sctp_init_net(struct net *net, u_int16_t proto) +static int sctp_init_net(struct net *net) { struct nf_sctp_net *sn = sctp_pernet(net); struct nf_proto_net *pn = &sn->pn; @@ -760,49 +759,12 @@ static struct nf_proto_net *sctp_get_net_proto(struct net *net) return &net->ct.nf_ct_proto.sctp.pn; } -const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp4 = { - .l3proto = PF_INET, - .l4proto = IPPROTO_SCTP, -#ifdef CONFIG_NF_CONNTRACK_PROCFS - .print_conntrack = sctp_print_conntrack, -#endif - .packet = sctp_packet, - .new = sctp_new, - .error = sctp_error, - .can_early_drop = sctp_can_early_drop, - .me = THIS_MODULE, -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) - .nlattr_size = SCTP_NLATTR_SIZE, - .to_nlattr = sctp_to_nlattr, - .from_nlattr = nlattr_to_sctp, - .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, - .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, - .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, - .nla_policy = nf_ct_port_nla_policy, -#endif -#ifdef CONFIG_NF_CONNTRACK_TIMEOUT - .ctnl_timeout = { - .nlattr_to_obj = sctp_timeout_nlattr_to_obj, - .obj_to_nlattr = sctp_timeout_obj_to_nlattr, - .nlattr_max = CTA_TIMEOUT_SCTP_MAX, - .obj_size = sizeof(unsigned int) * SCTP_CONNTRACK_MAX, - .nla_policy = sctp_timeout_nla_policy, - }, -#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = sctp_init_net, - .get_net_proto = sctp_get_net_proto, -}; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_sctp4); - -const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp6 = { - .l3proto = PF_INET6, +const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp = { .l4proto = IPPROTO_SCTP, #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = sctp_print_conntrack, #endif .packet = sctp_packet, - .new = sctp_new, - .error = sctp_error, .can_early_drop = sctp_can_early_drop, .me = THIS_MODULE, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) @@ -826,4 +788,3 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp6 = { .init_net = sctp_init_net, .get_net_proto = sctp_get_net_proto, }; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_sctp6); diff --git a/net/netfilter/nf_conntrack_proto_tcp.c b/net/netfilter/nf_conntrack_proto_tcp.c index b4bdf9eda7b7..1bcf9984d45e 100644 --- a/net/netfilter/nf_conntrack_proto_tcp.c +++ b/net/netfilter/nf_conntrack_proto_tcp.c @@ -717,35 +717,26 @@ static const u8 tcp_valid_flags[(TCPHDR_FIN|TCPHDR_SYN|TCPHDR_RST|TCPHDR_ACK| [TCPHDR_ACK|TCPHDR_URG] = 1, }; -static void tcp_error_log(const struct sk_buff *skb, struct net *net, - u8 pf, const char *msg) +static void tcp_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) { - nf_l4proto_log_invalid(skb, net, pf, IPPROTO_TCP, "%s", msg); + nf_l4proto_log_invalid(skb, state->net, state->pf, IPPROTO_TCP, "%s", msg); } /* Protect conntrack agaist broken packets. Code taken from ipt_unclean.c. */ -static int tcp_error(struct net *net, struct nf_conn *tmpl, - struct sk_buff *skb, - unsigned int dataoff, - u_int8_t pf, - unsigned int hooknum) +static bool tcp_error(const struct tcphdr *th, + struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) { - const struct tcphdr *th; - struct tcphdr _tcph; unsigned int tcplen = skb->len - dataoff; - u_int8_t tcpflags; - - /* Smaller that minimal TCP header? */ - th = skb_header_pointer(skb, dataoff, sizeof(_tcph), &_tcph); - if (th == NULL) { - tcp_error_log(skb, net, pf, "short packet"); - return -NF_ACCEPT; - } + u8 tcpflags; /* Not whole TCP header or malformed packet */ if (th->doff*4 < sizeof(struct tcphdr) || tcplen < th->doff*4) { - tcp_error_log(skb, net, pf, "truncated packet"); - return -NF_ACCEPT; + tcp_error_log(skb, state, "truncated packet"); + return true; } /* Checksum invalid? Ignore. @@ -753,27 +744,101 @@ static int tcp_error(struct net *net, struct nf_conn *tmpl, * because the checksum is assumed to be correct. */ /* FIXME: Source route IP option packets --RR */ - if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING && - nf_checksum(skb, hooknum, dataoff, IPPROTO_TCP, pf)) { - tcp_error_log(skb, net, pf, "bad checksum"); - return -NF_ACCEPT; + if (state->net->ct.sysctl_checksum && + state->hook == NF_INET_PRE_ROUTING && + nf_checksum(skb, state->hook, dataoff, IPPROTO_TCP, state->pf)) { + tcp_error_log(skb, state, "bad checksum"); + return true; } /* Check TCP flags. */ tcpflags = (tcp_flag_byte(th) & ~(TCPHDR_ECE|TCPHDR_CWR|TCPHDR_PSH)); if (!tcp_valid_flags[tcpflags]) { - tcp_error_log(skb, net, pf, "invalid tcp flag combination"); - return -NF_ACCEPT; + tcp_error_log(skb, state, "invalid tcp flag combination"); + return true; } - return NF_ACCEPT; + return false; +} + +static noinline bool tcp_new(struct nf_conn *ct, const struct sk_buff *skb, + unsigned int dataoff, + const struct tcphdr *th) +{ + enum tcp_conntrack new_state; + struct net *net = nf_ct_net(ct); + const struct nf_tcp_net *tn = tcp_pernet(net); + const struct ip_ct_tcp_state *sender = &ct->proto.tcp.seen[0]; + const struct ip_ct_tcp_state *receiver = &ct->proto.tcp.seen[1]; + + /* Don't need lock here: this conntrack not in circulation yet */ + new_state = tcp_conntracks[0][get_conntrack_index(th)][TCP_CONNTRACK_NONE]; + + /* Invalid: delete conntrack */ + if (new_state >= TCP_CONNTRACK_MAX) { + pr_debug("nf_ct_tcp: invalid new deleting.\n"); + return false; + } + + if (new_state == TCP_CONNTRACK_SYN_SENT) { + memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp)); + /* SYN packet */ + ct->proto.tcp.seen[0].td_end = + segment_seq_plus_len(ntohl(th->seq), skb->len, + dataoff, th); + ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window); + if (ct->proto.tcp.seen[0].td_maxwin == 0) + ct->proto.tcp.seen[0].td_maxwin = 1; + ct->proto.tcp.seen[0].td_maxend = + ct->proto.tcp.seen[0].td_end; + + tcp_options(skb, dataoff, th, &ct->proto.tcp.seen[0]); + } else if (tn->tcp_loose == 0) { + /* Don't try to pick up connections. */ + return false; + } else { + memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp)); + /* + * We are in the middle of a connection, + * its history is lost for us. + * Let's try to use the data from the packet. + */ + ct->proto.tcp.seen[0].td_end = + segment_seq_plus_len(ntohl(th->seq), skb->len, + dataoff, th); + ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window); + if (ct->proto.tcp.seen[0].td_maxwin == 0) + ct->proto.tcp.seen[0].td_maxwin = 1; + ct->proto.tcp.seen[0].td_maxend = + ct->proto.tcp.seen[0].td_end + + ct->proto.tcp.seen[0].td_maxwin; + + /* We assume SACK and liberal window checking to handle + * window scaling */ + ct->proto.tcp.seen[0].flags = + ct->proto.tcp.seen[1].flags = IP_CT_TCP_FLAG_SACK_PERM | + IP_CT_TCP_FLAG_BE_LIBERAL; + } + + /* tcp_packet will set them */ + ct->proto.tcp.last_index = TCP_NONE_SET; + + pr_debug("%s: sender end=%u maxend=%u maxwin=%u scale=%i " + "receiver end=%u maxend=%u maxwin=%u scale=%i\n", + __func__, + sender->td_end, sender->td_maxend, sender->td_maxwin, + sender->td_scale, + receiver->td_end, receiver->td_maxend, receiver->td_maxwin, + receiver->td_scale); + return true; } /* Returns verdict for packet, or -1 for invalid. */ static int tcp_packet(struct nf_conn *ct, - const struct sk_buff *skb, + struct sk_buff *skb, unsigned int dataoff, - enum ip_conntrack_info ctinfo) + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { struct net *net = nf_ct_net(ct); struct nf_tcp_net *tn = tcp_pernet(net); @@ -786,7 +851,14 @@ static int tcp_packet(struct nf_conn *ct, unsigned long timeout; th = skb_header_pointer(skb, dataoff, sizeof(_tcph), &_tcph); - BUG_ON(th == NULL); + if (th == NULL) + return -NF_ACCEPT; + + if (tcp_error(th, skb, dataoff, state)) + return -NF_ACCEPT; + + if (!nf_ct_is_confirmed(ct) && !tcp_new(ct, skb, dataoff, th)) + return -NF_ACCEPT; spin_lock_bh(&ct->lock); old_state = ct->proto.tcp.state; @@ -1067,82 +1139,6 @@ static int tcp_packet(struct nf_conn *ct, return NF_ACCEPT; } -/* Called when a new connection for this protocol found. */ -static bool tcp_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) -{ - enum tcp_conntrack new_state; - const struct tcphdr *th; - struct tcphdr _tcph; - struct net *net = nf_ct_net(ct); - struct nf_tcp_net *tn = tcp_pernet(net); - const struct ip_ct_tcp_state *sender = &ct->proto.tcp.seen[0]; - const struct ip_ct_tcp_state *receiver = &ct->proto.tcp.seen[1]; - - th = skb_header_pointer(skb, dataoff, sizeof(_tcph), &_tcph); - BUG_ON(th == NULL); - - /* Don't need lock here: this conntrack not in circulation yet */ - new_state = tcp_conntracks[0][get_conntrack_index(th)][TCP_CONNTRACK_NONE]; - - /* Invalid: delete conntrack */ - if (new_state >= TCP_CONNTRACK_MAX) { - pr_debug("nf_ct_tcp: invalid new deleting.\n"); - return false; - } - - if (new_state == TCP_CONNTRACK_SYN_SENT) { - memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp)); - /* SYN packet */ - ct->proto.tcp.seen[0].td_end = - segment_seq_plus_len(ntohl(th->seq), skb->len, - dataoff, th); - ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window); - if (ct->proto.tcp.seen[0].td_maxwin == 0) - ct->proto.tcp.seen[0].td_maxwin = 1; - ct->proto.tcp.seen[0].td_maxend = - ct->proto.tcp.seen[0].td_end; - - tcp_options(skb, dataoff, th, &ct->proto.tcp.seen[0]); - } else if (tn->tcp_loose == 0) { - /* Don't try to pick up connections. */ - return false; - } else { - memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp)); - /* - * We are in the middle of a connection, - * its history is lost for us. - * Let's try to use the data from the packet. - */ - ct->proto.tcp.seen[0].td_end = - segment_seq_plus_len(ntohl(th->seq), skb->len, - dataoff, th); - ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window); - if (ct->proto.tcp.seen[0].td_maxwin == 0) - ct->proto.tcp.seen[0].td_maxwin = 1; - ct->proto.tcp.seen[0].td_maxend = - ct->proto.tcp.seen[0].td_end + - ct->proto.tcp.seen[0].td_maxwin; - - /* We assume SACK and liberal window checking to handle - * window scaling */ - ct->proto.tcp.seen[0].flags = - ct->proto.tcp.seen[1].flags = IP_CT_TCP_FLAG_SACK_PERM | - IP_CT_TCP_FLAG_BE_LIBERAL; - } - - /* tcp_packet will set them */ - ct->proto.tcp.last_index = TCP_NONE_SET; - - pr_debug("tcp_new: sender end=%u maxend=%u maxwin=%u scale=%i " - "receiver end=%u maxend=%u maxwin=%u scale=%i\n", - sender->td_end, sender->td_maxend, sender->td_maxwin, - sender->td_scale, - receiver->td_end, receiver->td_maxend, receiver->td_maxwin, - receiver->td_scale); - return true; -} - static bool tcp_can_early_drop(const struct nf_conn *ct) { switch (ct->proto.tcp.state) { @@ -1213,8 +1209,8 @@ static const struct nla_policy tcp_nla_policy[CTA_PROTOINFO_TCP_MAX+1] = { #define TCP_NLATTR_SIZE ( \ NLA_ALIGN(NLA_HDRLEN + 1) + \ NLA_ALIGN(NLA_HDRLEN + 1) + \ - NLA_ALIGN(NLA_HDRLEN + sizeof(sizeof(struct nf_ct_tcp_flags))) + \ - NLA_ALIGN(NLA_HDRLEN + sizeof(sizeof(struct nf_ct_tcp_flags)))) + NLA_ALIGN(NLA_HDRLEN + sizeof(struct nf_ct_tcp_flags)) + \ + NLA_ALIGN(NLA_HDRLEN + sizeof(struct nf_ct_tcp_flags))) static int nlattr_to_tcp(struct nlattr *cda[], struct nf_conn *ct) { @@ -1510,7 +1506,7 @@ static int tcp_kmemdup_sysctl_table(struct nf_proto_net *pn, return 0; } -static int tcp_init_net(struct net *net, u_int16_t proto) +static int tcp_init_net(struct net *net) { struct nf_tcp_net *tn = tcp_pernet(net); struct nf_proto_net *pn = &tn->pn; @@ -1538,16 +1534,13 @@ static struct nf_proto_net *tcp_get_net_proto(struct net *net) return &net->ct.nf_ct_proto.tcp.pn; } -const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp4 = +const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp = { - .l3proto = PF_INET, .l4proto = IPPROTO_TCP, #ifdef CONFIG_NF_CONNTRACK_PROCFS .print_conntrack = tcp_print_conntrack, #endif .packet = tcp_packet, - .new = tcp_new, - .error = tcp_error, .can_early_drop = tcp_can_early_drop, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .to_nlattr = tcp_to_nlattr, @@ -1571,39 +1564,3 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp4 = .init_net = tcp_init_net, .get_net_proto = tcp_get_net_proto, }; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_tcp4); - -const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp6 = -{ - .l3proto = PF_INET6, - .l4proto = IPPROTO_TCP, -#ifdef CONFIG_NF_CONNTRACK_PROCFS - .print_conntrack = tcp_print_conntrack, -#endif - .packet = tcp_packet, - .new = tcp_new, - .error = tcp_error, - .can_early_drop = tcp_can_early_drop, -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) - .nlattr_size = TCP_NLATTR_SIZE, - .to_nlattr = tcp_to_nlattr, - .from_nlattr = nlattr_to_tcp, - .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, - .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, - .nlattr_tuple_size = tcp_nlattr_tuple_size, - .nla_policy = nf_ct_port_nla_policy, -#endif -#ifdef CONFIG_NF_CONNTRACK_TIMEOUT - .ctnl_timeout = { - .nlattr_to_obj = tcp_timeout_nlattr_to_obj, - .obj_to_nlattr = tcp_timeout_obj_to_nlattr, - .nlattr_max = CTA_TIMEOUT_TCP_MAX, - .obj_size = sizeof(unsigned int) * - TCP_CONNTRACK_TIMEOUT_MAX, - .nla_policy = tcp_timeout_nla_policy, - }, -#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = tcp_init_net, - .get_net_proto = tcp_get_net_proto, -}; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_tcp6); diff --git a/net/netfilter/nf_conntrack_proto_udp.c b/net/netfilter/nf_conntrack_proto_udp.c index 3065fb8ef91b..a7aa70370913 100644 --- a/net/netfilter/nf_conntrack_proto_udp.c +++ b/net/netfilter/nf_conntrack_proto_udp.c @@ -42,14 +42,65 @@ static unsigned int *udp_get_timeouts(struct net *net) return udp_pernet(net)->timeouts; } +static void udp_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) +{ + nf_l4proto_log_invalid(skb, state->net, state->pf, + IPPROTO_UDP, "%s", msg); +} + +static bool udp_error(struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) +{ + unsigned int udplen = skb->len - dataoff; + const struct udphdr *hdr; + struct udphdr _hdr; + + /* Header is too small? */ + hdr = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr); + if (!hdr) { + udp_error_log(skb, state, "short packet"); + return true; + } + + /* Truncated/malformed packets */ + if (ntohs(hdr->len) > udplen || ntohs(hdr->len) < sizeof(*hdr)) { + udp_error_log(skb, state, "truncated/malformed packet"); + return true; + } + + /* Packet with no checksum */ + if (!hdr->check) + return false; + + /* Checksum invalid? Ignore. + * We skip checking packets on the outgoing path + * because the checksum is assumed to be correct. + * FIXME: Source route IP option packets --RR */ + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_checksum(skb, state->hook, dataoff, IPPROTO_UDP, state->pf)) { + udp_error_log(skb, state, "bad checksum"); + return true; + } + + return false; +} + /* Returns verdict for packet, and may modify conntracktype */ static int udp_packet(struct nf_conn *ct, - const struct sk_buff *skb, + struct sk_buff *skb, unsigned int dataoff, - enum ip_conntrack_info ctinfo) + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { unsigned int *timeouts; + if (udp_error(skb, dataoff, state)) + return -NF_ACCEPT; + timeouts = nf_ct_timeout_lookup(ct); if (!timeouts) timeouts = udp_get_timeouts(nf_ct_net(ct)); @@ -69,24 +120,18 @@ static int udp_packet(struct nf_conn *ct, return NF_ACCEPT; } -/* Called when a new connection for this protocol found. */ -static bool udp_new(struct nf_conn *ct, const struct sk_buff *skb, - unsigned int dataoff) -{ - return true; -} - #ifdef CONFIG_NF_CT_PROTO_UDPLITE -static void udplite_error_log(const struct sk_buff *skb, struct net *net, - u8 pf, const char *msg) +static void udplite_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) { - nf_l4proto_log_invalid(skb, net, pf, IPPROTO_UDPLITE, "%s", msg); + nf_l4proto_log_invalid(skb, state->net, state->pf, + IPPROTO_UDPLITE, "%s", msg); } -static int udplite_error(struct net *net, struct nf_conn *tmpl, - struct sk_buff *skb, - unsigned int dataoff, - u8 pf, unsigned int hooknum) +static bool udplite_error(struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) { unsigned int udplen = skb->len - dataoff; const struct udphdr *hdr; @@ -96,80 +141,67 @@ static int udplite_error(struct net *net, struct nf_conn *tmpl, /* Header is too small? */ hdr = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr); if (!hdr) { - udplite_error_log(skb, net, pf, "short packet"); - return -NF_ACCEPT; + udplite_error_log(skb, state, "short packet"); + return true; } cscov = ntohs(hdr->len); if (cscov == 0) { cscov = udplen; } else if (cscov < sizeof(*hdr) || cscov > udplen) { - udplite_error_log(skb, net, pf, "invalid checksum coverage"); - return -NF_ACCEPT; + udplite_error_log(skb, state, "invalid checksum coverage"); + return true; } /* UDPLITE mandates checksums */ if (!hdr->check) { - udplite_error_log(skb, net, pf, "checksum missing"); - return -NF_ACCEPT; + udplite_error_log(skb, state, "checksum missing"); + return true; } /* Checksum invalid? Ignore. */ - if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING && - nf_checksum_partial(skb, hooknum, dataoff, cscov, IPPROTO_UDP, - pf)) { - udplite_error_log(skb, net, pf, "bad checksum"); - return -NF_ACCEPT; + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_checksum_partial(skb, state->hook, dataoff, cscov, IPPROTO_UDP, + state->pf)) { + udplite_error_log(skb, state, "bad checksum"); + return true; } - return NF_ACCEPT; -} -#endif - -static void udp_error_log(const struct sk_buff *skb, struct net *net, - u8 pf, const char *msg) -{ - nf_l4proto_log_invalid(skb, net, pf, IPPROTO_UDP, "%s", msg); + return false; } -static int udp_error(struct net *net, struct nf_conn *tmpl, struct sk_buff *skb, - unsigned int dataoff, - u_int8_t pf, - unsigned int hooknum) +/* Returns verdict for packet, and may modify conntracktype */ +static int udplite_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) { - unsigned int udplen = skb->len - dataoff; - const struct udphdr *hdr; - struct udphdr _hdr; - - /* Header is too small? */ - hdr = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr); - if (hdr == NULL) { - udp_error_log(skb, net, pf, "short packet"); - return -NF_ACCEPT; - } + unsigned int *timeouts; - /* Truncated/malformed packets */ - if (ntohs(hdr->len) > udplen || ntohs(hdr->len) < sizeof(*hdr)) { - udp_error_log(skb, net, pf, "truncated/malformed packet"); + if (udplite_error(skb, dataoff, state)) return -NF_ACCEPT; - } - /* Packet with no checksum */ - if (!hdr->check) - return NF_ACCEPT; + timeouts = nf_ct_timeout_lookup(ct); + if (!timeouts) + timeouts = udp_get_timeouts(nf_ct_net(ct)); - /* Checksum invalid? Ignore. - * We skip checking packets on the outgoing path - * because the checksum is assumed to be correct. - * FIXME: Source route IP option packets --RR */ - if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING && - nf_checksum(skb, hooknum, dataoff, IPPROTO_UDP, pf)) { - udp_error_log(skb, net, pf, "bad checksum"); - return -NF_ACCEPT; + /* If we've seen traffic both ways, this is some kind of UDP + stream. Extend timeout. */ + if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) { + nf_ct_refresh_acct(ct, ctinfo, skb, + timeouts[UDP_CT_REPLIED]); + /* Also, more likely to be important, and not a probe */ + if (!test_and_set_bit(IPS_ASSURED_BIT, &ct->status)) + nf_conntrack_event_cache(IPCT_ASSURED, ct); + } else { + nf_ct_refresh_acct(ct, ctinfo, skb, + timeouts[UDP_CT_UNREPLIED]); } - return NF_ACCEPT; } +#endif #ifdef CONFIG_NF_CONNTRACK_TIMEOUT @@ -258,7 +290,7 @@ static int udp_kmemdup_sysctl_table(struct nf_proto_net *pn, return 0; } -static int udp_init_net(struct net *net, u_int16_t proto) +static int udp_init_net(struct net *net) { struct nf_udp_net *un = udp_pernet(net); struct nf_proto_net *pn = &un->pn; @@ -278,72 +310,11 @@ static struct nf_proto_net *udp_get_net_proto(struct net *net) return &net->ct.nf_ct_proto.udp.pn; } -const struct nf_conntrack_l4proto nf_conntrack_l4proto_udp4 = -{ - .l3proto = PF_INET, - .l4proto = IPPROTO_UDP, - .allow_clash = true, - .packet = udp_packet, - .new = udp_new, - .error = udp_error, -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) - .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, - .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, - .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, - .nla_policy = nf_ct_port_nla_policy, -#endif -#ifdef CONFIG_NF_CONNTRACK_TIMEOUT - .ctnl_timeout = { - .nlattr_to_obj = udp_timeout_nlattr_to_obj, - .obj_to_nlattr = udp_timeout_obj_to_nlattr, - .nlattr_max = CTA_TIMEOUT_UDP_MAX, - .obj_size = sizeof(unsigned int) * CTA_TIMEOUT_UDP_MAX, - .nla_policy = udp_timeout_nla_policy, - }, -#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = udp_init_net, - .get_net_proto = udp_get_net_proto, -}; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_udp4); - -#ifdef CONFIG_NF_CT_PROTO_UDPLITE -const struct nf_conntrack_l4proto nf_conntrack_l4proto_udplite4 = -{ - .l3proto = PF_INET, - .l4proto = IPPROTO_UDPLITE, - .allow_clash = true, - .packet = udp_packet, - .new = udp_new, - .error = udplite_error, -#if IS_ENABLED(CONFIG_NF_CT_NETLINK) - .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, - .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, - .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, - .nla_policy = nf_ct_port_nla_policy, -#endif -#ifdef CONFIG_NF_CONNTRACK_TIMEOUT - .ctnl_timeout = { - .nlattr_to_obj = udp_timeout_nlattr_to_obj, - .obj_to_nlattr = udp_timeout_obj_to_nlattr, - .nlattr_max = CTA_TIMEOUT_UDP_MAX, - .obj_size = sizeof(unsigned int) * CTA_TIMEOUT_UDP_MAX, - .nla_policy = udp_timeout_nla_policy, - }, -#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ - .init_net = udp_init_net, - .get_net_proto = udp_get_net_proto, -}; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_udplite4); -#endif - -const struct nf_conntrack_l4proto nf_conntrack_l4proto_udp6 = +const struct nf_conntrack_l4proto nf_conntrack_l4proto_udp = { - .l3proto = PF_INET6, .l4proto = IPPROTO_UDP, .allow_clash = true, .packet = udp_packet, - .new = udp_new, - .error = udp_error, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, @@ -362,17 +333,13 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_udp6 = .init_net = udp_init_net, .get_net_proto = udp_get_net_proto, }; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_udp6); #ifdef CONFIG_NF_CT_PROTO_UDPLITE -const struct nf_conntrack_l4proto nf_conntrack_l4proto_udplite6 = +const struct nf_conntrack_l4proto nf_conntrack_l4proto_udplite = { - .l3proto = PF_INET6, .l4proto = IPPROTO_UDPLITE, .allow_clash = true, - .packet = udp_packet, - .new = udp_new, - .error = udplite_error, + .packet = udplite_packet, #if IS_ENABLED(CONFIG_NF_CT_NETLINK) .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, @@ -391,5 +358,4 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_udplite6 = .init_net = udp_init_net, .get_net_proto = udp_get_net_proto, }; -EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_udplite6); #endif diff --git a/net/netfilter/nf_conntrack_standalone.c b/net/netfilter/nf_conntrack_standalone.c index 13279f683da9..463d17d349c1 100644 --- a/net/netfilter/nf_conntrack_standalone.c +++ b/net/netfilter/nf_conntrack_standalone.c @@ -292,7 +292,7 @@ static int ct_seq_show(struct seq_file *s, void *v) if (!net_eq(nf_ct_net(ct), net)) goto release; - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct)); + l4proto = __nf_ct_l4proto_find(nf_ct_protonum(ct)); WARN_ON(!l4proto); ret = -ENOSPC; @@ -720,10 +720,3 @@ static void __exit nf_conntrack_standalone_fini(void) module_init(nf_conntrack_standalone_init); module_exit(nf_conntrack_standalone_fini); - -/* Some modules need us, but don't depend directly on any symbol. - They should call this. */ -void need_conntrack(void) -{ -} -EXPORT_SYMBOL_GPL(need_conntrack); diff --git a/net/netfilter/nf_flow_table_core.c b/net/netfilter/nf_flow_table_core.c index d8125616edc7..b7a4816add76 100644 --- a/net/netfilter/nf_flow_table_core.c +++ b/net/netfilter/nf_flow_table_core.c @@ -120,7 +120,7 @@ static void flow_offload_fixup_ct_state(struct nf_conn *ct) if (l4num == IPPROTO_TCP) flow_offload_fixup_tcp(&ct->proto.tcp); - l4proto = __nf_ct_l4proto_find(nf_ct_l3num(ct), l4num); + l4proto = __nf_ct_l4proto_find(l4num); if (!l4proto) return; @@ -233,8 +233,8 @@ flow_offload_lookup(struct nf_flowtable *flow_table, struct flow_offload *flow; int dir; - tuplehash = rhashtable_lookup_fast(&flow_table->rhashtable, tuple, - nf_flow_offload_rhash_params); + tuplehash = rhashtable_lookup(&flow_table->rhashtable, tuple, + nf_flow_offload_rhash_params); if (!tuplehash) return NULL; @@ -254,20 +254,17 @@ int nf_flow_table_iterate(struct nf_flowtable *flow_table, struct flow_offload_tuple_rhash *tuplehash; struct rhashtable_iter hti; struct flow_offload *flow; - int err; - - err = rhashtable_walk_init(&flow_table->rhashtable, &hti, GFP_KERNEL); - if (err) - return err; + int err = 0; + rhashtable_walk_enter(&flow_table->rhashtable, &hti); rhashtable_walk_start(&hti); while ((tuplehash = rhashtable_walk_next(&hti))) { if (IS_ERR(tuplehash)) { - err = PTR_ERR(tuplehash); - if (err != -EAGAIN) - goto out; - + if (PTR_ERR(tuplehash) != -EAGAIN) { + err = PTR_ERR(tuplehash); + break; + } continue; } if (tuplehash->tuple.dir) @@ -277,7 +274,6 @@ int nf_flow_table_iterate(struct nf_flowtable *flow_table, iter(flow, data); } -out: rhashtable_walk_stop(&hti); rhashtable_walk_exit(&hti); @@ -290,25 +286,19 @@ static inline bool nf_flow_has_expired(const struct flow_offload *flow) return (__s32)(flow->timeout - (u32)jiffies) <= 0; } -static int nf_flow_offload_gc_step(struct nf_flowtable *flow_table) +static void nf_flow_offload_gc_step(struct nf_flowtable *flow_table) { struct flow_offload_tuple_rhash *tuplehash; struct rhashtable_iter hti; struct flow_offload *flow; - int err; - - err = rhashtable_walk_init(&flow_table->rhashtable, &hti, GFP_KERNEL); - if (err) - return 0; + rhashtable_walk_enter(&flow_table->rhashtable, &hti); rhashtable_walk_start(&hti); while ((tuplehash = rhashtable_walk_next(&hti))) { if (IS_ERR(tuplehash)) { - err = PTR_ERR(tuplehash); - if (err != -EAGAIN) - goto out; - + if (PTR_ERR(tuplehash) != -EAGAIN) + break; continue; } if (tuplehash->tuple.dir) @@ -321,11 +311,8 @@ static int nf_flow_offload_gc_step(struct nf_flowtable *flow_table) FLOW_OFFLOAD_TEARDOWN))) flow_offload_del(flow_table, flow); } -out: rhashtable_walk_stop(&hti); rhashtable_walk_exit(&hti); - - return 1; } static void nf_flow_offload_work_gc(struct work_struct *work) @@ -478,14 +465,17 @@ EXPORT_SYMBOL_GPL(nf_flow_table_init); static void nf_flow_table_do_cleanup(struct flow_offload *flow, void *data) { struct net_device *dev = data; + struct flow_offload_entry *e; + + e = container_of(flow, struct flow_offload_entry, flow); if (!dev) { flow_offload_teardown(flow); return; } - - if (flow->tuplehash[0].tuple.iifidx == dev->ifindex || - flow->tuplehash[1].tuple.iifidx == dev->ifindex) + if (net_eq(nf_ct_net(e->ct), dev_net(dev)) && + (flow->tuplehash[0].tuple.iifidx == dev->ifindex || + flow->tuplehash[1].tuple.iifidx == dev->ifindex)) flow_offload_dead(flow); } @@ -496,7 +486,7 @@ static void nf_flow_table_iterate_cleanup(struct nf_flowtable *flowtable, flush_delayed_work(&flowtable->gc_work); } -void nf_flow_table_cleanup(struct net *net, struct net_device *dev) +void nf_flow_table_cleanup(struct net_device *dev) { struct nf_flowtable *flowtable; @@ -514,7 +504,7 @@ void nf_flow_table_free(struct nf_flowtable *flow_table) mutex_unlock(&flowtable_lock); cancel_delayed_work_sync(&flow_table->gc_work); nf_flow_table_iterate(flow_table, nf_flow_table_do_cleanup, NULL); - WARN_ON(!nf_flow_offload_gc_step(flow_table)); + nf_flow_offload_gc_step(flow_table); rhashtable_destroy(&flow_table->rhashtable); } EXPORT_SYMBOL_GPL(nf_flow_table_free); diff --git a/net/netfilter/nf_flow_table_ip.c b/net/netfilter/nf_flow_table_ip.c index 15ed91309992..1d291a51cd45 100644 --- a/net/netfilter/nf_flow_table_ip.c +++ b/net/netfilter/nf_flow_table_ip.c @@ -254,8 +254,7 @@ nf_flow_offload_ip_hook(void *priv, struct sk_buff *skb, if (nf_flow_state_check(flow, ip_hdr(skb)->protocol, skb, thoff)) return NF_ACCEPT; - if (flow->flags & (FLOW_OFFLOAD_SNAT | FLOW_OFFLOAD_DNAT) && - nf_flow_nat_ip(flow, skb, thoff, dir) < 0) + if (nf_flow_nat_ip(flow, skb, thoff, dir) < 0) return NF_DROP; flow->timeout = (u32)jiffies + NF_FLOW_TIMEOUT; @@ -471,8 +470,7 @@ nf_flow_offload_ipv6_hook(void *priv, struct sk_buff *skb, if (skb_try_make_writable(skb, sizeof(*ip6h))) return NF_DROP; - if (flow->flags & (FLOW_OFFLOAD_SNAT | FLOW_OFFLOAD_DNAT) && - nf_flow_nat_ipv6(flow, skb, dir) < 0) + if (nf_flow_nat_ipv6(flow, skb, dir) < 0) return NF_DROP; flow->timeout = (u32)jiffies + NF_FLOW_TIMEOUT; diff --git a/net/netfilter/nf_nat_helper.c b/net/netfilter/nf_nat_helper.c index 99606baedda4..38793b95d9bc 100644 --- a/net/netfilter/nf_nat_helper.c +++ b/net/netfilter/nf_nat_helper.c @@ -37,7 +37,7 @@ static void mangle_contents(struct sk_buff *skb, { unsigned char *data; - BUG_ON(skb_is_nonlinear(skb)); + SKB_LINEAR_ASSERT(skb); data = skb_network_header(skb) + dataoff; /* move post-replacement */ @@ -110,8 +110,6 @@ bool __nf_nat_mangle_tcp_packet(struct sk_buff *skb, !enlarge_skb(skb, rep_len - match_len)) return false; - SKB_LINEAR_ASSERT(skb); - tcph = (void *)skb->data + protoff; oldlen = skb->len - protoff; diff --git a/net/netfilter/nf_nat_redirect.c b/net/netfilter/nf_nat_redirect.c index adee04af8d43..78a9e6454ff3 100644 --- a/net/netfilter/nf_nat_redirect.c +++ b/net/netfilter/nf_nat_redirect.c @@ -52,13 +52,11 @@ nf_nat_redirect_ipv4(struct sk_buff *skb, newdst = 0; - rcu_read_lock(); indev = __in_dev_get_rcu(skb->dev); if (indev && indev->ifa_list) { ifa = indev->ifa_list; newdst = ifa->ifa_local; } - rcu_read_unlock(); if (!newdst) return NF_DROP; @@ -97,7 +95,6 @@ nf_nat_redirect_ipv6(struct sk_buff *skb, const struct nf_nat_range2 *range, struct inet6_ifaddr *ifa; bool addr = false; - rcu_read_lock(); idev = __in6_dev_get(skb->dev); if (idev != NULL) { read_lock_bh(&idev->lock); @@ -108,7 +105,6 @@ nf_nat_redirect_ipv6(struct sk_buff *skb, const struct nf_nat_range2 *range, } read_unlock_bh(&idev->lock); } - rcu_read_unlock(); if (!addr) return NF_DROP; diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c index 2cfb173cd0b2..42487d01a3ed 100644 --- a/net/netfilter/nf_tables_api.c +++ b/net/netfilter/nf_tables_api.c @@ -27,6 +27,8 @@ static LIST_HEAD(nf_tables_expressions); static LIST_HEAD(nf_tables_objects); static LIST_HEAD(nf_tables_flowtables); +static LIST_HEAD(nf_tables_destroy_list); +static DEFINE_SPINLOCK(nf_tables_destroy_list_lock); static u64 table_handle; enum { @@ -64,6 +66,8 @@ static void nft_validate_state_update(struct net *net, u8 new_validate_state) net->nft.validate_state = new_validate_state; } +static void nf_tables_trans_destroy_work(struct work_struct *w); +static DECLARE_WORK(trans_destroy_work, nf_tables_trans_destroy_work); static void nft_ctx_init(struct nft_ctx *ctx, struct net *net, @@ -207,6 +211,18 @@ static int nft_delchain(struct nft_ctx *ctx) return err; } +/* either expr ops provide both activate/deactivate, or neither */ +static bool nft_expr_check_ops(const struct nft_expr_ops *ops) +{ + if (!ops) + return true; + + if (WARN_ON_ONCE((!ops->activate ^ !ops->deactivate))) + return false; + + return true; +} + static void nft_rule_expr_activate(const struct nft_ctx *ctx, struct nft_rule *rule) { @@ -298,7 +314,7 @@ static int nft_delrule_by_chain(struct nft_ctx *ctx) return 0; } -static int nft_trans_set_add(struct nft_ctx *ctx, int msg_type, +static int nft_trans_set_add(const struct nft_ctx *ctx, int msg_type, struct nft_set *set) { struct nft_trans *trans; @@ -318,7 +334,7 @@ static int nft_trans_set_add(struct nft_ctx *ctx, int msg_type, return 0; } -static int nft_delset(struct nft_ctx *ctx, struct nft_set *set) +static int nft_delset(const struct nft_ctx *ctx, struct nft_set *set) { int err; @@ -1005,7 +1021,8 @@ static int nf_tables_deltable(struct net *net, struct sock *nlsk, static void nf_tables_table_destroy(struct nft_ctx *ctx) { - BUG_ON(ctx->table->use > 0); + if (WARN_ON(ctx->table->use > 0)) + return; rhltable_destroy(&ctx->table->chains_ht); kfree(ctx->table->name); @@ -1412,7 +1429,8 @@ static void nf_tables_chain_destroy(struct nft_ctx *ctx) { struct nft_chain *chain = ctx->chain; - BUG_ON(chain->use > 0); + if (WARN_ON(chain->use > 0)) + return; /* no concurrent access possible anymore */ nf_tables_chain_free_chain_rules(chain); @@ -1907,6 +1925,9 @@ static int nf_tables_delchain(struct net *net, struct sock *nlsk, */ int nft_register_expr(struct nft_expr_type *type) { + if (!nft_expr_check_ops(type->ops)) + return -EINVAL; + nfnl_lock(NFNL_SUBSYS_NFTABLES); if (type->family == NFPROTO_UNSPEC) list_add_tail_rcu(&type->list, &nf_tables_expressions); @@ -2054,6 +2075,10 @@ static int nf_tables_expr_parse(const struct nft_ctx *ctx, err = PTR_ERR(ops); goto err1; } + if (!nft_expr_check_ops(ops)) { + err = -EINVAL; + goto err1; + } } else ops = type->ops; @@ -2434,7 +2459,6 @@ static void nf_tables_rule_destroy(const struct nft_ctx *ctx, { struct nft_expr *expr; - lockdep_assert_held(&ctx->net->nft.commit_mutex); /* * Careful: some expressions might not be initialized in case this * is called on error from nf_tables_newrule(). @@ -3567,13 +3591,6 @@ static void nft_set_destroy(struct nft_set *set) kvfree(set); } -static void nf_tables_set_destroy(const struct nft_ctx *ctx, struct nft_set *set) -{ - list_del_rcu(&set->list); - nf_tables_set_notify(ctx, set, NFT_MSG_DELSET, GFP_ATOMIC); - nft_set_destroy(set); -} - static int nf_tables_delset(struct net *net, struct sock *nlsk, struct sk_buff *skb, const struct nlmsghdr *nlh, const struct nlattr * const nla[], @@ -3668,17 +3685,38 @@ bind: } EXPORT_SYMBOL_GPL(nf_tables_bind_set); -void nf_tables_unbind_set(const struct nft_ctx *ctx, struct nft_set *set, +void nf_tables_rebind_set(const struct nft_ctx *ctx, struct nft_set *set, struct nft_set_binding *binding) { + if (list_empty(&set->bindings) && nft_set_is_anonymous(set) && + nft_is_active(ctx->net, set)) + list_add_tail_rcu(&set->list, &ctx->table->sets); + + list_add_tail_rcu(&binding->list, &set->bindings); +} +EXPORT_SYMBOL_GPL(nf_tables_rebind_set); + +void nf_tables_unbind_set(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_binding *binding) +{ list_del_rcu(&binding->list); if (list_empty(&set->bindings) && nft_set_is_anonymous(set) && nft_is_active(ctx->net, set)) - nf_tables_set_destroy(ctx, set); + list_del_rcu(&set->list); } EXPORT_SYMBOL_GPL(nf_tables_unbind_set); +void nf_tables_destroy_set(const struct nft_ctx *ctx, struct nft_set *set) +{ + if (list_empty(&set->bindings) && nft_set_is_anonymous(set) && + nft_is_active(ctx->net, set)) { + nf_tables_set_notify(ctx, set, NFT_MSG_DELSET, GFP_ATOMIC); + nft_set_destroy(set); + } +} +EXPORT_SYMBOL_GPL(nf_tables_destroy_set); + const struct nft_set_ext_type nft_set_ext_types[] = { [NFT_SET_EXT_KEY] = { .align = __alignof__(u32), @@ -6191,19 +6229,28 @@ static void nft_commit_release(struct nft_trans *trans) nf_tables_flowtable_destroy(nft_trans_flowtable(trans)); break; } + + if (trans->put_net) + put_net(trans->ctx.net); + kfree(trans); } -static void nf_tables_commit_release(struct net *net) +static void nf_tables_trans_destroy_work(struct work_struct *w) { struct nft_trans *trans, *next; + LIST_HEAD(head); - if (list_empty(&net->nft.commit_list)) + spin_lock(&nf_tables_destroy_list_lock); + list_splice_init(&nf_tables_destroy_list, &head); + spin_unlock(&nf_tables_destroy_list_lock); + + if (list_empty(&head)) return; synchronize_rcu(); - list_for_each_entry_safe(trans, next, &net->nft.commit_list, list) { + list_for_each_entry_safe(trans, next, &head, list) { list_del(&trans->list); nft_commit_release(trans); } @@ -6334,6 +6381,37 @@ static void nft_chain_del(struct nft_chain *chain) list_del_rcu(&chain->list); } +static void nf_tables_commit_release(struct net *net) +{ + struct nft_trans *trans; + + /* all side effects have to be made visible. + * For example, if a chain named 'foo' has been deleted, a + * new transaction must not find it anymore. + * + * Memory reclaim happens asynchronously from work queue + * to prevent expensive synchronize_rcu() in commit phase. + */ + if (list_empty(&net->nft.commit_list)) { + mutex_unlock(&net->nft.commit_mutex); + return; + } + + trans = list_last_entry(&net->nft.commit_list, + struct nft_trans, list); + get_net(trans->ctx.net); + WARN_ON_ONCE(trans->put_net); + + trans->put_net = true; + spin_lock(&nf_tables_destroy_list_lock); + list_splice_tail_init(&net->nft.commit_list, &nf_tables_destroy_list); + spin_unlock(&nf_tables_destroy_list_lock); + + mutex_unlock(&net->nft.commit_mutex); + + schedule_work(&trans_destroy_work); +} + static int nf_tables_commit(struct net *net, struct sk_buff *skb) { struct nft_trans *trans, *next; @@ -6495,9 +6573,8 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb) } } - nf_tables_commit_release(net); nf_tables_gen_notify(net, skb, NFT_MSG_NEWGEN); - mutex_unlock(&net->nft.commit_mutex); + nf_tables_commit_release(net); return 0; } @@ -7168,7 +7245,8 @@ int __nft_release_basechain(struct nft_ctx *ctx) { struct nft_rule *rule, *nr; - BUG_ON(!nft_is_base_chain(ctx->chain)); + if (WARN_ON(!nft_is_base_chain(ctx->chain))) + return 0; nf_tables_unregister_hook(ctx->net, ctx->chain->table, ctx->chain); list_for_each_entry_safe(rule, nr, &ctx->chain->rules, list) { @@ -7202,9 +7280,6 @@ static void __nft_release_tables(struct net *net) list_for_each_entry(chain, &table->chains, list) nf_tables_unregister_hook(net, table, chain); - list_for_each_entry(flowtable, &table->flowtables, list) - nf_unregister_net_hooks(net, flowtable->ops, - flowtable->ops_len); /* No packets are walking on these chains anymore. */ ctx.table = table; list_for_each_entry(chain, &table->chains, list) { @@ -7271,6 +7346,7 @@ static int __init nf_tables_module_init(void) { int err; + spin_lock_init(&nf_tables_destroy_list_lock); err = register_pernet_subsys(&nf_tables_net_ops); if (err < 0) return err; @@ -7310,6 +7386,7 @@ static void __exit nf_tables_module_exit(void) unregister_netdevice_notifier(&nf_tables_flowtable_notifier); nft_chain_filter_fini(); unregister_pernet_subsys(&nf_tables_net_ops); + cancel_work_sync(&trans_destroy_work); rcu_barrier(); nf_tables_core_module_exit(); } diff --git a/net/netfilter/nf_tables_core.c b/net/netfilter/nf_tables_core.c index ffd5c0f9412b..3fbce3b9c5ec 100644 --- a/net/netfilter/nf_tables_core.c +++ b/net/netfilter/nf_tables_core.c @@ -249,12 +249,24 @@ static struct nft_expr_type *nft_basic_types[] = { &nft_exthdr_type, }; +static struct nft_object_type *nft_basic_objects[] = { +#ifdef CONFIG_NETWORK_SECMARK + &nft_secmark_obj_type, +#endif +}; + int __init nf_tables_core_module_init(void) { - int err, i; + int err, i, j = 0; + + for (i = 0; i < ARRAY_SIZE(nft_basic_objects); i++) { + err = nft_register_obj(nft_basic_objects[i]); + if (err) + goto err; + } - for (i = 0; i < ARRAY_SIZE(nft_basic_types); i++) { - err = nft_register_expr(nft_basic_types[i]); + for (j = 0; j < ARRAY_SIZE(nft_basic_types); j++) { + err = nft_register_expr(nft_basic_types[j]); if (err) goto err; } @@ -262,8 +274,12 @@ int __init nf_tables_core_module_init(void) return 0; err: + while (j-- > 0) + nft_unregister_expr(nft_basic_types[j]); + while (i-- > 0) - nft_unregister_expr(nft_basic_types[i]); + nft_unregister_obj(nft_basic_objects[i]); + return err; } @@ -274,4 +290,8 @@ void nf_tables_core_module_exit(void) i = ARRAY_SIZE(nft_basic_types); while (i-- > 0) nft_unregister_expr(nft_basic_types[i]); + + i = ARRAY_SIZE(nft_basic_objects); + while (i-- > 0) + nft_unregister_obj(nft_basic_objects[i]); } diff --git a/net/netfilter/nfnetlink_cttimeout.c b/net/netfilter/nfnetlink_cttimeout.c index a30f8ba4b89a..e7a50af1b3d6 100644 --- a/net/netfilter/nfnetlink_cttimeout.c +++ b/net/netfilter/nfnetlink_cttimeout.c @@ -53,9 +53,6 @@ ctnl_timeout_parse_policy(void *timeout, struct nlattr **tb; int ret = 0; - if (!l4proto->ctnl_timeout.nlattr_to_obj) - return 0; - tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb), GFP_KERNEL); @@ -125,7 +122,7 @@ static int cttimeout_new_timeout(struct net *net, struct sock *ctnl, return -EBUSY; } - l4proto = nf_ct_l4proto_find_get(l3num, l4num); + l4proto = nf_ct_l4proto_find_get(l4num); /* This protocol is not supportted, skip. */ if (l4proto->l4proto != l4num) { @@ -167,6 +164,8 @@ ctnl_timeout_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, struct nfgenmsg *nfmsg; unsigned int flags = portid ? NLM_F_MULTI : 0; const struct nf_conntrack_l4proto *l4proto = timeout->timeout.l4proto; + struct nlattr *nest_parms; + int ret; event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK_TIMEOUT, event); nlh = nlmsg_put(skb, portid, seq, event, sizeof(*nfmsg), flags); @@ -186,22 +185,15 @@ ctnl_timeout_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, htonl(refcount_read(&timeout->refcnt)))) goto nla_put_failure; - if (likely(l4proto->ctnl_timeout.obj_to_nlattr)) { - struct nlattr *nest_parms; - int ret; - - nest_parms = nla_nest_start(skb, - CTA_TIMEOUT_DATA | NLA_F_NESTED); - if (!nest_parms) - goto nla_put_failure; + nest_parms = nla_nest_start(skb, CTA_TIMEOUT_DATA | NLA_F_NESTED); + if (!nest_parms) + goto nla_put_failure; - ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, - &timeout->timeout.data); - if (ret < 0) - goto nla_put_failure; + ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->timeout.data); + if (ret < 0) + goto nla_put_failure; - nla_nest_end(skb, nest_parms); - } + nla_nest_end(skb, nest_parms); nlmsg_end(skb, nlh); return skb->len; @@ -358,7 +350,6 @@ static int cttimeout_default_set(struct net *net, struct sock *ctnl, struct netlink_ext_ack *extack) { const struct nf_conntrack_l4proto *l4proto; - __u16 l3num; __u8 l4num; int ret; @@ -367,9 +358,8 @@ static int cttimeout_default_set(struct net *net, struct sock *ctnl, !cda[CTA_TIMEOUT_DATA]) return -EINVAL; - l3num = ntohs(nla_get_be16(cda[CTA_TIMEOUT_L3PROTO])); l4num = nla_get_u8(cda[CTA_TIMEOUT_L4PROTO]); - l4proto = nf_ct_l4proto_find_get(l3num, l4num); + l4proto = nf_ct_l4proto_find_get(l4num); /* This protocol is not supported, skip. */ if (l4proto->l4proto != l4num) { @@ -391,12 +381,14 @@ err: static int cttimeout_default_fill_info(struct net *net, struct sk_buff *skb, u32 portid, - u32 seq, u32 type, int event, + u32 seq, u32 type, int event, u16 l3num, const struct nf_conntrack_l4proto *l4proto) { struct nlmsghdr *nlh; struct nfgenmsg *nfmsg; unsigned int flags = portid ? NLM_F_MULTI : 0; + struct nlattr *nest_parms; + int ret; event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK_TIMEOUT, event); nlh = nlmsg_put(skb, portid, seq, event, sizeof(*nfmsg), flags); @@ -408,25 +400,19 @@ cttimeout_default_fill_info(struct net *net, struct sk_buff *skb, u32 portid, nfmsg->version = NFNETLINK_V0; nfmsg->res_id = 0; - if (nla_put_be16(skb, CTA_TIMEOUT_L3PROTO, htons(l4proto->l3proto)) || + if (nla_put_be16(skb, CTA_TIMEOUT_L3PROTO, htons(l3num)) || nla_put_u8(skb, CTA_TIMEOUT_L4PROTO, l4proto->l4proto)) goto nla_put_failure; - if (likely(l4proto->ctnl_timeout.obj_to_nlattr)) { - struct nlattr *nest_parms; - int ret; - - nest_parms = nla_nest_start(skb, - CTA_TIMEOUT_DATA | NLA_F_NESTED); - if (!nest_parms) - goto nla_put_failure; + nest_parms = nla_nest_start(skb, CTA_TIMEOUT_DATA | NLA_F_NESTED); + if (!nest_parms) + goto nla_put_failure; - ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, NULL); - if (ret < 0) - goto nla_put_failure; + ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, NULL); + if (ret < 0) + goto nla_put_failure; - nla_nest_end(skb, nest_parms); - } + nla_nest_end(skb, nest_parms); nlmsg_end(skb, nlh); return skb->len; @@ -454,7 +440,7 @@ static int cttimeout_default_get(struct net *net, struct sock *ctnl, l3num = ntohs(nla_get_be16(cda[CTA_TIMEOUT_L3PROTO])); l4num = nla_get_u8(cda[CTA_TIMEOUT_L4PROTO]); - l4proto = nf_ct_l4proto_find_get(l3num, l4num); + l4proto = nf_ct_l4proto_find_get(l4num); /* This protocol is not supported, skip. */ if (l4proto->l4proto != l4num) { @@ -472,6 +458,7 @@ static int cttimeout_default_get(struct net *net, struct sock *ctnl, nlh->nlmsg_seq, NFNL_MSG_TYPE(nlh->nlmsg_type), IPCTNL_MSG_TIMEOUT_DEFAULT_SET, + l3num, l4proto); if (ret <= 0) { kfree_skb(skb2); diff --git a/net/netfilter/nfnetlink_osf.c b/net/netfilter/nfnetlink_osf.c index 00db27dfd2ff..6f41dd74729d 100644 --- a/net/netfilter/nfnetlink_osf.c +++ b/net/netfilter/nfnetlink_osf.c @@ -30,32 +30,27 @@ EXPORT_SYMBOL_GPL(nf_osf_fingers); static inline int nf_osf_ttl(const struct sk_buff *skb, int ttl_check, unsigned char f_ttl) { + struct in_device *in_dev = __in_dev_get_rcu(skb->dev); const struct iphdr *ip = ip_hdr(skb); - - if (ttl_check != -1) { - if (ttl_check == NF_OSF_TTL_TRUE) - return ip->ttl == f_ttl; - if (ttl_check == NF_OSF_TTL_NOCHECK) - return 1; - else if (ip->ttl <= f_ttl) - return 1; - else { - struct in_device *in_dev = __in_dev_get_rcu(skb->dev); - int ret = 0; - - for_ifa(in_dev) { - if (inet_ifa_match(ip->saddr, ifa)) { - ret = (ip->ttl == f_ttl); - break; - } - } - endfor_ifa(in_dev); - - return ret; + int ret = 0; + + if (ttl_check == NF_OSF_TTL_TRUE) + return ip->ttl == f_ttl; + if (ttl_check == NF_OSF_TTL_NOCHECK) + return 1; + else if (ip->ttl <= f_ttl) + return 1; + + for_ifa(in_dev) { + if (inet_ifa_match(ip->saddr, ifa)) { + ret = (ip->ttl == f_ttl); + break; } } - return ip->ttl == f_ttl; + endfor_ifa(in_dev); + + return ret; } struct nf_osf_hdr_ctx { @@ -213,7 +208,7 @@ nf_osf_match(const struct sk_buff *skb, u_int8_t family, if (!tcp) return false; - ttl_check = (info->flags & NF_OSF_TTL) ? info->ttl : -1; + ttl_check = (info->flags & NF_OSF_TTL) ? info->ttl : 0; list_for_each_entry_rcu(kf, &nf_osf_fingers[ctx.df], finger_entry) { @@ -257,7 +252,8 @@ nf_osf_match(const struct sk_buff *skb, u_int8_t family, EXPORT_SYMBOL_GPL(nf_osf_match); const char *nf_osf_find(const struct sk_buff *skb, - const struct list_head *nf_osf_fingers) + const struct list_head *nf_osf_fingers, + const int ttl_check) { const struct iphdr *ip = ip_hdr(skb); const struct nf_osf_user_finger *f; @@ -275,7 +271,7 @@ const char *nf_osf_find(const struct sk_buff *skb, list_for_each_entry_rcu(kf, &nf_osf_fingers[ctx.df], finger_entry) { f = &kf->finger; - if (!nf_osf_match_one(skb, f, -1, &ctx)) + if (!nf_osf_match_one(skb, f, ttl_check, &ctx)) continue; genre = f->genre; diff --git a/net/netfilter/nft_cmp.c b/net/netfilter/nft_cmp.c index fa90a8402845..79d48c1d06f4 100644 --- a/net/netfilter/nft_cmp.c +++ b/net/netfilter/nft_cmp.c @@ -79,7 +79,8 @@ static int nft_cmp_init(const struct nft_ctx *ctx, const struct nft_expr *expr, err = nft_data_init(NULL, &priv->data, sizeof(priv->data), &desc, tb[NFTA_CMP_DATA]); - BUG_ON(err < 0); + if (err < 0) + return err; priv->sreg = nft_parse_register(tb[NFTA_CMP_SREG]); err = nft_validate_register_load(priv->sreg, desc.len); @@ -129,7 +130,8 @@ static int nft_cmp_fast_init(const struct nft_ctx *ctx, err = nft_data_init(NULL, &data, sizeof(data), &desc, tb[NFTA_CMP_DATA]); - BUG_ON(err < 0); + if (err < 0) + return err; priv->sreg = nft_parse_register(tb[NFTA_CMP_SREG]); err = nft_validate_register_load(priv->sreg, desc.len); diff --git a/net/netfilter/nft_compat.c b/net/netfilter/nft_compat.c index 32535eea51b2..768292eac2a4 100644 --- a/net/netfilter/nft_compat.c +++ b/net/netfilter/nft_compat.c @@ -290,6 +290,24 @@ nft_target_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) module_put(target->me); } +static int nft_extension_dump_info(struct sk_buff *skb, int attr, + const void *info, + unsigned int size, unsigned int user_size) +{ + unsigned int info_size, aligned_size = XT_ALIGN(size); + struct nlattr *nla; + + nla = nla_reserve(skb, attr, aligned_size); + if (!nla) + return -1; + + info_size = user_size ? : size; + memcpy(nla_data(nla), info, info_size); + memset(nla_data(nla) + info_size, 0, aligned_size - info_size); + + return 0; +} + static int nft_target_dump(struct sk_buff *skb, const struct nft_expr *expr) { const struct xt_target *target = expr->ops->data; @@ -297,7 +315,8 @@ static int nft_target_dump(struct sk_buff *skb, const struct nft_expr *expr) if (nla_put_string(skb, NFTA_TARGET_NAME, target->name) || nla_put_be32(skb, NFTA_TARGET_REV, htonl(target->revision)) || - nla_put(skb, NFTA_TARGET_INFO, XT_ALIGN(target->targetsize), info)) + nft_extension_dump_info(skb, NFTA_TARGET_INFO, info, + target->targetsize, target->usersize)) goto nla_put_failure; return 0; @@ -532,7 +551,8 @@ static int __nft_match_dump(struct sk_buff *skb, const struct nft_expr *expr, if (nla_put_string(skb, NFTA_MATCH_NAME, match->name) || nla_put_be32(skb, NFTA_MATCH_REV, htonl(match->revision)) || - nla_put(skb, NFTA_MATCH_INFO, XT_ALIGN(match->matchsize), info)) + nft_extension_dump_info(skb, NFTA_MATCH_INFO, info, + match->matchsize, match->usersize)) goto nla_put_failure; return 0; diff --git a/net/netfilter/nft_ct.c b/net/netfilter/nft_ct.c index 5dd87748afa8..586627c361df 100644 --- a/net/netfilter/nft_ct.c +++ b/net/netfilter/nft_ct.c @@ -279,7 +279,7 @@ static void nft_ct_set_eval(const struct nft_expr *expr, { const struct nft_ct *priv = nft_expr_priv(expr); struct sk_buff *skb = pkt->skb; -#ifdef CONFIG_NF_CONNTRACK_MARK +#if defined(CONFIG_NF_CONNTRACK_MARK) || defined(CONFIG_NF_CONNTRACK_SECMARK) u32 value = regs->data[priv->sreg]; #endif enum ip_conntrack_info ctinfo; @@ -298,6 +298,14 @@ static void nft_ct_set_eval(const struct nft_expr *expr, } break; #endif +#ifdef CONFIG_NF_CONNTRACK_SECMARK + case NFT_CT_SECMARK: + if (ct->secmark != value) { + ct->secmark = value; + nf_conntrack_event_cache(IPCT_SECMARK, ct); + } + break; +#endif #ifdef CONFIG_NF_CONNTRACK_LABELS case NFT_CT_LABELS: nf_connlabels_replace(ct, @@ -565,6 +573,13 @@ static int nft_ct_set_init(const struct nft_ctx *ctx, len = sizeof(u32); break; #endif +#ifdef CONFIG_NF_CONNTRACK_SECMARK + case NFT_CT_SECMARK: + if (tb[NFTA_CT_DIRECTION]) + return -EINVAL; + len = sizeof(u32); + break; +#endif default: return -EOPNOTSUPP; } @@ -776,9 +791,6 @@ nft_ct_timeout_parse_policy(void *timeouts, struct nlattr **tb; int ret = 0; - if (!l4proto->ctnl_timeout.nlattr_to_obj) - return 0; - tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb), GFP_KERNEL); @@ -858,7 +870,7 @@ static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx, l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]); priv->l4proto = l4num; - l4proto = nf_ct_l4proto_find_get(l3num, l4num); + l4proto = nf_ct_l4proto_find_get(l4num); if (l4proto->l4proto != l4num) { ret = -EOPNOTSUPP; diff --git a/net/netfilter/nft_dup_netdev.c b/net/netfilter/nft_dup_netdev.c index 2cc1e0ef56e8..15cc62b293d6 100644 --- a/net/netfilter/nft_dup_netdev.c +++ b/net/netfilter/nft_dup_netdev.c @@ -46,8 +46,6 @@ static int nft_dup_netdev_init(const struct nft_ctx *ctx, return nft_validate_register_load(priv->sreg_dev, sizeof(int)); } -static const struct nft_expr_ops nft_dup_netdev_ingress_ops; - static int nft_dup_netdev_dump(struct sk_buff *skb, const struct nft_expr *expr) { struct nft_dup_netdev *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_dynset.c b/net/netfilter/nft_dynset.c index 6e91a37d57f2..07d4efd3d851 100644 --- a/net/netfilter/nft_dynset.c +++ b/net/netfilter/nft_dynset.c @@ -235,14 +235,31 @@ err1: return err; } +static void nft_dynset_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_dynset *priv = nft_expr_priv(expr); + + nf_tables_rebind_set(ctx, priv->set, &priv->binding); +} + +static void nft_dynset_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_dynset *priv = nft_expr_priv(expr); + + nf_tables_unbind_set(ctx, priv->set, &priv->binding); +} + static void nft_dynset_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) { struct nft_dynset *priv = nft_expr_priv(expr); - nf_tables_unbind_set(ctx, priv->set, &priv->binding); if (priv->expr != NULL) nft_expr_destroy(ctx, priv->expr); + + nf_tables_destroy_set(ctx, priv->set); } static int nft_dynset_dump(struct sk_buff *skb, const struct nft_expr *expr) @@ -279,6 +296,8 @@ static const struct nft_expr_ops nft_dynset_ops = { .eval = nft_dynset_eval, .init = nft_dynset_init, .destroy = nft_dynset_destroy, + .activate = nft_dynset_activate, + .deactivate = nft_dynset_deactivate, .dump = nft_dynset_dump, }; diff --git a/net/netfilter/nft_flow_offload.c b/net/netfilter/nft_flow_offload.c index d6bab8c3cbb0..e82d9a966c45 100644 --- a/net/netfilter/nft_flow_offload.c +++ b/net/netfilter/nft_flow_offload.c @@ -201,7 +201,7 @@ static int flow_offload_netdev_event(struct notifier_block *this, if (event != NETDEV_DOWN) return NOTIFY_DONE; - nf_flow_table_cleanup(dev_net(dev), dev); + nf_flow_table_cleanup(dev); return NOTIFY_DONE; } diff --git a/net/netfilter/nft_fwd_netdev.c b/net/netfilter/nft_fwd_netdev.c index 8abb9891cdf2..d7694e7255a0 100644 --- a/net/netfilter/nft_fwd_netdev.c +++ b/net/netfilter/nft_fwd_netdev.c @@ -53,8 +53,6 @@ static int nft_fwd_netdev_init(const struct nft_ctx *ctx, return nft_validate_register_load(priv->sreg_dev, sizeof(int)); } -static const struct nft_expr_ops nft_fwd_netdev_ingress_ops; - static int nft_fwd_netdev_dump(struct sk_buff *skb, const struct nft_expr *expr) { struct nft_fwd_netdev *priv = nft_expr_priv(expr); @@ -169,8 +167,6 @@ static int nft_fwd_neigh_init(const struct nft_ctx *ctx, return nft_validate_register_load(priv->sreg_addr, addr_len); } -static const struct nft_expr_ops nft_fwd_netdev_ingress_ops; - static int nft_fwd_neigh_dump(struct sk_buff *skb, const struct nft_expr *expr) { struct nft_fwd_neigh *priv = nft_expr_priv(expr); diff --git a/net/netfilter/nft_lookup.c b/net/netfilter/nft_lookup.c index ad13e8643599..227b2b15a19c 100644 --- a/net/netfilter/nft_lookup.c +++ b/net/netfilter/nft_lookup.c @@ -121,12 +121,28 @@ static int nft_lookup_init(const struct nft_ctx *ctx, return 0; } +static void nft_lookup_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_lookup *priv = nft_expr_priv(expr); + + nf_tables_rebind_set(ctx, priv->set, &priv->binding); +} + +static void nft_lookup_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_lookup *priv = nft_expr_priv(expr); + + nf_tables_unbind_set(ctx, priv->set, &priv->binding); +} + static void nft_lookup_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) { struct nft_lookup *priv = nft_expr_priv(expr); - nf_tables_unbind_set(ctx, priv->set, &priv->binding); + nf_tables_destroy_set(ctx, priv->set); } static int nft_lookup_dump(struct sk_buff *skb, const struct nft_expr *expr) @@ -209,6 +225,8 @@ static const struct nft_expr_ops nft_lookup_ops = { .size = NFT_EXPR_SIZE(sizeof(struct nft_lookup)), .eval = nft_lookup_eval, .init = nft_lookup_init, + .activate = nft_lookup_activate, + .deactivate = nft_lookup_deactivate, .destroy = nft_lookup_destroy, .dump = nft_lookup_dump, .validate = nft_lookup_validate, diff --git a/net/netfilter/nft_meta.c b/net/netfilter/nft_meta.c index 297fe7d97c18..6180626c3f80 100644 --- a/net/netfilter/nft_meta.c +++ b/net/netfilter/nft_meta.c @@ -284,6 +284,11 @@ static void nft_meta_set_eval(const struct nft_expr *expr, skb->nf_trace = !!value8; break; +#ifdef CONFIG_NETWORK_SECMARK + case NFT_META_SECMARK: + skb->secmark = value; + break; +#endif default: WARN_ON(1); } @@ -436,6 +441,9 @@ static int nft_meta_set_init(const struct nft_ctx *ctx, switch (priv->key) { case NFT_META_MARK: case NFT_META_PRIORITY: +#ifdef CONFIG_NETWORK_SECMARK + case NFT_META_SECMARK: +#endif len = sizeof(u32); break; case NFT_META_NFTRACE: @@ -543,3 +551,111 @@ struct nft_expr_type nft_meta_type __read_mostly = { .maxattr = NFTA_META_MAX, .owner = THIS_MODULE, }; + +#ifdef CONFIG_NETWORK_SECMARK +struct nft_secmark { + u32 secid; + char *ctx; +}; + +static const struct nla_policy nft_secmark_policy[NFTA_SECMARK_MAX + 1] = { + [NFTA_SECMARK_CTX] = { .type = NLA_STRING, .len = NFT_SECMARK_CTX_MAXLEN }, +}; + +static int nft_secmark_compute_secid(struct nft_secmark *priv) +{ + u32 tmp_secid = 0; + int err; + + err = security_secctx_to_secid(priv->ctx, strlen(priv->ctx), &tmp_secid); + if (err) + return err; + + if (!tmp_secid) + return -ENOENT; + + err = security_secmark_relabel_packet(tmp_secid); + if (err) + return err; + + priv->secid = tmp_secid; + return 0; +} + +static void nft_secmark_obj_eval(struct nft_object *obj, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_secmark *priv = nft_obj_data(obj); + struct sk_buff *skb = pkt->skb; + + skb->secmark = priv->secid; +} + +static int nft_secmark_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_secmark *priv = nft_obj_data(obj); + int err; + + if (tb[NFTA_SECMARK_CTX] == NULL) + return -EINVAL; + + priv->ctx = nla_strdup(tb[NFTA_SECMARK_CTX], GFP_KERNEL); + if (!priv->ctx) + return -ENOMEM; + + err = nft_secmark_compute_secid(priv); + if (err) { + kfree(priv->ctx); + return err; + } + + security_secmark_refcount_inc(); + + return 0; +} + +static int nft_secmark_obj_dump(struct sk_buff *skb, struct nft_object *obj, + bool reset) +{ + struct nft_secmark *priv = nft_obj_data(obj); + int err; + + if (nla_put_string(skb, NFTA_SECMARK_CTX, priv->ctx)) + return -1; + + if (reset) { + err = nft_secmark_compute_secid(priv); + if (err) + return err; + } + + return 0; +} + +static void nft_secmark_obj_destroy(const struct nft_ctx *ctx, struct nft_object *obj) +{ + struct nft_secmark *priv = nft_obj_data(obj); + + security_secmark_refcount_dec(); + + kfree(priv->ctx); +} + +static const struct nft_object_ops nft_secmark_obj_ops = { + .type = &nft_secmark_obj_type, + .size = sizeof(struct nft_secmark), + .init = nft_secmark_obj_init, + .eval = nft_secmark_obj_eval, + .dump = nft_secmark_obj_dump, + .destroy = nft_secmark_obj_destroy, +}; +struct nft_object_type nft_secmark_obj_type __read_mostly = { + .type = NFT_OBJECT_SECMARK, + .ops = &nft_secmark_obj_ops, + .maxattr = NFTA_SECMARK_MAX, + .policy = nft_secmark_policy, + .owner = THIS_MODULE, +}; +#endif /* CONFIG_NETWORK_SECMARK */ diff --git a/net/netfilter/nft_objref.c b/net/netfilter/nft_objref.c index cdf348f751ec..a3185ca2a3a9 100644 --- a/net/netfilter/nft_objref.c +++ b/net/netfilter/nft_objref.c @@ -155,12 +155,28 @@ nla_put_failure: return -1; } +static void nft_objref_map_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_objref_map *priv = nft_expr_priv(expr); + + nf_tables_rebind_set(ctx, priv->set, &priv->binding); +} + +static void nft_objref_map_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_objref_map *priv = nft_expr_priv(expr); + + nf_tables_unbind_set(ctx, priv->set, &priv->binding); +} + static void nft_objref_map_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) { struct nft_objref_map *priv = nft_expr_priv(expr); - nf_tables_unbind_set(ctx, priv->set, &priv->binding); + nf_tables_destroy_set(ctx, priv->set); } static struct nft_expr_type nft_objref_type; @@ -169,6 +185,8 @@ static const struct nft_expr_ops nft_objref_map_ops = { .size = NFT_EXPR_SIZE(sizeof(struct nft_objref_map)), .eval = nft_objref_map_eval, .init = nft_objref_map_init, + .activate = nft_objref_map_activate, + .deactivate = nft_objref_map_deactivate, .destroy = nft_objref_map_destroy, .dump = nft_objref_map_dump, }; diff --git a/net/netfilter/nft_osf.c b/net/netfilter/nft_osf.c index 5af74b37f423..ca5e5d8c5ef8 100644 --- a/net/netfilter/nft_osf.c +++ b/net/netfilter/nft_osf.c @@ -6,10 +6,12 @@ struct nft_osf { enum nft_registers dreg:8; + u8 ttl; }; static const struct nla_policy nft_osf_policy[NFTA_OSF_MAX + 1] = { [NFTA_OSF_DREG] = { .type = NLA_U32 }, + [NFTA_OSF_TTL] = { .type = NLA_U8 }, }; static void nft_osf_eval(const struct nft_expr *expr, struct nft_regs *regs, @@ -33,7 +35,7 @@ static void nft_osf_eval(const struct nft_expr *expr, struct nft_regs *regs, return; } - os_name = nf_osf_find(skb, nf_osf_fingers); + os_name = nf_osf_find(skb, nf_osf_fingers, priv->ttl); if (!os_name) strncpy((char *)dest, "unknown", NFT_OSF_MAXGENRELEN); else @@ -46,10 +48,18 @@ static int nft_osf_init(const struct nft_ctx *ctx, { struct nft_osf *priv = nft_expr_priv(expr); int err; + u8 ttl; + + if (nla_get_u8(tb[NFTA_OSF_TTL])) { + ttl = nla_get_u8(tb[NFTA_OSF_TTL]); + if (ttl > 2) + return -EINVAL; + priv->ttl = ttl; + } priv->dreg = nft_parse_register(tb[NFTA_OSF_DREG]); err = nft_validate_register_store(ctx, priv->dreg, NULL, - NFTA_DATA_VALUE, NFT_OSF_MAXGENRELEN); + NFT_DATA_VALUE, NFT_OSF_MAXGENRELEN); if (err < 0) return err; @@ -60,6 +70,9 @@ static int nft_osf_dump(struct sk_buff *skb, const struct nft_expr *expr) { const struct nft_osf *priv = nft_expr_priv(expr); + if (nla_put_u8(skb, NFTA_OSF_TTL, priv->ttl)) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_OSF_DREG, priv->dreg)) goto nla_put_failure; @@ -69,6 +82,15 @@ nla_put_failure: return -1; } +static int nft_osf_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + return nft_chain_validate_hooks(ctx->chain, (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_FORWARD)); +} + static struct nft_expr_type nft_osf_type; static const struct nft_expr_ops nft_osf_op = { .eval = nft_osf_eval, @@ -76,6 +98,7 @@ static const struct nft_expr_ops nft_osf_op = { .init = nft_osf_init, .dump = nft_osf_dump, .type = &nft_osf_type, + .validate = nft_osf_validate, }; static struct nft_expr_type nft_osf_type __read_mostly = { diff --git a/net/netfilter/nft_reject.c b/net/netfilter/nft_reject.c index 29f5bd2377b0..b48e58cceeb7 100644 --- a/net/netfilter/nft_reject.c +++ b/net/netfilter/nft_reject.c @@ -94,7 +94,8 @@ static u8 icmp_code_v4[NFT_REJECT_ICMPX_MAX + 1] = { int nft_reject_icmp_code(u8 code) { - BUG_ON(code > NFT_REJECT_ICMPX_MAX); + if (WARN_ON_ONCE(code > NFT_REJECT_ICMPX_MAX)) + return ICMP_NET_UNREACH; return icmp_code_v4[code]; } @@ -111,7 +112,8 @@ static u8 icmp_code_v6[NFT_REJECT_ICMPX_MAX + 1] = { int nft_reject_icmpv6_code(u8 code) { - BUG_ON(code > NFT_REJECT_ICMPX_MAX); + if (WARN_ON_ONCE(code > NFT_REJECT_ICMPX_MAX)) + return ICMPV6_NOROUTE; return icmp_code_v6[code]; } diff --git a/net/netfilter/nft_rt.c b/net/netfilter/nft_rt.c index 76dba9f6b6f6..f35fa33913ae 100644 --- a/net/netfilter/nft_rt.c +++ b/net/netfilter/nft_rt.c @@ -90,6 +90,11 @@ static void nft_rt_get_eval(const struct nft_expr *expr, case NFT_RT_TCPMSS: nft_reg_store16(dest, get_tcpmss(pkt, dst)); break; +#ifdef CONFIG_XFRM + case NFT_RT_XFRM: + nft_reg_store8(dest, !!dst->xfrm); + break; +#endif default: WARN_ON(1); goto err; @@ -130,6 +135,11 @@ static int nft_rt_get_init(const struct nft_ctx *ctx, case NFT_RT_TCPMSS: len = sizeof(u16); break; +#ifdef CONFIG_XFRM + case NFT_RT_XFRM: + len = sizeof(u8); + break; +#endif default: return -EOPNOTSUPP; } @@ -164,6 +174,7 @@ static int nft_rt_validate(const struct nft_ctx *ctx, const struct nft_expr *exp case NFT_RT_NEXTHOP4: case NFT_RT_NEXTHOP6: case NFT_RT_CLASSID: + case NFT_RT_XFRM: return 0; case NFT_RT_TCPMSS: hooks = (1 << NF_INET_FORWARD) | diff --git a/net/netfilter/nft_set_hash.c b/net/netfilter/nft_set_hash.c index 015124e649cb..339a9dd1c832 100644 --- a/net/netfilter/nft_set_hash.c +++ b/net/netfilter/nft_set_hash.c @@ -88,7 +88,7 @@ static bool nft_rhash_lookup(const struct net *net, const struct nft_set *set, .key = key, }; - he = rhashtable_lookup_fast(&priv->ht, &arg, nft_rhash_params); + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); if (he != NULL) *ext = &he->ext; @@ -106,7 +106,7 @@ static void *nft_rhash_get(const struct net *net, const struct nft_set *set, .key = elem->key.val.data, }; - he = rhashtable_lookup_fast(&priv->ht, &arg, nft_rhash_params); + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); if (he != NULL) return he; @@ -129,7 +129,7 @@ static bool nft_rhash_update(struct nft_set *set, const u32 *key, .key = key, }; - he = rhashtable_lookup_fast(&priv->ht, &arg, nft_rhash_params); + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); if (he != NULL) goto out; @@ -217,7 +217,7 @@ static void *nft_rhash_deactivate(const struct net *net, }; rcu_read_lock(); - he = rhashtable_lookup_fast(&priv->ht, &arg, nft_rhash_params); + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); if (he != NULL && !nft_rhash_flush(net, set, he)) he = NULL; @@ -244,21 +244,15 @@ static void nft_rhash_walk(const struct nft_ctx *ctx, struct nft_set *set, struct nft_rhash_elem *he; struct rhashtable_iter hti; struct nft_set_elem elem; - int err; - - err = rhashtable_walk_init(&priv->ht, &hti, GFP_ATOMIC); - iter->err = err; - if (err) - return; + rhashtable_walk_enter(&priv->ht, &hti); rhashtable_walk_start(&hti); while ((he = rhashtable_walk_next(&hti))) { if (IS_ERR(he)) { - err = PTR_ERR(he); - if (err != -EAGAIN) { - iter->err = err; - goto out; + if (PTR_ERR(he) != -EAGAIN) { + iter->err = PTR_ERR(he); + break; } continue; @@ -275,13 +269,11 @@ static void nft_rhash_walk(const struct nft_ctx *ctx, struct nft_set *set, iter->err = iter->fn(ctx, set, iter, &elem); if (iter->err < 0) - goto out; + break; cont: iter->count++; } - -out: rhashtable_walk_stop(&hti); rhashtable_walk_exit(&hti); } @@ -293,21 +285,17 @@ static void nft_rhash_gc(struct work_struct *work) struct nft_rhash *priv; struct nft_set_gc_batch *gcb = NULL; struct rhashtable_iter hti; - int err; priv = container_of(work, struct nft_rhash, gc_work.work); set = nft_set_container_of(priv); - err = rhashtable_walk_init(&priv->ht, &hti, GFP_KERNEL); - if (err) - goto schedule; - + rhashtable_walk_enter(&priv->ht, &hti); rhashtable_walk_start(&hti); while ((he = rhashtable_walk_next(&hti))) { if (IS_ERR(he)) { if (PTR_ERR(he) != -EAGAIN) - goto out; + break; continue; } @@ -326,17 +314,15 @@ gc: gcb = nft_set_gc_batch_check(set, gcb, GFP_ATOMIC); if (gcb == NULL) - goto out; + break; rhashtable_remove_fast(&priv->ht, &he->node, nft_rhash_params); atomic_dec(&set->nelems); nft_set_gc_batch_add(gcb, he); } -out: rhashtable_walk_stop(&hti); rhashtable_walk_exit(&hti); nft_set_gc_batch_complete(gcb); -schedule: queue_delayed_work(system_power_efficient_wq, &priv->gc_work, nft_set_gc_interval(set)); } diff --git a/net/netfilter/nft_set_rbtree.c b/net/netfilter/nft_set_rbtree.c index 55e2d9215c0d..fa61208371f8 100644 --- a/net/netfilter/nft_set_rbtree.c +++ b/net/netfilter/nft_set_rbtree.c @@ -135,9 +135,12 @@ static bool __nft_rbtree_get(const struct net *net, const struct nft_set *set, d = memcmp(this, key, set->klen); if (d < 0) { parent = rcu_dereference_raw(parent->rb_left); - interval = rbe; + if (!(flags & NFT_SET_ELEM_INTERVAL_END)) + interval = rbe; } else if (d > 0) { parent = rcu_dereference_raw(parent->rb_right); + if (flags & NFT_SET_ELEM_INTERVAL_END) + interval = rbe; } else { if (!nft_set_elem_active(&rbe->ext, genmask)) parent = rcu_dereference_raw(parent->rb_left); @@ -154,7 +157,10 @@ static bool __nft_rbtree_get(const struct net *net, const struct nft_set *set, if (set->flags & NFT_SET_INTERVAL && interval != NULL && nft_set_elem_active(&interval->ext, genmask) && - !nft_rbtree_interval_end(interval)) { + ((!nft_rbtree_interval_end(interval) && + !(flags & NFT_SET_ELEM_INTERVAL_END)) || + (nft_rbtree_interval_end(interval) && + (flags & NFT_SET_ELEM_INTERVAL_END)))) { *elem = interval; return true; } @@ -355,12 +361,11 @@ cont: static void nft_rbtree_gc(struct work_struct *work) { + struct nft_rbtree_elem *rbe, *rbe_end = NULL, *rbe_prev = NULL; struct nft_set_gc_batch *gcb = NULL; - struct rb_node *node, *prev = NULL; - struct nft_rbtree_elem *rbe; struct nft_rbtree *priv; + struct rb_node *node; struct nft_set *set; - int i; priv = container_of(work, struct nft_rbtree, gc_work.work); set = nft_set_container_of(priv); @@ -371,7 +376,7 @@ static void nft_rbtree_gc(struct work_struct *work) rbe = rb_entry(node, struct nft_rbtree_elem, node); if (nft_rbtree_interval_end(rbe)) { - prev = node; + rbe_end = rbe; continue; } if (!nft_set_elem_expired(&rbe->ext)) @@ -379,29 +384,30 @@ static void nft_rbtree_gc(struct work_struct *work) if (nft_set_elem_mark_busy(&rbe->ext)) continue; + if (rbe_prev) { + rb_erase(&rbe_prev->node, &priv->root); + rbe_prev = NULL; + } gcb = nft_set_gc_batch_check(set, gcb, GFP_ATOMIC); if (!gcb) break; atomic_dec(&set->nelems); nft_set_gc_batch_add(gcb, rbe); + rbe_prev = rbe; - if (prev) { - rbe = rb_entry(prev, struct nft_rbtree_elem, node); + if (rbe_end) { atomic_dec(&set->nelems); - nft_set_gc_batch_add(gcb, rbe); - prev = NULL; + nft_set_gc_batch_add(gcb, rbe_end); + rb_erase(&rbe_end->node, &priv->root); + rbe_end = NULL; } node = rb_next(node); if (!node) break; } - if (gcb) { - for (i = 0; i < gcb->head.cnt; i++) { - rbe = gcb->elems[i]; - rb_erase(&rbe->node, &priv->root); - } - } + if (rbe_prev) + rb_erase(&rbe_prev->node, &priv->root); write_seqcount_end(&priv->count); write_unlock_bh(&priv->lock); diff --git a/net/netfilter/nft_xfrm.c b/net/netfilter/nft_xfrm.c new file mode 100644 index 000000000000..5322609f7662 --- /dev/null +++ b/net/netfilter/nft_xfrm.c @@ -0,0 +1,294 @@ +/* + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + * + * Generic part shared by ipv4 and ipv6 backends. + */ + +#include <linux/kernel.h> +#include <linux/init.h> +#include <linux/module.h> +#include <linux/netlink.h> +#include <linux/netfilter.h> +#include <linux/netfilter/nf_tables.h> +#include <net/netfilter/nf_tables_core.h> +#include <net/netfilter/nf_tables.h> +#include <linux/in.h> +#include <net/xfrm.h> + +static const struct nla_policy nft_xfrm_policy[NFTA_XFRM_MAX + 1] = { + [NFTA_XFRM_KEY] = { .type = NLA_U32 }, + [NFTA_XFRM_DIR] = { .type = NLA_U8 }, + [NFTA_XFRM_SPNUM] = { .type = NLA_U32 }, + [NFTA_XFRM_DREG] = { .type = NLA_U32 }, +}; + +struct nft_xfrm { + enum nft_xfrm_keys key:8; + enum nft_registers dreg:8; + u8 dir; + u8 spnum; +}; + +static int nft_xfrm_get_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_xfrm *priv = nft_expr_priv(expr); + unsigned int len = 0; + u32 spnum = 0; + u8 dir; + + if (!tb[NFTA_XFRM_KEY] || !tb[NFTA_XFRM_DIR] || !tb[NFTA_XFRM_DREG]) + return -EINVAL; + + switch (ctx->family) { + case NFPROTO_IPV4: + case NFPROTO_IPV6: + case NFPROTO_INET: + break; + default: + return -EOPNOTSUPP; + } + + priv->key = ntohl(nla_get_u32(tb[NFTA_XFRM_KEY])); + switch (priv->key) { + case NFT_XFRM_KEY_REQID: + case NFT_XFRM_KEY_SPI: + len = sizeof(u32); + break; + case NFT_XFRM_KEY_DADDR_IP4: + case NFT_XFRM_KEY_SADDR_IP4: + len = sizeof(struct in_addr); + break; + case NFT_XFRM_KEY_DADDR_IP6: + case NFT_XFRM_KEY_SADDR_IP6: + len = sizeof(struct in6_addr); + break; + default: + return -EINVAL; + } + + dir = nla_get_u8(tb[NFTA_XFRM_DIR]); + switch (dir) { + case XFRM_POLICY_IN: + case XFRM_POLICY_OUT: + priv->dir = dir; + break; + default: + return -EINVAL; + } + + if (tb[NFTA_XFRM_SPNUM]) + spnum = ntohl(nla_get_be32(tb[NFTA_XFRM_SPNUM])); + + if (spnum >= XFRM_MAX_DEPTH) + return -ERANGE; + + priv->spnum = spnum; + + priv->dreg = nft_parse_register(tb[NFTA_XFRM_DREG]); + return nft_validate_register_store(ctx, priv->dreg, NULL, + NFT_DATA_VALUE, len); +} + +/* Return true if key asks for daddr/saddr and current + * state does have a valid address (BEET, TUNNEL). + */ +static bool xfrm_state_addr_ok(enum nft_xfrm_keys k, u8 family, u8 mode) +{ + switch (k) { + case NFT_XFRM_KEY_DADDR_IP4: + case NFT_XFRM_KEY_SADDR_IP4: + if (family == NFPROTO_IPV4) + break; + return false; + case NFT_XFRM_KEY_DADDR_IP6: + case NFT_XFRM_KEY_SADDR_IP6: + if (family == NFPROTO_IPV6) + break; + return false; + default: + return true; + } + + return mode == XFRM_MODE_BEET || mode == XFRM_MODE_TUNNEL; +} + +static void nft_xfrm_state_get_key(const struct nft_xfrm *priv, + struct nft_regs *regs, + const struct xfrm_state *state) +{ + u32 *dest = ®s->data[priv->dreg]; + + if (!xfrm_state_addr_ok(priv->key, + state->props.family, + state->props.mode)) { + regs->verdict.code = NFT_BREAK; + return; + } + + switch (priv->key) { + case NFT_XFRM_KEY_UNSPEC: + case __NFT_XFRM_KEY_MAX: + WARN_ON_ONCE(1); + break; + case NFT_XFRM_KEY_DADDR_IP4: + *dest = state->id.daddr.a4; + return; + case NFT_XFRM_KEY_DADDR_IP6: + memcpy(dest, &state->id.daddr.in6, sizeof(struct in6_addr)); + return; + case NFT_XFRM_KEY_SADDR_IP4: + *dest = state->props.saddr.a4; + return; + case NFT_XFRM_KEY_SADDR_IP6: + memcpy(dest, &state->props.saddr.in6, sizeof(struct in6_addr)); + return; + case NFT_XFRM_KEY_REQID: + *dest = state->props.reqid; + return; + case NFT_XFRM_KEY_SPI: + *dest = state->id.spi; + return; + } + + regs->verdict.code = NFT_BREAK; +} + +static void nft_xfrm_get_eval_in(const struct nft_xfrm *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct sec_path *sp = pkt->skb->sp; + const struct xfrm_state *state; + + if (sp == NULL || sp->len <= priv->spnum) { + regs->verdict.code = NFT_BREAK; + return; + } + + state = sp->xvec[priv->spnum]; + nft_xfrm_state_get_key(priv, regs, state); +} + +static void nft_xfrm_get_eval_out(const struct nft_xfrm *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct dst_entry *dst = skb_dst(pkt->skb); + int i; + + for (i = 0; dst && dst->xfrm; + dst = ((const struct xfrm_dst *)dst)->child, i++) { + if (i < priv->spnum) + continue; + + nft_xfrm_state_get_key(priv, regs, dst->xfrm); + return; + } + + regs->verdict.code = NFT_BREAK; +} + +static void nft_xfrm_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_xfrm *priv = nft_expr_priv(expr); + + switch (priv->dir) { + case XFRM_POLICY_IN: + nft_xfrm_get_eval_in(priv, regs, pkt); + break; + case XFRM_POLICY_OUT: + nft_xfrm_get_eval_out(priv, regs, pkt); + break; + default: + WARN_ON_ONCE(1); + regs->verdict.code = NFT_BREAK; + break; + } +} + +static int nft_xfrm_get_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_xfrm *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_XFRM_DREG, priv->dreg)) + return -1; + + if (nla_put_be32(skb, NFTA_XFRM_KEY, htonl(priv->key))) + return -1; + if (nla_put_u8(skb, NFTA_XFRM_DIR, priv->dir)) + return -1; + if (nla_put_be32(skb, NFTA_XFRM_SPNUM, htonl(priv->spnum))) + return -1; + + return 0; +} + +static int nft_xfrm_validate(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nft_data **data) +{ + const struct nft_xfrm *priv = nft_expr_priv(expr); + unsigned int hooks; + + switch (priv->dir) { + case XFRM_POLICY_IN: + hooks = (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_PRE_ROUTING); + break; + case XFRM_POLICY_OUT: + hooks = (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING); + break; + default: + WARN_ON_ONCE(1); + return -EINVAL; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +} + + +static struct nft_expr_type nft_xfrm_type; +static const struct nft_expr_ops nft_xfrm_get_ops = { + .type = &nft_xfrm_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_xfrm)), + .eval = nft_xfrm_get_eval, + .init = nft_xfrm_get_init, + .dump = nft_xfrm_get_dump, + .validate = nft_xfrm_validate, +}; + +static struct nft_expr_type nft_xfrm_type __read_mostly = { + .name = "xfrm", + .ops = &nft_xfrm_get_ops, + .policy = nft_xfrm_policy, + .maxattr = NFTA_XFRM_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_xfrm_module_init(void) +{ + return nft_register_expr(&nft_xfrm_type); +} + +static void __exit nft_xfrm_module_exit(void) +{ + nft_unregister_expr(&nft_xfrm_type); +} + +module_init(nft_xfrm_module_init); +module_exit(nft_xfrm_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("nf_tables: xfrm/IPSec matching"); +MODULE_AUTHOR("Florian Westphal <fw@strlen.de>"); +MODULE_AUTHOR("Máté Eckl <ecklm94@gmail.com>"); +MODULE_ALIAS_NFT_EXPR("xfrm"); diff --git a/net/netfilter/xt_CT.c b/net/netfilter/xt_CT.c index 89457efd2e00..2c7a4b80206f 100644 --- a/net/netfilter/xt_CT.c +++ b/net/netfilter/xt_CT.c @@ -159,7 +159,7 @@ xt_ct_set_timeout(struct nf_conn *ct, const struct xt_tgchk_param *par, /* Make sure the timeout policy matches any existing protocol tracker, * otherwise default to generic. */ - l4proto = __nf_ct_l4proto_find(par->family, proto); + l4proto = __nf_ct_l4proto_find(proto); if (timeout->l4proto->l4proto != l4proto->l4proto) { ret = -EINVAL; pr_info_ratelimited("Timeout policy `%s' can only be used by L%d protocol number %d\n", diff --git a/net/netfilter/xt_IDLETIMER.c b/net/netfilter/xt_IDLETIMER.c index 5ee859193783..c6acfc2d9c84 100644 --- a/net/netfilter/xt_IDLETIMER.c +++ b/net/netfilter/xt_IDLETIMER.c @@ -68,8 +68,6 @@ struct idletimer_tg *__idletimer_tg_find_by_label(const char *label) { struct idletimer_tg *entry; - BUG_ON(!label); - list_for_each_entry(entry, &idletimer_tg_list, entry) { if (!strcmp(label, entry->attr.attr.name)) return entry; @@ -172,8 +170,6 @@ static unsigned int idletimer_tg_target(struct sk_buff *skb, pr_debug("resetting timer %s, timeout period %u\n", info->label, info->timeout); - BUG_ON(!info->timer); - mod_timer(&info->timer->timer, msecs_to_jiffies(info->timeout * 1000) + jiffies); diff --git a/net/netfilter/xt_SECMARK.c b/net/netfilter/xt_SECMARK.c index 4ad5fe27e08b..f16202d26c20 100644 --- a/net/netfilter/xt_SECMARK.c +++ b/net/netfilter/xt_SECMARK.c @@ -35,8 +35,6 @@ secmark_tg(struct sk_buff *skb, const struct xt_action_param *par) u32 secmark = 0; const struct xt_secmark_target_info *info = par->targinfo; - BUG_ON(info->mode != mode); - switch (mode) { case SECMARK_MODE_SEL: secmark = info->secid; diff --git a/net/netfilter/xt_TEE.c b/net/netfilter/xt_TEE.c index 0d0d68c989df..1dae02a97ee3 100644 --- a/net/netfilter/xt_TEE.c +++ b/net/netfilter/xt_TEE.c @@ -14,6 +14,8 @@ #include <linux/skbuff.h> #include <linux/route.h> #include <linux/netfilter/x_tables.h> +#include <net/net_namespace.h> +#include <net/netns/generic.h> #include <net/route.h> #include <net/netfilter/ipv4/nf_dup_ipv4.h> #include <net/netfilter/ipv6/nf_dup_ipv6.h> @@ -25,8 +27,15 @@ struct xt_tee_priv { int oif; }; +static unsigned int tee_net_id __read_mostly; static const union nf_inet_addr tee_zero_address; +struct tee_net { + struct list_head priv_list; + /* lock protects the priv_list */ + struct mutex lock; +}; + static unsigned int tee_tg4(struct sk_buff *skb, const struct xt_action_param *par) { @@ -51,17 +60,16 @@ tee_tg6(struct sk_buff *skb, const struct xt_action_param *par) } #endif -static DEFINE_MUTEX(priv_list_mutex); -static LIST_HEAD(priv_list); - static int tee_netdev_event(struct notifier_block *this, unsigned long event, void *ptr) { struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct tee_net *tn = net_generic(net, tee_net_id); struct xt_tee_priv *priv; - mutex_lock(&priv_list_mutex); - list_for_each_entry(priv, &priv_list, list) { + mutex_lock(&tn->lock); + list_for_each_entry(priv, &tn->priv_list, list) { switch (event) { case NETDEV_REGISTER: if (!strcmp(dev->name, priv->tginfo->oif)) @@ -79,13 +87,14 @@ static int tee_netdev_event(struct notifier_block *this, unsigned long event, break; } } - mutex_unlock(&priv_list_mutex); + mutex_unlock(&tn->lock); return NOTIFY_DONE; } static int tee_tg_check(const struct xt_tgchk_param *par) { + struct tee_net *tn = net_generic(par->net, tee_net_id); struct xt_tee_tginfo *info = par->targinfo; struct xt_tee_priv *priv; @@ -95,6 +104,8 @@ static int tee_tg_check(const struct xt_tgchk_param *par) return -EINVAL; if (info->oif[0]) { + struct net_device *dev; + if (info->oif[sizeof(info->oif)-1] != '\0') return -EINVAL; @@ -106,9 +117,14 @@ static int tee_tg_check(const struct xt_tgchk_param *par) priv->oif = -1; info->priv = priv; - mutex_lock(&priv_list_mutex); - list_add(&priv->list, &priv_list); - mutex_unlock(&priv_list_mutex); + dev = dev_get_by_name(par->net, info->oif); + if (dev) { + priv->oif = dev->ifindex; + dev_put(dev); + } + mutex_lock(&tn->lock); + list_add(&priv->list, &tn->priv_list); + mutex_unlock(&tn->lock); } else info->priv = NULL; @@ -118,12 +134,13 @@ static int tee_tg_check(const struct xt_tgchk_param *par) static void tee_tg_destroy(const struct xt_tgdtor_param *par) { + struct tee_net *tn = net_generic(par->net, tee_net_id); struct xt_tee_tginfo *info = par->targinfo; if (info->priv) { - mutex_lock(&priv_list_mutex); + mutex_lock(&tn->lock); list_del(&info->priv->list); - mutex_unlock(&priv_list_mutex); + mutex_unlock(&tn->lock); kfree(info->priv); } static_key_slow_dec(&xt_tee_enabled); @@ -156,6 +173,21 @@ static struct xt_target tee_tg_reg[] __read_mostly = { #endif }; +static int __net_init tee_net_init(struct net *net) +{ + struct tee_net *tn = net_generic(net, tee_net_id); + + INIT_LIST_HEAD(&tn->priv_list); + mutex_init(&tn->lock); + return 0; +} + +static struct pernet_operations tee_net_ops = { + .init = tee_net_init, + .id = &tee_net_id, + .size = sizeof(struct tee_net), +}; + static struct notifier_block tee_netdev_notifier = { .notifier_call = tee_netdev_event, }; @@ -164,22 +196,32 @@ static int __init tee_tg_init(void) { int ret; - ret = xt_register_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); - if (ret) + ret = register_pernet_subsys(&tee_net_ops); + if (ret < 0) return ret; + + ret = xt_register_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); + if (ret < 0) + goto cleanup_subsys; + ret = register_netdevice_notifier(&tee_netdev_notifier); - if (ret) { - xt_unregister_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); - return ret; - } + if (ret < 0) + goto unregister_targets; return 0; + +unregister_targets: + xt_unregister_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); +cleanup_subsys: + unregister_pernet_subsys(&tee_net_ops); + return ret; } static void __exit tee_tg_exit(void) { unregister_netdevice_notifier(&tee_netdev_notifier); xt_unregister_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); + unregister_pernet_subsys(&tee_net_ops); } module_init(tee_tg_init); diff --git a/net/netfilter/xt_cgroup.c b/net/netfilter/xt_cgroup.c index 5d92e1781980..5cb1ecb29ea4 100644 --- a/net/netfilter/xt_cgroup.c +++ b/net/netfilter/xt_cgroup.c @@ -68,6 +68,38 @@ static int cgroup_mt_check_v1(const struct xt_mtchk_param *par) return 0; } +static int cgroup_mt_check_v2(const struct xt_mtchk_param *par) +{ + struct xt_cgroup_info_v2 *info = par->matchinfo; + struct cgroup *cgrp; + + if ((info->invert_path & ~1) || (info->invert_classid & ~1)) + return -EINVAL; + + if (!info->has_path && !info->has_classid) { + pr_info("xt_cgroup: no path or classid specified\n"); + return -EINVAL; + } + + if (info->has_path && info->has_classid) { + pr_info_ratelimited("path and classid specified\n"); + return -EINVAL; + } + + info->priv = NULL; + if (info->has_path) { + cgrp = cgroup_get_from_path(info->path); + if (IS_ERR(cgrp)) { + pr_info_ratelimited("invalid path, errno=%ld\n", + PTR_ERR(cgrp)); + return -EINVAL; + } + info->priv = cgrp; + } + + return 0; +} + static bool cgroup_mt_v0(const struct sk_buff *skb, struct xt_action_param *par) { @@ -99,6 +131,24 @@ static bool cgroup_mt_v1(const struct sk_buff *skb, struct xt_action_param *par) info->invert_classid; } +static bool cgroup_mt_v2(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_cgroup_info_v2 *info = par->matchinfo; + struct sock_cgroup_data *skcd = &skb->sk->sk_cgrp_data; + struct cgroup *ancestor = info->priv; + struct sock *sk = skb->sk; + + if (!sk || !sk_fullsock(sk) || !net_eq(xt_net(par), sock_net(sk))) + return false; + + if (ancestor) + return cgroup_is_descendant(sock_cgroup_ptr(skcd), ancestor) ^ + info->invert_path; + else + return (info->classid == sock_cgroup_classid(skcd)) ^ + info->invert_classid; +} + static void cgroup_mt_destroy_v1(const struct xt_mtdtor_param *par) { struct xt_cgroup_info_v1 *info = par->matchinfo; @@ -107,6 +157,14 @@ static void cgroup_mt_destroy_v1(const struct xt_mtdtor_param *par) cgroup_put(info->priv); } +static void cgroup_mt_destroy_v2(const struct xt_mtdtor_param *par) +{ + struct xt_cgroup_info_v2 *info = par->matchinfo; + + if (info->priv) + cgroup_put(info->priv); +} + static struct xt_match cgroup_mt_reg[] __read_mostly = { { .name = "cgroup", @@ -134,6 +192,20 @@ static struct xt_match cgroup_mt_reg[] __read_mostly = { (1 << NF_INET_POST_ROUTING) | (1 << NF_INET_LOCAL_IN), }, + { + .name = "cgroup", + .revision = 2, + .family = NFPROTO_UNSPEC, + .checkentry = cgroup_mt_check_v2, + .match = cgroup_mt_v2, + .matchsize = sizeof(struct xt_cgroup_info_v2), + .usersize = offsetof(struct xt_cgroup_info_v2, priv), + .destroy = cgroup_mt_destroy_v2, + .me = THIS_MODULE, + .hooks = (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN), + }, }; static int __init cgroup_mt_init(void) diff --git a/net/netfilter/xt_nat.c b/net/netfilter/xt_nat.c index 8af9707f8789..ac91170fc8c8 100644 --- a/net/netfilter/xt_nat.c +++ b/net/netfilter/xt_nat.c @@ -216,6 +216,8 @@ static struct xt_target xt_nat_target_reg[] __read_mostly = { { .name = "DNAT", .revision = 2, + .checkentry = xt_nat_checkentry, + .destroy = xt_nat_destroy, .target = xt_dnat_target_v2, .targetsize = sizeof(struct nf_nat_range2), .table = "nat", diff --git a/net/netfilter/xt_osf.c b/net/netfilter/xt_osf.c index bf7bba80e24c..7a103553d10d 100644 --- a/net/netfilter/xt_osf.c +++ b/net/netfilter/xt_osf.c @@ -40,14 +40,8 @@ static bool xt_osf_match_packet(const struct sk_buff *skb, struct xt_action_param *p) { - const struct xt_osf_info *info = p->matchinfo; - struct net *net = xt_net(p); - - if (!info) - return false; - return nf_osf_match(skb, xt_family(p), xt_hooknum(p), xt_in(p), - xt_out(p), info, net, nf_osf_fingers); + xt_out(p), p->matchinfo, xt_net(p), nf_osf_fingers); } static struct xt_match xt_osf_match = { diff --git a/net/netfilter/xt_socket.c b/net/netfilter/xt_socket.c index 0472f3472842..ada144e5645b 100644 --- a/net/netfilter/xt_socket.c +++ b/net/netfilter/xt_socket.c @@ -56,7 +56,7 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par, struct sk_buff *pskb = (struct sk_buff *)skb; struct sock *sk = skb->sk; - if (!net_eq(xt_net(par), sock_net(sk))) + if (sk && !net_eq(xt_net(par), sock_net(sk))) sk = NULL; if (!sk) @@ -117,7 +117,7 @@ socket_mt6_v1_v2_v3(const struct sk_buff *skb, struct xt_action_param *par) struct sk_buff *pskb = (struct sk_buff *)skb; struct sock *sk = skb->sk; - if (!net_eq(xt_net(par), sock_net(sk))) + if (sk && !net_eq(xt_net(par), sock_net(sk))) sk = NULL; if (!sk) diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index e3a0538ec0be..6bb9f3cde0b0 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -1706,6 +1706,13 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, nlk->flags &= ~NETLINK_F_EXT_ACK; err = 0; break; + case NETLINK_DUMP_STRICT_CHK: + if (val) + nlk->flags |= NETLINK_F_STRICT_CHK; + else + nlk->flags &= ~NETLINK_F_STRICT_CHK; + err = 0; + break; default: err = -ENOPROTOOPT; } @@ -1799,6 +1806,15 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname, return -EFAULT; err = 0; break; + case NETLINK_DUMP_STRICT_CHK: + if (len < sizeof(int)) + return -EINVAL; + len = sizeof(int); + val = nlk->flags & NETLINK_F_STRICT_CHK ? 1 : 0; + if (put_user(len, optlen) || put_user(val, optval)) + return -EFAULT; + err = 0; + break; default: err = -ENOPROTOOPT; } @@ -2171,6 +2187,7 @@ EXPORT_SYMBOL(__nlmsg_put); static int netlink_dump(struct sock *sk) { struct netlink_sock *nlk = nlk_sk(sk); + struct netlink_ext_ack extack = {}; struct netlink_callback *cb; struct sk_buff *skb = NULL; struct nlmsghdr *nlh; @@ -2222,8 +2239,11 @@ static int netlink_dump(struct sock *sk) skb_reserve(skb, skb_tailroom(skb) - alloc_size); netlink_skb_set_owner_r(skb, sk); - if (nlk->dump_done_errno > 0) + if (nlk->dump_done_errno > 0) { + cb->extack = &extack; nlk->dump_done_errno = cb->dump(skb, cb); + cb->extack = NULL; + } if (nlk->dump_done_errno > 0 || skb_tailroom(skb) < nlmsg_total_size(sizeof(nlk->dump_done_errno))) { @@ -2237,7 +2257,8 @@ static int netlink_dump(struct sock *sk) } nlh = nlmsg_put_answer(skb, cb, NLMSG_DONE, - sizeof(nlk->dump_done_errno), NLM_F_MULTI); + sizeof(nlk->dump_done_errno), + NLM_F_MULTI | cb->answer_flags); if (WARN_ON(!nlh)) goto errout_skb; @@ -2246,6 +2267,12 @@ static int netlink_dump(struct sock *sk) memcpy(nlmsg_data(nlh), &nlk->dump_done_errno, sizeof(nlk->dump_done_errno)); + if (extack._msg && nlk->flags & NETLINK_F_EXT_ACK) { + nlh->nlmsg_flags |= NLM_F_ACK_TLVS; + if (!nla_put_string(skb, NLMSGERR_ATTR_MSG, extack._msg)) + nlmsg_end(skb, nlh); + } + if (sk_filter(sk, skb)) kfree_skb(skb); else @@ -2272,9 +2299,9 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb, const struct nlmsghdr *nlh, struct netlink_dump_control *control) { + struct netlink_sock *nlk, *nlk2; struct netlink_callback *cb; struct sock *sk; - struct netlink_sock *nlk; int ret; refcount_inc(&skb->users); @@ -2308,6 +2335,9 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb, cb->min_dump_alloc = control->min_dump_alloc; cb->skb = skb; + nlk2 = nlk_sk(NETLINK_CB(skb).sk); + cb->strict_check = !!(nlk2->flags & NETLINK_F_STRICT_CHK); + if (control->start) { ret = control->start(cb); if (ret) diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h index 962de7b3c023..5f454c8de6a4 100644 --- a/net/netlink/af_netlink.h +++ b/net/netlink/af_netlink.h @@ -15,6 +15,7 @@ #define NETLINK_F_LISTEN_ALL_NSID 0x10 #define NETLINK_F_CAP_ACK 0x20 #define NETLINK_F_EXT_ACK 0x40 +#define NETLINK_F_STRICT_CHK 0x80 #define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) #define NLGRPLONGS(x) (NLGRPSZ(x)/sizeof(unsigned long)) diff --git a/net/nfc/llcp_sock.c b/net/nfc/llcp_sock.c index dd4adf8b1167..ae296273ce3d 100644 --- a/net/nfc/llcp_sock.c +++ b/net/nfc/llcp_sock.c @@ -556,7 +556,7 @@ static __poll_t llcp_sock_poll(struct file *file, struct socket *sock, pr_debug("%p\n", sk); - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); if (sk->sk_state == LLCP_LISTEN) return llcp_accept_poll(sk); diff --git a/net/nfc/nci/uart.c b/net/nfc/nci/uart.c index 4503937915ad..78fe622eba65 100644 --- a/net/nfc/nci/uart.c +++ b/net/nfc/nci/uart.c @@ -463,6 +463,7 @@ static struct tty_ldisc_ops nci_uart_ldisc = { .receive_buf = nci_uart_tty_receive, .write_wakeup = nci_uart_tty_wakeup, .ioctl = nci_uart_tty_ioctl, + .compat_ioctl = nci_uart_tty_ioctl, }; static int __init nci_uart_init(void) diff --git a/net/openvswitch/conntrack.c b/net/openvswitch/conntrack.c index 86a75105af1a..6bec37ab4472 100644 --- a/net/openvswitch/conntrack.c +++ b/net/openvswitch/conntrack.c @@ -933,6 +933,11 @@ static int __ovs_ct_lookup(struct net *net, struct sw_flow_key *key, struct nf_conn *ct; if (!cached) { + struct nf_hook_state state = { + .hook = NF_INET_PRE_ROUTING, + .pf = info->family, + .net = net, + }; struct nf_conn *tmpl = info->ct; int err; @@ -944,8 +949,7 @@ static int __ovs_ct_lookup(struct net *net, struct sw_flow_key *key, nf_ct_set(skb, tmpl, IP_CT_NEW); } - err = nf_conntrack_in(net, info->family, - NF_INET_PRE_ROUTING, skb); + err = nf_conntrack_in(skb, &state); if (err != NF_ACCEPT) return -ENOENT; @@ -1312,6 +1316,10 @@ static int ovs_ct_add_helper(struct ovs_conntrack_info *info, const char *name, rcu_assign_pointer(help->helper, helper); info->helper = helper; + + if (info->nat) + request_module("ip_nat_%s", name); + return 0; } @@ -1624,10 +1632,6 @@ int ovs_ct_copy_action(struct net *net, const struct nlattr *attr, OVS_NLERR(log, "Failed to allocate conntrack template"); return -ENOMEM; } - - __set_bit(IPS_CONFIRMED_BIT, &ct_info.ct->status); - nf_conntrack_get(&ct_info.ct->ct_general); - if (helper) { err = ovs_ct_add_helper(&ct_info, helper, key, log); if (err) @@ -1639,6 +1643,8 @@ int ovs_ct_copy_action(struct net *net, const struct nlattr *attr, if (err) goto err_free_ct; + __set_bit(IPS_CONFIRMED_BIT, &ct_info.ct->status); + nf_conntrack_get(&ct_info.ct->ct_general); return 0; err_free_ct: __ovs_ct_free_action(&ct_info); diff --git a/net/openvswitch/flow_netlink.c b/net/openvswitch/flow_netlink.c index a70097ecf33c..865ecef68196 100644 --- a/net/openvswitch/flow_netlink.c +++ b/net/openvswitch/flow_netlink.c @@ -3030,7 +3030,7 @@ static int __ovs_nla_copy_actions(struct net *net, const struct nlattr *attr, * is already present */ if (mac_proto != MAC_PROTO_NONE) return -EINVAL; - mac_proto = MAC_PROTO_NONE; + mac_proto = MAC_PROTO_ETHERNET; break; case OVS_ACTION_ATTR_POP_ETH: @@ -3038,7 +3038,7 @@ static int __ovs_nla_copy_actions(struct net *net, const struct nlattr *attr, return -EINVAL; if (vlan_tci & htons(VLAN_TAG_PRESENT)) return -EINVAL; - mac_proto = MAC_PROTO_ETHERNET; + mac_proto = MAC_PROTO_NONE; break; case OVS_ACTION_ATTR_PUSH_NSH: diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c index f85f67b5c1f4..ec3095f13aae 100644 --- a/net/packet/af_packet.c +++ b/net/packet/af_packet.c @@ -2715,10 +2715,12 @@ tpacket_error: } } - if (po->has_vnet_hdr && virtio_net_hdr_to_skb(skb, vnet_hdr, - vio_le())) { - tp_len = -EINVAL; - goto tpacket_error; + if (po->has_vnet_hdr) { + if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) { + tp_len = -EINVAL; + goto tpacket_error; + } + virtio_net_hdr_set_proto(skb, vnet_hdr); } skb->destructor = tpacket_destruct_skb; @@ -2915,6 +2917,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len) if (err) goto out_free; len += sizeof(vnet_hdr); + virtio_net_hdr_set_proto(skb, &vnet_hdr); } skb_probe_transport_header(skb, reserve); diff --git a/net/rds/send.c b/net/rds/send.c index 57b3d5a8b2db..fe785ee819dd 100644 --- a/net/rds/send.c +++ b/net/rds/send.c @@ -1007,7 +1007,8 @@ static int rds_cmsg_send(struct rds_sock *rs, struct rds_message *rm, return ret; } -static int rds_send_mprds_hash(struct rds_sock *rs, struct rds_connection *conn) +static int rds_send_mprds_hash(struct rds_sock *rs, + struct rds_connection *conn, int nonblock) { int hash; @@ -1023,10 +1024,16 @@ static int rds_send_mprds_hash(struct rds_sock *rs, struct rds_connection *conn) * used. But if we are interrupted, we have to use the zero * c_path in case the connection ends up being non-MP capable. */ - if (conn->c_npaths == 0) + if (conn->c_npaths == 0) { + /* Cannot wait for the connection be made, so just use + * the base c_path. + */ + if (nonblock) + return 0; if (wait_event_interruptible(conn->c_hs_waitq, conn->c_npaths != 0)) hash = 0; + } if (conn->c_npaths == 1) hash = 0; } @@ -1256,7 +1263,7 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) } if (conn->c_trans->t_mp_capable) - cpath = &conn->c_path[rds_send_mprds_hash(rs, conn)]; + cpath = &conn->c_path[rds_send_mprds_hash(rs, conn, nonblock)]; else cpath = &conn->c_path[0]; diff --git a/net/rxrpc/af_rxrpc.c b/net/rxrpc/af_rxrpc.c index ac44d8afffb1..64362d078da8 100644 --- a/net/rxrpc/af_rxrpc.c +++ b/net/rxrpc/af_rxrpc.c @@ -97,7 +97,8 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx, srx->transport_len > len) return -EINVAL; - if (srx->transport.family != rx->family) + if (srx->transport.family != rx->family && + srx->transport.family == AF_INET && rx->family != AF_INET6) return -EAFNOSUPPORT; switch (srx->transport.family) { @@ -385,6 +386,20 @@ u32 rxrpc_kernel_check_life(struct socket *sock, struct rxrpc_call *call) EXPORT_SYMBOL(rxrpc_kernel_check_life); /** + * rxrpc_kernel_get_epoch - Retrieve the epoch value from a call. + * @sock: The socket the call is on + * @call: The call to query + * + * Allow a kernel service to retrieve the epoch value from a service call to + * see if the client at the other end rebooted. + */ +u32 rxrpc_kernel_get_epoch(struct socket *sock, struct rxrpc_call *call) +{ + return call->conn->proto.epoch; +} +EXPORT_SYMBOL(rxrpc_kernel_get_epoch); + +/** * rxrpc_kernel_check_call - Check a call's state * @sock: The socket the call is on * @call: The call to check @@ -741,7 +756,7 @@ static __poll_t rxrpc_poll(struct file *file, struct socket *sock, struct rxrpc_sock *rx = rxrpc_sk(sk); __poll_t mask; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); mask = 0; /* the socket is readable if there are any messages waiting on the Rx diff --git a/net/rxrpc/ar-internal.h b/net/rxrpc/ar-internal.h index c97558710421..382196e57a26 100644 --- a/net/rxrpc/ar-internal.h +++ b/net/rxrpc/ar-internal.h @@ -40,17 +40,12 @@ struct rxrpc_crypt { struct rxrpc_connection; /* - * Mark applied to socket buffers. + * Mark applied to socket buffers in skb->mark. skb->priority is used + * to pass supplementary information. */ enum rxrpc_skb_mark { - RXRPC_SKB_MARK_DATA, /* data message */ - RXRPC_SKB_MARK_FINAL_ACK, /* final ACK received message */ - RXRPC_SKB_MARK_BUSY, /* server busy message */ - RXRPC_SKB_MARK_REMOTE_ABORT, /* remote abort message */ - RXRPC_SKB_MARK_LOCAL_ABORT, /* local abort message */ - RXRPC_SKB_MARK_NET_ERROR, /* network error message */ - RXRPC_SKB_MARK_LOCAL_ERROR, /* local error message */ - RXRPC_SKB_MARK_NEW_CALL, /* local error message */ + RXRPC_SKB_MARK_REJECT_BUSY, /* Reject with BUSY */ + RXRPC_SKB_MARK_REJECT_ABORT, /* Reject with ABORT (code in skb->priority) */ }; /* @@ -293,7 +288,6 @@ struct rxrpc_peer { struct hlist_node hash_link; struct rxrpc_local *local; struct hlist_head error_targets; /* targets for net error distribution */ - struct work_struct error_distributor; struct rb_root service_conns; /* Service connections */ struct list_head keepalive_link; /* Link in net->peer_keepalive[] */ time64_t last_tx_at; /* Last time packet sent here */ @@ -304,12 +298,11 @@ struct rxrpc_peer { unsigned int maxdata; /* data size (MTU - hdrsize) */ unsigned short hdrsize; /* header size (IP + UDP + RxRPC) */ int debug_id; /* debug ID for printks */ - int error_report; /* Net (+0) or local (+1000000) to distribute */ -#define RXRPC_LOCAL_ERROR_OFFSET 1000000 struct sockaddr_rxrpc srx; /* remote address */ /* calculated RTT cache */ #define RXRPC_RTT_CACHE_SIZE 32 + spinlock_t rtt_input_lock; /* RTT lock for input routine */ ktime_t rtt_last_req; /* Time of last RTT request */ u64 rtt; /* Current RTT estimate (in nS) */ u64 rtt_sum; /* Sum of cache contents */ @@ -442,7 +435,7 @@ struct rxrpc_connection { struct sk_buff_head rx_queue; /* received conn-level packets */ const struct rxrpc_security *security; /* applied security module */ struct key *server_key; /* security for this service */ - struct crypto_skcipher *cipher; /* encryption handle */ + struct crypto_sync_skcipher *cipher; /* encryption handle */ struct rxrpc_crypt csum_iv; /* packet checksum base */ unsigned long flags; unsigned long events; @@ -450,19 +443,29 @@ struct rxrpc_connection { spinlock_t state_lock; /* state-change lock */ enum rxrpc_conn_cache_state cache_state; enum rxrpc_conn_proto_state state; /* current state of connection */ - u32 local_abort; /* local abort code */ - u32 remote_abort; /* remote abort code */ + u32 abort_code; /* Abort code of connection abort */ int debug_id; /* debug ID for printks */ atomic_t serial; /* packet serial number counter */ unsigned int hi_serial; /* highest serial number received */ u32 security_nonce; /* response re-use preventer */ - u16 service_id; /* Service ID, possibly upgraded */ + u32 service_id; /* Service ID, possibly upgraded */ u8 size_align; /* data size alignment (for security) */ u8 security_size; /* security header size */ u8 security_ix; /* security type */ u8 out_clientflag; /* RXRPC_CLIENT_INITIATED if we are client */ + short error; /* Local error code */ }; +static inline bool rxrpc_to_server(const struct rxrpc_skb_priv *sp) +{ + return sp->hdr.flags & RXRPC_CLIENT_INITIATED; +} + +static inline bool rxrpc_to_client(const struct rxrpc_skb_priv *sp) +{ + return !rxrpc_to_server(sp); +} + /* * Flags in call->flags. */ @@ -633,6 +636,8 @@ struct rxrpc_call { bool tx_phase; /* T if transmission phase, F if receive phase */ u8 nr_jumbo_bad; /* Number of jumbo dups/exceeds-windows */ + spinlock_t input_lock; /* Lock for packet input to this call */ + /* receive-phase ACK management */ u8 ackr_reason; /* reason to ACK */ u16 ackr_skew; /* skew on packet being ACK'd */ @@ -717,7 +722,7 @@ extern struct workqueue_struct *rxrpc_workqueue; int rxrpc_service_prealloc(struct rxrpc_sock *, gfp_t); void rxrpc_discard_prealloc(struct rxrpc_sock *); struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *, - struct rxrpc_connection *, + struct rxrpc_sock *, struct sk_buff *); void rxrpc_accept_incoming_calls(struct rxrpc_local *); struct rxrpc_call *rxrpc_accept_call(struct rxrpc_sock *, unsigned long, @@ -887,8 +892,9 @@ extern unsigned long rxrpc_conn_idle_client_fast_expiry; extern struct idr rxrpc_client_conn_ids; void rxrpc_destroy_client_conn_ids(void); -int rxrpc_connect_call(struct rxrpc_call *, struct rxrpc_conn_parameters *, - struct sockaddr_rxrpc *, gfp_t); +int rxrpc_connect_call(struct rxrpc_sock *, struct rxrpc_call *, + struct rxrpc_conn_parameters *, struct sockaddr_rxrpc *, + gfp_t); void rxrpc_expose_client_call(struct rxrpc_call *); void rxrpc_disconnect_client_call(struct rxrpc_call *); void rxrpc_put_client_conn(struct rxrpc_connection *); @@ -908,7 +914,8 @@ extern unsigned int rxrpc_closed_conn_expiry; struct rxrpc_connection *rxrpc_alloc_connection(gfp_t); struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *, - struct sk_buff *); + struct sk_buff *, + struct rxrpc_peer **); void __rxrpc_disconnect_call(struct rxrpc_connection *, struct rxrpc_call *); void rxrpc_disconnect_call(struct rxrpc_call *); void rxrpc_kill_connection(struct rxrpc_connection *); @@ -960,7 +967,7 @@ void rxrpc_unpublish_service_conn(struct rxrpc_connection *); /* * input.c */ -void rxrpc_data_ready(struct sock *); +int rxrpc_input_packet(struct sock *, struct sk_buff *); /* * insecure.c @@ -1031,7 +1038,6 @@ void rxrpc_send_keepalive(struct rxrpc_peer *); * peer_event.c */ void rxrpc_error_report(struct sock *); -void rxrpc_peer_error_distributor(struct work_struct *); void rxrpc_peer_add_rtt(struct rxrpc_call *, enum rxrpc_rtt_rx_trace, rxrpc_serial_t, rxrpc_serial_t, ktime_t, ktime_t); void rxrpc_peer_keepalive_worker(struct work_struct *); @@ -1041,22 +1047,22 @@ void rxrpc_peer_keepalive_worker(struct work_struct *); */ struct rxrpc_peer *rxrpc_lookup_peer_rcu(struct rxrpc_local *, const struct sockaddr_rxrpc *); -struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_local *, +struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *, struct rxrpc_local *, struct sockaddr_rxrpc *, gfp_t); struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *, gfp_t); -struct rxrpc_peer *rxrpc_lookup_incoming_peer(struct rxrpc_local *, - struct rxrpc_peer *); +void rxrpc_new_incoming_peer(struct rxrpc_sock *, struct rxrpc_local *, + struct rxrpc_peer *); void rxrpc_destroy_all_peers(struct rxrpc_net *); struct rxrpc_peer *rxrpc_get_peer(struct rxrpc_peer *); struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *); void rxrpc_put_peer(struct rxrpc_peer *); -void __rxrpc_queue_peer_error(struct rxrpc_peer *); /* * proc.c */ extern const struct seq_operations rxrpc_call_seq_ops; extern const struct seq_operations rxrpc_connection_seq_ops; +extern const struct seq_operations rxrpc_peer_seq_ops; /* * recvmsg.c @@ -1093,7 +1099,6 @@ void rxrpc_new_skb(struct sk_buff *, enum rxrpc_skb_trace); void rxrpc_see_skb(struct sk_buff *, enum rxrpc_skb_trace); void rxrpc_get_skb(struct sk_buff *, enum rxrpc_skb_trace); void rxrpc_free_skb(struct sk_buff *, enum rxrpc_skb_trace); -void rxrpc_lose_skb(struct sk_buff *, enum rxrpc_skb_trace); void rxrpc_purge_queue(struct sk_buff_head *); /* @@ -1110,8 +1115,7 @@ static inline void rxrpc_sysctl_exit(void) {} /* * utils.c */ -int rxrpc_extract_addr_from_skb(struct rxrpc_local *, struct sockaddr_rxrpc *, - struct sk_buff *); +int rxrpc_extract_addr_from_skb(struct sockaddr_rxrpc *, struct sk_buff *); static inline bool before(u32 seq1, u32 seq2) { diff --git a/net/rxrpc/call_accept.c b/net/rxrpc/call_accept.c index 9d1e298b784c..44860505246d 100644 --- a/net/rxrpc/call_accept.c +++ b/net/rxrpc/call_accept.c @@ -249,11 +249,11 @@ void rxrpc_discard_prealloc(struct rxrpc_sock *rx) */ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, struct rxrpc_local *local, + struct rxrpc_peer *peer, struct rxrpc_connection *conn, struct sk_buff *skb) { struct rxrpc_backlog *b = rx->backlog; - struct rxrpc_peer *peer, *xpeer; struct rxrpc_call *call; unsigned short call_head, conn_head, peer_head; unsigned short call_tail, conn_tail, peer_tail; @@ -276,21 +276,18 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, return NULL; if (!conn) { - /* No connection. We're going to need a peer to start off - * with. If one doesn't yet exist, use a spare from the - * preallocation set. We dump the address into the spare in - * anticipation - and to save on stack space. - */ - xpeer = b->peer_backlog[peer_tail]; - if (rxrpc_extract_addr_from_skb(local, &xpeer->srx, skb) < 0) - return NULL; - - peer = rxrpc_lookup_incoming_peer(local, xpeer); - if (peer == xpeer) { + if (peer && !rxrpc_get_peer_maybe(peer)) + peer = NULL; + if (!peer) { + peer = b->peer_backlog[peer_tail]; + if (rxrpc_extract_addr_from_skb(&peer->srx, skb) < 0) + return NULL; b->peer_backlog[peer_tail] = NULL; smp_store_release(&b->peer_backlog_tail, (peer_tail + 1) & (RXRPC_BACKLOG_MAX - 1)); + + rxrpc_new_incoming_peer(rx, local, peer); } /* Now allocate and set up the connection */ @@ -335,45 +332,38 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, * The call is returned with the user access mutex held. */ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, - struct rxrpc_connection *conn, + struct rxrpc_sock *rx, struct sk_buff *skb) { struct rxrpc_skb_priv *sp = rxrpc_skb(skb); - struct rxrpc_sock *rx; + struct rxrpc_connection *conn; + struct rxrpc_peer *peer = NULL; struct rxrpc_call *call; - u16 service_id = sp->hdr.serviceId; _enter(""); - /* Get the socket providing the service */ - rx = rcu_dereference(local->service); - if (rx && (service_id == rx->srx.srx_service || - service_id == rx->second_service)) - goto found_service; - - trace_rxrpc_abort(0, "INV", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, - RX_INVALID_OPERATION, EOPNOTSUPP); - skb->mark = RXRPC_SKB_MARK_LOCAL_ABORT; - skb->priority = RX_INVALID_OPERATION; - _leave(" = NULL [service]"); - return NULL; - -found_service: spin_lock(&rx->incoming_lock); if (rx->sk.sk_state == RXRPC_SERVER_LISTEN_DISABLED || rx->sk.sk_state == RXRPC_CLOSE) { trace_rxrpc_abort(0, "CLS", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, RX_INVALID_OPERATION, ESHUTDOWN); - skb->mark = RXRPC_SKB_MARK_LOCAL_ABORT; + skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; skb->priority = RX_INVALID_OPERATION; _leave(" = NULL [close]"); call = NULL; goto out; } - call = rxrpc_alloc_incoming_call(rx, local, conn, skb); + /* The peer, connection and call may all have sprung into existence due + * to a duplicate packet being handled on another CPU in parallel, so + * we have to recheck the routing. However, we're now holding + * rx->incoming_lock, so the values should remain stable. + */ + conn = rxrpc_find_connection_rcu(local, skb, &peer); + + call = rxrpc_alloc_incoming_call(rx, local, peer, conn, skb); if (!call) { - skb->mark = RXRPC_SKB_MARK_BUSY; + skb->mark = RXRPC_SKB_MARK_REJECT_BUSY; _leave(" = NULL [busy]"); call = NULL; goto out; @@ -413,20 +403,22 @@ found_service: case RXRPC_CONN_SERVICE: write_lock(&call->state_lock); - if (rx->discard_new_call) - call->state = RXRPC_CALL_SERVER_RECV_REQUEST; - else - call->state = RXRPC_CALL_SERVER_ACCEPTING; + if (call->state < RXRPC_CALL_COMPLETE) { + if (rx->discard_new_call) + call->state = RXRPC_CALL_SERVER_RECV_REQUEST; + else + call->state = RXRPC_CALL_SERVER_ACCEPTING; + } write_unlock(&call->state_lock); break; case RXRPC_CONN_REMOTELY_ABORTED: rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, - conn->remote_abort, -ECONNABORTED); + conn->abort_code, conn->error); break; case RXRPC_CONN_LOCALLY_ABORTED: rxrpc_abort_call("CON", call, sp->hdr.seq, - conn->local_abort, -ECONNABORTED); + conn->abort_code, conn->error); break; default: BUG(); diff --git a/net/rxrpc/call_object.c b/net/rxrpc/call_object.c index 9486293fef5c..8f1a8f85b1f9 100644 --- a/net/rxrpc/call_object.c +++ b/net/rxrpc/call_object.c @@ -138,6 +138,7 @@ struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *rx, gfp_t gfp, init_waitqueue_head(&call->waitq); spin_lock_init(&call->lock); spin_lock_init(&call->notify_lock); + spin_lock_init(&call->input_lock); rwlock_init(&call->state_lock); atomic_set(&call->usage, 1); call->debug_id = debug_id; @@ -287,7 +288,7 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, /* Set up or get a connection record and set the protocol parameters, * including channel number and call ID. */ - ret = rxrpc_connect_call(call, cp, srx, gfp); + ret = rxrpc_connect_call(rx, call, cp, srx, gfp); if (ret < 0) goto error; @@ -339,7 +340,7 @@ int rxrpc_retry_client_call(struct rxrpc_sock *rx, /* Set up or get a connection record and set the protocol parameters, * including channel number and call ID. */ - ret = rxrpc_connect_call(call, cp, srx, gfp); + ret = rxrpc_connect_call(rx, call, cp, srx, gfp); if (ret < 0) goto error; @@ -400,7 +401,7 @@ void rxrpc_incoming_call(struct rxrpc_sock *rx, rcu_assign_pointer(conn->channels[chan].call, call); spin_lock(&conn->params.peer->lock); - hlist_add_head(&call->error_link, &conn->params.peer->error_targets); + hlist_add_head_rcu(&call->error_link, &conn->params.peer->error_targets); spin_unlock(&conn->params.peer->lock); _net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id); diff --git a/net/rxrpc/conn_client.c b/net/rxrpc/conn_client.c index f8f37188a932..521189f4b666 100644 --- a/net/rxrpc/conn_client.c +++ b/net/rxrpc/conn_client.c @@ -276,7 +276,8 @@ dont_reuse: * If we return with a connection, the call will be on its waiting list. It's * left to the caller to assign a channel and wake up the call. */ -static int rxrpc_get_client_conn(struct rxrpc_call *call, +static int rxrpc_get_client_conn(struct rxrpc_sock *rx, + struct rxrpc_call *call, struct rxrpc_conn_parameters *cp, struct sockaddr_rxrpc *srx, gfp_t gfp) @@ -289,7 +290,7 @@ static int rxrpc_get_client_conn(struct rxrpc_call *call, _enter("{%d,%lx},", call->debug_id, call->user_call_ID); - cp->peer = rxrpc_lookup_peer(cp->local, srx, gfp); + cp->peer = rxrpc_lookup_peer(rx, cp->local, srx, gfp); if (!cp->peer) goto error; @@ -683,7 +684,8 @@ out: * find a connection for a call * - called in process context with IRQs enabled */ -int rxrpc_connect_call(struct rxrpc_call *call, +int rxrpc_connect_call(struct rxrpc_sock *rx, + struct rxrpc_call *call, struct rxrpc_conn_parameters *cp, struct sockaddr_rxrpc *srx, gfp_t gfp) @@ -696,7 +698,7 @@ int rxrpc_connect_call(struct rxrpc_call *call, rxrpc_discard_expired_client_conns(&rxnet->client_conn_reaper); rxrpc_cull_active_client_conns(rxnet); - ret = rxrpc_get_client_conn(call, cp, srx, gfp); + ret = rxrpc_get_client_conn(rx, call, cp, srx, gfp); if (ret < 0) goto out; @@ -710,8 +712,8 @@ int rxrpc_connect_call(struct rxrpc_call *call, } spin_lock_bh(&call->conn->params.peer->lock); - hlist_add_head(&call->error_link, - &call->conn->params.peer->error_targets); + hlist_add_head_rcu(&call->error_link, + &call->conn->params.peer->error_targets); spin_unlock_bh(&call->conn->params.peer->lock); out: diff --git a/net/rxrpc/conn_event.c b/net/rxrpc/conn_event.c index 6df56ce68861..b6fca8ebb117 100644 --- a/net/rxrpc/conn_event.c +++ b/net/rxrpc/conn_event.c @@ -126,7 +126,7 @@ static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, switch (chan->last_type) { case RXRPC_PACKET_TYPE_ABORT: - _proto("Tx ABORT %%%u { %d } [re]", serial, conn->local_abort); + _proto("Tx ABORT %%%u { %d } [re]", serial, conn->abort_code); break; case RXRPC_PACKET_TYPE_ACK: trace_rxrpc_tx_ack(chan->call_debug_id, serial, @@ -153,13 +153,12 @@ static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, * pass a connection-level abort onto all calls on that connection */ static void rxrpc_abort_calls(struct rxrpc_connection *conn, - enum rxrpc_call_completion compl, - u32 abort_code, int error) + enum rxrpc_call_completion compl) { struct rxrpc_call *call; int i; - _enter("{%d},%x", conn->debug_id, abort_code); + _enter("{%d},%x", conn->debug_id, conn->abort_code); spin_lock(&conn->channel_lock); @@ -172,9 +171,11 @@ static void rxrpc_abort_calls(struct rxrpc_connection *conn, trace_rxrpc_abort(call->debug_id, "CON", call->cid, call->call_id, 0, - abort_code, error); + conn->abort_code, + conn->error); if (rxrpc_set_call_completion(call, compl, - abort_code, error)) + conn->abort_code, + conn->error)) rxrpc_notify_socket(call); } } @@ -207,10 +208,12 @@ static int rxrpc_abort_connection(struct rxrpc_connection *conn, return 0; } + conn->error = error; + conn->abort_code = abort_code; conn->state = RXRPC_CONN_LOCALLY_ABORTED; spin_unlock_bh(&conn->state_lock); - rxrpc_abort_calls(conn, RXRPC_CALL_LOCALLY_ABORTED, abort_code, error); + rxrpc_abort_calls(conn, RXRPC_CALL_LOCALLY_ABORTED); msg.msg_name = &conn->params.peer->srx.transport; msg.msg_namelen = conn->params.peer->srx.transport_len; @@ -229,7 +232,7 @@ static int rxrpc_abort_connection(struct rxrpc_connection *conn, whdr._rsvd = 0; whdr.serviceId = htons(conn->service_id); - word = htonl(conn->local_abort); + word = htonl(conn->abort_code); iov[0].iov_base = &whdr; iov[0].iov_len = sizeof(whdr); @@ -240,7 +243,7 @@ static int rxrpc_abort_connection(struct rxrpc_connection *conn, serial = atomic_inc_return(&conn->serial); whdr.serial = htonl(serial); - _proto("Tx CONN ABORT %%%u { %d }", serial, conn->local_abort); + _proto("Tx CONN ABORT %%%u { %d }", serial, conn->abort_code); ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); if (ret < 0) { @@ -315,9 +318,10 @@ static int rxrpc_process_event(struct rxrpc_connection *conn, abort_code = ntohl(wtmp); _proto("Rx ABORT %%%u { ac=%d }", sp->hdr.serial, abort_code); + conn->error = -ECONNABORTED; + conn->abort_code = abort_code; conn->state = RXRPC_CONN_REMOTELY_ABORTED; - rxrpc_abort_calls(conn, RXRPC_CALL_REMOTELY_ABORTED, - abort_code, -ECONNABORTED); + rxrpc_abort_calls(conn, RXRPC_CALL_REMOTELY_ABORTED); return -ECONNABORTED; case RXRPC_PACKET_TYPE_CHALLENGE: diff --git a/net/rxrpc/conn_object.c b/net/rxrpc/conn_object.c index 77440a356b14..c332722820c2 100644 --- a/net/rxrpc/conn_object.c +++ b/net/rxrpc/conn_object.c @@ -69,10 +69,14 @@ struct rxrpc_connection *rxrpc_alloc_connection(gfp_t gfp) * If successful, a pointer to the connection is returned, but no ref is taken. * NULL is returned if there is no match. * + * When searching for a service call, if we find a peer but no connection, we + * return that through *_peer in case we need to create a new service call. + * * The caller must be holding the RCU read lock. */ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, - struct sk_buff *skb) + struct sk_buff *skb, + struct rxrpc_peer **_peer) { struct rxrpc_connection *conn; struct rxrpc_conn_proto k; @@ -82,14 +86,12 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, _enter(",%x", sp->hdr.cid & RXRPC_CIDMASK); - if (rxrpc_extract_addr_from_skb(local, &srx, skb) < 0) + if (rxrpc_extract_addr_from_skb(&srx, skb) < 0) goto not_found; - k.epoch = sp->hdr.epoch; - k.cid = sp->hdr.cid & RXRPC_CIDMASK; - - /* We may have to handle mixing IPv4 and IPv6 */ - if (srx.transport.family != local->srx.transport.family) { + if (srx.transport.family != local->srx.transport.family && + (srx.transport.family == AF_INET && + local->srx.transport.family != AF_INET6)) { pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n", srx.transport.family, local->srx.transport.family); @@ -99,7 +101,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, k.epoch = sp->hdr.epoch; k.cid = sp->hdr.cid & RXRPC_CIDMASK; - if (sp->hdr.flags & RXRPC_CLIENT_INITIATED) { + if (rxrpc_to_server(sp)) { /* We need to look up service connections by the full protocol * parameter set. We look up the peer first as an intermediate * step and then the connection from the peer's tree. @@ -107,6 +109,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, peer = rxrpc_lookup_peer_rcu(local, &srx); if (!peer) goto not_found; + *_peer = peer; conn = rxrpc_find_service_conn_rcu(peer, skb); if (!conn || atomic_read(&conn->usage) == 0) goto not_found; @@ -214,7 +217,7 @@ void rxrpc_disconnect_call(struct rxrpc_call *call) call->peer->cong_cwnd = call->cong_cwnd; spin_lock_bh(&conn->params.peer->lock); - hlist_del_init(&call->error_link); + hlist_del_rcu(&call->error_link); spin_unlock_bh(&conn->params.peer->lock); if (rxrpc_is_client_call(call)) diff --git a/net/rxrpc/input.c b/net/rxrpc/input.c index ee8e7e1d5c0f..9128aa0e40aa 100644 --- a/net/rxrpc/input.c +++ b/net/rxrpc/input.c @@ -216,10 +216,11 @@ static void rxrpc_send_ping(struct rxrpc_call *call, struct sk_buff *skb, /* * Apply a hard ACK by advancing the Tx window. */ -static void rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, +static bool rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, struct rxrpc_ack_summary *summary) { struct sk_buff *skb, *list = NULL; + bool rot_last = false; int ix; u8 annotation; @@ -243,15 +244,17 @@ static void rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, skb->next = list; list = skb; - if (annotation & RXRPC_TX_ANNO_LAST) + if (annotation & RXRPC_TX_ANNO_LAST) { set_bit(RXRPC_CALL_TX_LAST, &call->flags); + rot_last = true; + } if ((annotation & RXRPC_TX_ANNO_MASK) != RXRPC_TX_ANNO_ACK) summary->nr_rot_new_acks++; } spin_unlock(&call->lock); - trace_rxrpc_transmit(call, (test_bit(RXRPC_CALL_TX_LAST, &call->flags) ? + trace_rxrpc_transmit(call, (rot_last ? rxrpc_transmit_rotate_last : rxrpc_transmit_rotate)); wake_up(&call->waitq); @@ -262,6 +265,8 @@ static void rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, skb_mark_not_on_list(skb); rxrpc_free_skb(skb, rxrpc_skb_tx_freed); } + + return rot_last; } /* @@ -273,23 +278,26 @@ static void rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, static bool rxrpc_end_tx_phase(struct rxrpc_call *call, bool reply_begun, const char *abort_why) { + unsigned int state; ASSERT(test_bit(RXRPC_CALL_TX_LAST, &call->flags)); write_lock(&call->state_lock); - switch (call->state) { + state = call->state; + switch (state) { case RXRPC_CALL_CLIENT_SEND_REQUEST: case RXRPC_CALL_CLIENT_AWAIT_REPLY: if (reply_begun) - call->state = RXRPC_CALL_CLIENT_RECV_REPLY; + call->state = state = RXRPC_CALL_CLIENT_RECV_REPLY; else - call->state = RXRPC_CALL_CLIENT_AWAIT_REPLY; + call->state = state = RXRPC_CALL_CLIENT_AWAIT_REPLY; break; case RXRPC_CALL_SERVER_AWAIT_ACK: __rxrpc_call_completed(call); rxrpc_notify_socket(call); + state = call->state; break; default: @@ -297,11 +305,10 @@ static bool rxrpc_end_tx_phase(struct rxrpc_call *call, bool reply_begun, } write_unlock(&call->state_lock); - if (call->state == RXRPC_CALL_CLIENT_AWAIT_REPLY) { + if (state == RXRPC_CALL_CLIENT_AWAIT_REPLY) trace_rxrpc_transmit(call, rxrpc_transmit_await_reply); - } else { + else trace_rxrpc_transmit(call, rxrpc_transmit_end); - } _leave(" = ok"); return true; @@ -332,11 +339,11 @@ static bool rxrpc_receiving_reply(struct rxrpc_call *call) trace_rxrpc_timer(call, rxrpc_timer_init_for_reply, now); } - if (!test_bit(RXRPC_CALL_TX_LAST, &call->flags)) - rxrpc_rotate_tx_window(call, top, &summary); if (!test_bit(RXRPC_CALL_TX_LAST, &call->flags)) { - rxrpc_proto_abort("TXL", call, top); - return false; + if (!rxrpc_rotate_tx_window(call, top, &summary)) { + rxrpc_proto_abort("TXL", call, top); + return false; + } } if (!rxrpc_end_tx_phase(call, true, "ETD")) return false; @@ -452,13 +459,15 @@ static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb, } } + spin_lock(&call->input_lock); + /* Received data implicitly ACKs all of the request packets we sent * when we're acting as a client. */ if ((state == RXRPC_CALL_CLIENT_SEND_REQUEST || state == RXRPC_CALL_CLIENT_AWAIT_REPLY) && !rxrpc_receiving_reply(call)) - return; + goto unlock; call->ackr_prev_seq = seq; @@ -488,12 +497,16 @@ next_subpacket: if (flags & RXRPC_LAST_PACKET) { if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) && - seq != call->rx_top) - return rxrpc_proto_abort("LSN", call, seq); + seq != call->rx_top) { + rxrpc_proto_abort("LSN", call, seq); + goto unlock; + } } else { if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) && - after_eq(seq, call->rx_top)) - return rxrpc_proto_abort("LSA", call, seq); + after_eq(seq, call->rx_top)) { + rxrpc_proto_abort("LSA", call, seq); + goto unlock; + } } trace_rxrpc_rx_data(call->debug_id, seq, serial, flags, annotation); @@ -560,8 +573,10 @@ next_subpacket: skip: offset += len; if (flags & RXRPC_JUMBO_PACKET) { - if (skb_copy_bits(skb, offset, &flags, 1) < 0) - return rxrpc_proto_abort("XJF", call, seq); + if (skb_copy_bits(skb, offset, &flags, 1) < 0) { + rxrpc_proto_abort("XJF", call, seq); + goto unlock; + } offset += sizeof(struct rxrpc_jumbo_header); seq++; serial++; @@ -601,6 +616,9 @@ ack: trace_rxrpc_notify_socket(call->debug_id, serial); rxrpc_notify_socket(call); } + +unlock: + spin_unlock(&call->input_lock); _leave(" [queued]"); } @@ -622,13 +640,14 @@ static void rxrpc_input_requested_ack(struct rxrpc_call *call, if (!skb) continue; + sent_at = skb->tstamp; + smp_rmb(); /* Read timestamp before serial. */ sp = rxrpc_skb(skb); if (sp->hdr.serial != orig_serial) continue; - smp_rmb(); - sent_at = skb->tstamp; goto found; } + return; found: @@ -686,15 +705,14 @@ static void rxrpc_input_ping_response(struct rxrpc_call *call, ping_time = call->ping_time; smp_rmb(); - ping_serial = call->ping_serial; + ping_serial = READ_ONCE(call->ping_serial); if (orig_serial == call->acks_lost_ping) rxrpc_input_check_for_lost_ack(call); - if (!test_bit(RXRPC_CALL_PINGING, &call->flags) || - before(orig_serial, ping_serial)) + if (before(orig_serial, ping_serial) || + !test_and_clear_bit(RXRPC_CALL_PINGING, &call->flags)) return; - clear_bit(RXRPC_CALL_PINGING, &call->flags); if (after(orig_serial, ping_serial)) return; @@ -860,15 +878,32 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb, rxrpc_propose_ack_respond_to_ack); } + /* Discard any out-of-order or duplicate ACKs. */ + if (before_eq(sp->hdr.serial, call->acks_latest)) + return; + + buf.info.rxMTU = 0; ioffset = offset + nr_acks + 3; - if (skb->len >= ioffset + sizeof(buf.info)) { - if (skb_copy_bits(skb, ioffset, &buf.info, sizeof(buf.info)) < 0) - return rxrpc_proto_abort("XAI", call, 0); + if (skb->len >= ioffset + sizeof(buf.info) && + skb_copy_bits(skb, ioffset, &buf.info, sizeof(buf.info)) < 0) + return rxrpc_proto_abort("XAI", call, 0); + + spin_lock(&call->input_lock); + + /* Discard any out-of-order or duplicate ACKs. */ + if (before_eq(sp->hdr.serial, call->acks_latest)) + goto out; + call->acks_latest_ts = skb->tstamp; + call->acks_latest = sp->hdr.serial; + + /* Parse rwind and mtu sizes if provided. */ + if (buf.info.rxMTU) rxrpc_input_ackinfo(call, skb, &buf.info); - } - if (first_soft_ack == 0) - return rxrpc_proto_abort("AK0", call, 0); + if (first_soft_ack == 0) { + rxrpc_proto_abort("AK0", call, 0); + goto out; + } /* Ignore ACKs unless we are or have just been transmitting. */ switch (READ_ONCE(call->state)) { @@ -878,39 +913,35 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb, case RXRPC_CALL_SERVER_AWAIT_ACK: break; default: - return; - } - - /* Discard any out-of-order or duplicate ACKs. */ - if (before_eq(sp->hdr.serial, call->acks_latest)) { - _debug("discard ACK %d <= %d", - sp->hdr.serial, call->acks_latest); - return; + goto out; } - call->acks_latest_ts = skb->tstamp; - call->acks_latest = sp->hdr.serial; if (before(hard_ack, call->tx_hard_ack) || - after(hard_ack, call->tx_top)) - return rxrpc_proto_abort("AKW", call, 0); - if (nr_acks > call->tx_top - hard_ack) - return rxrpc_proto_abort("AKN", call, 0); + after(hard_ack, call->tx_top)) { + rxrpc_proto_abort("AKW", call, 0); + goto out; + } + if (nr_acks > call->tx_top - hard_ack) { + rxrpc_proto_abort("AKN", call, 0); + goto out; + } - if (after(hard_ack, call->tx_hard_ack)) - rxrpc_rotate_tx_window(call, hard_ack, &summary); + if (after(hard_ack, call->tx_hard_ack)) { + if (rxrpc_rotate_tx_window(call, hard_ack, &summary)) { + rxrpc_end_tx_phase(call, false, "ETA"); + goto out; + } + } if (nr_acks > 0) { - if (skb_copy_bits(skb, offset, buf.acks, nr_acks) < 0) - return rxrpc_proto_abort("XSA", call, 0); + if (skb_copy_bits(skb, offset, buf.acks, nr_acks) < 0) { + rxrpc_proto_abort("XSA", call, 0); + goto out; + } rxrpc_input_soft_acks(call, buf.acks, first_soft_ack, nr_acks, &summary); } - if (test_bit(RXRPC_CALL_TX_LAST, &call->flags)) { - rxrpc_end_tx_phase(call, false, "ETA"); - return; - } - if (call->rxtx_annotations[call->tx_top & RXRPC_RXTX_BUFF_MASK] & RXRPC_TX_ANNO_LAST && summary.nr_acks == call->tx_top - hard_ack && @@ -919,7 +950,9 @@ static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb, false, true, rxrpc_propose_ack_ping_for_lost_reply); - return rxrpc_congestion_management(call, skb, &summary, acked_serial); + rxrpc_congestion_management(call, skb, &summary, acked_serial); +out: + spin_unlock(&call->input_lock); } /* @@ -932,9 +965,12 @@ static void rxrpc_input_ackall(struct rxrpc_call *call, struct sk_buff *skb) _proto("Rx ACKALL %%%u", sp->hdr.serial); - rxrpc_rotate_tx_window(call, call->tx_top, &summary); - if (test_bit(RXRPC_CALL_TX_LAST, &call->flags)) + spin_lock(&call->input_lock); + + if (rxrpc_rotate_tx_window(call, call->tx_top, &summary)) rxrpc_end_tx_phase(call, false, "ETL"); + + spin_unlock(&call->input_lock); } /* @@ -1017,18 +1053,19 @@ static void rxrpc_input_call_packet(struct rxrpc_call *call, } /* - * Handle a new call on a channel implicitly completing the preceding call on - * that channel. + * Handle a new service call on a channel implicitly completing the preceding + * call on that channel. This does not apply to client conns. * * TODO: If callNumber > call_id + 1, renegotiate security. */ -static void rxrpc_input_implicit_end_call(struct rxrpc_connection *conn, +static void rxrpc_input_implicit_end_call(struct rxrpc_sock *rx, + struct rxrpc_connection *conn, struct rxrpc_call *call) { switch (READ_ONCE(call->state)) { case RXRPC_CALL_SERVER_AWAIT_ACK: rxrpc_call_completed(call); - break; + /* Fall through */ case RXRPC_CALL_COMPLETE: break; default: @@ -1036,11 +1073,13 @@ static void rxrpc_input_implicit_end_call(struct rxrpc_connection *conn, set_bit(RXRPC_CALL_EV_ABORT, &call->events); rxrpc_queue_call(call); } + trace_rxrpc_improper_term(call); break; } - trace_rxrpc_improper_term(call); + spin_lock(&rx->incoming_lock); __rxrpc_disconnect_call(conn, call); + spin_unlock(&rx->incoming_lock); rxrpc_notify_socket(call); } @@ -1119,43 +1158,29 @@ int rxrpc_extract_header(struct rxrpc_skb_priv *sp, struct sk_buff *skb) * The socket is locked by the caller and this prevents the socket from being * shut down and the local endpoint from going away, thus sk_user_data will not * be cleared until this function returns. + * + * Called with the RCU read lock held from the IP layer via UDP. */ -void rxrpc_data_ready(struct sock *udp_sk) +int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb) { struct rxrpc_connection *conn; struct rxrpc_channel *chan; - struct rxrpc_call *call; + struct rxrpc_call *call = NULL; struct rxrpc_skb_priv *sp; struct rxrpc_local *local = udp_sk->sk_user_data; - struct sk_buff *skb; + struct rxrpc_peer *peer = NULL; + struct rxrpc_sock *rx = NULL; unsigned int channel; - int ret, skew; + int skew = 0; _enter("%p", udp_sk); - ASSERT(!irqs_disabled()); - - skb = skb_recv_udp(udp_sk, 0, 1, &ret); - if (!skb) { - if (ret == -EAGAIN) - return; - _debug("UDP socket error %d", ret); - return; - } + if (skb->tstamp == 0) + skb->tstamp = ktime_get_real(); rxrpc_new_skb(skb, rxrpc_skb_rx_received); - _net("recv skb %p", skb); - - /* we'll probably need to checksum it (didn't call sock_recvmsg) */ - if (skb_checksum_complete(skb)) { - rxrpc_free_skb(skb, rxrpc_skb_rx_freed); - __UDP_INC_STATS(&init_net, UDP_MIB_INERRORS, 0); - _leave(" [CSUM failed]"); - return; - } - - __UDP_INC_STATS(&init_net, UDP_MIB_INDATAGRAMS, 0); + skb_pull(skb, sizeof(struct udphdr)); /* The UDP protocol already released all skb resources; * we are free to add our own data there. @@ -1170,69 +1195,104 @@ void rxrpc_data_ready(struct sock *udp_sk) static int lose; if ((lose++ & 7) == 7) { trace_rxrpc_rx_lose(sp); - rxrpc_lose_skb(skb, rxrpc_skb_rx_lost); - return; + rxrpc_free_skb(skb, rxrpc_skb_rx_lost); + return 0; } } + if (skb->tstamp == 0) + skb->tstamp = ktime_get_real(); trace_rxrpc_rx_packet(sp); - _net("Rx RxRPC %s ep=%x call=%x:%x", - sp->hdr.flags & RXRPC_CLIENT_INITIATED ? "ToServer" : "ToClient", - sp->hdr.epoch, sp->hdr.cid, sp->hdr.callNumber); - - if (sp->hdr.type >= RXRPC_N_PACKET_TYPES || - !((RXRPC_SUPPORTED_PACKET_TYPES >> sp->hdr.type) & 1)) { - _proto("Rx Bad Packet Type %u", sp->hdr.type); - goto bad_message; - } - switch (sp->hdr.type) { case RXRPC_PACKET_TYPE_VERSION: - if (!(sp->hdr.flags & RXRPC_CLIENT_INITIATED)) + if (rxrpc_to_client(sp)) goto discard; rxrpc_post_packet_to_local(local, skb); goto out; case RXRPC_PACKET_TYPE_BUSY: - if (sp->hdr.flags & RXRPC_CLIENT_INITIATED) + if (rxrpc_to_server(sp)) goto discard; /* Fall through */ + case RXRPC_PACKET_TYPE_ACK: + case RXRPC_PACKET_TYPE_ACKALL: + if (sp->hdr.callNumber == 0) + goto bad_message; + /* Fall through */ + case RXRPC_PACKET_TYPE_ABORT: + break; case RXRPC_PACKET_TYPE_DATA: - if (sp->hdr.callNumber == 0) + if (sp->hdr.callNumber == 0 || + sp->hdr.seq == 0) goto bad_message; if (sp->hdr.flags & RXRPC_JUMBO_PACKET && !rxrpc_validate_jumbo(skb)) goto bad_message; break; + case RXRPC_PACKET_TYPE_CHALLENGE: + if (rxrpc_to_server(sp)) + goto discard; + break; + case RXRPC_PACKET_TYPE_RESPONSE: + if (rxrpc_to_client(sp)) + goto discard; + break; + /* Packet types 9-11 should just be ignored. */ case RXRPC_PACKET_TYPE_PARAMS: case RXRPC_PACKET_TYPE_10: case RXRPC_PACKET_TYPE_11: goto discard; + + default: + _proto("Rx Bad Packet Type %u", sp->hdr.type); + goto bad_message; } - rcu_read_lock(); + if (sp->hdr.serviceId == 0) + goto bad_message; + + if (rxrpc_to_server(sp)) { + /* Weed out packets to services we're not offering. Packets + * that would begin a call are explicitly rejected and the rest + * are just discarded. + */ + rx = rcu_dereference(local->service); + if (!rx || (sp->hdr.serviceId != rx->srx.srx_service && + sp->hdr.serviceId != rx->second_service)) { + if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA && + sp->hdr.seq == 1) + goto unsupported_service; + goto discard; + } + } - conn = rxrpc_find_connection_rcu(local, skb); + conn = rxrpc_find_connection_rcu(local, skb, &peer); if (conn) { if (sp->hdr.securityIndex != conn->security_ix) goto wrong_security; if (sp->hdr.serviceId != conn->service_id) { - if (!test_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags) || - conn->service_id != conn->params.service_id) + int old_id; + + if (!test_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags)) + goto reupgrade; + old_id = cmpxchg(&conn->service_id, conn->params.service_id, + sp->hdr.serviceId); + + if (old_id != conn->params.service_id && + old_id != sp->hdr.serviceId) goto reupgrade; - conn->service_id = sp->hdr.serviceId; } if (sp->hdr.callNumber == 0) { /* Connection-level packet */ _debug("CONN %p {%d}", conn, conn->debug_id); rxrpc_post_packet_to_conn(conn, skb); - goto out_unlock; + goto out; } /* Note the serial number skew here */ @@ -1251,19 +1311,19 @@ void rxrpc_data_ready(struct sock *udp_sk) /* Ignore really old calls */ if (sp->hdr.callNumber < chan->last_call) - goto discard_unlock; + goto discard; if (sp->hdr.callNumber == chan->last_call) { if (chan->call || sp->hdr.type == RXRPC_PACKET_TYPE_ABORT) - goto discard_unlock; + goto discard; /* For the previous service call, if completed * successfully, we discard all further packets. */ if (rxrpc_conn_is_service(conn) && chan->last_type == RXRPC_PACKET_TYPE_ACK) - goto discard_unlock; + goto discard; /* But otherwise we need to retransmit the final packet * from data cached in the connection record. @@ -1274,18 +1334,16 @@ void rxrpc_data_ready(struct sock *udp_sk) sp->hdr.serial, sp->hdr.flags, 0); rxrpc_post_packet_to_conn(conn, skb); - goto out_unlock; + goto out; } call = rcu_dereference(chan->call); if (sp->hdr.callNumber > chan->call_id) { - if (!(sp->hdr.flags & RXRPC_CLIENT_INITIATED)) { - rcu_read_unlock(); + if (rxrpc_to_client(sp)) goto reject_packet; - } if (call) - rxrpc_input_implicit_end_call(conn, call); + rxrpc_input_implicit_end_call(rx, conn, call); call = NULL; } @@ -1297,66 +1355,57 @@ void rxrpc_data_ready(struct sock *udp_sk) if (!test_bit(RXRPC_CALL_RX_HEARD, &call->flags)) set_bit(RXRPC_CALL_RX_HEARD, &call->flags); } - } else { - skew = 0; - call = NULL; } if (!call || atomic_read(&call->usage) == 0) { - if (!(sp->hdr.type & RXRPC_CLIENT_INITIATED) || - sp->hdr.callNumber == 0 || + if (rxrpc_to_client(sp) || sp->hdr.type != RXRPC_PACKET_TYPE_DATA) - goto bad_message_unlock; + goto bad_message; if (sp->hdr.seq != 1) - goto discard_unlock; - call = rxrpc_new_incoming_call(local, conn, skb); - if (!call) { - rcu_read_unlock(); + goto discard; + call = rxrpc_new_incoming_call(local, rx, skb); + if (!call) goto reject_packet; - } rxrpc_send_ping(call, skb, skew); mutex_unlock(&call->user_mutex); } rxrpc_input_call_packet(call, skb, skew); - goto discard_unlock; + goto discard; -discard_unlock: - rcu_read_unlock(); discard: rxrpc_free_skb(skb, rxrpc_skb_rx_freed); out: trace_rxrpc_rx_done(0, 0); - return; - -out_unlock: - rcu_read_unlock(); - goto out; + return 0; wrong_security: - rcu_read_unlock(); trace_rxrpc_abort(0, "SEC", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, RXKADINCONSISTENCY, EBADMSG); skb->priority = RXKADINCONSISTENCY; goto post_abort; +unsupported_service: + trace_rxrpc_abort(0, "INV", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + RX_INVALID_OPERATION, EOPNOTSUPP); + skb->priority = RX_INVALID_OPERATION; + goto post_abort; + reupgrade: - rcu_read_unlock(); trace_rxrpc_abort(0, "UPG", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, RX_PROTOCOL_ERROR, EBADMSG); goto protocol_error; -bad_message_unlock: - rcu_read_unlock(); bad_message: trace_rxrpc_abort(0, "BAD", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, RX_PROTOCOL_ERROR, EBADMSG); protocol_error: skb->priority = RX_PROTOCOL_ERROR; post_abort: - skb->mark = RXRPC_SKB_MARK_LOCAL_ABORT; + skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; reject_packet: trace_rxrpc_rx_done(skb->mark, skb->priority); rxrpc_reject_packet(local, skb); _leave(" [badmsg]"); + return 0; } diff --git a/net/rxrpc/local_event.c b/net/rxrpc/local_event.c index 13bd8a4dfac7..927ead43df42 100644 --- a/net/rxrpc/local_event.c +++ b/net/rxrpc/local_event.c @@ -39,7 +39,7 @@ static void rxrpc_send_version_request(struct rxrpc_local *local, _enter(""); - if (rxrpc_extract_addr_from_skb(local, &srx, skb) < 0) + if (rxrpc_extract_addr_from_skb(&srx, skb) < 0) return; msg.msg_name = &srx.transport; diff --git a/net/rxrpc/local_object.c b/net/rxrpc/local_object.c index 777c3ed4cfc0..0906e51d3cfb 100644 --- a/net/rxrpc/local_object.c +++ b/net/rxrpc/local_object.c @@ -19,6 +19,7 @@ #include <linux/ip.h> #include <linux/hashtable.h> #include <net/sock.h> +#include <net/udp.h> #include <net/af_rxrpc.h> #include "ar-internal.h" @@ -108,7 +109,7 @@ static struct rxrpc_local *rxrpc_alloc_local(struct rxrpc_net *rxnet, */ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) { - struct sock *sock; + struct sock *usk; int ret, opt; _enter("%p{%d,%d}", @@ -122,6 +123,28 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) return ret; } + /* set the socket up */ + usk = local->socket->sk; + inet_sk(usk)->mc_loop = 0; + + /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */ + inet_inc_convert_csum(usk); + + rcu_assign_sk_user_data(usk, local); + + udp_sk(usk)->encap_type = UDP_ENCAP_RXRPC; + udp_sk(usk)->encap_rcv = rxrpc_input_packet; + udp_sk(usk)->encap_destroy = NULL; + udp_sk(usk)->gro_receive = NULL; + udp_sk(usk)->gro_complete = NULL; + + udp_encap_enable(); +#if IS_ENABLED(CONFIG_AF_RXRPC_IPV6) + if (local->srx.transport.family == AF_INET6) + udpv6_encap_enable(); +#endif + usk->sk_error_report = rxrpc_error_report; + /* if a local address was supplied then bind it */ if (local->srx.transport_len > sizeof(sa_family_t)) { _debug("bind"); @@ -135,10 +158,10 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) } switch (local->srx.transport.family) { - case AF_INET: - /* we want to receive ICMP errors */ + case AF_INET6: + /* we want to receive ICMPv6 errors */ opt = 1; - ret = kernel_setsockopt(local->socket, SOL_IP, IP_RECVERR, + ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_RECVERR, (char *) &opt, sizeof(opt)); if (ret < 0) { _debug("setsockopt failed"); @@ -146,19 +169,22 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) } /* we want to set the don't fragment bit */ - opt = IP_PMTUDISC_DO; - ret = kernel_setsockopt(local->socket, SOL_IP, IP_MTU_DISCOVER, + opt = IPV6_PMTUDISC_DO; + ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_MTU_DISCOVER, (char *) &opt, sizeof(opt)); if (ret < 0) { _debug("setsockopt failed"); goto error; } - break; - case AF_INET6: + /* Fall through and set IPv4 options too otherwise we don't get + * errors from IPv4 packets sent through the IPv6 socket. + */ + + case AF_INET: /* we want to receive ICMP errors */ opt = 1; - ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_RECVERR, + ret = kernel_setsockopt(local->socket, SOL_IP, IP_RECVERR, (char *) &opt, sizeof(opt)); if (ret < 0) { _debug("setsockopt failed"); @@ -166,24 +192,28 @@ static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) } /* we want to set the don't fragment bit */ - opt = IPV6_PMTUDISC_DO; - ret = kernel_setsockopt(local->socket, SOL_IPV6, IPV6_MTU_DISCOVER, + opt = IP_PMTUDISC_DO; + ret = kernel_setsockopt(local->socket, SOL_IP, IP_MTU_DISCOVER, (char *) &opt, sizeof(opt)); if (ret < 0) { _debug("setsockopt failed"); goto error; } + + /* We want receive timestamps. */ + opt = 1; + ret = kernel_setsockopt(local->socket, SOL_SOCKET, SO_TIMESTAMPNS, + (char *)&opt, sizeof(opt)); + if (ret < 0) { + _debug("setsockopt failed"); + goto error; + } break; default: BUG(); } - /* set the socket up */ - sock = local->socket->sk; - sock->sk_user_data = local; - sock->sk_data_ready = rxrpc_data_ready; - sock->sk_error_report = rxrpc_error_report; _leave(" = 0"); return 0; diff --git a/net/rxrpc/net_ns.c b/net/rxrpc/net_ns.c index 417d80867c4f..fd7eba8467fa 100644 --- a/net/rxrpc/net_ns.c +++ b/net/rxrpc/net_ns.c @@ -102,6 +102,9 @@ static __net_init int rxrpc_init_net(struct net *net) proc_create_net("conns", 0444, rxnet->proc_net, &rxrpc_connection_seq_ops, sizeof(struct seq_net_private)); + proc_create_net("peers", 0444, rxnet->proc_net, + &rxrpc_peer_seq_ops, + sizeof(struct seq_net_private)); return 0; err_proc: diff --git a/net/rxrpc/output.c b/net/rxrpc/output.c index ccf5de160444..189418888839 100644 --- a/net/rxrpc/output.c +++ b/net/rxrpc/output.c @@ -124,7 +124,6 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, struct kvec iov[2]; rxrpc_serial_t serial; rxrpc_seq_t hard_ack, top; - ktime_t now; size_t len, n; int ret; u8 reason; @@ -196,9 +195,7 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, /* We need to stick a time in before we send the packet in case * the reply gets back before kernel_sendmsg() completes - but * asking UDP to send the packet can take a relatively long - * time, so we update the time after, on the assumption that - * the packet transmission is more likely to happen towards the - * end of the kernel_sendmsg() call. + * time. */ call->ping_time = ktime_get_real(); set_bit(RXRPC_CALL_PINGING, &call->flags); @@ -206,9 +203,6 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, } ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); - now = ktime_get_real(); - if (ping) - call->ping_time = now; conn->params.peer->last_tx_at = ktime_get_seconds(); if (ret < 0) trace_rxrpc_tx_fail(call->debug_id, serial, ret, @@ -363,8 +357,14 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, /* If our RTT cache needs working on, request an ACK. Also request * ACKs if a DATA packet appears to have been lost. + * + * However, we mustn't request an ACK on the last reply packet of a + * service call, lest OpenAFS incorrectly send us an ACK with some + * soft-ACKs in it and then never follow up with a proper hard ACK. */ - if (!(sp->hdr.flags & RXRPC_LAST_PACKET) && + if ((!(sp->hdr.flags & RXRPC_LAST_PACKET) || + rxrpc_to_server(sp) + ) && (test_and_clear_bit(RXRPC_CALL_EV_ACK_LOST, &call->events) || retrans || call->cong_mode == RXRPC_CALL_SLOW_START || @@ -378,11 +378,13 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, if ((lose++ & 7) == 7) { ret = 0; lost = true; - goto done; } } - _proto("Tx DATA %%%u { #%u }", serial, sp->hdr.seq); + trace_rxrpc_tx_data(call, sp->hdr.seq, serial, whdr.flags, + retrans, lost); + if (lost) + goto done; /* send the packet with the don't fragment bit set if we currently * think it's small enough */ @@ -390,6 +392,11 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, goto send_fragmentable; down_read(&conn->params.local->defrag_sem); + + sp->hdr.serial = serial; + smp_wmb(); /* Set serial before timestamp */ + skb->tstamp = ktime_get_real(); + /* send the packet by UDP * - returns -EMSGSIZE if UDP would have to fragment the packet * to go out of the interface @@ -410,15 +417,9 @@ int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, goto send_fragmentable; done: - trace_rxrpc_tx_data(call, sp->hdr.seq, serial, whdr.flags, - retrans, lost); if (ret >= 0) { - ktime_t now = ktime_get_real(); - skb->tstamp = now; - smp_wmb(); - sp->hdr.serial = serial; if (whdr.flags & RXRPC_REQUEST_ACK) { - call->peer->rtt_last_req = now; + call->peer->rtt_last_req = skb->tstamp; trace_rxrpc_rtt_tx(call, rxrpc_rtt_tx_data, serial); if (call->peer->rtt_usage > 1) { unsigned long nowj = jiffies, ack_lost_at; @@ -457,6 +458,10 @@ send_fragmentable: down_write(&conn->params.local->defrag_sem); + sp->hdr.serial = serial; + smp_wmb(); /* Set serial before timestamp */ + skb->tstamp = ktime_get_real(); + switch (conn->params.local->srx.transport.family) { case AF_INET: opt = IP_PMTUDISC_DONT; @@ -519,7 +524,7 @@ void rxrpc_reject_packets(struct rxrpc_local *local) struct kvec iov[2]; size_t size; __be32 code; - int ret; + int ret, ioc; _enter("%d", local->debug_id); @@ -527,7 +532,6 @@ void rxrpc_reject_packets(struct rxrpc_local *local) iov[0].iov_len = sizeof(whdr); iov[1].iov_base = &code; iov[1].iov_len = sizeof(code); - size = sizeof(whdr) + sizeof(code); msg.msg_name = &srx.transport; msg.msg_control = NULL; @@ -535,16 +539,30 @@ void rxrpc_reject_packets(struct rxrpc_local *local) msg.msg_flags = 0; memset(&whdr, 0, sizeof(whdr)); - whdr.type = RXRPC_PACKET_TYPE_ABORT; while ((skb = skb_dequeue(&local->reject_queue))) { rxrpc_see_skb(skb, rxrpc_skb_rx_seen); sp = rxrpc_skb(skb); - if (rxrpc_extract_addr_from_skb(local, &srx, skb) == 0) { - msg.msg_namelen = srx.transport_len; - + switch (skb->mark) { + case RXRPC_SKB_MARK_REJECT_BUSY: + whdr.type = RXRPC_PACKET_TYPE_BUSY; + size = sizeof(whdr); + ioc = 1; + break; + case RXRPC_SKB_MARK_REJECT_ABORT: + whdr.type = RXRPC_PACKET_TYPE_ABORT; code = htonl(skb->priority); + size = sizeof(whdr) + sizeof(code); + ioc = 2; + break; + default: + rxrpc_free_skb(skb, rxrpc_skb_rx_freed); + continue; + } + + if (rxrpc_extract_addr_from_skb(&srx, skb) == 0) { + msg.msg_namelen = srx.transport_len; whdr.epoch = htonl(sp->hdr.epoch); whdr.cid = htonl(sp->hdr.cid); @@ -554,7 +572,8 @@ void rxrpc_reject_packets(struct rxrpc_local *local) whdr.flags ^= RXRPC_CLIENT_INITIATED; whdr.flags &= RXRPC_CLIENT_INITIATED; - ret = kernel_sendmsg(local->socket, &msg, iov, 2, size); + ret = kernel_sendmsg(local->socket, &msg, + iov, ioc, size); if (ret < 0) trace_rxrpc_tx_fail(local->debug_id, 0, ret, rxrpc_tx_point_reject); diff --git a/net/rxrpc/peer_event.c b/net/rxrpc/peer_event.c index 4f9da2f51c69..bc05af89fc38 100644 --- a/net/rxrpc/peer_event.c +++ b/net/rxrpc/peer_event.c @@ -23,6 +23,8 @@ #include "ar-internal.h" static void rxrpc_store_error(struct rxrpc_peer *, struct sock_exterr_skb *); +static void rxrpc_distribute_error(struct rxrpc_peer *, int, + enum rxrpc_call_completion); /* * Find the peer associated with an ICMP packet. @@ -45,6 +47,8 @@ static struct rxrpc_peer *rxrpc_lookup_peer_icmp_rcu(struct rxrpc_local *local, */ switch (srx->transport.family) { case AF_INET: + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.family = AF_INET; srx->transport.sin.sin_port = serr->port; switch (serr->ee.ee_origin) { case SO_EE_ORIGIN_ICMP: @@ -68,20 +72,20 @@ static struct rxrpc_peer *rxrpc_lookup_peer_icmp_rcu(struct rxrpc_local *local, #ifdef CONFIG_AF_RXRPC_IPV6 case AF_INET6: - srx->transport.sin6.sin6_port = serr->port; switch (serr->ee.ee_origin) { case SO_EE_ORIGIN_ICMP6: _net("Rx ICMP6"); + srx->transport.sin6.sin6_port = serr->port; memcpy(&srx->transport.sin6.sin6_addr, skb_network_header(skb) + serr->addr_offset, sizeof(struct in6_addr)); break; case SO_EE_ORIGIN_ICMP: _net("Rx ICMP on v6 sock"); - srx->transport.sin6.sin6_addr.s6_addr32[0] = 0; - srx->transport.sin6.sin6_addr.s6_addr32[1] = 0; - srx->transport.sin6.sin6_addr.s6_addr32[2] = htonl(0xffff); - memcpy(srx->transport.sin6.sin6_addr.s6_addr + 12, + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.family = AF_INET; + srx->transport.sin.sin_port = serr->port; + memcpy(&srx->transport.sin.sin_addr, skb_network_header(skb) + serr->addr_offset, sizeof(struct in_addr)); break; @@ -193,9 +197,8 @@ void rxrpc_error_report(struct sock *sk) rxrpc_store_error(peer, serr); rcu_read_unlock(); rxrpc_free_skb(skb, rxrpc_skb_rx_freed); + rxrpc_put_peer(peer); - /* The ref we obtained is passed off to the work item */ - __rxrpc_queue_peer_error(peer); _leave(""); } @@ -205,6 +208,7 @@ void rxrpc_error_report(struct sock *sk) static void rxrpc_store_error(struct rxrpc_peer *peer, struct sock_exterr_skb *serr) { + enum rxrpc_call_completion compl = RXRPC_CALL_NETWORK_ERROR; struct sock_extended_err *ee; int err; @@ -255,7 +259,7 @@ static void rxrpc_store_error(struct rxrpc_peer *peer, case SO_EE_ORIGIN_NONE: case SO_EE_ORIGIN_LOCAL: _proto("Rx Received local error { error=%d }", err); - err += RXRPC_LOCAL_ERROR_OFFSET; + compl = RXRPC_CALL_LOCAL_ERROR; break; case SO_EE_ORIGIN_ICMP6: @@ -264,48 +268,23 @@ static void rxrpc_store_error(struct rxrpc_peer *peer, break; } - peer->error_report = err; + rxrpc_distribute_error(peer, err, compl); } /* - * Distribute an error that occurred on a peer + * Distribute an error that occurred on a peer. */ -void rxrpc_peer_error_distributor(struct work_struct *work) +static void rxrpc_distribute_error(struct rxrpc_peer *peer, int error, + enum rxrpc_call_completion compl) { - struct rxrpc_peer *peer = - container_of(work, struct rxrpc_peer, error_distributor); struct rxrpc_call *call; - enum rxrpc_call_completion compl; - int error; - - _enter(""); - - error = READ_ONCE(peer->error_report); - if (error < RXRPC_LOCAL_ERROR_OFFSET) { - compl = RXRPC_CALL_NETWORK_ERROR; - } else { - compl = RXRPC_CALL_LOCAL_ERROR; - error -= RXRPC_LOCAL_ERROR_OFFSET; - } - _debug("ISSUE ERROR %s %d", rxrpc_call_completions[compl], error); - - spin_lock_bh(&peer->lock); - - while (!hlist_empty(&peer->error_targets)) { - call = hlist_entry(peer->error_targets.first, - struct rxrpc_call, error_link); - hlist_del_init(&call->error_link); + hlist_for_each_entry_rcu(call, &peer->error_targets, error_link) { rxrpc_see_call(call); - - if (rxrpc_set_call_completion(call, compl, 0, -error)) + if (call->state < RXRPC_CALL_COMPLETE && + rxrpc_set_call_completion(call, compl, 0, -error)) rxrpc_notify_socket(call); } - - spin_unlock_bh(&peer->lock); - - rxrpc_put_peer(peer); - _leave(""); } /* @@ -325,6 +304,8 @@ void rxrpc_peer_add_rtt(struct rxrpc_call *call, enum rxrpc_rtt_rx_trace why, if (rtt < 0) return; + spin_lock(&peer->rtt_input_lock); + /* Replace the oldest datum in the RTT buffer */ sum -= peer->rtt_cache[cursor]; sum += rtt; @@ -336,6 +317,8 @@ void rxrpc_peer_add_rtt(struct rxrpc_call *call, enum rxrpc_rtt_rx_trace why, peer->rtt_usage = usage; } + spin_unlock(&peer->rtt_input_lock); + /* Now recalculate the average */ if (usage == RXRPC_RTT_CACHE_SIZE) { avg = sum / RXRPC_RTT_CACHE_SIZE; @@ -344,6 +327,7 @@ void rxrpc_peer_add_rtt(struct rxrpc_call *call, enum rxrpc_rtt_rx_trace why, do_div(avg, usage); } + /* Don't need to update this under lock */ peer->rtt = avg; trace_rxrpc_rtt_rx(call, why, send_serial, resp_serial, rtt, usage, avg); diff --git a/net/rxrpc/peer_object.c b/net/rxrpc/peer_object.c index 1dc7648e3eff..5691b7d266ca 100644 --- a/net/rxrpc/peer_object.c +++ b/net/rxrpc/peer_object.c @@ -124,11 +124,9 @@ static struct rxrpc_peer *__rxrpc_lookup_peer_rcu( struct rxrpc_net *rxnet = local->rxnet; hash_for_each_possible_rcu(rxnet->peer_hash, peer, hash_link, hash_key) { - if (rxrpc_peer_cmp_key(peer, local, srx, hash_key) == 0) { - if (atomic_read(&peer->usage) == 0) - return NULL; + if (rxrpc_peer_cmp_key(peer, local, srx, hash_key) == 0 && + atomic_read(&peer->usage) > 0) return peer; - } } return NULL; @@ -155,8 +153,10 @@ struct rxrpc_peer *rxrpc_lookup_peer_rcu(struct rxrpc_local *local, * assess the MTU size for the network interface through which this peer is * reached */ -static void rxrpc_assess_MTU_size(struct rxrpc_peer *peer) +static void rxrpc_assess_MTU_size(struct rxrpc_sock *rx, + struct rxrpc_peer *peer) { + struct net *net = sock_net(&rx->sk); struct dst_entry *dst; struct rtable *rt; struct flowi fl; @@ -171,7 +171,7 @@ static void rxrpc_assess_MTU_size(struct rxrpc_peer *peer) switch (peer->srx.transport.family) { case AF_INET: rt = ip_route_output_ports( - &init_net, fl4, NULL, + net, fl4, NULL, peer->srx.transport.sin.sin_addr.s_addr, 0, htons(7000), htons(7001), IPPROTO_UDP, 0, 0); if (IS_ERR(rt)) { @@ -190,7 +190,7 @@ static void rxrpc_assess_MTU_size(struct rxrpc_peer *peer) sizeof(struct in6_addr)); fl6->fl6_dport = htons(7001); fl6->fl6_sport = htons(7000); - dst = ip6_route_output(&init_net, NULL, fl6); + dst = ip6_route_output(net, NULL, fl6); if (dst->error) { _leave(" [route err %d]", dst->error); return; @@ -222,11 +222,10 @@ struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp) atomic_set(&peer->usage, 1); peer->local = local; INIT_HLIST_HEAD(&peer->error_targets); - INIT_WORK(&peer->error_distributor, - &rxrpc_peer_error_distributor); peer->service_conns = RB_ROOT; seqlock_init(&peer->service_conn_lock); spin_lock_init(&peer->lock); + spin_lock_init(&peer->rtt_input_lock); peer->debug_id = atomic_inc_return(&rxrpc_debug_id); if (RXRPC_TX_SMSS > 2190) @@ -244,10 +243,11 @@ struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp) /* * Initialise peer record. */ -static void rxrpc_init_peer(struct rxrpc_peer *peer, unsigned long hash_key) +static void rxrpc_init_peer(struct rxrpc_sock *rx, struct rxrpc_peer *peer, + unsigned long hash_key) { peer->hash_key = hash_key; - rxrpc_assess_MTU_size(peer); + rxrpc_assess_MTU_size(rx, peer); peer->mtu = peer->if_mtu; peer->rtt_last_req = ktime_get_real(); @@ -279,7 +279,8 @@ static void rxrpc_init_peer(struct rxrpc_peer *peer, unsigned long hash_key) /* * Set up a new peer. */ -static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_local *local, +static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_sock *rx, + struct rxrpc_local *local, struct sockaddr_rxrpc *srx, unsigned long hash_key, gfp_t gfp) @@ -291,7 +292,7 @@ static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_local *local, peer = rxrpc_alloc_peer(local, gfp); if (peer) { memcpy(&peer->srx, srx, sizeof(*srx)); - rxrpc_init_peer(peer, hash_key); + rxrpc_init_peer(rx, peer, hash_key); } _leave(" = %p", peer); @@ -299,40 +300,31 @@ static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_local *local, } /* - * Set up a new incoming peer. The address is prestored in the preallocated - * peer. + * Set up a new incoming peer. There shouldn't be any other matching peers + * since we've already done a search in the list from the non-reentrant context + * (the data_ready handler) that is the only place we can add new peers. */ -struct rxrpc_peer *rxrpc_lookup_incoming_peer(struct rxrpc_local *local, - struct rxrpc_peer *prealloc) +void rxrpc_new_incoming_peer(struct rxrpc_sock *rx, struct rxrpc_local *local, + struct rxrpc_peer *peer) { - struct rxrpc_peer *peer; struct rxrpc_net *rxnet = local->rxnet; unsigned long hash_key; - hash_key = rxrpc_peer_hash_key(local, &prealloc->srx); - prealloc->local = local; - rxrpc_init_peer(prealloc, hash_key); + hash_key = rxrpc_peer_hash_key(local, &peer->srx); + peer->local = local; + rxrpc_init_peer(rx, peer, hash_key); spin_lock(&rxnet->peer_hash_lock); - - /* Need to check that we aren't racing with someone else */ - peer = __rxrpc_lookup_peer_rcu(local, &prealloc->srx, hash_key); - if (peer && !rxrpc_get_peer_maybe(peer)) - peer = NULL; - if (!peer) { - peer = prealloc; - hash_add_rcu(rxnet->peer_hash, &peer->hash_link, hash_key); - list_add_tail(&peer->keepalive_link, &rxnet->peer_keepalive_new); - } - + hash_add_rcu(rxnet->peer_hash, &peer->hash_link, hash_key); + list_add_tail(&peer->keepalive_link, &rxnet->peer_keepalive_new); spin_unlock(&rxnet->peer_hash_lock); - return peer; } /* * obtain a remote transport endpoint for the specified address */ -struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_local *local, +struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, + struct rxrpc_local *local, struct sockaddr_rxrpc *srx, gfp_t gfp) { struct rxrpc_peer *peer, *candidate; @@ -352,7 +344,7 @@ struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_local *local, /* The peer is not yet present in hash - create a candidate * for a new record and then redo the search. */ - candidate = rxrpc_create_peer(local, srx, hash_key, gfp); + candidate = rxrpc_create_peer(rx, local, srx, hash_key, gfp); if (!candidate) { _leave(" = NULL [nomem]"); return NULL; @@ -416,21 +408,6 @@ struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *peer) } /* - * Queue a peer record. This passes the caller's ref to the workqueue. - */ -void __rxrpc_queue_peer_error(struct rxrpc_peer *peer) -{ - const void *here = __builtin_return_address(0); - int n; - - n = atomic_read(&peer->usage); - if (rxrpc_queue_work(&peer->error_distributor)) - trace_rxrpc_peer(peer, rxrpc_peer_queued_error, n, here); - else - rxrpc_put_peer(peer); -} - -/* * Discard a peer record. */ static void __rxrpc_put_peer(struct rxrpc_peer *peer) diff --git a/net/rxrpc/proc.c b/net/rxrpc/proc.c index 9805e3b85c36..c7d976859d40 100644 --- a/net/rxrpc/proc.c +++ b/net/rxrpc/proc.c @@ -212,3 +212,129 @@ const struct seq_operations rxrpc_connection_seq_ops = { .stop = rxrpc_connection_seq_stop, .show = rxrpc_connection_seq_show, }; + +/* + * generate a list of extant virtual peers in /proc/net/rxrpc/peers + */ +static int rxrpc_peer_seq_show(struct seq_file *seq, void *v) +{ + struct rxrpc_peer *peer; + time64_t now; + char lbuff[50], rbuff[50]; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "Proto Local " + " Remote " + " Use CW MTU LastUse RTT Rc\n" + ); + return 0; + } + + peer = list_entry(v, struct rxrpc_peer, hash_link); + + sprintf(lbuff, "%pISpc", &peer->local->srx.transport); + + sprintf(rbuff, "%pISpc", &peer->srx.transport); + + now = ktime_get_seconds(); + seq_printf(seq, + "UDP %-47.47s %-47.47s %3u" + " %3u %5u %6llus %12llu %2u\n", + lbuff, + rbuff, + atomic_read(&peer->usage), + peer->cong_cwnd, + peer->mtu, + now - peer->last_tx_at, + peer->rtt, + peer->rtt_cursor); + + return 0; +} + +static void *rxrpc_peer_seq_start(struct seq_file *seq, loff_t *_pos) + __acquires(rcu) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + unsigned int bucket, n; + unsigned int shift = 32 - HASH_BITS(rxnet->peer_hash); + void *p; + + rcu_read_lock(); + + if (*_pos >= UINT_MAX) + return NULL; + + n = *_pos & ((1U << shift) - 1); + bucket = *_pos >> shift; + for (;;) { + if (bucket >= HASH_SIZE(rxnet->peer_hash)) { + *_pos = UINT_MAX; + return NULL; + } + if (n == 0) { + if (bucket == 0) + return SEQ_START_TOKEN; + *_pos += 1; + n++; + } + + p = seq_hlist_start_rcu(&rxnet->peer_hash[bucket], n - 1); + if (p) + return p; + bucket++; + n = 1; + *_pos = (bucket << shift) | n; + } +} + +static void *rxrpc_peer_seq_next(struct seq_file *seq, void *v, loff_t *_pos) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + unsigned int bucket, n; + unsigned int shift = 32 - HASH_BITS(rxnet->peer_hash); + void *p; + + if (*_pos >= UINT_MAX) + return NULL; + + bucket = *_pos >> shift; + + p = seq_hlist_next_rcu(v, &rxnet->peer_hash[bucket], _pos); + if (p) + return p; + + for (;;) { + bucket++; + n = 1; + *_pos = (bucket << shift) | n; + + if (bucket >= HASH_SIZE(rxnet->peer_hash)) { + *_pos = UINT_MAX; + return NULL; + } + if (n == 0) { + *_pos += 1; + n++; + } + + p = seq_hlist_start_rcu(&rxnet->peer_hash[bucket], n - 1); + if (p) + return p; + } +} + +static void rxrpc_peer_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + + +const struct seq_operations rxrpc_peer_seq_ops = { + .start = rxrpc_peer_seq_start, + .next = rxrpc_peer_seq_next, + .stop = rxrpc_peer_seq_stop, + .show = rxrpc_peer_seq_show, +}; diff --git a/net/rxrpc/protocol.h b/net/rxrpc/protocol.h index 93da73bf7098..f9cb83c938f3 100644 --- a/net/rxrpc/protocol.h +++ b/net/rxrpc/protocol.h @@ -50,7 +50,6 @@ struct rxrpc_wire_header { #define RXRPC_PACKET_TYPE_10 10 /* Ignored */ #define RXRPC_PACKET_TYPE_11 11 /* Ignored */ #define RXRPC_PACKET_TYPE_VERSION 13 /* version string request */ -#define RXRPC_N_PACKET_TYPES 14 /* number of packet types (incl type 0) */ uint8_t flags; /* packet flags */ #define RXRPC_CLIENT_INITIATED 0x01 /* signifies a packet generated by a client */ @@ -72,20 +71,6 @@ struct rxrpc_wire_header { } __packed; -#define RXRPC_SUPPORTED_PACKET_TYPES ( \ - (1 << RXRPC_PACKET_TYPE_DATA) | \ - (1 << RXRPC_PACKET_TYPE_ACK) | \ - (1 << RXRPC_PACKET_TYPE_BUSY) | \ - (1 << RXRPC_PACKET_TYPE_ABORT) | \ - (1 << RXRPC_PACKET_TYPE_ACKALL) | \ - (1 << RXRPC_PACKET_TYPE_CHALLENGE) | \ - (1 << RXRPC_PACKET_TYPE_RESPONSE) | \ - /*(1 << RXRPC_PACKET_TYPE_DEBUG) | */ \ - (1 << RXRPC_PACKET_TYPE_PARAMS) | \ - (1 << RXRPC_PACKET_TYPE_10) | \ - (1 << RXRPC_PACKET_TYPE_11) | \ - (1 << RXRPC_PACKET_TYPE_VERSION)) - /*****************************************************************************/ /* * jumbo packet secondary header diff --git a/net/rxrpc/recvmsg.c b/net/rxrpc/recvmsg.c index 816b19a78809..eaf19ebaa964 100644 --- a/net/rxrpc/recvmsg.c +++ b/net/rxrpc/recvmsg.c @@ -715,3 +715,46 @@ call_complete: goto out; } EXPORT_SYMBOL(rxrpc_kernel_recv_data); + +/** + * rxrpc_kernel_get_reply_time - Get timestamp on first reply packet + * @sock: The socket that the call exists on + * @call: The call to query + * @_ts: Where to put the timestamp + * + * Retrieve the timestamp from the first DATA packet of the reply if it is + * in the ring. Returns true if successful, false if not. + */ +bool rxrpc_kernel_get_reply_time(struct socket *sock, struct rxrpc_call *call, + ktime_t *_ts) +{ + struct sk_buff *skb; + rxrpc_seq_t hard_ack, top, seq; + bool success = false; + + mutex_lock(&call->user_mutex); + + if (READ_ONCE(call->state) != RXRPC_CALL_CLIENT_RECV_REPLY) + goto out; + + hard_ack = call->rx_hard_ack; + if (hard_ack != 0) + goto out; + + seq = hard_ack + 1; + top = smp_load_acquire(&call->rx_top); + if (after(seq, top)) + goto out; + + skb = call->rxtx_buffer[seq & RXRPC_RXTX_BUFF_MASK]; + if (!skb) + goto out; + + *_ts = skb_get_ktime(skb); + success = true; + +out: + mutex_unlock(&call->user_mutex); + return success; +} +EXPORT_SYMBOL(rxrpc_kernel_get_reply_time); diff --git a/net/rxrpc/rxkad.c b/net/rxrpc/rxkad.c index cea16838d588..cbef9ea43dec 100644 --- a/net/rxrpc/rxkad.c +++ b/net/rxrpc/rxkad.c @@ -46,7 +46,7 @@ struct rxkad_level2_hdr { * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE * packets */ -static struct crypto_skcipher *rxkad_ci; +static struct crypto_sync_skcipher *rxkad_ci; static DEFINE_MUTEX(rxkad_ci_mutex); /* @@ -54,7 +54,7 @@ static DEFINE_MUTEX(rxkad_ci_mutex); */ static int rxkad_init_connection_security(struct rxrpc_connection *conn) { - struct crypto_skcipher *ci; + struct crypto_sync_skcipher *ci; struct rxrpc_key_token *token; int ret; @@ -63,14 +63,14 @@ static int rxkad_init_connection_security(struct rxrpc_connection *conn) token = conn->params.key->payload.data[0]; conn->security_ix = token->security_index; - ci = crypto_alloc_skcipher("pcbc(fcrypt)", 0, CRYPTO_ALG_ASYNC); + ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0); if (IS_ERR(ci)) { _debug("no cipher"); ret = PTR_ERR(ci); goto error; } - if (crypto_skcipher_setkey(ci, token->kad->session_key, + if (crypto_sync_skcipher_setkey(ci, token->kad->session_key, sizeof(token->kad->session_key)) < 0) BUG(); @@ -104,7 +104,7 @@ error: static int rxkad_prime_packet_security(struct rxrpc_connection *conn) { struct rxrpc_key_token *token; - SKCIPHER_REQUEST_ON_STACK(req, conn->cipher); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher); struct scatterlist sg; struct rxrpc_crypt iv; __be32 *tmpbuf; @@ -128,7 +128,7 @@ static int rxkad_prime_packet_security(struct rxrpc_connection *conn) tmpbuf[3] = htonl(conn->security_ix); sg_init_one(&sg, tmpbuf, tmpsize); - skcipher_request_set_tfm(req, conn->cipher); + skcipher_request_set_sync_tfm(req, conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x); crypto_skcipher_encrypt(req); @@ -167,7 +167,7 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call, memset(&iv, 0, sizeof(iv)); sg_init_one(&sg, sechdr, 8); - skcipher_request_set_tfm(req, call->conn->cipher); + skcipher_request_set_sync_tfm(req, call->conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); crypto_skcipher_encrypt(req); @@ -212,7 +212,7 @@ static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call, memcpy(&iv, token->kad->session_key, sizeof(iv)); sg_init_one(&sg[0], sechdr, sizeof(rxkhdr)); - skcipher_request_set_tfm(req, call->conn->cipher); + skcipher_request_set_sync_tfm(req, call->conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg[0], &sg[0], sizeof(rxkhdr), iv.x); crypto_skcipher_encrypt(req); @@ -250,7 +250,7 @@ static int rxkad_secure_packet(struct rxrpc_call *call, void *sechdr) { struct rxrpc_skb_priv *sp; - SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher); struct rxrpc_crypt iv; struct scatterlist sg; u32 x, y; @@ -279,7 +279,7 @@ static int rxkad_secure_packet(struct rxrpc_call *call, call->crypto_buf[1] = htonl(x); sg_init_one(&sg, call->crypto_buf, 8); - skcipher_request_set_tfm(req, call->conn->cipher); + skcipher_request_set_sync_tfm(req, call->conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); crypto_skcipher_encrypt(req); @@ -352,7 +352,7 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, /* start the decryption afresh */ memset(&iv, 0, sizeof(iv)); - skcipher_request_set_tfm(req, call->conn->cipher); + skcipher_request_set_sync_tfm(req, call->conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, 8, iv.x); crypto_skcipher_decrypt(req); @@ -450,7 +450,7 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, token = call->conn->params.key->payload.data[0]; memcpy(&iv, token->kad->session_key, sizeof(iv)); - skcipher_request_set_tfm(req, call->conn->cipher); + skcipher_request_set_sync_tfm(req, call->conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, len, iv.x); crypto_skcipher_decrypt(req); @@ -506,7 +506,7 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, unsigned int offset, unsigned int len, rxrpc_seq_t seq, u16 expected_cksum) { - SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher); struct rxrpc_crypt iv; struct scatterlist sg; bool aborted; @@ -529,7 +529,7 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, call->crypto_buf[1] = htonl(x); sg_init_one(&sg, call->crypto_buf, 8); - skcipher_request_set_tfm(req, call->conn->cipher); + skcipher_request_set_sync_tfm(req, call->conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); crypto_skcipher_encrypt(req); @@ -755,7 +755,7 @@ static void rxkad_encrypt_response(struct rxrpc_connection *conn, struct rxkad_response *resp, const struct rxkad_key *s2) { - SKCIPHER_REQUEST_ON_STACK(req, conn->cipher); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher); struct rxrpc_crypt iv; struct scatterlist sg[1]; @@ -764,7 +764,7 @@ static void rxkad_encrypt_response(struct rxrpc_connection *conn, sg_init_table(sg, 1); sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted)); - skcipher_request_set_tfm(req, conn->cipher); + skcipher_request_set_sync_tfm(req, conn->cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x); crypto_skcipher_encrypt(req); @@ -1021,7 +1021,7 @@ static void rxkad_decrypt_response(struct rxrpc_connection *conn, struct rxkad_response *resp, const struct rxrpc_crypt *session_key) { - SKCIPHER_REQUEST_ON_STACK(req, rxkad_ci); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, rxkad_ci); struct scatterlist sg[1]; struct rxrpc_crypt iv; @@ -1031,7 +1031,7 @@ static void rxkad_decrypt_response(struct rxrpc_connection *conn, ASSERT(rxkad_ci != NULL); mutex_lock(&rxkad_ci_mutex); - if (crypto_skcipher_setkey(rxkad_ci, session_key->x, + if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x, sizeof(*session_key)) < 0) BUG(); @@ -1039,7 +1039,7 @@ static void rxkad_decrypt_response(struct rxrpc_connection *conn, sg_init_table(sg, 1); sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted)); - skcipher_request_set_tfm(req, rxkad_ci); + skcipher_request_set_sync_tfm(req, rxkad_ci); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x); crypto_skcipher_decrypt(req); @@ -1218,7 +1218,7 @@ static void rxkad_clear(struct rxrpc_connection *conn) _enter(""); if (conn->cipher) - crypto_free_skcipher(conn->cipher); + crypto_free_sync_skcipher(conn->cipher); } /* @@ -1228,7 +1228,7 @@ static int rxkad_init(void) { /* pin the cipher we need so that the crypto layer doesn't invoke * keventd to go get it */ - rxkad_ci = crypto_alloc_skcipher("pcbc(fcrypt)", 0, CRYPTO_ALG_ASYNC); + rxkad_ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0); return PTR_ERR_OR_ZERO(rxkad_ci); } @@ -1238,7 +1238,7 @@ static int rxkad_init(void) static void rxkad_exit(void) { if (rxkad_ci) - crypto_free_skcipher(rxkad_ci); + crypto_free_sync_skcipher(rxkad_ci); } /* diff --git a/net/rxrpc/skbuff.c b/net/rxrpc/skbuff.c index b8985d01876a..913dca65cc65 100644 --- a/net/rxrpc/skbuff.c +++ b/net/rxrpc/skbuff.c @@ -69,21 +69,6 @@ void rxrpc_free_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) } /* - * Note the injected loss of a socket buffer. - */ -void rxrpc_lose_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) -{ - const void *here = __builtin_return_address(0); - if (skb) { - int n; - CHECK_SLAB_OKAY(&skb->users); - n = atomic_dec_return(select_skb_count(op)); - trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, here); - kfree_skb(skb); - } -} - -/* * Clear a queue of socket buffers. */ void rxrpc_purge_queue(struct sk_buff_head *list) diff --git a/net/rxrpc/utils.c b/net/rxrpc/utils.c index e801171fa351..ff7af71c4b49 100644 --- a/net/rxrpc/utils.c +++ b/net/rxrpc/utils.c @@ -17,28 +17,17 @@ /* * Fill out a peer address from a socket buffer containing a packet. */ -int rxrpc_extract_addr_from_skb(struct rxrpc_local *local, - struct sockaddr_rxrpc *srx, - struct sk_buff *skb) +int rxrpc_extract_addr_from_skb(struct sockaddr_rxrpc *srx, struct sk_buff *skb) { memset(srx, 0, sizeof(*srx)); switch (ntohs(skb->protocol)) { case ETH_P_IP: - if (local->srx.transport.family == AF_INET6) { - srx->transport_type = SOCK_DGRAM; - srx->transport_len = sizeof(srx->transport.sin6); - srx->transport.sin6.sin6_family = AF_INET6; - srx->transport.sin6.sin6_port = udp_hdr(skb)->source; - srx->transport.sin6.sin6_addr.s6_addr32[2] = htonl(0xffff); - srx->transport.sin6.sin6_addr.s6_addr32[3] = ip_hdr(skb)->saddr; - } else { - srx->transport_type = SOCK_DGRAM; - srx->transport_len = sizeof(srx->transport.sin); - srx->transport.sin.sin_family = AF_INET; - srx->transport.sin.sin_port = udp_hdr(skb)->source; - srx->transport.sin.sin_addr.s_addr = ip_hdr(skb)->saddr; - } + srx->transport_type = SOCK_DGRAM; + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.sin.sin_family = AF_INET; + srx->transport.sin.sin_port = udp_hdr(skb)->source; + srx->transport.sin.sin_addr.s_addr = ip_hdr(skb)->saddr; return 0; #ifdef CONFIG_AF_RXRPC_IPV6 diff --git a/net/sched/Kconfig b/net/sched/Kconfig index e95741388311..1b9afdee5ba9 100644 --- a/net/sched/Kconfig +++ b/net/sched/Kconfig @@ -194,6 +194,17 @@ config NET_SCH_ETF To compile this code as a module, choose M here: the module will be called sch_etf. +config NET_SCH_TAPRIO + tristate "Time Aware Priority (taprio) Scheduler" + help + Say Y here if you want to use the Time Aware Priority (taprio) packet + scheduling algorithm. + + See the top of <file:net/sched/sch_taprio.c> for more details. + + To compile this code as a module, choose M here: the + module will be called sch_taprio. + config NET_SCH_GRED tristate "Generic Random Early Detection (GRED)" ---help--- diff --git a/net/sched/Makefile b/net/sched/Makefile index f0403f49edcb..8a40431d7b5c 100644 --- a/net/sched/Makefile +++ b/net/sched/Makefile @@ -57,6 +57,7 @@ obj-$(CONFIG_NET_SCH_HHF) += sch_hhf.o obj-$(CONFIG_NET_SCH_PIE) += sch_pie.o obj-$(CONFIG_NET_SCH_CBS) += sch_cbs.o obj-$(CONFIG_NET_SCH_ETF) += sch_etf.o +obj-$(CONFIG_NET_SCH_TAPRIO) += sch_taprio.o obj-$(CONFIG_NET_CLS_U32) += cls_u32.o obj-$(CONFIG_NET_CLS_ROUTE4) += cls_route.o diff --git a/net/sched/act_api.c b/net/sched/act_api.c index 3c7c23421885..9c1b0729aebf 100644 --- a/net/sched/act_api.c +++ b/net/sched/act_api.c @@ -104,11 +104,11 @@ static int __tcf_action_put(struct tc_action *p, bool bind) { struct tcf_idrinfo *idrinfo = p->idrinfo; - if (refcount_dec_and_lock(&p->tcfa_refcnt, &idrinfo->lock)) { + if (refcount_dec_and_mutex_lock(&p->tcfa_refcnt, &idrinfo->lock)) { if (bind) atomic_dec(&p->tcfa_bindcnt); idr_remove(&idrinfo->action_idr, p->tcfa_index); - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); tcf_action_cleanup(p); return 1; @@ -200,7 +200,7 @@ static int tcf_dump_walker(struct tcf_idrinfo *idrinfo, struct sk_buff *skb, struct tc_action *p; unsigned long id = 1; - spin_lock(&idrinfo->lock); + mutex_lock(&idrinfo->lock); s_i = cb->args[0]; @@ -235,7 +235,7 @@ done: if (index >= 0) cb->args[0] = index + 1; - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); if (n_i) { if (act_flags & TCA_FLAG_LARGE_DUMP_ON) cb->args[1] = n_i; @@ -277,18 +277,18 @@ static int tcf_del_walker(struct tcf_idrinfo *idrinfo, struct sk_buff *skb, if (nla_put_string(skb, TCA_KIND, ops->kind)) goto nla_put_failure; - spin_lock(&idrinfo->lock); + mutex_lock(&idrinfo->lock); idr_for_each_entry_ul(idr, p, id) { ret = tcf_idr_release_unsafe(p); if (ret == ACT_P_DELETED) { module_put(ops->owner); n_i++; } else if (ret < 0) { - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); goto nla_put_failure; } } - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); if (nla_put_u32(skb, TCA_FCNT, n_i)) goto nla_put_failure; @@ -324,13 +324,13 @@ int tcf_idr_search(struct tc_action_net *tn, struct tc_action **a, u32 index) struct tcf_idrinfo *idrinfo = tn->idrinfo; struct tc_action *p; - spin_lock(&idrinfo->lock); + mutex_lock(&idrinfo->lock); p = idr_find(&idrinfo->action_idr, index); if (IS_ERR(p)) p = NULL; else if (p) refcount_inc(&p->tcfa_refcnt); - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); if (p) { *a = p; @@ -345,10 +345,10 @@ static int tcf_idr_delete_index(struct tcf_idrinfo *idrinfo, u32 index) struct tc_action *p; int ret = 0; - spin_lock(&idrinfo->lock); + mutex_lock(&idrinfo->lock); p = idr_find(&idrinfo->action_idr, index); if (!p) { - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); return -ENOENT; } @@ -358,7 +358,7 @@ static int tcf_idr_delete_index(struct tcf_idrinfo *idrinfo, u32 index) WARN_ON(p != idr_remove(&idrinfo->action_idr, p->tcfa_index)); - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); tcf_action_cleanup(p); module_put(owner); @@ -369,7 +369,7 @@ static int tcf_idr_delete_index(struct tcf_idrinfo *idrinfo, u32 index) ret = -EPERM; } - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); return ret; } @@ -431,10 +431,10 @@ void tcf_idr_insert(struct tc_action_net *tn, struct tc_action *a) { struct tcf_idrinfo *idrinfo = tn->idrinfo; - spin_lock(&idrinfo->lock); + mutex_lock(&idrinfo->lock); /* Replace ERR_PTR(-EBUSY) allocated by tcf_idr_check_alloc */ WARN_ON(!IS_ERR(idr_replace(&idrinfo->action_idr, a, a->tcfa_index))); - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); } EXPORT_SYMBOL(tcf_idr_insert); @@ -444,10 +444,10 @@ void tcf_idr_cleanup(struct tc_action_net *tn, u32 index) { struct tcf_idrinfo *idrinfo = tn->idrinfo; - spin_lock(&idrinfo->lock); + mutex_lock(&idrinfo->lock); /* Remove ERR_PTR(-EBUSY) allocated by tcf_idr_check_alloc */ WARN_ON(!IS_ERR(idr_remove(&idrinfo->action_idr, index))); - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); } EXPORT_SYMBOL(tcf_idr_cleanup); @@ -465,14 +465,14 @@ int tcf_idr_check_alloc(struct tc_action_net *tn, u32 *index, int ret; again: - spin_lock(&idrinfo->lock); + mutex_lock(&idrinfo->lock); if (*index) { p = idr_find(&idrinfo->action_idr, *index); if (IS_ERR(p)) { /* This means that another process allocated * index but did not assign the pointer yet. */ - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); goto again; } @@ -485,7 +485,7 @@ again: } else { *a = NULL; ret = idr_alloc_u32(&idrinfo->action_idr, NULL, index, - *index, GFP_ATOMIC); + *index, GFP_KERNEL); if (!ret) idr_replace(&idrinfo->action_idr, ERR_PTR(-EBUSY), *index); @@ -494,12 +494,12 @@ again: *index = 1; *a = NULL; ret = idr_alloc_u32(&idrinfo->action_idr, NULL, index, - UINT_MAX, GFP_ATOMIC); + UINT_MAX, GFP_KERNEL); if (!ret) idr_replace(&idrinfo->action_idr, ERR_PTR(-EBUSY), *index); } - spin_unlock(&idrinfo->lock); + mutex_unlock(&idrinfo->lock); return ret; } EXPORT_SYMBOL(tcf_idr_check_alloc); @@ -1452,7 +1452,7 @@ static int tc_dump_action(struct sk_buff *skb, struct netlink_callback *cb) u32 act_count = 0; ret = nlmsg_parse(cb->nlh, sizeof(struct tcamsg), tb, TCA_ROOT_MAX, - tcaa_policy, NULL); + tcaa_policy, cb->extack); if (ret < 0) return ret; diff --git a/net/sched/act_gact.c b/net/sched/act_gact.c index c89a7fa43d1b..b61c20ebb314 100644 --- a/net/sched/act_gact.c +++ b/net/sched/act_gact.c @@ -88,6 +88,11 @@ static int tcf_gact_init(struct net *net, struct nlattr *nla, p_parm = nla_data(tb[TCA_GACT_PROB]); if (p_parm->ptype >= MAX_RAND) return -EINVAL; + if (TC_ACT_EXT_CMP(p_parm->paction, TC_ACT_GOTO_CHAIN)) { + NL_SET_ERR_MSG(extack, + "goto chain not allowed on fallback"); + return -EINVAL; + } } #endif diff --git a/net/sched/act_ipt.c b/net/sched/act_ipt.c index 1efbfb10b1fc..8af6c11d2482 100644 --- a/net/sched/act_ipt.c +++ b/net/sched/act_ipt.c @@ -135,7 +135,7 @@ static int __tcf_ipt_init(struct net *net, unsigned int id, struct nlattr *nla, } td = (struct xt_entry_target *)nla_data(tb[TCA_IPT_TARG]); - if (nla_len(tb[TCA_IPT_TARG]) < td->u.target_size) { + if (nla_len(tb[TCA_IPT_TARG]) != td->u.target_size) { if (exists) tcf_idr_release(*a, bind); else diff --git a/net/sched/act_police.c b/net/sched/act_police.c index 92649d2667ed..052855d47354 100644 --- a/net/sched/act_police.c +++ b/net/sched/act_police.c @@ -185,8 +185,6 @@ static int tcf_police_init(struct net *net, struct nlattr *nla, new->peak_present = false; } - if (tb[TCA_POLICE_RESULT]) - new->tcfp_result = nla_get_u32(tb[TCA_POLICE_RESULT]); new->tcfp_burst = PSCHED_TICKS2NS(parm->burst); new->tcfp_toks = new->tcfp_burst; if (new->peak_present) { @@ -198,6 +196,16 @@ static int tcf_police_init(struct net *net, struct nlattr *nla, if (tb[TCA_POLICE_AVRATE]) new->tcfp_ewma_rate = nla_get_u32(tb[TCA_POLICE_AVRATE]); + if (tb[TCA_POLICE_RESULT]) { + new->tcfp_result = nla_get_u32(tb[TCA_POLICE_RESULT]); + if (TC_ACT_EXT_CMP(new->tcfp_result, TC_ACT_GOTO_CHAIN)) { + NL_SET_ERR_MSG(extack, + "goto chain not allowed on fallback"); + err = -EINVAL; + goto failure; + } + } + spin_lock_bh(&police->tcf_lock); new->tcfp_t_c = ktime_get_ns(); police->tcf_action = parm->action; diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c index d670d3066ebd..f427a1e00e7e 100644 --- a/net/sched/cls_api.c +++ b/net/sched/cls_api.c @@ -31,6 +31,8 @@ #include <net/pkt_sched.h> #include <net/pkt_cls.h> +extern const struct nla_policy rtm_tca_policy[TCA_MAX + 1]; + /* The list of all installed classifier types */ static LIST_HEAD(tcf_proto_base); @@ -1304,7 +1306,7 @@ static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, replay: tp_created = 0; - err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, NULL, extack); + err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, rtm_tca_policy, extack); if (err < 0) return err; @@ -1454,7 +1456,7 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, if (!netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) return -EPERM; - err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, NULL, extack); + err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, rtm_tca_policy, extack); if (err < 0) return err; @@ -1570,7 +1572,7 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, void *fh = NULL; int err; - err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, NULL, extack); + err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, rtm_tca_policy, extack); if (err < 0) return err; @@ -1727,7 +1729,8 @@ static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) if (nlmsg_len(cb->nlh) < sizeof(*tcm)) return skb->len; - err = nlmsg_parse(cb->nlh, sizeof(*tcm), tca, TCA_MAX, NULL, NULL); + err = nlmsg_parse(cb->nlh, sizeof(*tcm), tca, TCA_MAX, NULL, + cb->extack); if (err) return err; @@ -1936,7 +1939,7 @@ static int tc_ctl_chain(struct sk_buff *skb, struct nlmsghdr *n, return -EPERM; replay: - err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, NULL, extack); + err = nlmsg_parse(n, sizeof(*t), tca, TCA_MAX, rtm_tca_policy, extack); if (err < 0) return err; @@ -2054,7 +2057,8 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) if (nlmsg_len(cb->nlh) < sizeof(*tcm)) return skb->len; - err = nlmsg_parse(cb->nlh, sizeof(*tcm), tca, TCA_MAX, NULL, NULL); + err = nlmsg_parse(cb->nlh, sizeof(*tcm), tca, TCA_MAX, rtm_tca_policy, + cb->extack); if (err) return err; diff --git a/net/sched/cls_flower.c b/net/sched/cls_flower.c index 92dd5071a708..9aada2d0ef06 100644 --- a/net/sched/cls_flower.c +++ b/net/sched/cls_flower.c @@ -79,7 +79,6 @@ struct fl_flow_tmplt { struct fl_flow_key mask; struct flow_dissector dissector; struct tcf_chain *chain; - struct rcu_work rwork; }; struct cls_fl_head { @@ -1438,20 +1437,12 @@ errout_tb: return ERR_PTR(err); } -static void fl_tmplt_destroy_work(struct work_struct *work) -{ - struct fl_flow_tmplt *tmplt = container_of(to_rcu_work(work), - struct fl_flow_tmplt, rwork); - - fl_hw_destroy_tmplt(tmplt->chain, tmplt); - kfree(tmplt); -} - static void fl_tmplt_destroy(void *tmplt_priv) { struct fl_flow_tmplt *tmplt = tmplt_priv; - tcf_queue_work(&tmplt->rwork, fl_tmplt_destroy_work); + fl_hw_destroy_tmplt(tmplt->chain, tmplt); + kfree(tmplt); } static int fl_dump_key_val(struct sk_buff *skb, diff --git a/net/sched/cls_u32.c b/net/sched/cls_u32.c index f218ccf1e2d9..4b28fd44576d 100644 --- a/net/sched/cls_u32.c +++ b/net/sched/cls_u32.c @@ -68,7 +68,6 @@ struct tc_u_knode { u32 mask; u32 __percpu *pcpu_success; #endif - struct tcf_proto *tp; struct rcu_work rwork; /* The 'sel' field MUST be the last field in structure to allow for * tc_u32_keys allocated at end of structure. @@ -80,10 +79,10 @@ struct tc_u_hnode { struct tc_u_hnode __rcu *next; u32 handle; u32 prio; - struct tc_u_common *tp_c; int refcnt; unsigned int divisor; struct idr handle_idr; + bool is_root; struct rcu_head rcu; u32 flags; /* The 'ht' field MUST be the last field in structure to allow for @@ -98,7 +97,7 @@ struct tc_u_common { int refcnt; struct idr handle_idr; struct hlist_node hnode; - struct rcu_head rcu; + long knodes; }; static inline unsigned int u32_hash_fold(__be32 key, @@ -344,19 +343,16 @@ static void *tc_u_common_ptr(const struct tcf_proto *tp) return block->q; } -static unsigned int tc_u_hash(const struct tcf_proto *tp) +static struct hlist_head *tc_u_hash(void *key) { - return hash_ptr(tc_u_common_ptr(tp), U32_HASH_SHIFT); + return tc_u_common_hash + hash_ptr(key, U32_HASH_SHIFT); } -static struct tc_u_common *tc_u_common_find(const struct tcf_proto *tp) +static struct tc_u_common *tc_u_common_find(void *key) { struct tc_u_common *tc; - unsigned int h; - - h = tc_u_hash(tp); - hlist_for_each_entry(tc, &tc_u_common_hash[h], hnode) { - if (tc->ptr == tc_u_common_ptr(tp)) + hlist_for_each_entry(tc, tc_u_hash(key), hnode) { + if (tc->ptr == key) return tc; } return NULL; @@ -365,10 +361,8 @@ static struct tc_u_common *tc_u_common_find(const struct tcf_proto *tp) static int u32_init(struct tcf_proto *tp) { struct tc_u_hnode *root_ht; - struct tc_u_common *tp_c; - unsigned int h; - - tp_c = tc_u_common_find(tp); + void *key = tc_u_common_ptr(tp); + struct tc_u_common *tp_c = tc_u_common_find(key); root_ht = kzalloc(sizeof(*root_ht), GFP_KERNEL); if (root_ht == NULL) @@ -377,6 +371,7 @@ static int u32_init(struct tcf_proto *tp) root_ht->refcnt++; root_ht->handle = tp_c ? gen_new_htid(tp_c, root_ht) : 0x80000000; root_ht->prio = tp->prio; + root_ht->is_root = true; idr_init(&root_ht->handle_idr); if (tp_c == NULL) { @@ -385,26 +380,24 @@ static int u32_init(struct tcf_proto *tp) kfree(root_ht); return -ENOBUFS; } - tp_c->ptr = tc_u_common_ptr(tp); + tp_c->ptr = key; INIT_HLIST_NODE(&tp_c->hnode); idr_init(&tp_c->handle_idr); - h = tc_u_hash(tp); - hlist_add_head(&tp_c->hnode, &tc_u_common_hash[h]); + hlist_add_head(&tp_c->hnode, tc_u_hash(key)); } tp_c->refcnt++; RCU_INIT_POINTER(root_ht->next, tp_c->hlist); rcu_assign_pointer(tp_c->hlist, root_ht); - root_ht->tp_c = tp_c; + root_ht->refcnt++; rcu_assign_pointer(tp->root, root_ht); tp->data = tp_c; return 0; } -static int u32_destroy_key(struct tcf_proto *tp, struct tc_u_knode *n, - bool free_pf) +static int u32_destroy_key(struct tc_u_knode *n, bool free_pf) { struct tc_u_hnode *ht = rtnl_dereference(n->ht_down); @@ -438,7 +431,7 @@ static void u32_delete_key_work(struct work_struct *work) struct tc_u_knode, rwork); rtnl_lock(); - u32_destroy_key(key->tp, key, false); + u32_destroy_key(key, false); rtnl_unlock(); } @@ -455,12 +448,13 @@ static void u32_delete_key_freepf_work(struct work_struct *work) struct tc_u_knode, rwork); rtnl_lock(); - u32_destroy_key(key->tp, key, true); + u32_destroy_key(key, true); rtnl_unlock(); } static int u32_delete_key(struct tcf_proto *tp, struct tc_u_knode *key) { + struct tc_u_common *tp_c = tp->data; struct tc_u_knode __rcu **kp; struct tc_u_knode *pkp; struct tc_u_hnode *ht = rtnl_dereference(key->ht_up); @@ -471,6 +465,7 @@ static int u32_delete_key(struct tcf_proto *tp, struct tc_u_knode *key) kp = &pkp->next, pkp = rtnl_dereference(*kp)) { if (pkp == key) { RCU_INIT_POINTER(*kp, key->next); + tp_c->knodes--; tcf_unbind_filter(tp, &key->res); idr_remove(&ht->handle_idr, key->handle); @@ -585,6 +580,7 @@ static int u32_replace_hw_knode(struct tcf_proto *tp, struct tc_u_knode *n, static void u32_clear_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, struct netlink_ext_ack *extack) { + struct tc_u_common *tp_c = tp->data; struct tc_u_knode *n; unsigned int h; @@ -592,13 +588,14 @@ static void u32_clear_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, while ((n = rtnl_dereference(ht->ht[h])) != NULL) { RCU_INIT_POINTER(ht->ht[h], rtnl_dereference(n->next)); + tp_c->knodes--; tcf_unbind_filter(tp, &n->res); u32_remove_hw_knode(tp, n, extack); idr_remove(&ht->handle_idr, n->handle); if (tcf_exts_get_net(&n->exts)) tcf_queue_work(&n->rwork, u32_delete_key_freepf_work); else - u32_destroy_key(n->tp, n, true); + u32_destroy_key(n, true); } } } @@ -610,7 +607,7 @@ static int u32_destroy_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, struct tc_u_hnode __rcu **hn; struct tc_u_hnode *phn; - WARN_ON(ht->refcnt); + WARN_ON(--ht->refcnt); u32_clear_hnode(tp, ht, extack); @@ -631,17 +628,6 @@ static int u32_destroy_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, return -ENOENT; } -static bool ht_empty(struct tc_u_hnode *ht) -{ - unsigned int h; - - for (h = 0; h <= ht->divisor; h++) - if (rcu_access_pointer(ht->ht[h])) - return false; - - return true; -} - static void u32_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) { struct tc_u_common *tp_c = tp->data; @@ -649,7 +635,7 @@ static void u32_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack) WARN_ON(root_ht == NULL); - if (root_ht && --root_ht->refcnt == 0) + if (root_ht && --root_ht->refcnt == 1) u32_destroy_hnode(tp, root_ht, extack); if (--tp_c->refcnt == 0) { @@ -679,26 +665,21 @@ static int u32_delete(struct tcf_proto *tp, void *arg, bool *last, struct netlink_ext_ack *extack) { struct tc_u_hnode *ht = arg; - struct tc_u_hnode *root_ht = rtnl_dereference(tp->root); struct tc_u_common *tp_c = tp->data; int ret = 0; - if (ht == NULL) - goto out; - if (TC_U32_KEY(ht->handle)) { u32_remove_hw_knode(tp, (struct tc_u_knode *)ht, extack); ret = u32_delete_key(tp, (struct tc_u_knode *)ht); goto out; } - if (root_ht == ht) { + if (ht->is_root) { NL_SET_ERR_MSG_MOD(extack, "Not allowed to delete root node"); return -EINVAL; } if (ht->refcnt == 1) { - ht->refcnt--; u32_destroy_hnode(tp, ht, extack); } else { NL_SET_ERR_MSG_MOD(extack, "Can not delete in-use filter"); @@ -706,38 +687,7 @@ static int u32_delete(struct tcf_proto *tp, void *arg, bool *last, } out: - *last = true; - if (root_ht) { - if (root_ht->refcnt > 1) { - *last = false; - goto ret; - } - if (root_ht->refcnt == 1) { - if (!ht_empty(root_ht)) { - *last = false; - goto ret; - } - } - } - - if (tp_c->refcnt > 1) { - *last = false; - goto ret; - } - - if (tp_c->refcnt == 1) { - struct tc_u_hnode *ht; - - for (ht = rtnl_dereference(tp_c->hlist); - ht; - ht = rtnl_dereference(ht->next)) - if (!ht_empty(ht)) { - *last = false; - break; - } - } - -ret: + *last = tp_c->refcnt == 1 && tp_c->knodes == 0; return ret; } @@ -768,7 +718,7 @@ static const struct nla_policy u32_policy[TCA_U32_MAX + 1] = { }; static int u32_set_parms(struct net *net, struct tcf_proto *tp, - unsigned long base, struct tc_u_hnode *ht, + unsigned long base, struct tc_u_knode *n, struct nlattr **tb, struct nlattr *est, bool ovr, struct netlink_ext_ack *extack) @@ -789,12 +739,16 @@ static int u32_set_parms(struct net *net, struct tcf_proto *tp, } if (handle) { - ht_down = u32_lookup_ht(ht->tp_c, handle); + ht_down = u32_lookup_ht(tp->data, handle); if (!ht_down) { NL_SET_ERR_MSG_MOD(extack, "Link hash table not found"); return -EINVAL; } + if (ht_down->is_root) { + NL_SET_ERR_MSG_MOD(extack, "Not linking to root node"); + return -EINVAL; + } ht_down->refcnt++; } @@ -891,7 +845,6 @@ static struct tc_u_knode *u32_init_knode(struct tcf_proto *tp, /* Similarly success statistics must be moved as pointers */ new->pcpu_success = n->pcpu_success; #endif - new->tp = tp; memcpy(&new->sel, s, sizeof(*s) + s->nkeys*sizeof(struct tc_u32_key)); if (tcf_exts_init(&new->exts, TCA_U32_ACT, TCA_U32_POLICE)) { @@ -960,18 +913,17 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, if (!new) return -ENOMEM; - err = u32_set_parms(net, tp, base, - rtnl_dereference(n->ht_up), new, tb, + err = u32_set_parms(net, tp, base, new, tb, tca[TCA_RATE], ovr, extack); if (err) { - u32_destroy_key(tp, new, false); + u32_destroy_key(new, false); return err; } err = u32_replace_hw_knode(tp, new, flags, extack); if (err) { - u32_destroy_key(tp, new, false); + u32_destroy_key(new, false); return err; } @@ -988,7 +940,11 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, if (tb[TCA_U32_DIVISOR]) { unsigned int divisor = nla_get_u32(tb[TCA_U32_DIVISOR]); - if (--divisor > 0x100) { + if (!is_power_of_2(divisor)) { + NL_SET_ERR_MSG_MOD(extack, "Divisor is not a power of 2"); + return -EINVAL; + } + if (divisor-- > 0x100) { NL_SET_ERR_MSG_MOD(extack, "Exceeded maximum 256 hash buckets"); return -EINVAL; } @@ -1013,7 +969,6 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, return err; } } - ht->tp_c = tp_c; ht->refcnt = 1; ht->divisor = divisor; ht->handle = handle; @@ -1103,7 +1058,6 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, n->handle = handle; n->fshift = s->hmask ? ffs(ntohl(s->hmask)) - 1 : 0; n->flags = flags; - n->tp = tp; err = tcf_exts_init(&n->exts, TCA_U32_ACT, TCA_U32_POLICE); if (err < 0) @@ -1125,7 +1079,7 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, } #endif - err = u32_set_parms(net, tp, base, ht, n, tb, tca[TCA_RATE], ovr, + err = u32_set_parms(net, tp, base, n, tb, tca[TCA_RATE], ovr, extack); if (err == 0) { struct tc_u_knode __rcu **ins; @@ -1146,6 +1100,7 @@ static int u32_change(struct net *net, struct sk_buff *in_skb, RCU_INIT_POINTER(n->next, pins); rcu_assign_pointer(*ins, n); + tp_c->knodes++; *arg = n; return 0; } diff --git a/net/sched/sch_api.c b/net/sched/sch_api.c index 22e9799e5b69..ca3b0f46de53 100644 --- a/net/sched/sch_api.c +++ b/net/sched/sch_api.c @@ -1318,6 +1318,17 @@ check_loop_fn(struct Qdisc *q, unsigned long cl, struct qdisc_walker *w) return 0; } +const struct nla_policy rtm_tca_policy[TCA_MAX + 1] = { + [TCA_KIND] = { .type = NLA_STRING }, + [TCA_RATE] = { .type = NLA_BINARY, + .len = sizeof(struct tc_estimator) }, + [TCA_STAB] = { .type = NLA_NESTED }, + [TCA_DUMP_INVISIBLE] = { .type = NLA_FLAG }, + [TCA_CHAIN] = { .type = NLA_U32 }, + [TCA_INGRESS_BLOCK] = { .type = NLA_U32 }, + [TCA_EGRESS_BLOCK] = { .type = NLA_U32 }, +}; + /* * Delete/get qdisc. */ @@ -1338,7 +1349,8 @@ static int tc_get_qdisc(struct sk_buff *skb, struct nlmsghdr *n, !netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) return -EPERM; - err = nlmsg_parse(n, sizeof(*tcm), tca, TCA_MAX, NULL, extack); + err = nlmsg_parse(n, sizeof(*tcm), tca, TCA_MAX, rtm_tca_policy, + extack); if (err < 0) return err; @@ -1422,7 +1434,8 @@ static int tc_modify_qdisc(struct sk_buff *skb, struct nlmsghdr *n, replay: /* Reinit, just in case something touches this. */ - err = nlmsg_parse(n, sizeof(*tcm), tca, TCA_MAX, NULL, extack); + err = nlmsg_parse(n, sizeof(*tcm), tca, TCA_MAX, rtm_tca_policy, + extack); if (err < 0) return err; @@ -1656,7 +1669,8 @@ static int tc_dump_qdisc(struct sk_buff *skb, struct netlink_callback *cb) idx = 0; ASSERT_RTNL(); - err = nlmsg_parse(nlh, sizeof(struct tcmsg), tca, TCA_MAX, NULL, NULL); + err = nlmsg_parse(nlh, sizeof(struct tcmsg), tca, TCA_MAX, + rtm_tca_policy, cb->extack); if (err < 0) return err; @@ -1875,7 +1889,8 @@ static int tc_ctl_tclass(struct sk_buff *skb, struct nlmsghdr *n, !netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) return -EPERM; - err = nlmsg_parse(n, sizeof(*tcm), tca, TCA_MAX, NULL, extack); + err = nlmsg_parse(n, sizeof(*tcm), tca, TCA_MAX, rtm_tca_policy, + extack); if (err < 0) return err; @@ -2054,7 +2069,8 @@ static int tc_dump_tclass_root(struct Qdisc *root, struct sk_buff *skb, if (tcm->tcm_parent) { q = qdisc_match_from_root(root, TC_H_MAJ(tcm->tcm_parent)); - if (q && tc_dump_tclass_qdisc(q, skb, tcm, cb, t_p, s_t) < 0) + if (q && q != root && + tc_dump_tclass_qdisc(q, skb, tcm, cb, t_p, s_t) < 0) return -1; return 0; } diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c index dc539295ae65..b910cd5c56f7 100644 --- a/net/sched/sch_cake.c +++ b/net/sched/sch_cake.c @@ -2644,7 +2644,7 @@ static int cake_init(struct Qdisc *sch, struct nlattr *opt, for (i = 1; i <= CAKE_QUEUES; i++) quantum_div[i] = 65535 / i; - q->tins = kvzalloc(CAKE_MAX_TINS * sizeof(struct cake_tin_data), + q->tins = kvcalloc(CAKE_MAX_TINS, sizeof(struct cake_tin_data), GFP_KERNEL); if (!q->tins) goto nomem; diff --git a/net/sched/sch_fq.c b/net/sched/sch_fq.c index 338222a6c664..4b1af706896c 100644 --- a/net/sched/sch_fq.c +++ b/net/sched/sch_fq.c @@ -92,8 +92,8 @@ struct fq_sched_data { u32 quantum; u32 initial_quantum; u32 flow_refill_delay; - u32 flow_max_rate; /* optional max rate per flow */ u32 flow_plimit; /* max packets per flow */ + unsigned long flow_max_rate; /* optional max rate per flow */ u32 orphan_mask; /* mask for orphaned skb */ u32 low_rate_threshold; struct rb_root *fq_root; @@ -416,7 +416,8 @@ static struct sk_buff *fq_dequeue(struct Qdisc *sch) struct fq_flow_head *head; struct sk_buff *skb; struct fq_flow *f; - u32 rate, plen; + unsigned long rate; + u32 plen; skb = fq_dequeue_head(sch, &q->internal); if (skb) @@ -443,7 +444,7 @@ begin: } skb = f->head; - if (skb && !skb_is_tcp_pure_ack(skb)) { + if (skb) { u64 time_next_packet = max_t(u64, ktime_to_ns(skb->tstamp), f->time_next_packet); @@ -485,11 +486,11 @@ begin: if (f->credit > 0) goto out; } - if (rate != ~0U) { + if (rate != ~0UL) { u64 len = (u64)plen * NSEC_PER_SEC; if (likely(rate)) - do_div(len, rate); + len = div64_ul(len, rate); /* Since socket rate can change later, * clamp the delay to 1 second. * Really, providers of too big packets should be fixed ! @@ -701,9 +702,11 @@ static int fq_change(struct Qdisc *sch, struct nlattr *opt, pr_warn_ratelimited("sch_fq: defrate %u ignored.\n", nla_get_u32(tb[TCA_FQ_FLOW_DEFAULT_RATE])); - if (tb[TCA_FQ_FLOW_MAX_RATE]) - q->flow_max_rate = nla_get_u32(tb[TCA_FQ_FLOW_MAX_RATE]); + if (tb[TCA_FQ_FLOW_MAX_RATE]) { + u32 rate = nla_get_u32(tb[TCA_FQ_FLOW_MAX_RATE]); + q->flow_max_rate = (rate == ~0U) ? ~0UL : rate; + } if (tb[TCA_FQ_LOW_RATE_THRESHOLD]) q->low_rate_threshold = nla_get_u32(tb[TCA_FQ_LOW_RATE_THRESHOLD]); @@ -766,7 +769,7 @@ static int fq_init(struct Qdisc *sch, struct nlattr *opt, q->quantum = 2 * psched_mtu(qdisc_dev(sch)); q->initial_quantum = 10 * psched_mtu(qdisc_dev(sch)); q->flow_refill_delay = msecs_to_jiffies(40); - q->flow_max_rate = ~0U; + q->flow_max_rate = ~0UL; q->time_next_delayed_flow = ~0ULL; q->rate_enable = 1; q->new_flows.first = NULL; @@ -802,7 +805,8 @@ static int fq_dump(struct Qdisc *sch, struct sk_buff *skb) nla_put_u32(skb, TCA_FQ_QUANTUM, q->quantum) || nla_put_u32(skb, TCA_FQ_INITIAL_QUANTUM, q->initial_quantum) || nla_put_u32(skb, TCA_FQ_RATE_ENABLE, q->rate_enable) || - nla_put_u32(skb, TCA_FQ_FLOW_MAX_RATE, q->flow_max_rate) || + nla_put_u32(skb, TCA_FQ_FLOW_MAX_RATE, + min_t(unsigned long, q->flow_max_rate, ~0U)) || nla_put_u32(skb, TCA_FQ_FLOW_REFILL_DELAY, jiffies_to_usecs(q->flow_refill_delay)) || nla_put_u32(skb, TCA_FQ_ORPHAN_MASK, q->orphan_mask) || diff --git a/net/sched/sch_generic.c b/net/sched/sch_generic.c index 3023929852e8..de1663f7d3ad 100644 --- a/net/sched/sch_generic.c +++ b/net/sched/sch_generic.c @@ -572,6 +572,18 @@ struct Qdisc noop_qdisc = { .dev_queue = &noop_netdev_queue, .running = SEQCNT_ZERO(noop_qdisc.running), .busylock = __SPIN_LOCK_UNLOCKED(noop_qdisc.busylock), + .gso_skb = { + .next = (struct sk_buff *)&noop_qdisc.gso_skb, + .prev = (struct sk_buff *)&noop_qdisc.gso_skb, + .qlen = 0, + .lock = __SPIN_LOCK_UNLOCKED(noop_qdisc.gso_skb.lock), + }, + .skb_bad_txq = { + .next = (struct sk_buff *)&noop_qdisc.skb_bad_txq, + .prev = (struct sk_buff *)&noop_qdisc.skb_bad_txq, + .qlen = 0, + .lock = __SPIN_LOCK_UNLOCKED(noop_qdisc.skb_bad_txq.lock), + }, }; EXPORT_SYMBOL(noop_qdisc); @@ -1273,8 +1285,6 @@ static void dev_init_scheduler_queue(struct net_device *dev, rcu_assign_pointer(dev_queue->qdisc, qdisc); dev_queue->qdisc_sleeping = qdisc; - __skb_queue_head_init(&qdisc->gso_skb); - __skb_queue_head_init(&qdisc->skb_bad_txq); } void dev_init_scheduler(struct net_device *dev) diff --git a/net/sched/sch_gred.c b/net/sched/sch_gred.c index cbe4831f46f4..4a042abf844c 100644 --- a/net/sched/sch_gred.c +++ b/net/sched/sch_gred.c @@ -413,7 +413,7 @@ static int gred_change(struct Qdisc *sch, struct nlattr *opt, if (tb[TCA_GRED_PARMS] == NULL && tb[TCA_GRED_STAB] == NULL) { if (tb[TCA_GRED_LIMIT] != NULL) sch->limit = nla_get_u32(tb[TCA_GRED_LIMIT]); - return gred_change_table_def(sch, opt); + return gred_change_table_def(sch, tb[TCA_GRED_DPS]); } if (tb[TCA_GRED_PARMS] == NULL || diff --git a/net/sched/sch_pie.c b/net/sched/sch_pie.c index 18d30bb86881..d1429371592f 100644 --- a/net/sched/sch_pie.c +++ b/net/sched/sch_pie.c @@ -110,8 +110,8 @@ static bool drop_early(struct Qdisc *sch, u32 packet_size) /* If current delay is less than half of target, and * if drop prob is low already, disable early_drop */ - if ((q->vars.qdelay < q->params.target / 2) - && (q->vars.prob < MAX_PROB / 5)) + if ((q->vars.qdelay < q->params.target / 2) && + (q->vars.prob < MAX_PROB / 5)) return false; /* If we have fewer than 2 mtu-sized packets, disable drop_early, @@ -209,7 +209,8 @@ static int pie_change(struct Qdisc *sch, struct nlattr *opt, /* tupdate is in jiffies */ if (tb[TCA_PIE_TUPDATE]) - q->params.tupdate = usecs_to_jiffies(nla_get_u32(tb[TCA_PIE_TUPDATE])); + q->params.tupdate = + usecs_to_jiffies(nla_get_u32(tb[TCA_PIE_TUPDATE])); if (tb[TCA_PIE_LIMIT]) { u32 limit = nla_get_u32(tb[TCA_PIE_LIMIT]); @@ -247,7 +248,6 @@ static int pie_change(struct Qdisc *sch, struct nlattr *opt, static void pie_process_dequeue(struct Qdisc *sch, struct sk_buff *skb) { - struct pie_sched_data *q = qdisc_priv(sch); int qlen = sch->qstats.backlog; /* current queue size in bytes */ @@ -294,9 +294,9 @@ static void pie_process_dequeue(struct Qdisc *sch, struct sk_buff *skb) * dq_count to 0 to re-enter the if block when the next * packet is dequeued */ - if (qlen < QUEUE_THRESHOLD) + if (qlen < QUEUE_THRESHOLD) { q->vars.dq_count = DQCOUNT_INVALID; - else { + } else { q->vars.dq_count = 0; q->vars.dq_tstamp = psched_get_time(); } @@ -370,7 +370,7 @@ static void calculate_probability(struct Qdisc *sch) oldprob = q->vars.prob; /* to ensure we increase probability in steps of no more than 2% */ - if (delta > (s32) (MAX_PROB / (100 / 2)) && + if (delta > (s32)(MAX_PROB / (100 / 2)) && q->vars.prob >= MAX_PROB / 10) delta = (MAX_PROB / 100) * 2; @@ -405,7 +405,7 @@ static void calculate_probability(struct Qdisc *sch) * delay is 0 for 2 consecutive Tupdate periods. */ - if ((qdelay == 0) && (qdelay_old == 0) && update_prob) + if (qdelay == 0 && qdelay_old == 0 && update_prob) q->vars.prob = (q->vars.prob * 98) / 100; q->vars.qdelay = qdelay; @@ -419,8 +419,8 @@ static void calculate_probability(struct Qdisc *sch) */ if ((q->vars.qdelay < q->params.target / 2) && (q->vars.qdelay_old < q->params.target / 2) && - (q->vars.prob == 0) && - (q->vars.avg_dq_rate > 0)) + q->vars.prob == 0 && + q->vars.avg_dq_rate > 0) pie_vars_init(&q->vars); } @@ -437,7 +437,6 @@ static void pie_timer(struct timer_list *t) if (q->params.tupdate) mod_timer(&q->adapt_timer, jiffies + q->params.tupdate); spin_unlock(root_lock); - } static int pie_init(struct Qdisc *sch, struct nlattr *opt, @@ -469,15 +468,16 @@ static int pie_dump(struct Qdisc *sch, struct sk_buff *skb) struct nlattr *opts; opts = nla_nest_start(skb, TCA_OPTIONS); - if (opts == NULL) + if (!opts) goto nla_put_failure; /* convert target from pschedtime to us */ if (nla_put_u32(skb, TCA_PIE_TARGET, - ((u32) PSCHED_TICKS2NS(q->params.target)) / + ((u32)PSCHED_TICKS2NS(q->params.target)) / NSEC_PER_USEC) || nla_put_u32(skb, TCA_PIE_LIMIT, sch->limit) || - nla_put_u32(skb, TCA_PIE_TUPDATE, jiffies_to_usecs(q->params.tupdate)) || + nla_put_u32(skb, TCA_PIE_TUPDATE, + jiffies_to_usecs(q->params.tupdate)) || nla_put_u32(skb, TCA_PIE_ALPHA, q->params.alpha) || nla_put_u32(skb, TCA_PIE_BETA, q->params.beta) || nla_put_u32(skb, TCA_PIE_ECN, q->params.ecn) || @@ -489,7 +489,6 @@ static int pie_dump(struct Qdisc *sch, struct sk_buff *skb) nla_put_failure: nla_nest_cancel(skb, opts); return -1; - } static int pie_dump_stats(struct Qdisc *sch, struct gnet_dump *d) @@ -497,7 +496,7 @@ static int pie_dump_stats(struct Qdisc *sch, struct gnet_dump *d) struct pie_sched_data *q = qdisc_priv(sch); struct tc_pie_xstats st = { .prob = q->vars.prob, - .delay = ((u32) PSCHED_TICKS2NS(q->vars.qdelay)) / + .delay = ((u32)PSCHED_TICKS2NS(q->vars.qdelay)) / NSEC_PER_USEC, /* unscale and return dq_rate in bytes per sec */ .avg_dq_rate = q->vars.avg_dq_rate * @@ -514,8 +513,7 @@ static int pie_dump_stats(struct Qdisc *sch, struct gnet_dump *d) static struct sk_buff *pie_qdisc_dequeue(struct Qdisc *sch) { - struct sk_buff *skb; - skb = qdisc_dequeue_head(sch); + struct sk_buff *skb = qdisc_dequeue_head(sch); if (!skb) return NULL; @@ -527,6 +525,7 @@ static struct sk_buff *pie_qdisc_dequeue(struct Qdisc *sch) static void pie_reset(struct Qdisc *sch) { struct pie_sched_data *q = qdisc_priv(sch); + qdisc_reset_queue(sch); pie_vars_init(&q->vars); } @@ -534,6 +533,7 @@ static void pie_reset(struct Qdisc *sch) static void pie_destroy(struct Qdisc *sch) { struct pie_sched_data *q = qdisc_priv(sch); + q->params.tupdate = 0; del_timer_sync(&q->adapt_timer); } diff --git a/net/sched/sch_taprio.c b/net/sched/sch_taprio.c new file mode 100644 index 000000000000..206e4dbed12f --- /dev/null +++ b/net/sched/sch_taprio.c @@ -0,0 +1,962 @@ +// SPDX-License-Identifier: GPL-2.0 + +/* net/sched/sch_taprio.c Time Aware Priority Scheduler + * + * Authors: Vinicius Costa Gomes <vinicius.gomes@intel.com> + * + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/kernel.h> +#include <linux/string.h> +#include <linux/list.h> +#include <linux/errno.h> +#include <linux/skbuff.h> +#include <linux/module.h> +#include <linux/spinlock.h> +#include <net/netlink.h> +#include <net/pkt_sched.h> +#include <net/pkt_cls.h> +#include <net/sch_generic.h> + +#define TAPRIO_ALL_GATES_OPEN -1 + +struct sched_entry { + struct list_head list; + + /* The instant that this entry "closes" and the next one + * should open, the qdisc will make some effort so that no + * packet leaves after this time. + */ + ktime_t close_time; + atomic_t budget; + int index; + u32 gate_mask; + u32 interval; + u8 command; +}; + +struct taprio_sched { + struct Qdisc **qdiscs; + struct Qdisc *root; + s64 base_time; + int clockid; + int picos_per_byte; /* Using picoseconds because for 10Gbps+ + * speeds it's sub-nanoseconds per byte + */ + size_t num_entries; + + /* Protects the update side of the RCU protected current_entry */ + spinlock_t current_entry_lock; + struct sched_entry __rcu *current_entry; + struct list_head entries; + ktime_t (*get_time)(void); + struct hrtimer advance_timer; +}; + +static int taprio_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct Qdisc *child; + int queue; + + queue = skb_get_queue_mapping(skb); + + child = q->qdiscs[queue]; + if (unlikely(!child)) + return qdisc_drop(skb, sch, to_free); + + qdisc_qstats_backlog_inc(sch, skb); + sch->q.qlen++; + + return qdisc_enqueue(skb, child, to_free); +} + +static struct sk_buff *taprio_peek(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct sched_entry *entry; + struct sk_buff *skb; + u32 gate_mask; + int i; + + rcu_read_lock(); + entry = rcu_dereference(q->current_entry); + gate_mask = entry ? entry->gate_mask : -1; + rcu_read_unlock(); + + if (!gate_mask) + return NULL; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct Qdisc *child = q->qdiscs[i]; + int prio; + u8 tc; + + if (unlikely(!child)) + continue; + + skb = child->ops->peek(child); + if (!skb) + continue; + + prio = skb->priority; + tc = netdev_get_prio_tc_map(dev, prio); + + if (!(gate_mask & BIT(tc))) + return NULL; + + return skb; + } + + return NULL; +} + +static inline int length_to_duration(struct taprio_sched *q, int len) +{ + return (len * q->picos_per_byte) / 1000; +} + +static struct sk_buff *taprio_dequeue(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct sched_entry *entry; + struct sk_buff *skb; + u32 gate_mask; + int i; + + rcu_read_lock(); + entry = rcu_dereference(q->current_entry); + /* if there's no entry, it means that the schedule didn't + * start yet, so force all gates to be open, this is in + * accordance to IEEE 802.1Qbv-2015 Section 8.6.9.4.5 + * "AdminGateSates" + */ + gate_mask = entry ? entry->gate_mask : TAPRIO_ALL_GATES_OPEN; + rcu_read_unlock(); + + if (!gate_mask) + return NULL; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct Qdisc *child = q->qdiscs[i]; + ktime_t guard; + int prio; + int len; + u8 tc; + + if (unlikely(!child)) + continue; + + skb = child->ops->peek(child); + if (!skb) + continue; + + prio = skb->priority; + tc = netdev_get_prio_tc_map(dev, prio); + + if (!(gate_mask & BIT(tc))) + continue; + + len = qdisc_pkt_len(skb); + guard = ktime_add_ns(q->get_time(), + length_to_duration(q, len)); + + /* In the case that there's no gate entry, there's no + * guard band ... + */ + if (gate_mask != TAPRIO_ALL_GATES_OPEN && + ktime_after(guard, entry->close_time)) + return NULL; + + /* ... and no budget. */ + if (gate_mask != TAPRIO_ALL_GATES_OPEN && + atomic_sub_return(len, &entry->budget) < 0) + return NULL; + + skb = child->ops->dequeue(child); + if (unlikely(!skb)) + return NULL; + + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + + return skb; + } + + return NULL; +} + +static bool should_restart_cycle(const struct taprio_sched *q, + const struct sched_entry *entry) +{ + WARN_ON(!entry); + + return list_is_last(&entry->list, &q->entries); +} + +static enum hrtimer_restart advance_sched(struct hrtimer *timer) +{ + struct taprio_sched *q = container_of(timer, struct taprio_sched, + advance_timer); + struct sched_entry *entry, *next; + struct Qdisc *sch = q->root; + ktime_t close_time; + + spin_lock(&q->current_entry_lock); + entry = rcu_dereference_protected(q->current_entry, + lockdep_is_held(&q->current_entry_lock)); + + /* This is the case that it's the first time that the schedule + * runs, so it only happens once per schedule. The first entry + * is pre-calculated during the schedule initialization. + */ + if (unlikely(!entry)) { + next = list_first_entry(&q->entries, struct sched_entry, + list); + close_time = next->close_time; + goto first_run; + } + + if (should_restart_cycle(q, entry)) + next = list_first_entry(&q->entries, struct sched_entry, + list); + else + next = list_next_entry(entry, list); + + close_time = ktime_add_ns(entry->close_time, next->interval); + + next->close_time = close_time; + atomic_set(&next->budget, + (next->interval * 1000) / q->picos_per_byte); + +first_run: + rcu_assign_pointer(q->current_entry, next); + spin_unlock(&q->current_entry_lock); + + hrtimer_set_expires(&q->advance_timer, close_time); + + rcu_read_lock(); + __netif_schedule(sch); + rcu_read_unlock(); + + return HRTIMER_RESTART; +} + +static const struct nla_policy entry_policy[TCA_TAPRIO_SCHED_ENTRY_MAX + 1] = { + [TCA_TAPRIO_SCHED_ENTRY_INDEX] = { .type = NLA_U32 }, + [TCA_TAPRIO_SCHED_ENTRY_CMD] = { .type = NLA_U8 }, + [TCA_TAPRIO_SCHED_ENTRY_GATE_MASK] = { .type = NLA_U32 }, + [TCA_TAPRIO_SCHED_ENTRY_INTERVAL] = { .type = NLA_U32 }, +}; + +static const struct nla_policy entry_list_policy[TCA_TAPRIO_SCHED_MAX + 1] = { + [TCA_TAPRIO_SCHED_ENTRY] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy taprio_policy[TCA_TAPRIO_ATTR_MAX + 1] = { + [TCA_TAPRIO_ATTR_PRIOMAP] = { + .len = sizeof(struct tc_mqprio_qopt) + }, + [TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST] = { .type = NLA_NESTED }, + [TCA_TAPRIO_ATTR_SCHED_BASE_TIME] = { .type = NLA_S64 }, + [TCA_TAPRIO_ATTR_SCHED_SINGLE_ENTRY] = { .type = NLA_NESTED }, + [TCA_TAPRIO_ATTR_SCHED_CLOCKID] = { .type = NLA_S32 }, +}; + +static int fill_sched_entry(struct nlattr **tb, struct sched_entry *entry, + struct netlink_ext_ack *extack) +{ + u32 interval = 0; + + if (tb[TCA_TAPRIO_SCHED_ENTRY_CMD]) + entry->command = nla_get_u8( + tb[TCA_TAPRIO_SCHED_ENTRY_CMD]); + + if (tb[TCA_TAPRIO_SCHED_ENTRY_GATE_MASK]) + entry->gate_mask = nla_get_u32( + tb[TCA_TAPRIO_SCHED_ENTRY_GATE_MASK]); + + if (tb[TCA_TAPRIO_SCHED_ENTRY_INTERVAL]) + interval = nla_get_u32( + tb[TCA_TAPRIO_SCHED_ENTRY_INTERVAL]); + + if (interval == 0) { + NL_SET_ERR_MSG(extack, "Invalid interval for schedule entry"); + return -EINVAL; + } + + entry->interval = interval; + + return 0; +} + +static int parse_sched_entry(struct nlattr *n, struct sched_entry *entry, + int index, struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_TAPRIO_SCHED_ENTRY_MAX + 1] = { }; + int err; + + err = nla_parse_nested(tb, TCA_TAPRIO_SCHED_ENTRY_MAX, n, + entry_policy, NULL); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Could not parse nested entry"); + return -EINVAL; + } + + entry->index = index; + + return fill_sched_entry(tb, entry, extack); +} + +/* Returns the number of entries in case of success */ +static int parse_sched_single_entry(struct nlattr *n, + struct taprio_sched *q, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb_entry[TCA_TAPRIO_SCHED_ENTRY_MAX + 1] = { }; + struct nlattr *tb_list[TCA_TAPRIO_SCHED_MAX + 1] = { }; + struct sched_entry *entry; + bool found = false; + u32 index; + int err; + + err = nla_parse_nested(tb_list, TCA_TAPRIO_SCHED_MAX, + n, entry_list_policy, NULL); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Could not parse nested entry"); + return -EINVAL; + } + + if (!tb_list[TCA_TAPRIO_SCHED_ENTRY]) { + NL_SET_ERR_MSG(extack, "Single-entry must include an entry"); + return -EINVAL; + } + + err = nla_parse_nested(tb_entry, TCA_TAPRIO_SCHED_ENTRY_MAX, + tb_list[TCA_TAPRIO_SCHED_ENTRY], + entry_policy, NULL); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Could not parse nested entry"); + return -EINVAL; + } + + if (!tb_entry[TCA_TAPRIO_SCHED_ENTRY_INDEX]) { + NL_SET_ERR_MSG(extack, "Entry must specify an index\n"); + return -EINVAL; + } + + index = nla_get_u32(tb_entry[TCA_TAPRIO_SCHED_ENTRY_INDEX]); + if (index >= q->num_entries) { + NL_SET_ERR_MSG(extack, "Index for single entry exceeds number of entries in schedule"); + return -EINVAL; + } + + list_for_each_entry(entry, &q->entries, list) { + if (entry->index == index) { + found = true; + break; + } + } + + if (!found) { + NL_SET_ERR_MSG(extack, "Could not find entry"); + return -ENOENT; + } + + err = fill_sched_entry(tb_entry, entry, extack); + if (err < 0) + return err; + + return q->num_entries; +} + +static int parse_sched_list(struct nlattr *list, + struct taprio_sched *q, + struct netlink_ext_ack *extack) +{ + struct nlattr *n; + int err, rem; + int i = 0; + + if (!list) + return -EINVAL; + + nla_for_each_nested(n, list, rem) { + struct sched_entry *entry; + + if (nla_type(n) != TCA_TAPRIO_SCHED_ENTRY) { + NL_SET_ERR_MSG(extack, "Attribute is not of type 'entry'"); + continue; + } + + entry = kzalloc(sizeof(*entry), GFP_KERNEL); + if (!entry) { + NL_SET_ERR_MSG(extack, "Not enough memory for entry"); + return -ENOMEM; + } + + err = parse_sched_entry(n, entry, i, extack); + if (err < 0) { + kfree(entry); + return err; + } + + list_add_tail(&entry->list, &q->entries); + i++; + } + + q->num_entries = i; + + return i; +} + +/* Returns the number of entries in case of success */ +static int parse_taprio_opt(struct nlattr **tb, struct taprio_sched *q, + struct netlink_ext_ack *extack) +{ + int err = 0; + int clockid; + + if (tb[TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST] && + tb[TCA_TAPRIO_ATTR_SCHED_SINGLE_ENTRY]) + return -EINVAL; + + if (tb[TCA_TAPRIO_ATTR_SCHED_SINGLE_ENTRY] && q->num_entries == 0) + return -EINVAL; + + if (q->clockid == -1 && !tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]) + return -EINVAL; + + if (tb[TCA_TAPRIO_ATTR_SCHED_BASE_TIME]) + q->base_time = nla_get_s64( + tb[TCA_TAPRIO_ATTR_SCHED_BASE_TIME]); + + if (tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]) { + clockid = nla_get_s32(tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]); + + /* We only support static clockids and we don't allow + * for it to be modified after the first init. + */ + if (clockid < 0 || (q->clockid != -1 && q->clockid != clockid)) + return -EINVAL; + + q->clockid = clockid; + } + + if (tb[TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST]) + err = parse_sched_list( + tb[TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST], q, extack); + else if (tb[TCA_TAPRIO_ATTR_SCHED_SINGLE_ENTRY]) + err = parse_sched_single_entry( + tb[TCA_TAPRIO_ATTR_SCHED_SINGLE_ENTRY], q, extack); + + /* parse_sched_* return the number of entries in the schedule, + * a schedule with zero entries is an error. + */ + if (err == 0) { + NL_SET_ERR_MSG(extack, "The schedule should contain at least one entry"); + return -EINVAL; + } + + return err; +} + +static int taprio_parse_mqprio_opt(struct net_device *dev, + struct tc_mqprio_qopt *qopt, + struct netlink_ext_ack *extack) +{ + int i, j; + + if (!qopt) { + NL_SET_ERR_MSG(extack, "'mqprio' configuration is necessary"); + return -EINVAL; + } + + /* Verify num_tc is not out of max range */ + if (qopt->num_tc > TC_MAX_QUEUE) { + NL_SET_ERR_MSG(extack, "Number of traffic classes is outside valid range"); + return -EINVAL; + } + + /* taprio imposes that traffic classes map 1:n to tx queues */ + if (qopt->num_tc > dev->num_tx_queues) { + NL_SET_ERR_MSG(extack, "Number of traffic classes is greater than number of HW queues"); + return -EINVAL; + } + + /* Verify priority mapping uses valid tcs */ + for (i = 0; i < TC_BITMASK + 1; i++) { + if (qopt->prio_tc_map[i] >= qopt->num_tc) { + NL_SET_ERR_MSG(extack, "Invalid traffic class in priority to traffic class mapping"); + return -EINVAL; + } + } + + for (i = 0; i < qopt->num_tc; i++) { + unsigned int last = qopt->offset[i] + qopt->count[i]; + + /* Verify the queue count is in tx range being equal to the + * real_num_tx_queues indicates the last queue is in use. + */ + if (qopt->offset[i] >= dev->num_tx_queues || + !qopt->count[i] || + last > dev->real_num_tx_queues) { + NL_SET_ERR_MSG(extack, "Invalid queue in traffic class to queue mapping"); + return -EINVAL; + } + + /* Verify that the offset and counts do not overlap */ + for (j = i + 1; j < qopt->num_tc; j++) { + if (last > qopt->offset[j]) { + NL_SET_ERR_MSG(extack, "Detected overlap in the traffic class to queue mapping"); + return -EINVAL; + } + } + } + + return 0; +} + +static ktime_t taprio_get_start_time(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct sched_entry *entry; + ktime_t now, base, cycle; + s64 n; + + base = ns_to_ktime(q->base_time); + cycle = 0; + + /* Calculate the cycle_time, by summing all the intervals. + */ + list_for_each_entry(entry, &q->entries, list) + cycle = ktime_add_ns(cycle, entry->interval); + + if (!cycle) + return base; + + now = q->get_time(); + + if (ktime_after(base, now)) + return base; + + /* Schedule the start time for the beginning of the next + * cycle. + */ + n = div64_s64(ktime_sub_ns(now, base), cycle); + + return ktime_add_ns(base, (n + 1) * cycle); +} + +static void taprio_start_sched(struct Qdisc *sch, ktime_t start) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct sched_entry *first; + unsigned long flags; + + spin_lock_irqsave(&q->current_entry_lock, flags); + + first = list_first_entry(&q->entries, struct sched_entry, + list); + + first->close_time = ktime_add_ns(start, first->interval); + atomic_set(&first->budget, + (first->interval * 1000) / q->picos_per_byte); + rcu_assign_pointer(q->current_entry, NULL); + + spin_unlock_irqrestore(&q->current_entry_lock, flags); + + hrtimer_start(&q->advance_timer, start, HRTIMER_MODE_ABS); +} + +static int taprio_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_TAPRIO_ATTR_MAX + 1] = { }; + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct tc_mqprio_qopt *mqprio = NULL; + struct ethtool_link_ksettings ecmd; + int i, err, size; + s64 link_speed; + ktime_t start; + + err = nla_parse_nested(tb, TCA_TAPRIO_ATTR_MAX, opt, + taprio_policy, extack); + if (err < 0) + return err; + + err = -EINVAL; + if (tb[TCA_TAPRIO_ATTR_PRIOMAP]) + mqprio = nla_data(tb[TCA_TAPRIO_ATTR_PRIOMAP]); + + err = taprio_parse_mqprio_opt(dev, mqprio, extack); + if (err < 0) + return err; + + /* A schedule with less than one entry is an error */ + size = parse_taprio_opt(tb, q, extack); + if (size < 0) + return size; + + hrtimer_init(&q->advance_timer, q->clockid, HRTIMER_MODE_ABS); + q->advance_timer.function = advance_sched; + + switch (q->clockid) { + case CLOCK_REALTIME: + q->get_time = ktime_get_real; + break; + case CLOCK_MONOTONIC: + q->get_time = ktime_get; + break; + case CLOCK_BOOTTIME: + q->get_time = ktime_get_boottime; + break; + case CLOCK_TAI: + q->get_time = ktime_get_clocktai; + break; + default: + return -ENOTSUPP; + } + + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *dev_queue; + struct Qdisc *qdisc; + + dev_queue = netdev_get_tx_queue(dev, i); + qdisc = qdisc_create_dflt(dev_queue, + &pfifo_qdisc_ops, + TC_H_MAKE(TC_H_MAJ(sch->handle), + TC_H_MIN(i + 1)), + extack); + if (!qdisc) + return -ENOMEM; + + if (i < dev->real_num_tx_queues) + qdisc_hash_add(qdisc, false); + + q->qdiscs[i] = qdisc; + } + + if (mqprio) { + netdev_set_num_tc(dev, mqprio->num_tc); + for (i = 0; i < mqprio->num_tc; i++) + netdev_set_tc_queue(dev, i, + mqprio->count[i], + mqprio->offset[i]); + + /* Always use supplied priority mappings */ + for (i = 0; i < TC_BITMASK + 1; i++) + netdev_set_prio_tc_map(dev, i, + mqprio->prio_tc_map[i]); + } + + if (!__ethtool_get_link_ksettings(dev, &ecmd)) + link_speed = ecmd.base.speed; + else + link_speed = SPEED_1000; + + q->picos_per_byte = div64_s64(NSEC_PER_SEC * 1000LL * 8, + link_speed * 1000 * 1000); + + start = taprio_get_start_time(sch); + if (!start) + return 0; + + taprio_start_sched(sch, start); + + return 0; +} + +static void taprio_destroy(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct sched_entry *entry, *n; + unsigned int i; + + hrtimer_cancel(&q->advance_timer); + + if (q->qdiscs) { + for (i = 0; i < dev->num_tx_queues && q->qdiscs[i]; i++) + qdisc_put(q->qdiscs[i]); + + kfree(q->qdiscs); + } + q->qdiscs = NULL; + + netdev_set_num_tc(dev, 0); + + list_for_each_entry_safe(entry, n, &q->entries, list) { + list_del(&entry->list); + kfree(entry); + } +} + +static int taprio_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + + INIT_LIST_HEAD(&q->entries); + spin_lock_init(&q->current_entry_lock); + + /* We may overwrite the configuration later */ + hrtimer_init(&q->advance_timer, CLOCK_TAI, HRTIMER_MODE_ABS); + + q->root = sch; + + /* We only support static clockids. Use an invalid value as default + * and get the valid one on taprio_change(). + */ + q->clockid = -1; + + if (sch->parent != TC_H_ROOT) + return -EOPNOTSUPP; + + if (!netif_is_multiqueue(dev)) + return -EOPNOTSUPP; + + /* pre-allocate qdisc, attachment can't fail */ + q->qdiscs = kcalloc(dev->num_tx_queues, + sizeof(q->qdiscs[0]), + GFP_KERNEL); + + if (!q->qdiscs) + return -ENOMEM; + + if (!opt) + return -EINVAL; + + return taprio_change(sch, opt, extack); +} + +static struct netdev_queue *taprio_queue_get(struct Qdisc *sch, + unsigned long cl) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned long ntx = cl - 1; + + if (ntx >= dev->num_tx_queues) + return NULL; + + return netdev_get_tx_queue(dev, ntx); +} + +static int taprio_graft(struct Qdisc *sch, unsigned long cl, + struct Qdisc *new, struct Qdisc **old, + struct netlink_ext_ack *extack) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + if (!dev_queue) + return -EINVAL; + + if (dev->flags & IFF_UP) + dev_deactivate(dev); + + *old = q->qdiscs[cl - 1]; + q->qdiscs[cl - 1] = new; + + if (new) + new->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + + if (dev->flags & IFF_UP) + dev_activate(dev); + + return 0; +} + +static int dump_entry(struct sk_buff *msg, + const struct sched_entry *entry) +{ + struct nlattr *item; + + item = nla_nest_start(msg, TCA_TAPRIO_SCHED_ENTRY); + if (!item) + return -ENOSPC; + + if (nla_put_u32(msg, TCA_TAPRIO_SCHED_ENTRY_INDEX, entry->index)) + goto nla_put_failure; + + if (nla_put_u8(msg, TCA_TAPRIO_SCHED_ENTRY_CMD, entry->command)) + goto nla_put_failure; + + if (nla_put_u32(msg, TCA_TAPRIO_SCHED_ENTRY_GATE_MASK, + entry->gate_mask)) + goto nla_put_failure; + + if (nla_put_u32(msg, TCA_TAPRIO_SCHED_ENTRY_INTERVAL, + entry->interval)) + goto nla_put_failure; + + return nla_nest_end(msg, item); + +nla_put_failure: + nla_nest_cancel(msg, item); + return -1; +} + +static int taprio_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct tc_mqprio_qopt opt = { 0 }; + struct nlattr *nest, *entry_list; + struct sched_entry *entry; + unsigned int i; + + opt.num_tc = netdev_get_num_tc(dev); + memcpy(opt.prio_tc_map, dev->prio_tc_map, sizeof(opt.prio_tc_map)); + + for (i = 0; i < netdev_get_num_tc(dev); i++) { + opt.count[i] = dev->tc_to_txq[i].count; + opt.offset[i] = dev->tc_to_txq[i].offset; + } + + nest = nla_nest_start(skb, TCA_OPTIONS); + if (!nest) + return -ENOSPC; + + if (nla_put(skb, TCA_TAPRIO_ATTR_PRIOMAP, sizeof(opt), &opt)) + goto options_error; + + if (nla_put_s64(skb, TCA_TAPRIO_ATTR_SCHED_BASE_TIME, + q->base_time, TCA_TAPRIO_PAD)) + goto options_error; + + if (nla_put_s32(skb, TCA_TAPRIO_ATTR_SCHED_CLOCKID, q->clockid)) + goto options_error; + + entry_list = nla_nest_start(skb, TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST); + if (!entry_list) + goto options_error; + + list_for_each_entry(entry, &q->entries, list) { + if (dump_entry(skb, entry) < 0) + goto options_error; + } + + nla_nest_end(skb, entry_list); + + return nla_nest_end(skb, nest); + +options_error: + nla_nest_cancel(skb, nest); + return -1; +} + +static struct Qdisc *taprio_leaf(struct Qdisc *sch, unsigned long cl) +{ + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + if (!dev_queue) + return NULL; + + return dev_queue->qdisc_sleeping; +} + +static unsigned long taprio_find(struct Qdisc *sch, u32 classid) +{ + unsigned int ntx = TC_H_MIN(classid); + + if (!taprio_queue_get(sch, ntx)) + return 0; + return ntx; +} + +static int taprio_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_handle |= TC_H_MIN(cl); + tcm->tcm_info = dev_queue->qdisc_sleeping->handle; + + return 0; +} + +static int taprio_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) + __releases(d->lock) + __acquires(d->lock) +{ + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + sch = dev_queue->qdisc_sleeping; + if (gnet_stats_copy_basic(&sch->running, d, NULL, &sch->bstats) < 0 || + gnet_stats_copy_queue(d, NULL, &sch->qstats, sch->q.qlen) < 0) + return -1; + return 0; +} + +static void taprio_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned long ntx; + + if (arg->stop) + return; + + arg->count = arg->skip; + for (ntx = arg->skip; ntx < dev->num_tx_queues; ntx++) { + if (arg->fn(sch, ntx + 1, arg) < 0) { + arg->stop = 1; + break; + } + arg->count++; + } +} + +static struct netdev_queue *taprio_select_queue(struct Qdisc *sch, + struct tcmsg *tcm) +{ + return taprio_queue_get(sch, TC_H_MIN(tcm->tcm_parent)); +} + +static const struct Qdisc_class_ops taprio_class_ops = { + .graft = taprio_graft, + .leaf = taprio_leaf, + .find = taprio_find, + .walk = taprio_walk, + .dump = taprio_dump_class, + .dump_stats = taprio_dump_class_stats, + .select_queue = taprio_select_queue, +}; + +static struct Qdisc_ops taprio_qdisc_ops __read_mostly = { + .cl_ops = &taprio_class_ops, + .id = "taprio", + .priv_size = sizeof(struct taprio_sched), + .init = taprio_init, + .destroy = taprio_destroy, + .peek = taprio_peek, + .dequeue = taprio_dequeue, + .enqueue = taprio_enqueue, + .dump = taprio_dump, + .owner = THIS_MODULE, +}; + +static int __init taprio_module_init(void) +{ + return register_qdisc(&taprio_qdisc_ops); +} + +static void __exit taprio_module_exit(void) +{ + unregister_qdisc(&taprio_qdisc_ops); +} + +module_init(taprio_module_init); +module_exit(taprio_module_exit); +MODULE_LICENSE("GPL"); diff --git a/net/sctp/associola.c b/net/sctp/associola.c index 297d9cf960b9..6a28b96e779e 100644 --- a/net/sctp/associola.c +++ b/net/sctp/associola.c @@ -499,8 +499,9 @@ void sctp_assoc_set_primary(struct sctp_association *asoc, void sctp_assoc_rm_peer(struct sctp_association *asoc, struct sctp_transport *peer) { - struct list_head *pos; - struct sctp_transport *transport; + struct sctp_transport *transport; + struct list_head *pos; + struct sctp_chunk *ch; pr_debug("%s: association:%p addr:%pISpc\n", __func__, asoc, &peer->ipaddr.sa); @@ -564,7 +565,6 @@ void sctp_assoc_rm_peer(struct sctp_association *asoc, */ if (!list_empty(&peer->transmitted)) { struct sctp_transport *active = asoc->peer.active_path; - struct sctp_chunk *ch; /* Reset the transport of each chunk on this list */ list_for_each_entry(ch, &peer->transmitted, @@ -586,6 +586,10 @@ void sctp_assoc_rm_peer(struct sctp_association *asoc, sctp_transport_hold(active); } + list_for_each_entry(ch, &asoc->outqueue.out_chunk_list, list) + if (ch->transport == peer) + ch->transport = NULL; + asoc->peer.transport_count--; sctp_transport_free(peer); @@ -1450,7 +1454,8 @@ void sctp_assoc_sync_pmtu(struct sctp_association *asoc) /* Get the lowest pmtu of all the transports. */ list_for_each_entry(t, &asoc->peer.transport_addr_list, transports) { if (t->pmtu_pending && t->dst) { - sctp_transport_update_pmtu(t, sctp_dst_mtu(t->dst)); + sctp_transport_update_pmtu(t, + atomic_read(&t->mtu_info)); t->pmtu_pending = 0; } if (!pmtu || (t->pathmtu < pmtu)) diff --git a/net/sctp/input.c b/net/sctp/input.c index 9bbc5f92c941..5c36a99882ed 100644 --- a/net/sctp/input.c +++ b/net/sctp/input.c @@ -395,6 +395,7 @@ void sctp_icmp_frag_needed(struct sock *sk, struct sctp_association *asoc, return; if (sock_owned_by_user(sk)) { + atomic_set(&t->mtu_info, pmtu); asoc->pmtu_pending = 1; t->pmtu_pending = 1; return; diff --git a/net/sctp/output.c b/net/sctp/output.c index 7f849b01ec8e..67939ad99c01 100644 --- a/net/sctp/output.c +++ b/net/sctp/output.c @@ -120,6 +120,12 @@ void sctp_packet_config(struct sctp_packet *packet, __u32 vtag, sctp_assoc_sync_pmtu(asoc); } + if (asoc->pmtu_pending) { + if (asoc->param_flags & SPP_PMTUD_ENABLE) + sctp_assoc_sync_pmtu(asoc); + asoc->pmtu_pending = 0; + } + /* If there a is a prepend chunk stick it on the list before * any other chunks get appended. */ diff --git a/net/sctp/outqueue.c b/net/sctp/outqueue.c index d74d00b29942..9cb854b05342 100644 --- a/net/sctp/outqueue.c +++ b/net/sctp/outqueue.c @@ -385,9 +385,7 @@ static int sctp_prsctp_prune_sent(struct sctp_association *asoc, asoc->outqueue.outstanding_bytes -= sctp_data_size(chk); } - msg_len -= SCTP_DATA_SNDSIZE(chk) + - sizeof(struct sk_buff) + - sizeof(struct sctp_chunk); + msg_len -= chk->skb->truesize + sizeof(struct sctp_chunk); if (msg_len <= 0) break; } @@ -421,9 +419,7 @@ static int sctp_prsctp_prune_unsent(struct sctp_association *asoc, streamout->ext->abandoned_unsent[SCTP_PR_INDEX(PRIO)]++; } - msg_len -= SCTP_DATA_SNDSIZE(chk) + - sizeof(struct sk_buff) + - sizeof(struct sctp_chunk); + msg_len -= chk->skb->truesize + sizeof(struct sctp_chunk); sctp_chunk_free(chk); if (msg_len <= 0) break; @@ -1048,7 +1044,7 @@ static void sctp_outq_flush_data(struct sctp_flush_ctx *ctx, if (!ctx->packet || !ctx->packet->has_cookie_echo) return; - /* fallthru */ + /* fall through */ case SCTP_STATE_ESTABLISHED: case SCTP_STATE_SHUTDOWN_PENDING: case SCTP_STATE_SHUTDOWN_RECEIVED: diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c index e948db29ab53..9b277bd36d1a 100644 --- a/net/sctp/protocol.c +++ b/net/sctp/protocol.c @@ -46,7 +46,7 @@ #include <linux/netdevice.h> #include <linux/inetdevice.h> #include <linux/seq_file.h> -#include <linux/bootmem.h> +#include <linux/memblock.h> #include <linux/highmem.h> #include <linux/swap.h> #include <linux/slab.h> diff --git a/net/sctp/socket.c b/net/sctp/socket.c index f73e9d38d5ba..739f3e50120d 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -83,7 +83,7 @@ #include <net/sctp/stream_sched.h> /* Forward declarations for internal helper functions. */ -static int sctp_writeable(struct sock *sk); +static bool sctp_writeable(struct sock *sk); static void sctp_wfree(struct sk_buff *skb); static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p, size_t msg_len); @@ -119,25 +119,10 @@ static void sctp_enter_memory_pressure(struct sock *sk) /* Get the sndbuf space available at the time on the association. */ static inline int sctp_wspace(struct sctp_association *asoc) { - int amt; + struct sock *sk = asoc->base.sk; - if (asoc->ep->sndbuf_policy) - amt = asoc->sndbuf_used; - else - amt = sk_wmem_alloc_get(asoc->base.sk); - - if (amt >= asoc->base.sk->sk_sndbuf) { - if (asoc->base.sk->sk_userlocks & SOCK_SNDBUF_LOCK) - amt = 0; - else { - amt = sk_stream_wspace(asoc->base.sk); - if (amt < 0) - amt = 0; - } - } else { - amt = asoc->base.sk->sk_sndbuf - amt; - } - return amt; + return asoc->ep->sndbuf_policy ? sk->sk_sndbuf - asoc->sndbuf_used + : sk_stream_wspace(sk); } /* Increment the used sndbuf space count of the corresponding association by @@ -166,12 +151,9 @@ static inline void sctp_set_owner_w(struct sctp_chunk *chunk) /* Save the chunk pointer in skb for sctp_wfree to use later. */ skb_shinfo(chunk->skb)->destructor_arg = chunk; - asoc->sndbuf_used += SCTP_DATA_SNDSIZE(chunk) + - sizeof(struct sk_buff) + - sizeof(struct sctp_chunk); - refcount_add(sizeof(struct sctp_chunk), &sk->sk_wmem_alloc); - sk->sk_wmem_queued += chunk->skb->truesize; + asoc->sndbuf_used += chunk->skb->truesize + sizeof(struct sctp_chunk); + sk->sk_wmem_queued += chunk->skb->truesize + sizeof(struct sctp_chunk); sk_mem_charge(sk, chunk->skb->truesize); } @@ -271,11 +253,10 @@ struct sctp_association *sctp_id2assoc(struct sock *sk, sctp_assoc_t id) spin_lock_bh(&sctp_assocs_id_lock); asoc = (struct sctp_association *)idr_find(&sctp_assocs_id, (int)id); + if (asoc && (asoc->base.sk != sk || asoc->base.dead)) + asoc = NULL; spin_unlock_bh(&sctp_assocs_id_lock); - if (!asoc || (asoc->base.sk != sk) || asoc->base.dead) - return NULL; - return asoc; } @@ -1928,10 +1909,10 @@ static int sctp_sendmsg_to_asoc(struct sctp_association *asoc, asoc->pmtu_pending = 0; } - if (sctp_wspace(asoc) < msg_len) + if (sctp_wspace(asoc) < (int)msg_len) sctp_prsctp_prune(asoc, sinfo, msg_len - sctp_wspace(asoc)); - if (!sctp_wspace(asoc)) { + if (sctp_wspace(asoc) <= 0) { timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); err = sctp_wait_for_sndbuf(asoc, &timeo, msg_len); if (err) @@ -1946,8 +1927,10 @@ static int sctp_sendmsg_to_asoc(struct sctp_association *asoc, if (sp->strm_interleave) { timeo = sock_sndtimeo(sk, 0); err = sctp_wait_for_connect(asoc, &timeo); - if (err) + if (err) { + err = -ESRCH; goto err; + } } else { wait_connect = true; } @@ -7100,14 +7083,15 @@ static int sctp_getsockopt_pr_assocstatus(struct sock *sk, int len, } policy = params.sprstat_policy; - if (policy & ~SCTP_PR_SCTP_MASK) + if (!policy || (policy & ~(SCTP_PR_SCTP_MASK | SCTP_PR_SCTP_ALL)) || + ((policy & SCTP_PR_SCTP_ALL) && (policy & SCTP_PR_SCTP_MASK))) goto out; asoc = sctp_id2assoc(sk, params.sprstat_assoc_id); if (!asoc) goto out; - if (policy == SCTP_PR_SCTP_NONE) { + if (policy == SCTP_PR_SCTP_ALL) { params.sprstat_abandoned_unsent = 0; params.sprstat_abandoned_sent = 0; for (policy = 0; policy <= SCTP_PR_INDEX(MAX); policy++) { @@ -7159,7 +7143,8 @@ static int sctp_getsockopt_pr_streamstatus(struct sock *sk, int len, } policy = params.sprstat_policy; - if (policy & ~SCTP_PR_SCTP_MASK) + if (!policy || (policy & ~(SCTP_PR_SCTP_MASK | SCTP_PR_SCTP_ALL)) || + ((policy & SCTP_PR_SCTP_ALL) && (policy & SCTP_PR_SCTP_MASK))) goto out; asoc = sctp_id2assoc(sk, params.sprstat_assoc_id); @@ -7175,7 +7160,7 @@ static int sctp_getsockopt_pr_streamstatus(struct sock *sk, int len, goto out; } - if (policy == SCTP_PR_SCTP_NONE) { + if (policy == SCTP_PR_SCTP_ALL) { params.sprstat_abandoned_unsent = 0; params.sprstat_abandoned_sent = 0; for (policy = 0; policy <= SCTP_PR_INDEX(MAX); policy++) { @@ -8460,17 +8445,11 @@ static void sctp_wfree(struct sk_buff *skb) struct sctp_association *asoc = chunk->asoc; struct sock *sk = asoc->base.sk; - asoc->sndbuf_used -= SCTP_DATA_SNDSIZE(chunk) + - sizeof(struct sk_buff) + - sizeof(struct sctp_chunk); - - WARN_ON(refcount_sub_and_test(sizeof(struct sctp_chunk), &sk->sk_wmem_alloc)); - - /* - * This undoes what is done via sctp_set_owner_w and sk_mem_charge - */ - sk->sk_wmem_queued -= skb->truesize; sk_mem_uncharge(sk, skb->truesize); + sk->sk_wmem_queued -= skb->truesize + sizeof(struct sctp_chunk); + asoc->sndbuf_used -= skb->truesize + sizeof(struct sctp_chunk); + WARN_ON(refcount_sub_and_test(sizeof(struct sctp_chunk), + &sk->sk_wmem_alloc)); if (chunk->shkey) { struct sctp_shared_key *shkey = chunk->shkey; @@ -8544,7 +8523,7 @@ static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p, goto do_error; if (signal_pending(current)) goto do_interrupted; - if (msg_len <= sctp_wspace(asoc)) + if ((int)msg_len <= sctp_wspace(asoc)) break; /* Let another process have a go. Since we are going @@ -8619,14 +8598,9 @@ void sctp_write_space(struct sock *sk) * UDP-style sockets or TCP-style sockets, this code should work. * - Daisy */ -static int sctp_writeable(struct sock *sk) +static bool sctp_writeable(struct sock *sk) { - int amt = 0; - - amt = sk->sk_sndbuf - sk_wmem_alloc_get(sk); - if (amt < 0) - amt = 0; - return amt; + return sk->sk_sndbuf > sk->sk_wmem_queued; } /* Wait for an association to go into ESTABLISHED state. If timeout is 0, diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index 015231789ed2..80e2119f1c70 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -1543,7 +1543,7 @@ static __poll_t smc_poll(struct file *file, struct socket *sock, mask |= EPOLLERR; } else { if (sk->sk_state != SMC_CLOSED) - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); if (sk->sk_err) mask |= EPOLLERR; if ((sk->sk_shutdown == SHUTDOWN_MASK) || diff --git a/net/smc/smc_clc.c b/net/smc/smc_clc.c index 52241d679cc9..89c3a8c7859a 100644 --- a/net/smc/smc_clc.c +++ b/net/smc/smc_clc.c @@ -286,7 +286,7 @@ int smc_clc_wait_msg(struct smc_sock *smc, void *buf, int buflen, */ krflags = MSG_PEEK | MSG_WAITALL; smc->clcsock->sk->sk_rcvtimeo = CLC_WAIT_TIME; - iov_iter_kvec(&msg.msg_iter, READ | ITER_KVEC, &vec, 1, + iov_iter_kvec(&msg.msg_iter, READ, &vec, 1, sizeof(struct smc_clc_msg_hdr)); len = sock_recvmsg(smc->clcsock, &msg, krflags); if (signal_pending(current)) { @@ -325,7 +325,7 @@ int smc_clc_wait_msg(struct smc_sock *smc, void *buf, int buflen, /* receive the complete CLC message */ memset(&msg, 0, sizeof(struct msghdr)); - iov_iter_kvec(&msg.msg_iter, READ | ITER_KVEC, &vec, 1, datlen); + iov_iter_kvec(&msg.msg_iter, READ, &vec, 1, datlen); krflags = MSG_WAITALL; len = sock_recvmsg(smc->clcsock, &msg, krflags); if (len < datlen || !smc_clc_msg_hdr_valid(clcm)) { diff --git a/net/smc/smc_core.c b/net/smc/smc_core.c index e871368500e3..18daebcef181 100644 --- a/net/smc/smc_core.c +++ b/net/smc/smc_core.c @@ -122,22 +122,17 @@ static void __smc_lgr_unregister_conn(struct smc_connection *conn) sock_put(&smc->sk); /* sock_hold in smc_lgr_register_conn() */ } -/* Unregister connection and trigger lgr freeing if applicable +/* Unregister connection from lgr */ static void smc_lgr_unregister_conn(struct smc_connection *conn) { struct smc_link_group *lgr = conn->lgr; - int reduced = 0; write_lock_bh(&lgr->conns_lock); if (conn->alert_token_local) { - reduced = 1; __smc_lgr_unregister_conn(conn); } write_unlock_bh(&lgr->conns_lock); - if (!reduced || lgr->conns_num) - return; - smc_lgr_schedule_free_work(lgr); } /* Send delete link, either as client to request the initiation @@ -291,7 +286,8 @@ out: return rc; } -static void smc_buf_unuse(struct smc_connection *conn) +static void smc_buf_unuse(struct smc_connection *conn, + struct smc_link_group *lgr) { if (conn->sndbuf_desc) conn->sndbuf_desc->used = 0; @@ -301,8 +297,6 @@ static void smc_buf_unuse(struct smc_connection *conn) conn->rmb_desc->used = 0; } else { /* buf registration failed, reuse not possible */ - struct smc_link_group *lgr = conn->lgr; - write_lock_bh(&lgr->rmbs_lock); list_del(&conn->rmb_desc->list); write_unlock_bh(&lgr->rmbs_lock); @@ -315,16 +309,21 @@ static void smc_buf_unuse(struct smc_connection *conn) /* remove a finished connection from its link group */ void smc_conn_free(struct smc_connection *conn) { - if (!conn->lgr) + struct smc_link_group *lgr = conn->lgr; + + if (!lgr) return; - if (conn->lgr->is_smcd) { + if (lgr->is_smcd) { smc_ism_unset_conn(conn); tasklet_kill(&conn->rx_tsklet); } else { smc_cdc_tx_dismiss_slots(conn); } - smc_lgr_unregister_conn(conn); - smc_buf_unuse(conn); + smc_lgr_unregister_conn(conn); /* unsets conn->lgr */ + smc_buf_unuse(conn, lgr); /* allow buffer reuse */ + + if (!lgr->conns_num) + smc_lgr_schedule_free_work(lgr); } static void smc_link_clear(struct smc_link *lnk) diff --git a/net/socket.c b/net/socket.c index 01f3f8f32d6f..593826e11a53 100644 --- a/net/socket.c +++ b/net/socket.c @@ -635,7 +635,7 @@ EXPORT_SYMBOL(sock_sendmsg); int kernel_sendmsg(struct socket *sock, struct msghdr *msg, struct kvec *vec, size_t num, size_t size) { - iov_iter_kvec(&msg->msg_iter, WRITE | ITER_KVEC, vec, num, size); + iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size); return sock_sendmsg(sock, msg); } EXPORT_SYMBOL(kernel_sendmsg); @@ -648,7 +648,7 @@ int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg, if (!sock->ops->sendmsg_locked) return sock_no_sendmsg_locked(sk, msg, size); - iov_iter_kvec(&msg->msg_iter, WRITE | ITER_KVEC, vec, num, size); + iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size); return sock->ops->sendmsg_locked(sk, msg, msg_data_left(msg)); } @@ -823,7 +823,7 @@ int kernel_recvmsg(struct socket *sock, struct msghdr *msg, mm_segment_t oldfs = get_fs(); int result; - iov_iter_kvec(&msg->msg_iter, READ | ITER_KVEC, vec, num, size); + iov_iter_kvec(&msg->msg_iter, READ, vec, num, size); set_fs(KERNEL_DS); result = sock_recvmsg(sock, msg, flags); set_fs(oldfs); @@ -1475,7 +1475,7 @@ int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen) sock = sockfd_lookup_light(fd, &err, &fput_needed); if (sock) { err = move_addr_to_kernel(umyaddr, addrlen, &address); - if (err >= 0) { + if (!err) { err = security_socket_bind(sock, (struct sockaddr *)&address, addrlen); @@ -2342,7 +2342,7 @@ SYSCALL_DEFINE3(recvmsg, int, fd, struct user_msghdr __user *, msg, */ int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, - unsigned int flags, struct timespec *timeout) + unsigned int flags, struct timespec64 *timeout) { int fput_needed, err, datagrams; struct socket *sock; @@ -2407,8 +2407,7 @@ int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, if (timeout) { ktime_get_ts64(&timeout64); - *timeout = timespec64_to_timespec( - timespec64_sub(end_time, timeout64)); + *timeout = timespec64_sub(end_time, timeout64); if (timeout->tv_sec < 0) { timeout->tv_sec = timeout->tv_nsec = 0; break; @@ -2454,10 +2453,10 @@ out_put: static int do_sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, unsigned int flags, - struct timespec __user *timeout) + struct __kernel_timespec __user *timeout) { int datagrams; - struct timespec timeout_sys; + struct timespec64 timeout_sys; if (flags & MSG_CMSG_COMPAT) return -EINVAL; @@ -2465,13 +2464,12 @@ static int do_sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, if (!timeout) return __sys_recvmmsg(fd, mmsg, vlen, flags, NULL); - if (copy_from_user(&timeout_sys, timeout, sizeof(timeout_sys))) + if (get_timespec64(&timeout_sys, timeout)) return -EFAULT; datagrams = __sys_recvmmsg(fd, mmsg, vlen, flags, &timeout_sys); - if (datagrams > 0 && - copy_to_user(timeout, &timeout_sys, sizeof(timeout_sys))) + if (datagrams > 0 && put_timespec64(&timeout_sys, timeout)) datagrams = -EFAULT; return datagrams; @@ -2479,7 +2477,7 @@ static int do_sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, SYSCALL_DEFINE5(recvmmsg, int, fd, struct mmsghdr __user *, mmsg, unsigned int, vlen, unsigned int, flags, - struct timespec __user *, timeout) + struct __kernel_timespec __user *, timeout) { return do_sys_recvmmsg(fd, mmsg, vlen, flags, timeout); } @@ -2603,7 +2601,7 @@ SYSCALL_DEFINE2(socketcall, int, call, unsigned long __user *, args) break; case SYS_RECVMMSG: err = do_sys_recvmmsg(a0, (struct mmsghdr __user *)a1, a[2], - a[3], (struct timespec __user *)a[4]); + a[3], (struct __kernel_timespec __user *)a[4]); break; case SYS_ACCEPT4: err = __sys_accept4(a0, (struct sockaddr __user *)a1, @@ -2875,9 +2873,14 @@ static int ethtool_ioctl(struct net *net, struct compat_ifreq __user *ifr32) copy_in_user(&rxnfc->fs.ring_cookie, &compat_rxnfc->fs.ring_cookie, (void __user *)(&rxnfc->fs.location + 1) - - (void __user *)&rxnfc->fs.ring_cookie) || - copy_in_user(&rxnfc->rule_cnt, &compat_rxnfc->rule_cnt, - sizeof(rxnfc->rule_cnt))) + (void __user *)&rxnfc->fs.ring_cookie)) + return -EFAULT; + if (ethcmd == ETHTOOL_GRXCLSRLALL) { + if (put_user(rule_cnt, &rxnfc->rule_cnt)) + return -EFAULT; + } else if (copy_in_user(&rxnfc->rule_cnt, + &compat_rxnfc->rule_cnt, + sizeof(rxnfc->rule_cnt))) return -EFAULT; } diff --git a/net/strparser/Kconfig b/net/strparser/Kconfig index 6cff3f6d0c3a..94da19a2a220 100644 --- a/net/strparser/Kconfig +++ b/net/strparser/Kconfig @@ -1,4 +1,2 @@ - config STREAM_PARSER - tristate - default n + def_bool n diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index 305ecea92170..ad8ead738981 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -30,10 +30,9 @@ struct rpc_cred_cache { static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; -static DEFINE_SPINLOCK(rpc_authflavor_lock); -static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { - &authnull_ops, /* AUTH_NULL */ - &authunix_ops, /* AUTH_UNIX */ +static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = { + [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops, + [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops, NULL, /* others can be loadable modules */ }; @@ -93,39 +92,65 @@ pseudoflavor_to_flavor(u32 flavor) { int rpcauth_register(const struct rpc_authops *ops) { + const struct rpc_authops *old; rpc_authflavor_t flavor; - int ret = -EPERM; if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) return -EINVAL; - spin_lock(&rpc_authflavor_lock); - if (auth_flavors[flavor] == NULL) { - auth_flavors[flavor] = ops; - ret = 0; - } - spin_unlock(&rpc_authflavor_lock); - return ret; + old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops); + if (old == NULL || old == ops) + return 0; + return -EPERM; } EXPORT_SYMBOL_GPL(rpcauth_register); int rpcauth_unregister(const struct rpc_authops *ops) { + const struct rpc_authops *old; rpc_authflavor_t flavor; - int ret = -EPERM; if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) return -EINVAL; - spin_lock(&rpc_authflavor_lock); - if (auth_flavors[flavor] == ops) { - auth_flavors[flavor] = NULL; - ret = 0; - } - spin_unlock(&rpc_authflavor_lock); - return ret; + + old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL); + if (old == ops || old == NULL) + return 0; + return -EPERM; } EXPORT_SYMBOL_GPL(rpcauth_unregister); +static const struct rpc_authops * +rpcauth_get_authops(rpc_authflavor_t flavor) +{ + const struct rpc_authops *ops; + + if (flavor >= RPC_AUTH_MAXFLAVOR) + return NULL; + + rcu_read_lock(); + ops = rcu_dereference(auth_flavors[flavor]); + if (ops == NULL) { + rcu_read_unlock(); + request_module("rpc-auth-%u", flavor); + rcu_read_lock(); + ops = rcu_dereference(auth_flavors[flavor]); + if (ops == NULL) + goto out; + } + if (!try_module_get(ops->owner)) + ops = NULL; +out: + rcu_read_unlock(); + return ops; +} + +static void +rpcauth_put_authops(const struct rpc_authops *ops) +{ + module_put(ops->owner); +} + /** * rpcauth_get_pseudoflavor - check if security flavor is supported * @flavor: a security flavor @@ -138,25 +163,16 @@ EXPORT_SYMBOL_GPL(rpcauth_unregister); rpc_authflavor_t rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) { - const struct rpc_authops *ops; + const struct rpc_authops *ops = rpcauth_get_authops(flavor); rpc_authflavor_t pseudoflavor; - ops = auth_flavors[flavor]; - if (ops == NULL) - request_module("rpc-auth-%u", flavor); - spin_lock(&rpc_authflavor_lock); - ops = auth_flavors[flavor]; - if (ops == NULL || !try_module_get(ops->owner)) { - spin_unlock(&rpc_authflavor_lock); + if (!ops) return RPC_AUTH_MAXFLAVOR; - } - spin_unlock(&rpc_authflavor_lock); - pseudoflavor = flavor; if (ops->info2flavor != NULL) pseudoflavor = ops->info2flavor(info); - module_put(ops->owner); + rpcauth_put_authops(ops); return pseudoflavor; } EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); @@ -176,25 +192,15 @@ rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) const struct rpc_authops *ops; int result; - if (flavor >= RPC_AUTH_MAXFLAVOR) - return -EINVAL; - - ops = auth_flavors[flavor]; + ops = rpcauth_get_authops(flavor); if (ops == NULL) - request_module("rpc-auth-%u", flavor); - spin_lock(&rpc_authflavor_lock); - ops = auth_flavors[flavor]; - if (ops == NULL || !try_module_get(ops->owner)) { - spin_unlock(&rpc_authflavor_lock); return -ENOENT; - } - spin_unlock(&rpc_authflavor_lock); result = -ENOENT; if (ops->flavor2info != NULL) result = ops->flavor2info(pseudoflavor, info); - module_put(ops->owner); + rpcauth_put_authops(ops); return result; } EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); @@ -212,15 +218,13 @@ EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); int rpcauth_list_flavors(rpc_authflavor_t *array, int size) { - rpc_authflavor_t flavor; - int result = 0; + const struct rpc_authops *ops; + rpc_authflavor_t flavor, pseudos[4]; + int i, len, result = 0; - spin_lock(&rpc_authflavor_lock); + rcu_read_lock(); for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) { - const struct rpc_authops *ops = auth_flavors[flavor]; - rpc_authflavor_t pseudos[4]; - int i, len; - + ops = rcu_dereference(auth_flavors[flavor]); if (result >= size) { result = -ENOMEM; break; @@ -245,7 +249,7 @@ rpcauth_list_flavors(rpc_authflavor_t *array, int size) array[result++] = pseudos[i]; } } - spin_unlock(&rpc_authflavor_lock); + rcu_read_unlock(); dprintk("RPC: %s returns %d\n", __func__, result); return result; @@ -255,25 +259,17 @@ EXPORT_SYMBOL_GPL(rpcauth_list_flavors); struct rpc_auth * rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { - struct rpc_auth *auth; + struct rpc_auth *auth = ERR_PTR(-EINVAL); const struct rpc_authops *ops; - u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor); + u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor); - auth = ERR_PTR(-EINVAL); - if (flavor >= RPC_AUTH_MAXFLAVOR) + ops = rpcauth_get_authops(flavor); + if (ops == NULL) goto out; - if ((ops = auth_flavors[flavor]) == NULL) - request_module("rpc-auth-%u", flavor); - spin_lock(&rpc_authflavor_lock); - ops = auth_flavors[flavor]; - if (ops == NULL || !try_module_get(ops->owner)) { - spin_unlock(&rpc_authflavor_lock); - goto out; - } - spin_unlock(&rpc_authflavor_lock); auth = ops->create(args, clnt); - module_put(ops->owner); + + rpcauth_put_authops(ops); if (IS_ERR(auth)) return auth; if (clnt->cl_auth) @@ -288,32 +284,37 @@ EXPORT_SYMBOL_GPL(rpcauth_create); void rpcauth_release(struct rpc_auth *auth) { - if (!atomic_dec_and_test(&auth->au_count)) + if (!refcount_dec_and_test(&auth->au_count)) return; auth->au_ops->destroy(auth); } static DEFINE_SPINLOCK(rpc_credcache_lock); -static void +/* + * On success, the caller is responsible for freeing the reference + * held by the hashtable + */ +static bool rpcauth_unhash_cred_locked(struct rpc_cred *cred) { + if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + return false; hlist_del_rcu(&cred->cr_hash); - smp_mb__before_atomic(); - clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); + return true; } -static int +static bool rpcauth_unhash_cred(struct rpc_cred *cred) { spinlock_t *cache_lock; - int ret; + bool ret; + if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + return false; cache_lock = &cred->cr_auth->au_credcache->lock; spin_lock(cache_lock); - ret = atomic_read(&cred->cr_count) == 0; - if (ret) - rpcauth_unhash_cred_locked(cred); + ret = rpcauth_unhash_cred_locked(cred); spin_unlock(cache_lock); return ret; } @@ -392,6 +393,44 @@ void rpcauth_destroy_credlist(struct list_head *head) } } +static void +rpcauth_lru_add_locked(struct rpc_cred *cred) +{ + if (!list_empty(&cred->cr_lru)) + return; + number_cred_unused++; + list_add_tail(&cred->cr_lru, &cred_unused); +} + +static void +rpcauth_lru_add(struct rpc_cred *cred) +{ + if (!list_empty(&cred->cr_lru)) + return; + spin_lock(&rpc_credcache_lock); + rpcauth_lru_add_locked(cred); + spin_unlock(&rpc_credcache_lock); +} + +static void +rpcauth_lru_remove_locked(struct rpc_cred *cred) +{ + if (list_empty(&cred->cr_lru)) + return; + number_cred_unused--; + list_del_init(&cred->cr_lru); +} + +static void +rpcauth_lru_remove(struct rpc_cred *cred) +{ + if (list_empty(&cred->cr_lru)) + return; + spin_lock(&rpc_credcache_lock); + rpcauth_lru_remove_locked(cred); + spin_unlock(&rpc_credcache_lock); +} + /* * Clear the RPC credential cache, and delete those credentials * that are not referenced. @@ -411,13 +450,10 @@ rpcauth_clear_credcache(struct rpc_cred_cache *cache) head = &cache->hashtable[i]; while (!hlist_empty(head)) { cred = hlist_entry(head->first, struct rpc_cred, cr_hash); - get_rpccred(cred); - if (!list_empty(&cred->cr_lru)) { - list_del(&cred->cr_lru); - number_cred_unused--; - } - list_add_tail(&cred->cr_lru, &free); rpcauth_unhash_cred_locked(cred); + /* Note: We now hold a reference to cred */ + rpcauth_lru_remove_locked(cred); + list_add_tail(&cred->cr_lru, &free); } } spin_unlock(&cache->lock); @@ -451,7 +487,6 @@ EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); static long rpcauth_prune_expired(struct list_head *free, int nr_to_scan) { - spinlock_t *cache_lock; struct rpc_cred *cred, *next; unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; long freed = 0; @@ -460,32 +495,24 @@ rpcauth_prune_expired(struct list_head *free, int nr_to_scan) if (nr_to_scan-- == 0) break; + if (refcount_read(&cred->cr_count) > 1) { + rpcauth_lru_remove_locked(cred); + continue; + } /* * Enforce a 60 second garbage collection moratorium * Note that the cred_unused list must be time-ordered. */ - if (time_in_range(cred->cr_expire, expired, jiffies) && - test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { - freed = SHRINK_STOP; - break; - } - - list_del_init(&cred->cr_lru); - number_cred_unused--; - freed++; - if (atomic_read(&cred->cr_count) != 0) + if (!time_in_range(cred->cr_expire, expired, jiffies)) + continue; + if (!rpcauth_unhash_cred(cred)) continue; - cache_lock = &cred->cr_auth->au_credcache->lock; - spin_lock(cache_lock); - if (atomic_read(&cred->cr_count) == 0) { - get_rpccred(cred); - list_add_tail(&cred->cr_lru, free); - rpcauth_unhash_cred_locked(cred); - } - spin_unlock(cache_lock); + rpcauth_lru_remove_locked(cred); + freed++; + list_add_tail(&cred->cr_lru, free); } - return freed; + return freed ? freed : SHRINK_STOP; } static unsigned long @@ -561,19 +588,15 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, if (!entry->cr_ops->crmatch(acred, entry, flags)) continue; if (flags & RPCAUTH_LOOKUP_RCU) { - if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) && - !test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags)) - cred = entry; + if (test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags) || + refcount_read(&entry->cr_count) == 0) + continue; + cred = entry; break; } - spin_lock(&cache->lock); - if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) { - spin_unlock(&cache->lock); - continue; - } cred = get_rpccred(entry); - spin_unlock(&cache->lock); - break; + if (cred) + break; } rcu_read_unlock(); @@ -594,11 +617,13 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, if (!entry->cr_ops->crmatch(acred, entry, flags)) continue; cred = get_rpccred(entry); - break; + if (cred) + break; } if (cred == NULL) { cred = new; set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); + refcount_inc(&cred->cr_count); hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]); } else list_add_tail(&new->cr_lru, &free); @@ -645,7 +670,7 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, { INIT_HLIST_NODE(&cred->cr_hash); INIT_LIST_HEAD(&cred->cr_lru); - atomic_set(&cred->cr_count, 1); + refcount_set(&cred->cr_count, 1); cred->cr_auth = auth; cred->cr_ops = ops; cred->cr_expire = jiffies; @@ -713,36 +738,29 @@ put_rpccred(struct rpc_cred *cred) { if (cred == NULL) return; - /* Fast path for unhashed credentials */ - if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) { - if (atomic_dec_and_test(&cred->cr_count)) - cred->cr_ops->crdestroy(cred); - return; - } - - if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock)) - return; - if (!list_empty(&cred->cr_lru)) { - number_cred_unused--; - list_del_init(&cred->cr_lru); - } - if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { - if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { - cred->cr_expire = jiffies; - list_add_tail(&cred->cr_lru, &cred_unused); - number_cred_unused++; - goto out_nodestroy; - } - if (!rpcauth_unhash_cred(cred)) { - /* We were hashed and someone looked us up... */ - goto out_nodestroy; - } + rcu_read_lock(); + if (refcount_dec_and_test(&cred->cr_count)) + goto destroy; + if (refcount_read(&cred->cr_count) != 1 || + !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + goto out; + if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { + cred->cr_expire = jiffies; + rpcauth_lru_add(cred); + /* Race breaker */ + if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))) + rpcauth_lru_remove(cred); + } else if (rpcauth_unhash_cred(cred)) { + rpcauth_lru_remove(cred); + if (refcount_dec_and_test(&cred->cr_count)) + goto destroy; } - spin_unlock(&rpc_credcache_lock); - cred->cr_ops->crdestroy(cred); +out: + rcu_read_unlock(); return; -out_nodestroy: - spin_unlock(&rpc_credcache_lock); +destroy: + rcu_read_unlock(); + cred->cr_ops->crdestroy(cred); } EXPORT_SYMBOL_GPL(put_rpccred); @@ -817,6 +835,16 @@ rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, return rpcauth_unwrap_req_decode(decode, rqstp, data, obj); } +bool +rpcauth_xmit_need_reencode(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + if (!cred || !cred->cr_ops->crneed_reencode) + return false; + return cred->cr_ops->crneed_reencode(task); +} + int rpcauth_refreshcred(struct rpc_task *task) { diff --git a/net/sunrpc/auth_generic.c b/net/sunrpc/auth_generic.c index f1df9837f1ac..d8831b988b1e 100644 --- a/net/sunrpc/auth_generic.c +++ b/net/sunrpc/auth_generic.c @@ -274,7 +274,7 @@ static const struct rpc_authops generic_auth_ops = { static struct rpc_auth generic_auth = { .au_ops = &generic_auth_ops, - .au_count = ATOMIC_INIT(0), + .au_count = REFCOUNT_INIT(1), }; static bool generic_key_to_expire(struct rpc_cred *cred) diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c index 21c0aa0a0d1d..30f970cdc7f6 100644 --- a/net/sunrpc/auth_gss/auth_gss.c +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -1058,7 +1058,7 @@ gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) auth->au_flavor = flavor; if (gss_pseudoflavor_to_datatouch(gss_auth->mech, flavor)) auth->au_flags |= RPCAUTH_AUTH_DATATOUCH; - atomic_set(&auth->au_count, 1); + refcount_set(&auth->au_count, 1); kref_init(&gss_auth->kref); err = rpcauth_init_credcache(auth); @@ -1187,7 +1187,7 @@ gss_auth_find_or_add_hashed(const struct rpc_auth_create_args *args, if (strcmp(gss_auth->target_name, args->target_name)) continue; } - if (!atomic_inc_not_zero(&gss_auth->rpc_auth.au_count)) + if (!refcount_inc_not_zero(&gss_auth->rpc_auth.au_count)) continue; goto out; } @@ -1984,6 +1984,46 @@ gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp, return decode(rqstp, &xdr, obj); } +static bool +gss_seq_is_newer(u32 new, u32 old) +{ + return (s32)(new - old) > 0; +} + +static bool +gss_xmit_need_reencode(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *cred = req->rq_cred; + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + u32 win, seq_xmit; + bool ret = true; + + if (!ctx) + return true; + + if (gss_seq_is_newer(req->rq_seqno, READ_ONCE(ctx->gc_seq))) + goto out; + + seq_xmit = READ_ONCE(ctx->gc_seq_xmit); + while (gss_seq_is_newer(req->rq_seqno, seq_xmit)) { + u32 tmp = seq_xmit; + + seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, req->rq_seqno); + if (seq_xmit == tmp) { + ret = false; + goto out; + } + } + + win = ctx->gc_win; + if (win > 0) + ret = !gss_seq_is_newer(req->rq_seqno, seq_xmit - win); +out: + gss_put_ctx(ctx); + return ret; +} + static int gss_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj) @@ -2052,6 +2092,7 @@ static const struct rpc_credops gss_credops = { .crunwrap_resp = gss_unwrap_resp, .crkey_timeout = gss_key_timeout, .crstringify_acceptor = gss_stringify_acceptor, + .crneed_reencode = gss_xmit_need_reencode, }; static const struct rpc_credops gss_nullops = { diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c index 0220e1ca5280..4f43383971ba 100644 --- a/net/sunrpc/auth_gss/gss_krb5_crypto.c +++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c @@ -53,7 +53,7 @@ u32 krb5_encrypt( - struct crypto_skcipher *tfm, + struct crypto_sync_skcipher *tfm, void * iv, void * in, void * out, @@ -62,24 +62,24 @@ krb5_encrypt( u32 ret = -EINVAL; struct scatterlist sg[1]; u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; - SKCIPHER_REQUEST_ON_STACK(req, tfm); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); - if (length % crypto_skcipher_blocksize(tfm) != 0) + if (length % crypto_sync_skcipher_blocksize(tfm) != 0) goto out; - if (crypto_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { + if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { dprintk("RPC: gss_k5encrypt: tfm iv size too large %d\n", - crypto_skcipher_ivsize(tfm)); + crypto_sync_skcipher_ivsize(tfm)); goto out; } if (iv) - memcpy(local_iv, iv, crypto_skcipher_ivsize(tfm)); + memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm)); memcpy(out, in, length); sg_init_one(sg, out, length); - skcipher_request_set_tfm(req, tfm); + skcipher_request_set_sync_tfm(req, tfm); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, length, local_iv); @@ -92,7 +92,7 @@ out: u32 krb5_decrypt( - struct crypto_skcipher *tfm, + struct crypto_sync_skcipher *tfm, void * iv, void * in, void * out, @@ -101,23 +101,23 @@ krb5_decrypt( u32 ret = -EINVAL; struct scatterlist sg[1]; u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; - SKCIPHER_REQUEST_ON_STACK(req, tfm); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); - if (length % crypto_skcipher_blocksize(tfm) != 0) + if (length % crypto_sync_skcipher_blocksize(tfm) != 0) goto out; - if (crypto_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { + if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n", - crypto_skcipher_ivsize(tfm)); + crypto_sync_skcipher_ivsize(tfm)); goto out; } if (iv) - memcpy(local_iv,iv, crypto_skcipher_ivsize(tfm)); + memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm)); memcpy(out, in, length); sg_init_one(sg, out, length); - skcipher_request_set_tfm(req, tfm); + skcipher_request_set_sync_tfm(req, tfm); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, length, local_iv); @@ -466,7 +466,8 @@ encryptor(struct scatterlist *sg, void *data) { struct encryptor_desc *desc = data; struct xdr_buf *outbuf = desc->outbuf; - struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(desc->req); + struct crypto_sync_skcipher *tfm = + crypto_sync_skcipher_reqtfm(desc->req); struct page *in_page; int thislen = desc->fraglen + sg->length; int fraglen, ret; @@ -492,7 +493,7 @@ encryptor(struct scatterlist *sg, void *data) desc->fraglen += sg->length; desc->pos += sg->length; - fraglen = thislen & (crypto_skcipher_blocksize(tfm) - 1); + fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1); thislen -= fraglen; if (thislen == 0) @@ -526,16 +527,16 @@ encryptor(struct scatterlist *sg, void *data) } int -gss_encrypt_xdr_buf(struct crypto_skcipher *tfm, struct xdr_buf *buf, +gss_encrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf, int offset, struct page **pages) { int ret; struct encryptor_desc desc; - SKCIPHER_REQUEST_ON_STACK(req, tfm); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); - BUG_ON((buf->len - offset) % crypto_skcipher_blocksize(tfm) != 0); + BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0); - skcipher_request_set_tfm(req, tfm); + skcipher_request_set_sync_tfm(req, tfm); skcipher_request_set_callback(req, 0, NULL, NULL); memset(desc.iv, 0, sizeof(desc.iv)); @@ -567,7 +568,8 @@ decryptor(struct scatterlist *sg, void *data) { struct decryptor_desc *desc = data; int thislen = desc->fraglen + sg->length; - struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(desc->req); + struct crypto_sync_skcipher *tfm = + crypto_sync_skcipher_reqtfm(desc->req); int fraglen, ret; /* Worst case is 4 fragments: head, end of page 1, start @@ -578,7 +580,7 @@ decryptor(struct scatterlist *sg, void *data) desc->fragno++; desc->fraglen += sg->length; - fraglen = thislen & (crypto_skcipher_blocksize(tfm) - 1); + fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1); thislen -= fraglen; if (thislen == 0) @@ -608,17 +610,17 @@ decryptor(struct scatterlist *sg, void *data) } int -gss_decrypt_xdr_buf(struct crypto_skcipher *tfm, struct xdr_buf *buf, +gss_decrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf, int offset) { int ret; struct decryptor_desc desc; - SKCIPHER_REQUEST_ON_STACK(req, tfm); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); /* XXXJBF: */ - BUG_ON((buf->len - offset) % crypto_skcipher_blocksize(tfm) != 0); + BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0); - skcipher_request_set_tfm(req, tfm); + skcipher_request_set_sync_tfm(req, tfm); skcipher_request_set_callback(req, 0, NULL, NULL); memset(desc.iv, 0, sizeof(desc.iv)); @@ -672,12 +674,12 @@ xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen) } static u32 -gss_krb5_cts_crypt(struct crypto_skcipher *cipher, struct xdr_buf *buf, +gss_krb5_cts_crypt(struct crypto_sync_skcipher *cipher, struct xdr_buf *buf, u32 offset, u8 *iv, struct page **pages, int encrypt) { u32 ret; struct scatterlist sg[1]; - SKCIPHER_REQUEST_ON_STACK(req, cipher); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, cipher); u8 *data; struct page **save_pages; u32 len = buf->len - offset; @@ -706,7 +708,7 @@ gss_krb5_cts_crypt(struct crypto_skcipher *cipher, struct xdr_buf *buf, sg_init_one(sg, data, len); - skcipher_request_set_tfm(req, cipher); + skcipher_request_set_sync_tfm(req, cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, len, iv); @@ -735,7 +737,7 @@ gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_netobj hmac; u8 *cksumkey; u8 *ecptr; - struct crypto_skcipher *cipher, *aux_cipher; + struct crypto_sync_skcipher *cipher, *aux_cipher; int blocksize; struct page **save_pages; int nblocks, nbytes; @@ -754,7 +756,7 @@ gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, cksumkey = kctx->acceptor_integ; usage = KG_USAGE_ACCEPTOR_SEAL; } - blocksize = crypto_skcipher_blocksize(cipher); + blocksize = crypto_sync_skcipher_blocksize(cipher); /* hide the gss token header and insert the confounder */ offset += GSS_KRB5_TOK_HDR_LEN; @@ -807,7 +809,7 @@ gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, memset(desc.iv, 0, sizeof(desc.iv)); if (cbcbytes) { - SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); desc.pos = offset + GSS_KRB5_TOK_HDR_LEN; desc.fragno = 0; @@ -816,7 +818,7 @@ gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, desc.outbuf = buf; desc.req = req; - skcipher_request_set_tfm(req, aux_cipher); + skcipher_request_set_sync_tfm(req, aux_cipher); skcipher_request_set_callback(req, 0, NULL, NULL); sg_init_table(desc.infrags, 4); @@ -855,7 +857,7 @@ gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, struct xdr_buf subbuf; u32 ret = 0; u8 *cksum_key; - struct crypto_skcipher *cipher, *aux_cipher; + struct crypto_sync_skcipher *cipher, *aux_cipher; struct xdr_netobj our_hmac_obj; u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN]; u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN]; @@ -874,7 +876,7 @@ gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, cksum_key = kctx->initiator_integ; usage = KG_USAGE_INITIATOR_SEAL; } - blocksize = crypto_skcipher_blocksize(cipher); + blocksize = crypto_sync_skcipher_blocksize(cipher); /* create a segment skipping the header and leaving out the checksum */ @@ -891,13 +893,13 @@ gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, struct xdr_buf *buf, memset(desc.iv, 0, sizeof(desc.iv)); if (cbcbytes) { - SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); + SYNC_SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); desc.fragno = 0; desc.fraglen = 0; desc.req = req; - skcipher_request_set_tfm(req, aux_cipher); + skcipher_request_set_sync_tfm(req, aux_cipher); skcipher_request_set_callback(req, 0, NULL, NULL); sg_init_table(desc.frags, 4); @@ -946,7 +948,8 @@ out_err: * Set the key of the given cipher. */ int -krb5_rc4_setup_seq_key(struct krb5_ctx *kctx, struct crypto_skcipher *cipher, +krb5_rc4_setup_seq_key(struct krb5_ctx *kctx, + struct crypto_sync_skcipher *cipher, unsigned char *cksum) { struct crypto_shash *hmac; @@ -994,7 +997,7 @@ krb5_rc4_setup_seq_key(struct krb5_ctx *kctx, struct crypto_skcipher *cipher, if (err) goto out_err; - err = crypto_skcipher_setkey(cipher, Kseq, kctx->gk5e->keylength); + err = crypto_sync_skcipher_setkey(cipher, Kseq, kctx->gk5e->keylength); if (err) goto out_err; @@ -1012,7 +1015,8 @@ out_err: * Set the key of cipher kctx->enc. */ int -krb5_rc4_setup_enc_key(struct krb5_ctx *kctx, struct crypto_skcipher *cipher, +krb5_rc4_setup_enc_key(struct krb5_ctx *kctx, + struct crypto_sync_skcipher *cipher, s32 seqnum) { struct crypto_shash *hmac; @@ -1069,7 +1073,8 @@ krb5_rc4_setup_enc_key(struct krb5_ctx *kctx, struct crypto_skcipher *cipher, if (err) goto out_err; - err = crypto_skcipher_setkey(cipher, Kcrypt, kctx->gk5e->keylength); + err = crypto_sync_skcipher_setkey(cipher, Kcrypt, + kctx->gk5e->keylength); if (err) goto out_err; diff --git a/net/sunrpc/auth_gss/gss_krb5_keys.c b/net/sunrpc/auth_gss/gss_krb5_keys.c index f7fe2d2b851f..550fdf18d3b3 100644 --- a/net/sunrpc/auth_gss/gss_krb5_keys.c +++ b/net/sunrpc/auth_gss/gss_krb5_keys.c @@ -147,7 +147,7 @@ u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, size_t blocksize, keybytes, keylength, n; unsigned char *inblockdata, *outblockdata, *rawkey; struct xdr_netobj inblock, outblock; - struct crypto_skcipher *cipher; + struct crypto_sync_skcipher *cipher; u32 ret = EINVAL; blocksize = gk5e->blocksize; @@ -157,11 +157,10 @@ u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, if ((inkey->len != keylength) || (outkey->len != keylength)) goto err_return; - cipher = crypto_alloc_skcipher(gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + cipher = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0); if (IS_ERR(cipher)) goto err_return; - if (crypto_skcipher_setkey(cipher, inkey->data, inkey->len)) + if (crypto_sync_skcipher_setkey(cipher, inkey->data, inkey->len)) goto err_return; /* allocate and set up buffers */ @@ -238,7 +237,7 @@ err_free_in: memset(inblockdata, 0, blocksize); kfree(inblockdata); err_free_cipher: - crypto_free_skcipher(cipher); + crypto_free_sync_skcipher(cipher); err_return: return ret; } diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c index 7bb2514aadd9..eab71fc7af3e 100644 --- a/net/sunrpc/auth_gss/gss_krb5_mech.c +++ b/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -218,7 +218,7 @@ simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) static inline const void * get_key(const void *p, const void *end, - struct krb5_ctx *ctx, struct crypto_skcipher **res) + struct krb5_ctx *ctx, struct crypto_sync_skcipher **res) { struct xdr_netobj key; int alg; @@ -246,15 +246,14 @@ get_key(const void *p, const void *end, if (IS_ERR(p)) goto out_err; - *res = crypto_alloc_skcipher(ctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + *res = crypto_alloc_sync_skcipher(ctx->gk5e->encrypt_name, 0, 0); if (IS_ERR(*res)) { printk(KERN_WARNING "gss_kerberos_mech: unable to initialize " "crypto algorithm %s\n", ctx->gk5e->encrypt_name); *res = NULL; goto out_err_free_key; } - if (crypto_skcipher_setkey(*res, key.data, key.len)) { + if (crypto_sync_skcipher_setkey(*res, key.data, key.len)) { printk(KERN_WARNING "gss_kerberos_mech: error setting key for " "crypto algorithm %s\n", ctx->gk5e->encrypt_name); goto out_err_free_tfm; @@ -264,7 +263,7 @@ get_key(const void *p, const void *end, return p; out_err_free_tfm: - crypto_free_skcipher(*res); + crypto_free_sync_skcipher(*res); out_err_free_key: kfree(key.data); p = ERR_PTR(-EINVAL); @@ -275,6 +274,7 @@ out_err: static int gss_import_v1_context(const void *p, const void *end, struct krb5_ctx *ctx) { + u32 seq_send; int tmp; p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); @@ -316,9 +316,10 @@ gss_import_v1_context(const void *p, const void *end, struct krb5_ctx *ctx) p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); if (IS_ERR(p)) goto out_err; - p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send)); + p = simple_get_bytes(p, end, &seq_send, sizeof(seq_send)); if (IS_ERR(p)) goto out_err; + atomic_set(&ctx->seq_send, seq_send); p = simple_get_netobj(p, end, &ctx->mech_used); if (IS_ERR(p)) goto out_err; @@ -336,30 +337,30 @@ gss_import_v1_context(const void *p, const void *end, struct krb5_ctx *ctx) return 0; out_err_free_key2: - crypto_free_skcipher(ctx->seq); + crypto_free_sync_skcipher(ctx->seq); out_err_free_key1: - crypto_free_skcipher(ctx->enc); + crypto_free_sync_skcipher(ctx->enc); out_err_free_mech: kfree(ctx->mech_used.data); out_err: return PTR_ERR(p); } -static struct crypto_skcipher * +static struct crypto_sync_skcipher * context_v2_alloc_cipher(struct krb5_ctx *ctx, const char *cname, u8 *key) { - struct crypto_skcipher *cp; + struct crypto_sync_skcipher *cp; - cp = crypto_alloc_skcipher(cname, 0, CRYPTO_ALG_ASYNC); + cp = crypto_alloc_sync_skcipher(cname, 0, 0); if (IS_ERR(cp)) { dprintk("gss_kerberos_mech: unable to initialize " "crypto algorithm %s\n", cname); return NULL; } - if (crypto_skcipher_setkey(cp, key, ctx->gk5e->keylength)) { + if (crypto_sync_skcipher_setkey(cp, key, ctx->gk5e->keylength)) { dprintk("gss_kerberos_mech: error setting key for " "crypto algorithm %s\n", cname); - crypto_free_skcipher(cp); + crypto_free_sync_skcipher(cp); return NULL; } return cp; @@ -413,9 +414,9 @@ context_derive_keys_des3(struct krb5_ctx *ctx, gfp_t gfp_mask) return 0; out_free_enc: - crypto_free_skcipher(ctx->enc); + crypto_free_sync_skcipher(ctx->enc); out_free_seq: - crypto_free_skcipher(ctx->seq); + crypto_free_sync_skcipher(ctx->seq); out_err: return -EINVAL; } @@ -469,17 +470,15 @@ context_derive_keys_rc4(struct krb5_ctx *ctx) /* * allocate hash, and skciphers for data and seqnum encryption */ - ctx->enc = crypto_alloc_skcipher(ctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + ctx->enc = crypto_alloc_sync_skcipher(ctx->gk5e->encrypt_name, 0, 0); if (IS_ERR(ctx->enc)) { err = PTR_ERR(ctx->enc); goto out_err_free_hmac; } - ctx->seq = crypto_alloc_skcipher(ctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + ctx->seq = crypto_alloc_sync_skcipher(ctx->gk5e->encrypt_name, 0, 0); if (IS_ERR(ctx->seq)) { - crypto_free_skcipher(ctx->enc); + crypto_free_sync_skcipher(ctx->enc); err = PTR_ERR(ctx->seq); goto out_err_free_hmac; } @@ -591,7 +590,7 @@ context_derive_keys_new(struct krb5_ctx *ctx, gfp_t gfp_mask) context_v2_alloc_cipher(ctx, "cbc(aes)", ctx->acceptor_seal); if (ctx->acceptor_enc_aux == NULL) { - crypto_free_skcipher(ctx->initiator_enc_aux); + crypto_free_sync_skcipher(ctx->initiator_enc_aux); goto out_free_acceptor_enc; } } @@ -599,9 +598,9 @@ context_derive_keys_new(struct krb5_ctx *ctx, gfp_t gfp_mask) return 0; out_free_acceptor_enc: - crypto_free_skcipher(ctx->acceptor_enc); + crypto_free_sync_skcipher(ctx->acceptor_enc); out_free_initiator_enc: - crypto_free_skcipher(ctx->initiator_enc); + crypto_free_sync_skcipher(ctx->initiator_enc); out_err: return -EINVAL; } @@ -610,6 +609,7 @@ static int gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx, gfp_t gfp_mask) { + u64 seq_send64; int keylen; p = simple_get_bytes(p, end, &ctx->flags, sizeof(ctx->flags)); @@ -620,14 +620,15 @@ gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx, p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); if (IS_ERR(p)) goto out_err; - p = simple_get_bytes(p, end, &ctx->seq_send64, sizeof(ctx->seq_send64)); + p = simple_get_bytes(p, end, &seq_send64, sizeof(seq_send64)); if (IS_ERR(p)) goto out_err; + atomic64_set(&ctx->seq_send64, seq_send64); /* set seq_send for use by "older" enctypes */ - ctx->seq_send = ctx->seq_send64; - if (ctx->seq_send64 != ctx->seq_send) { - dprintk("%s: seq_send64 %lx, seq_send %x overflow?\n", __func__, - (unsigned long)ctx->seq_send64, ctx->seq_send); + atomic_set(&ctx->seq_send, seq_send64); + if (seq_send64 != atomic_read(&ctx->seq_send)) { + dprintk("%s: seq_send64 %llx, seq_send %x overflow?\n", __func__, + seq_send64, atomic_read(&ctx->seq_send)); p = ERR_PTR(-EINVAL); goto out_err; } @@ -713,12 +714,12 @@ static void gss_delete_sec_context_kerberos(void *internal_ctx) { struct krb5_ctx *kctx = internal_ctx; - crypto_free_skcipher(kctx->seq); - crypto_free_skcipher(kctx->enc); - crypto_free_skcipher(kctx->acceptor_enc); - crypto_free_skcipher(kctx->initiator_enc); - crypto_free_skcipher(kctx->acceptor_enc_aux); - crypto_free_skcipher(kctx->initiator_enc_aux); + crypto_free_sync_skcipher(kctx->seq); + crypto_free_sync_skcipher(kctx->enc); + crypto_free_sync_skcipher(kctx->acceptor_enc); + crypto_free_sync_skcipher(kctx->initiator_enc); + crypto_free_sync_skcipher(kctx->acceptor_enc_aux); + crypto_free_sync_skcipher(kctx->initiator_enc_aux); kfree(kctx->mech_used.data); kfree(kctx); } diff --git a/net/sunrpc/auth_gss/gss_krb5_seal.c b/net/sunrpc/auth_gss/gss_krb5_seal.c index eaad9bc7a0bd..48fe4a591b54 100644 --- a/net/sunrpc/auth_gss/gss_krb5_seal.c +++ b/net/sunrpc/auth_gss/gss_krb5_seal.c @@ -63,13 +63,12 @@ #include <linux/sunrpc/gss_krb5.h> #include <linux/random.h> #include <linux/crypto.h> +#include <linux/atomic.h> #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) # define RPCDBG_FACILITY RPCDBG_AUTH #endif -DEFINE_SPINLOCK(krb5_seq_lock); - static void * setup_token(struct krb5_ctx *ctx, struct xdr_netobj *token) { @@ -154,9 +153,7 @@ gss_get_mic_v1(struct krb5_ctx *ctx, struct xdr_buf *text, memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); - spin_lock(&krb5_seq_lock); - seq_send = ctx->seq_send++; - spin_unlock(&krb5_seq_lock); + seq_send = atomic_fetch_inc(&ctx->seq_send); if (krb5_make_seq_num(ctx, ctx->seq, ctx->initiate ? 0 : 0xff, seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8)) @@ -174,7 +171,6 @@ gss_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, .data = cksumdata}; void *krb5_hdr; s32 now; - u64 seq_send; u8 *cksumkey; unsigned int cksum_usage; __be64 seq_send_be64; @@ -185,11 +181,7 @@ gss_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, /* Set up the sequence number. Now 64-bits in clear * text and w/o direction indicator */ - spin_lock(&krb5_seq_lock); - seq_send = ctx->seq_send64++; - spin_unlock(&krb5_seq_lock); - - seq_send_be64 = cpu_to_be64(seq_send); + seq_send_be64 = cpu_to_be64(atomic64_fetch_inc(&ctx->seq_send64)); memcpy(krb5_hdr + 8, (char *) &seq_send_be64, 8); if (ctx->initiate) { diff --git a/net/sunrpc/auth_gss/gss_krb5_seqnum.c b/net/sunrpc/auth_gss/gss_krb5_seqnum.c index c8b9082f4a9d..fb6656295204 100644 --- a/net/sunrpc/auth_gss/gss_krb5_seqnum.c +++ b/net/sunrpc/auth_gss/gss_krb5_seqnum.c @@ -43,13 +43,12 @@ static s32 krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, unsigned char *cksum, unsigned char *buf) { - struct crypto_skcipher *cipher; + struct crypto_sync_skcipher *cipher; unsigned char plain[8]; s32 code; dprintk("RPC: %s:\n", __func__); - cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + cipher = crypto_alloc_sync_skcipher(kctx->gk5e->encrypt_name, 0, 0); if (IS_ERR(cipher)) return PTR_ERR(cipher); @@ -68,12 +67,12 @@ krb5_make_rc4_seq_num(struct krb5_ctx *kctx, int direction, s32 seqnum, code = krb5_encrypt(cipher, cksum, plain, buf, 8); out: - crypto_free_skcipher(cipher); + crypto_free_sync_skcipher(cipher); return code; } s32 krb5_make_seq_num(struct krb5_ctx *kctx, - struct crypto_skcipher *key, + struct crypto_sync_skcipher *key, int direction, u32 seqnum, unsigned char *cksum, unsigned char *buf) @@ -101,13 +100,12 @@ static s32 krb5_get_rc4_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, unsigned char *buf, int *direction, s32 *seqnum) { - struct crypto_skcipher *cipher; + struct crypto_sync_skcipher *cipher; unsigned char plain[8]; s32 code; dprintk("RPC: %s:\n", __func__); - cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + cipher = crypto_alloc_sync_skcipher(kctx->gk5e->encrypt_name, 0, 0); if (IS_ERR(cipher)) return PTR_ERR(cipher); @@ -130,7 +128,7 @@ krb5_get_rc4_seq_num(struct krb5_ctx *kctx, unsigned char *cksum, *seqnum = ((plain[0] << 24) | (plain[1] << 16) | (plain[2] << 8) | (plain[3])); out: - crypto_free_skcipher(cipher); + crypto_free_sync_skcipher(cipher); return code; } @@ -142,7 +140,7 @@ krb5_get_seq_num(struct krb5_ctx *kctx, { s32 code; unsigned char plain[8]; - struct crypto_skcipher *key = kctx->seq; + struct crypto_sync_skcipher *key = kctx->seq; dprintk("RPC: krb5_get_seq_num:\n"); diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c index 39a2e672900b..5cdde6cb703a 100644 --- a/net/sunrpc/auth_gss/gss_krb5_wrap.c +++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c @@ -174,7 +174,7 @@ gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, now = get_seconds(); - blocksize = crypto_skcipher_blocksize(kctx->enc); + blocksize = crypto_sync_skcipher_blocksize(kctx->enc); gss_krb5_add_padding(buf, offset, blocksize); BUG_ON((buf->len - offset) % blocksize); plainlen = conflen + buf->len - offset; @@ -228,9 +228,7 @@ gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); - spin_lock(&krb5_seq_lock); - seq_send = kctx->seq_send++; - spin_unlock(&krb5_seq_lock); + seq_send = atomic_fetch_inc(&kctx->seq_send); /* XXX would probably be more efficient to compute checksum * and encrypt at the same time: */ @@ -239,10 +237,10 @@ gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, return GSS_S_FAILURE; if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { - struct crypto_skcipher *cipher; + struct crypto_sync_skcipher *cipher; int err; - cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + cipher = crypto_alloc_sync_skcipher(kctx->gk5e->encrypt_name, + 0, 0); if (IS_ERR(cipher)) return GSS_S_FAILURE; @@ -250,7 +248,7 @@ gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, err = gss_encrypt_xdr_buf(cipher, buf, offset + headlen - conflen, pages); - crypto_free_skcipher(cipher); + crypto_free_sync_skcipher(cipher); if (err) return GSS_S_FAILURE; } else { @@ -327,18 +325,18 @@ gss_unwrap_kerberos_v1(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) return GSS_S_BAD_SIG; if (kctx->enctype == ENCTYPE_ARCFOUR_HMAC) { - struct crypto_skcipher *cipher; + struct crypto_sync_skcipher *cipher; int err; - cipher = crypto_alloc_skcipher(kctx->gk5e->encrypt_name, 0, - CRYPTO_ALG_ASYNC); + cipher = crypto_alloc_sync_skcipher(kctx->gk5e->encrypt_name, + 0, 0); if (IS_ERR(cipher)) return GSS_S_FAILURE; krb5_rc4_setup_enc_key(kctx, cipher, seqnum); err = gss_decrypt_xdr_buf(cipher, buf, crypt_offset); - crypto_free_skcipher(cipher); + crypto_free_sync_skcipher(cipher); if (err) return GSS_S_DEFECTIVE_TOKEN; } else { @@ -371,7 +369,7 @@ gss_unwrap_kerberos_v1(struct krb5_ctx *kctx, int offset, struct xdr_buf *buf) /* Copy the data back to the right position. XXX: Would probably be * better to copy and encrypt at the same time. */ - blocksize = crypto_skcipher_blocksize(kctx->enc); + blocksize = crypto_sync_skcipher_blocksize(kctx->enc); data_start = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) + conflen; orig_start = buf->head[0].iov_base + offset; @@ -477,9 +475,7 @@ gss_wrap_kerberos_v2(struct krb5_ctx *kctx, u32 offset, *be16ptr++ = 0; be64ptr = (__be64 *)be16ptr; - spin_lock(&krb5_seq_lock); - *be64ptr = cpu_to_be64(kctx->seq_send64++); - spin_unlock(&krb5_seq_lock); + *be64ptr = cpu_to_be64(atomic64_fetch_inc(&kctx->seq_send64)); err = (*kctx->gk5e->encrypt_v2)(kctx, offset, buf, pages); if (err) diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c index 5fec3abbe19b..16ac0f4cb7d8 100644 --- a/net/sunrpc/auth_gss/gss_mech_switch.c +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -117,7 +117,7 @@ int gss_mech_register(struct gss_api_mech *gm) if (status) return status; spin_lock(®istered_mechs_lock); - list_add(&gm->gm_list, ®istered_mechs); + list_add_rcu(&gm->gm_list, ®istered_mechs); spin_unlock(®istered_mechs_lock); dprintk("RPC: registered gss mechanism %s\n", gm->gm_name); return 0; @@ -132,7 +132,7 @@ EXPORT_SYMBOL_GPL(gss_mech_register); void gss_mech_unregister(struct gss_api_mech *gm) { spin_lock(®istered_mechs_lock); - list_del(&gm->gm_list); + list_del_rcu(&gm->gm_list); spin_unlock(®istered_mechs_lock); dprintk("RPC: unregistered gss mechanism %s\n", gm->gm_name); gss_mech_free(gm); @@ -151,15 +151,15 @@ _gss_mech_get_by_name(const char *name) { struct gss_api_mech *pos, *gm = NULL; - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { if (0 == strcmp(name, pos->gm_name)) { if (try_module_get(pos->gm_owner)) gm = pos; break; } } - spin_unlock(®istered_mechs_lock); + rcu_read_unlock(); return gm; } @@ -186,8 +186,8 @@ struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj) dprintk("RPC: %s(%s)\n", __func__, buf); request_module("rpc-auth-gss-%s", buf); - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { if (obj->len == pos->gm_oid.len) { if (0 == memcmp(obj->data, pos->gm_oid.data, obj->len)) { if (try_module_get(pos->gm_owner)) @@ -196,7 +196,7 @@ struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj) } } } - spin_unlock(®istered_mechs_lock); + rcu_read_unlock(); return gm; } @@ -216,15 +216,15 @@ static struct gss_api_mech *_gss_mech_get_by_pseudoflavor(u32 pseudoflavor) { struct gss_api_mech *gm = NULL, *pos; - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { if (!mech_supports_pseudoflavor(pos, pseudoflavor)) continue; if (try_module_get(pos->gm_owner)) gm = pos; break; } - spin_unlock(®istered_mechs_lock); + rcu_read_unlock(); return gm; } @@ -257,8 +257,8 @@ int gss_mech_list_pseudoflavors(rpc_authflavor_t *array_ptr, int size) struct gss_api_mech *pos = NULL; int j, i = 0; - spin_lock(®istered_mechs_lock); - list_for_each_entry(pos, ®istered_mechs, gm_list) { + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { for (j = 0; j < pos->gm_pf_num; j++) { if (i >= size) { spin_unlock(®istered_mechs_lock); @@ -267,7 +267,7 @@ int gss_mech_list_pseudoflavors(rpc_authflavor_t *array_ptr, int size) array_ptr[i++] = pos->gm_pfs[j].pseudoflavor; } } - spin_unlock(®istered_mechs_lock); + rcu_read_unlock(); return i; } diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.c b/net/sunrpc/auth_gss/gss_rpc_xdr.c index 444380f968f1..006062ad5f58 100644 --- a/net/sunrpc/auth_gss/gss_rpc_xdr.c +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.c @@ -784,6 +784,7 @@ void gssx_enc_accept_sec_context(struct rpc_rqst *req, xdr_inline_pages(&req->rq_rcv_buf, PAGE_SIZE/2 /* pretty arbitrary */, arg->pages, 0 /* page base */, arg->npages * PAGE_SIZE); + req->rq_rcv_buf.flags |= XDRBUF_SPARSE_PAGES; done: if (err) dprintk("RPC: gssx_enc_accept_sec_context: %d\n", err); diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c index 860f2a1bbb67..1ece4bc3eb8d 100644 --- a/net/sunrpc/auth_gss/svcauth_gss.c +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -76,6 +76,7 @@ struct rsi { struct xdr_netobj in_handle, in_token; struct xdr_netobj out_handle, out_token; int major_status, minor_status; + struct rcu_head rcu_head; }; static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old); @@ -89,13 +90,21 @@ static void rsi_free(struct rsi *rsii) kfree(rsii->out_token.data); } -static void rsi_put(struct kref *ref) +static void rsi_free_rcu(struct rcu_head *head) { - struct rsi *rsii = container_of(ref, struct rsi, h.ref); + struct rsi *rsii = container_of(head, struct rsi, rcu_head); + rsi_free(rsii); kfree(rsii); } +static void rsi_put(struct kref *ref) +{ + struct rsi *rsii = container_of(ref, struct rsi, h.ref); + + call_rcu(&rsii->rcu_head, rsi_free_rcu); +} + static inline int rsi_hash(struct rsi *item) { return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS) @@ -282,7 +291,7 @@ static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item) struct cache_head *ch; int hash = rsi_hash(item); - ch = sunrpc_cache_lookup(cd, &item->h, hash); + ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash); if (ch) return container_of(ch, struct rsi, h); else @@ -330,6 +339,7 @@ struct rsc { struct svc_cred cred; struct gss_svc_seq_data seqdata; struct gss_ctx *mechctx; + struct rcu_head rcu_head; }; static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old); @@ -343,12 +353,22 @@ static void rsc_free(struct rsc *rsci) free_svc_cred(&rsci->cred); } +static void rsc_free_rcu(struct rcu_head *head) +{ + struct rsc *rsci = container_of(head, struct rsc, rcu_head); + + kfree(rsci->handle.data); + kfree(rsci); +} + static void rsc_put(struct kref *ref) { struct rsc *rsci = container_of(ref, struct rsc, h.ref); - rsc_free(rsci); - kfree(rsci); + if (rsci->mechctx) + gss_delete_sec_context(&rsci->mechctx); + free_svc_cred(&rsci->cred); + call_rcu(&rsci->rcu_head, rsc_free_rcu); } static inline int @@ -542,7 +562,7 @@ static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item) struct cache_head *ch; int hash = rsc_hash(item); - ch = sunrpc_cache_lookup(cd, &item->h, hash); + ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash); if (ch) return container_of(ch, struct rsc, h); else @@ -1764,14 +1784,21 @@ out_err: } static void -svcauth_gss_domain_release(struct auth_domain *dom) +svcauth_gss_domain_release_rcu(struct rcu_head *head) { + struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head); struct gss_domain *gd = container_of(dom, struct gss_domain, h); kfree(dom->name); kfree(gd); } +static void +svcauth_gss_domain_release(struct auth_domain *dom) +{ + call_rcu(&dom->rcu_head, svcauth_gss_domain_release_rcu); +} + static struct auth_ops svcauthops_gss = { .name = "rpcsec_gss", .owner = THIS_MODULE, diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c index 4b48228ee8c7..2694a1bc026b 100644 --- a/net/sunrpc/auth_null.c +++ b/net/sunrpc/auth_null.c @@ -21,7 +21,7 @@ static struct rpc_cred null_cred; static struct rpc_auth * nul_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { - atomic_inc(&null_auth.au_count); + refcount_inc(&null_auth.au_count); return &null_auth; } @@ -119,7 +119,7 @@ struct rpc_auth null_auth = { .au_flags = RPCAUTH_AUTH_NO_CRKEY_TIMEOUT, .au_ops = &authnull_ops, .au_flavor = RPC_AUTH_NULL, - .au_count = ATOMIC_INIT(0), + .au_count = REFCOUNT_INIT(1), }; static @@ -138,6 +138,6 @@ struct rpc_cred null_cred = { .cr_lru = LIST_HEAD_INIT(null_cred.cr_lru), .cr_auth = &null_auth, .cr_ops = &null_credops, - .cr_count = ATOMIC_INIT(1), + .cr_count = REFCOUNT_INIT(2), .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, }; diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c index 185e56d4f9ae..4c1c7e56288f 100644 --- a/net/sunrpc/auth_unix.c +++ b/net/sunrpc/auth_unix.c @@ -34,7 +34,7 @@ unx_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) { dprintk("RPC: creating UNIX authenticator for client %p\n", clnt); - atomic_inc(&unix_auth.au_count); + refcount_inc(&unix_auth.au_count); return &unix_auth; } @@ -239,7 +239,7 @@ struct rpc_auth unix_auth = { .au_flags = RPCAUTH_AUTH_NO_CRKEY_TIMEOUT, .au_ops = &authunix_ops, .au_flavor = RPC_AUTH_UNIX, - .au_count = ATOMIC_INIT(0), + .au_count = REFCOUNT_INIT(1), }; static diff --git a/net/sunrpc/backchannel_rqst.c b/net/sunrpc/backchannel_rqst.c index 3c15a99b9700..fa5ba6ed3197 100644 --- a/net/sunrpc/backchannel_rqst.c +++ b/net/sunrpc/backchannel_rqst.c @@ -91,7 +91,6 @@ struct rpc_rqst *xprt_alloc_bc_req(struct rpc_xprt *xprt, gfp_t gfp_flags) return NULL; req->rq_xprt = xprt; - INIT_LIST_HEAD(&req->rq_list); INIT_LIST_HEAD(&req->rq_bc_list); /* Preallocate one XDR receive buffer */ diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c index 109fbe591e7b..f96345b1180e 100644 --- a/net/sunrpc/cache.c +++ b/net/sunrpc/cache.c @@ -54,28 +54,33 @@ static void cache_init(struct cache_head *h, struct cache_detail *detail) h->last_refresh = now; } -struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, - struct cache_head *key, int hash) +static struct cache_head *sunrpc_cache_find_rcu(struct cache_detail *detail, + struct cache_head *key, + int hash) { - struct cache_head *new = NULL, *freeme = NULL, *tmp = NULL; - struct hlist_head *head; - - head = &detail->hash_table[hash]; - - read_lock(&detail->hash_lock); + struct hlist_head *head = &detail->hash_table[hash]; + struct cache_head *tmp; - hlist_for_each_entry(tmp, head, cache_list) { + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, cache_list) { if (detail->match(tmp, key)) { if (cache_is_expired(detail, tmp)) - /* This entry is expired, we will discard it. */ - break; - cache_get(tmp); - read_unlock(&detail->hash_lock); + continue; + tmp = cache_get_rcu(tmp); + rcu_read_unlock(); return tmp; } } - read_unlock(&detail->hash_lock); - /* Didn't find anything, insert an empty entry */ + rcu_read_unlock(); + return NULL; +} + +static struct cache_head *sunrpc_cache_add_entry(struct cache_detail *detail, + struct cache_head *key, + int hash) +{ + struct cache_head *new, *tmp, *freeme = NULL; + struct hlist_head *head = &detail->hash_table[hash]; new = detail->alloc(); if (!new) @@ -87,35 +92,46 @@ struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, cache_init(new, detail); detail->init(new, key); - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); /* check if entry appeared while we slept */ - hlist_for_each_entry(tmp, head, cache_list) { + hlist_for_each_entry_rcu(tmp, head, cache_list) { if (detail->match(tmp, key)) { if (cache_is_expired(detail, tmp)) { - hlist_del_init(&tmp->cache_list); + hlist_del_init_rcu(&tmp->cache_list); detail->entries --; freeme = tmp; break; } cache_get(tmp); - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); cache_put(new, detail); return tmp; } } - hlist_add_head(&new->cache_list, head); + hlist_add_head_rcu(&new->cache_list, head); detail->entries++; cache_get(new); - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); if (freeme) cache_put(freeme, detail); return new; } -EXPORT_SYMBOL_GPL(sunrpc_cache_lookup); +struct cache_head *sunrpc_cache_lookup_rcu(struct cache_detail *detail, + struct cache_head *key, int hash) +{ + struct cache_head *ret; + + ret = sunrpc_cache_find_rcu(detail, key, hash); + if (ret) + return ret; + /* Didn't find anything, insert an empty entry */ + return sunrpc_cache_add_entry(detail, key, hash); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_lookup_rcu); static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch); @@ -151,18 +167,18 @@ struct cache_head *sunrpc_cache_update(struct cache_detail *detail, struct cache_head *tmp; if (!test_bit(CACHE_VALID, &old->flags)) { - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); if (!test_bit(CACHE_VALID, &old->flags)) { if (test_bit(CACHE_NEGATIVE, &new->flags)) set_bit(CACHE_NEGATIVE, &old->flags); else detail->update(old, new); cache_fresh_locked(old, new->expiry_time, detail); - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); cache_fresh_unlocked(old, detail); return old; } - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); } /* We need to insert a new entry */ tmp = detail->alloc(); @@ -173,7 +189,7 @@ struct cache_head *sunrpc_cache_update(struct cache_detail *detail, cache_init(tmp, detail); detail->init(tmp, old); - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); if (test_bit(CACHE_NEGATIVE, &new->flags)) set_bit(CACHE_NEGATIVE, &tmp->flags); else @@ -183,7 +199,7 @@ struct cache_head *sunrpc_cache_update(struct cache_detail *detail, cache_get(tmp); cache_fresh_locked(tmp, new->expiry_time, detail); cache_fresh_locked(old, 0, detail); - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); cache_fresh_unlocked(tmp, detail); cache_fresh_unlocked(old, detail); cache_put(old, detail); @@ -223,7 +239,7 @@ static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h { int rv; - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); rv = cache_is_valid(h); if (rv == -EAGAIN) { set_bit(CACHE_NEGATIVE, &h->flags); @@ -231,7 +247,7 @@ static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h detail); rv = -ENOENT; } - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); cache_fresh_unlocked(h, detail); return rv; } @@ -341,7 +357,7 @@ static struct delayed_work cache_cleaner; void sunrpc_init_cache_detail(struct cache_detail *cd) { - rwlock_init(&cd->hash_lock); + spin_lock_init(&cd->hash_lock); INIT_LIST_HEAD(&cd->queue); spin_lock(&cache_list_lock); cd->nextcheck = 0; @@ -361,11 +377,11 @@ void sunrpc_destroy_cache_detail(struct cache_detail *cd) { cache_purge(cd); spin_lock(&cache_list_lock); - write_lock(&cd->hash_lock); + spin_lock(&cd->hash_lock); if (current_detail == cd) current_detail = NULL; list_del_init(&cd->others); - write_unlock(&cd->hash_lock); + spin_unlock(&cd->hash_lock); spin_unlock(&cache_list_lock); if (list_empty(&cache_list)) { /* module must be being unloaded so its safe to kill the worker */ @@ -422,7 +438,7 @@ static int cache_clean(void) struct hlist_head *head; struct hlist_node *tmp; - write_lock(¤t_detail->hash_lock); + spin_lock(¤t_detail->hash_lock); /* Ok, now to clean this strand */ @@ -433,13 +449,13 @@ static int cache_clean(void) if (!cache_is_expired(current_detail, ch)) continue; - hlist_del_init(&ch->cache_list); + hlist_del_init_rcu(&ch->cache_list); current_detail->entries--; rv = 1; break; } - write_unlock(¤t_detail->hash_lock); + spin_unlock(¤t_detail->hash_lock); d = current_detail; if (!ch) current_index ++; @@ -494,9 +510,9 @@ void cache_purge(struct cache_detail *detail) struct hlist_node *tmp = NULL; int i = 0; - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); if (!detail->entries) { - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); return; } @@ -504,17 +520,17 @@ void cache_purge(struct cache_detail *detail) for (i = 0; i < detail->hash_size; i++) { head = &detail->hash_table[i]; hlist_for_each_entry_safe(ch, tmp, head, cache_list) { - hlist_del_init(&ch->cache_list); + hlist_del_init_rcu(&ch->cache_list); detail->entries--; set_bit(CACHE_CLEANED, &ch->flags); - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); cache_fresh_unlocked(ch, detail); cache_put(ch, detail); - write_lock(&detail->hash_lock); + spin_lock(&detail->hash_lock); } } - write_unlock(&detail->hash_lock); + spin_unlock(&detail->hash_lock); } EXPORT_SYMBOL_GPL(cache_purge); @@ -1289,21 +1305,19 @@ EXPORT_SYMBOL_GPL(qword_get); * get a header, then pass each real item in the cache */ -void *cache_seq_start(struct seq_file *m, loff_t *pos) - __acquires(cd->hash_lock) +static void *__cache_seq_start(struct seq_file *m, loff_t *pos) { loff_t n = *pos; unsigned int hash, entry; struct cache_head *ch; struct cache_detail *cd = m->private; - read_lock(&cd->hash_lock); if (!n--) return SEQ_START_TOKEN; hash = n >> 32; entry = n & ((1LL<<32) - 1); - hlist_for_each_entry(ch, &cd->hash_table[hash], cache_list) + hlist_for_each_entry_rcu(ch, &cd->hash_table[hash], cache_list) if (!entry--) return ch; n &= ~((1LL<<32) - 1); @@ -1315,12 +1329,12 @@ void *cache_seq_start(struct seq_file *m, loff_t *pos) if (hash >= cd->hash_size) return NULL; *pos = n+1; - return hlist_entry_safe(cd->hash_table[hash].first, + return hlist_entry_safe(rcu_dereference_raw( + hlist_first_rcu(&cd->hash_table[hash])), struct cache_head, cache_list); } -EXPORT_SYMBOL_GPL(cache_seq_start); -void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos) +static void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos) { struct cache_head *ch = p; int hash = (*pos >> 32); @@ -1333,7 +1347,8 @@ void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos) *pos += 1LL<<32; } else { ++*pos; - return hlist_entry_safe(ch->cache_list.next, + return hlist_entry_safe(rcu_dereference_raw( + hlist_next_rcu(&ch->cache_list)), struct cache_head, cache_list); } *pos &= ~((1LL<<32) - 1); @@ -1345,18 +1360,32 @@ void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos) if (hash >= cd->hash_size) return NULL; ++*pos; - return hlist_entry_safe(cd->hash_table[hash].first, + return hlist_entry_safe(rcu_dereference_raw( + hlist_first_rcu(&cd->hash_table[hash])), struct cache_head, cache_list); } EXPORT_SYMBOL_GPL(cache_seq_next); -void cache_seq_stop(struct seq_file *m, void *p) - __releases(cd->hash_lock) +void *cache_seq_start_rcu(struct seq_file *m, loff_t *pos) + __acquires(RCU) { - struct cache_detail *cd = m->private; - read_unlock(&cd->hash_lock); + rcu_read_lock(); + return __cache_seq_start(m, pos); +} +EXPORT_SYMBOL_GPL(cache_seq_start_rcu); + +void *cache_seq_next_rcu(struct seq_file *file, void *p, loff_t *pos) +{ + return cache_seq_next(file, p, pos); +} +EXPORT_SYMBOL_GPL(cache_seq_next_rcu); + +void cache_seq_stop_rcu(struct seq_file *m, void *p) + __releases(RCU) +{ + rcu_read_unlock(); } -EXPORT_SYMBOL_GPL(cache_seq_stop); +EXPORT_SYMBOL_GPL(cache_seq_stop_rcu); static int c_show(struct seq_file *m, void *p) { @@ -1384,9 +1413,9 @@ static int c_show(struct seq_file *m, void *p) } static const struct seq_operations cache_content_op = { - .start = cache_seq_start, - .next = cache_seq_next, - .stop = cache_seq_stop, + .start = cache_seq_start_rcu, + .next = cache_seq_next_rcu, + .stop = cache_seq_stop_rcu, .show = c_show, }; @@ -1844,13 +1873,13 @@ EXPORT_SYMBOL_GPL(sunrpc_cache_unregister_pipefs); void sunrpc_cache_unhash(struct cache_detail *cd, struct cache_head *h) { - write_lock(&cd->hash_lock); + spin_lock(&cd->hash_lock); if (!hlist_unhashed(&h->cache_list)){ - hlist_del_init(&h->cache_list); + hlist_del_init_rcu(&h->cache_list); cd->entries--; - write_unlock(&cd->hash_lock); + spin_unlock(&cd->hash_lock); cache_put(h, cd); } else - write_unlock(&cd->hash_lock); + spin_unlock(&cd->hash_lock); } EXPORT_SYMBOL_GPL(sunrpc_cache_unhash); diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index 8ea2f5fadd96..ae3b8145da35 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -61,6 +61,7 @@ static void call_start(struct rpc_task *task); static void call_reserve(struct rpc_task *task); static void call_reserveresult(struct rpc_task *task); static void call_allocate(struct rpc_task *task); +static void call_encode(struct rpc_task *task); static void call_decode(struct rpc_task *task); static void call_bind(struct rpc_task *task); static void call_bind_status(struct rpc_task *task); @@ -1137,10 +1138,10 @@ EXPORT_SYMBOL_GPL(rpc_call_async); struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req) { struct rpc_task *task; - struct xdr_buf *xbufp = &req->rq_snd_buf; struct rpc_task_setup task_setup_data = { .callback_ops = &rpc_default_ops, - .flags = RPC_TASK_SOFTCONN, + .flags = RPC_TASK_SOFTCONN | + RPC_TASK_NO_RETRANS_TIMEOUT, }; dprintk("RPC: rpc_run_bc_task req= %p\n", req); @@ -1148,14 +1149,7 @@ struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req) * Create an rpc_task to send the data */ task = rpc_new_task(&task_setup_data); - task->tk_rqstp = req; - - /* - * Set up the xdr_buf length. - * This also indicates that the buffer is XDR encoded already. - */ - xbufp->len = xbufp->head[0].iov_len + xbufp->page_len + - xbufp->tail[0].iov_len; + xprt_init_bc_request(req, task); task->tk_action = call_bc_transmit; atomic_inc(&task->tk_count); @@ -1558,7 +1552,6 @@ call_reserveresult(struct rpc_task *task) task->tk_status = 0; if (status >= 0) { if (task->tk_rqstp) { - xprt_request_init(task); task->tk_action = call_refresh; return; } @@ -1680,7 +1673,7 @@ call_allocate(struct rpc_task *task) dprint_status(task); task->tk_status = 0; - task->tk_action = call_bind; + task->tk_action = call_encode; if (req->rq_buffer) return; @@ -1721,22 +1714,15 @@ call_allocate(struct rpc_task *task) rpc_exit(task, -ERESTARTSYS); } -static inline int +static int rpc_task_need_encode(struct rpc_task *task) { - return task->tk_rqstp->rq_snd_buf.len == 0; + return test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) == 0 && + (!(task->tk_flags & RPC_TASK_SENT) || + !(task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) || + xprt_request_need_retransmit(task)); } -static inline void -rpc_task_force_reencode(struct rpc_task *task) -{ - task->tk_rqstp->rq_snd_buf.len = 0; - task->tk_rqstp->rq_bytes_sent = 0; -} - -/* - * 3. Encode arguments of an RPC call - */ static void rpc_xdr_encode(struct rpc_task *task) { @@ -1752,6 +1738,7 @@ rpc_xdr_encode(struct rpc_task *task) xdr_buf_init(&req->rq_rcv_buf, req->rq_rbuffer, req->rq_rcvsize); + req->rq_bytes_sent = 0; p = rpc_encode_header(task); if (p == NULL) { @@ -1766,6 +1753,36 @@ rpc_xdr_encode(struct rpc_task *task) task->tk_status = rpcauth_wrap_req(task, encode, req, p, task->tk_msg.rpc_argp); + if (task->tk_status == 0) + xprt_request_prepare(req); +} + +/* + * 3. Encode arguments of an RPC call + */ +static void +call_encode(struct rpc_task *task) +{ + if (!rpc_task_need_encode(task)) + goto out; + /* Encode here so that rpcsec_gss can use correct sequence number. */ + rpc_xdr_encode(task); + /* Did the encode result in an error condition? */ + if (task->tk_status != 0) { + /* Was the error nonfatal? */ + if (task->tk_status == -EAGAIN || task->tk_status == -ENOMEM) + rpc_delay(task, HZ >> 4); + else + rpc_exit(task, task->tk_status); + return; + } + + /* Add task to reply queue before transmission to avoid races */ + if (rpc_reply_expected(task)) + xprt_request_enqueue_receive(task); + xprt_request_enqueue_transmit(task); +out: + task->tk_action = call_bind; } /* @@ -1947,43 +1964,16 @@ call_connect_status(struct rpc_task *task) static void call_transmit(struct rpc_task *task) { - int is_retrans = RPC_WAS_SENT(task); - dprint_status(task); - task->tk_action = call_status; - if (task->tk_status < 0) - return; - if (!xprt_prepare_transmit(task)) - return; - task->tk_action = call_transmit_status; - /* Encode here so that rpcsec_gss can use correct sequence number. */ - if (rpc_task_need_encode(task)) { - rpc_xdr_encode(task); - /* Did the encode result in an error condition? */ - if (task->tk_status != 0) { - /* Was the error nonfatal? */ - if (task->tk_status == -EAGAIN) - rpc_delay(task, HZ >> 4); - else - rpc_exit(task, task->tk_status); + task->tk_status = 0; + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { + if (!xprt_prepare_transmit(task)) return; - } + xprt_transmit(task); } - xprt_transmit(task); - if (task->tk_status < 0) - return; - if (is_retrans) - task->tk_client->cl_stats->rpcretrans++; - /* - * On success, ensure that we call xprt_end_transmit() before sleeping - * in order to allow access to the socket to other RPC requests. - */ - call_transmit_status(task); - if (rpc_reply_expected(task)) - return; - task->tk_action = rpc_exit_task; - rpc_wake_up_queued_task(&task->tk_rqstp->rq_xprt->pending, task); + task->tk_action = call_transmit_status; + xprt_end_transmit(task); } /* @@ -1999,19 +1989,17 @@ call_transmit_status(struct rpc_task *task) * test first. */ if (task->tk_status == 0) { - xprt_end_transmit(task); - rpc_task_force_reencode(task); + xprt_request_wait_receive(task); return; } switch (task->tk_status) { - case -EAGAIN: - case -ENOBUFS: - break; default: dprint_status(task); - xprt_end_transmit(task); - rpc_task_force_reencode(task); + break; + case -EBADMSG: + task->tk_status = 0; + task->tk_action = call_encode; break; /* * Special cases: if we've been waiting on the @@ -2019,6 +2007,14 @@ call_transmit_status(struct rpc_task *task) * socket just returned a connection error, * then hold onto the transport lock. */ + case -ENOBUFS: + rpc_delay(task, HZ>>2); + /* fall through */ + case -EBADSLT: + case -EAGAIN: + task->tk_action = call_transmit; + task->tk_status = 0; + break; case -ECONNREFUSED: case -EHOSTDOWN: case -ENETDOWN: @@ -2026,7 +2022,6 @@ call_transmit_status(struct rpc_task *task) case -ENETUNREACH: case -EPERM: if (RPC_IS_SOFTCONN(task)) { - xprt_end_transmit(task); if (!task->tk_msg.rpc_proc->p_proc) trace_xprt_ping(task->tk_xprt, task->tk_status); @@ -2039,7 +2034,7 @@ call_transmit_status(struct rpc_task *task) case -EADDRINUSE: case -ENOTCONN: case -EPIPE: - rpc_task_force_reencode(task); + break; } } @@ -2053,6 +2048,11 @@ call_bc_transmit(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; + if (rpc_task_need_encode(task)) + xprt_request_enqueue_transmit(task); + if (!test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + goto out_wakeup; + if (!xprt_prepare_transmit(task)) goto out_retry; @@ -2061,14 +2061,9 @@ call_bc_transmit(struct rpc_task *task) "error: %d\n", task->tk_status); goto out_done; } - if (req->rq_connect_cookie != req->rq_xprt->connect_cookie) - req->rq_bytes_sent = 0; xprt_transmit(task); - if (task->tk_status == -EAGAIN) - goto out_nospace; - xprt_end_transmit(task); dprint_status(task); switch (task->tk_status) { @@ -2084,6 +2079,8 @@ call_bc_transmit(struct rpc_task *task) case -ENOTCONN: case -EPIPE: break; + case -EAGAIN: + goto out_retry; case -ETIMEDOUT: /* * Problem reaching the server. Disconnect and let the @@ -2107,12 +2104,11 @@ call_bc_transmit(struct rpc_task *task) "error: %d\n", task->tk_status); break; } +out_wakeup: rpc_wake_up_queued_task(&req->rq_xprt->pending, task); out_done: task->tk_action = rpc_exit_task; return; -out_nospace: - req->rq_connect_cookie = req->rq_xprt->connect_cookie; out_retry: task->tk_status = 0; } @@ -2125,15 +2121,11 @@ static void call_status(struct rpc_task *task) { struct rpc_clnt *clnt = task->tk_client; - struct rpc_rqst *req = task->tk_rqstp; int status; if (!task->tk_msg.rpc_proc->p_proc) trace_xprt_ping(task->tk_xprt, task->tk_status); - if (req->rq_reply_bytes_recvd > 0 && !req->rq_bytes_sent) - task->tk_status = req->rq_reply_bytes_recvd; - dprint_status(task); status = task->tk_status; @@ -2173,13 +2165,8 @@ call_status(struct rpc_task *task) /* fall through */ case -EPIPE: case -ENOTCONN: - task->tk_action = call_bind; - break; - case -ENOBUFS: - rpc_delay(task, HZ>>2); - /* fall through */ case -EAGAIN: - task->tk_action = call_transmit; + task->tk_action = call_encode; break; case -EIO: /* shutdown or soft timeout */ @@ -2244,7 +2231,7 @@ call_timeout(struct rpc_task *task) rpcauth_invalcred(task); retry: - task->tk_action = call_bind; + task->tk_action = call_encode; task->tk_status = 0; } @@ -2261,6 +2248,11 @@ call_decode(struct rpc_task *task) dprint_status(task); + if (!decode) { + task->tk_action = rpc_exit_task; + return; + } + if (task->tk_flags & RPC_CALL_MAJORSEEN) { if (clnt->cl_chatty) { printk(KERN_NOTICE "%s: server %s OK\n", @@ -2283,7 +2275,7 @@ call_decode(struct rpc_task *task) if (req->rq_rcv_buf.len < 12) { if (!RPC_IS_SOFT(task)) { - task->tk_action = call_bind; + task->tk_action = call_encode; goto out_retry; } dprintk("RPC: %s: too small RPC reply size (%d bytes)\n", @@ -2298,13 +2290,11 @@ call_decode(struct rpc_task *task) goto out_retry; return; } - task->tk_action = rpc_exit_task; - if (decode) { - task->tk_status = rpcauth_unwrap_resp(task, decode, req, p, - task->tk_msg.rpc_resp); - } + task->tk_status = rpcauth_unwrap_resp(task, decode, req, p, + task->tk_msg.rpc_resp); + dprintk("RPC: %5u call_decode result %d\n", task->tk_pid, task->tk_status); return; @@ -2416,7 +2406,7 @@ rpc_verify_header(struct rpc_task *task) task->tk_garb_retry--; dprintk("RPC: %5u %s: retry garbled creds\n", task->tk_pid, __func__); - task->tk_action = call_bind; + task->tk_action = call_encode; goto out_retry; case RPC_AUTH_TOOWEAK: printk(KERN_NOTICE "RPC: server %s requires stronger " @@ -2485,7 +2475,7 @@ out_garbage: task->tk_garb_retry--; dprintk("RPC: %5u %s: retrying\n", task->tk_pid, __func__); - task->tk_action = call_bind; + task->tk_action = call_encode; out_retry: return ERR_PTR(-EAGAIN); } diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c index 3fe5d60ab0e2..57ca5bead1cb 100644 --- a/net/sunrpc/sched.c +++ b/net/sunrpc/sched.c @@ -99,65 +99,79 @@ __rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task) list_add(&task->u.tk_wait.timer_list, &queue->timer_list.list); } -static void rpc_rotate_queue_owner(struct rpc_wait_queue *queue) -{ - struct list_head *q = &queue->tasks[queue->priority]; - struct rpc_task *task; - - if (!list_empty(q)) { - task = list_first_entry(q, struct rpc_task, u.tk_wait.list); - if (task->tk_owner == queue->owner) - list_move_tail(&task->u.tk_wait.list, q); - } -} - static void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority) { if (queue->priority != priority) { - /* Fairness: rotate the list when changing priority */ - rpc_rotate_queue_owner(queue); queue->priority = priority; + queue->nr = 1U << priority; } } -static void rpc_set_waitqueue_owner(struct rpc_wait_queue *queue, pid_t pid) -{ - queue->owner = pid; - queue->nr = RPC_BATCH_COUNT; -} - static void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue) { rpc_set_waitqueue_priority(queue, queue->maxpriority); - rpc_set_waitqueue_owner(queue, 0); } /* - * Add new request to a priority queue. + * Add a request to a queue list */ -static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, - struct rpc_task *task, - unsigned char queue_priority) +static void +__rpc_list_enqueue_task(struct list_head *q, struct rpc_task *task) { - struct list_head *q; struct rpc_task *t; - INIT_LIST_HEAD(&task->u.tk_wait.links); - if (unlikely(queue_priority > queue->maxpriority)) - queue_priority = queue->maxpriority; - if (queue_priority > queue->priority) - rpc_set_waitqueue_priority(queue, queue_priority); - q = &queue->tasks[queue_priority]; list_for_each_entry(t, q, u.tk_wait.list) { if (t->tk_owner == task->tk_owner) { - list_add_tail(&task->u.tk_wait.list, &t->u.tk_wait.links); + list_add_tail(&task->u.tk_wait.links, + &t->u.tk_wait.links); + /* Cache the queue head in task->u.tk_wait.list */ + task->u.tk_wait.list.next = q; + task->u.tk_wait.list.prev = NULL; return; } } + INIT_LIST_HEAD(&task->u.tk_wait.links); list_add_tail(&task->u.tk_wait.list, q); } /* + * Remove request from a queue list + */ +static void +__rpc_list_dequeue_task(struct rpc_task *task) +{ + struct list_head *q; + struct rpc_task *t; + + if (task->u.tk_wait.list.prev == NULL) { + list_del(&task->u.tk_wait.links); + return; + } + if (!list_empty(&task->u.tk_wait.links)) { + t = list_first_entry(&task->u.tk_wait.links, + struct rpc_task, + u.tk_wait.links); + /* Assume __rpc_list_enqueue_task() cached the queue head */ + q = t->u.tk_wait.list.next; + list_add_tail(&t->u.tk_wait.list, q); + list_del(&task->u.tk_wait.links); + } + list_del(&task->u.tk_wait.list); +} + +/* + * Add new request to a priority queue. + */ +static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) +{ + if (unlikely(queue_priority > queue->maxpriority)) + queue_priority = queue->maxpriority; + __rpc_list_enqueue_task(&queue->tasks[queue_priority], task); +} + +/* * Add new request to wait queue. * * Swapper tasks always get inserted at the head of the queue. @@ -194,13 +208,7 @@ static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, */ static void __rpc_remove_wait_queue_priority(struct rpc_task *task) { - struct rpc_task *t; - - if (!list_empty(&task->u.tk_wait.links)) { - t = list_entry(task->u.tk_wait.links.next, struct rpc_task, u.tk_wait.list); - list_move(&t->u.tk_wait.list, &task->u.tk_wait.list); - list_splice_init(&task->u.tk_wait.links, &t->u.tk_wait.links); - } + __rpc_list_dequeue_task(task); } /* @@ -212,7 +220,8 @@ static void __rpc_remove_wait_queue(struct rpc_wait_queue *queue, struct rpc_tas __rpc_disable_timer(queue, task); if (RPC_IS_PRIORITY(queue)) __rpc_remove_wait_queue_priority(task); - list_del(&task->u.tk_wait.list); + else + list_del(&task->u.tk_wait.list); queue->qlen--; dprintk("RPC: %5u removed from queue %p \"%s\"\n", task->tk_pid, queue, rpc_qname(queue)); @@ -440,14 +449,28 @@ static void __rpc_do_wake_up_task_on_wq(struct workqueue_struct *wq, /* * Wake up a queued task while the queue lock is being held */ -static void rpc_wake_up_task_on_wq_queue_locked(struct workqueue_struct *wq, - struct rpc_wait_queue *queue, struct rpc_task *task) +static struct rpc_task * +rpc_wake_up_task_on_wq_queue_action_locked(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, struct rpc_task *task, + bool (*action)(struct rpc_task *, void *), void *data) { if (RPC_IS_QUEUED(task)) { smp_rmb(); - if (task->tk_waitqueue == queue) - __rpc_do_wake_up_task_on_wq(wq, queue, task); + if (task->tk_waitqueue == queue) { + if (action == NULL || action(task, data)) { + __rpc_do_wake_up_task_on_wq(wq, queue, task); + return task; + } + } } + return NULL; +} + +static void +rpc_wake_up_task_on_wq_queue_locked(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, struct rpc_task *task) +{ + rpc_wake_up_task_on_wq_queue_action_locked(wq, queue, task, NULL, NULL); } /* @@ -465,6 +488,8 @@ void rpc_wake_up_queued_task_on_wq(struct workqueue_struct *wq, struct rpc_wait_queue *queue, struct rpc_task *task) { + if (!RPC_IS_QUEUED(task)) + return; spin_lock_bh(&queue->lock); rpc_wake_up_task_on_wq_queue_locked(wq, queue, task); spin_unlock_bh(&queue->lock); @@ -475,12 +500,48 @@ void rpc_wake_up_queued_task_on_wq(struct workqueue_struct *wq, */ void rpc_wake_up_queued_task(struct rpc_wait_queue *queue, struct rpc_task *task) { + if (!RPC_IS_QUEUED(task)) + return; spin_lock_bh(&queue->lock); rpc_wake_up_task_queue_locked(queue, task); spin_unlock_bh(&queue->lock); } EXPORT_SYMBOL_GPL(rpc_wake_up_queued_task); +static bool rpc_task_action_set_status(struct rpc_task *task, void *status) +{ + task->tk_status = *(int *)status; + return true; +} + +static void +rpc_wake_up_task_queue_set_status_locked(struct rpc_wait_queue *queue, + struct rpc_task *task, int status) +{ + rpc_wake_up_task_on_wq_queue_action_locked(rpciod_workqueue, queue, + task, rpc_task_action_set_status, &status); +} + +/** + * rpc_wake_up_queued_task_set_status - wake up a task and set task->tk_status + * @queue: pointer to rpc_wait_queue + * @task: pointer to rpc_task + * @status: integer error value + * + * If @task is queued on @queue, then it is woken up, and @task->tk_status is + * set to the value of @status. + */ +void +rpc_wake_up_queued_task_set_status(struct rpc_wait_queue *queue, + struct rpc_task *task, int status) +{ + if (!RPC_IS_QUEUED(task)) + return; + spin_lock_bh(&queue->lock); + rpc_wake_up_task_queue_set_status_locked(queue, task, status); + spin_unlock_bh(&queue->lock); +} + /* * Wake up the next task on a priority queue. */ @@ -493,17 +554,9 @@ static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *q * Service a batch of tasks from a single owner. */ q = &queue->tasks[queue->priority]; - if (!list_empty(q)) { - task = list_entry(q->next, struct rpc_task, u.tk_wait.list); - if (queue->owner == task->tk_owner) { - if (--queue->nr) - goto out; - list_move_tail(&task->u.tk_wait.list, q); - } - /* - * Check if we need to switch queues. - */ - goto new_owner; + if (!list_empty(q) && --queue->nr) { + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto out; } /* @@ -515,7 +568,7 @@ static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *q else q = q - 1; if (!list_empty(q)) { - task = list_entry(q->next, struct rpc_task, u.tk_wait.list); + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); goto new_queue; } } while (q != &queue->tasks[queue->priority]); @@ -525,8 +578,6 @@ static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *q new_queue: rpc_set_waitqueue_priority(queue, (unsigned int)(q - &queue->tasks[0])); -new_owner: - rpc_set_waitqueue_owner(queue, task->tk_owner); out: return task; } @@ -553,12 +604,9 @@ struct rpc_task *rpc_wake_up_first_on_wq(struct workqueue_struct *wq, queue, rpc_qname(queue)); spin_lock_bh(&queue->lock); task = __rpc_find_next_queued(queue); - if (task != NULL) { - if (func(task, data)) - rpc_wake_up_task_on_wq_queue_locked(wq, queue, task); - else - task = NULL; - } + if (task != NULL) + task = rpc_wake_up_task_on_wq_queue_action_locked(wq, queue, + task, func, data); spin_unlock_bh(&queue->lock); return task; diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c index f217c348b341..9062967575c4 100644 --- a/net/sunrpc/socklib.c +++ b/net/sunrpc/socklib.c @@ -26,7 +26,8 @@ * Possibly called several times to iterate over an sk_buff and copy * data out of it. */ -size_t xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) +static size_t +xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) { if (len > desc->count) len = desc->count; @@ -36,7 +37,6 @@ size_t xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) desc->offset += len; return len; } -EXPORT_SYMBOL_GPL(xdr_skb_read_bits); /** * xdr_skb_read_and_csum_bits - copy and checksum from skb to buffer @@ -69,7 +69,8 @@ static size_t xdr_skb_read_and_csum_bits(struct xdr_skb_reader *desc, void *to, * @copy_actor: virtual method for copying data * */ -ssize_t xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct xdr_skb_reader *desc, xdr_skb_read_actor copy_actor) +static ssize_t +xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct xdr_skb_reader *desc, xdr_skb_read_actor copy_actor) { struct page **ppage = xdr->pages; unsigned int len, pglen = xdr->page_len; @@ -104,7 +105,7 @@ ssize_t xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct /* ACL likes to be lazy in allocating pages - ACLs * are small by default but can get huge. */ - if (unlikely(*ppage == NULL)) { + if ((xdr->flags & XDRBUF_SPARSE_PAGES) && *ppage == NULL) { *ppage = alloc_page(GFP_ATOMIC); if (unlikely(*ppage == NULL)) { if (copied == 0) @@ -140,7 +141,6 @@ copy_tail: out: return copied; } -EXPORT_SYMBOL_GPL(xdr_partial_copy_from_skb); /** * csum_partial_copy_to_xdr - checksum and copy data diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c index 5185efb9027b..51d36230b6e3 100644 --- a/net/sunrpc/svc_xprt.c +++ b/net/sunrpc/svc_xprt.c @@ -171,7 +171,6 @@ void svc_xprt_init(struct net *net, struct svc_xprt_class *xcl, mutex_init(&xprt->xpt_mutex); spin_lock_init(&xprt->xpt_lock); set_bit(XPT_BUSY, &xprt->xpt_flags); - rpc_init_wait_queue(&xprt->xpt_bc_pending, "xpt_bc_pending"); xprt->xpt_net = get_net(net); strcpy(xprt->xpt_remotebuf, "uninitialized"); } @@ -895,7 +894,6 @@ int svc_send(struct svc_rqst *rqstp) else len = xprt->xpt_ops->xpo_sendto(rqstp); mutex_unlock(&xprt->xpt_mutex); - rpc_wake_up(&xprt->xpt_bc_pending); trace_svc_send(rqstp, len); svc_xprt_release(rqstp); @@ -989,7 +987,7 @@ static void call_xpt_users(struct svc_xprt *xprt) spin_lock(&xprt->xpt_lock); while (!list_empty(&xprt->xpt_users)) { u = list_first_entry(&xprt->xpt_users, struct svc_xpt_user, list); - list_del(&u->list); + list_del_init(&u->list); u->callback(u); } spin_unlock(&xprt->xpt_lock); diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c index bb8db3cb8032..775b8c94265b 100644 --- a/net/sunrpc/svcauth.c +++ b/net/sunrpc/svcauth.c @@ -27,12 +27,32 @@ extern struct auth_ops svcauth_null; extern struct auth_ops svcauth_unix; -static DEFINE_SPINLOCK(authtab_lock); -static struct auth_ops *authtab[RPC_AUTH_MAXFLAVOR] = { - [0] = &svcauth_null, - [1] = &svcauth_unix, +static struct auth_ops __rcu *authtab[RPC_AUTH_MAXFLAVOR] = { + [RPC_AUTH_NULL] = (struct auth_ops __force __rcu *)&svcauth_null, + [RPC_AUTH_UNIX] = (struct auth_ops __force __rcu *)&svcauth_unix, }; +static struct auth_ops * +svc_get_auth_ops(rpc_authflavor_t flavor) +{ + struct auth_ops *aops; + + if (flavor >= RPC_AUTH_MAXFLAVOR) + return NULL; + rcu_read_lock(); + aops = rcu_dereference(authtab[flavor]); + if (aops != NULL && !try_module_get(aops->owner)) + aops = NULL; + rcu_read_unlock(); + return aops; +} + +static void +svc_put_auth_ops(struct auth_ops *aops) +{ + module_put(aops->owner); +} + int svc_authenticate(struct svc_rqst *rqstp, __be32 *authp) { @@ -45,14 +65,11 @@ svc_authenticate(struct svc_rqst *rqstp, __be32 *authp) dprintk("svc: svc_authenticate (%d)\n", flavor); - spin_lock(&authtab_lock); - if (flavor >= RPC_AUTH_MAXFLAVOR || !(aops = authtab[flavor]) || - !try_module_get(aops->owner)) { - spin_unlock(&authtab_lock); + aops = svc_get_auth_ops(flavor); + if (aops == NULL) { *authp = rpc_autherr_badcred; return SVC_DENIED; } - spin_unlock(&authtab_lock); rqstp->rq_auth_slack = 0; init_svc_cred(&rqstp->rq_cred); @@ -82,7 +99,7 @@ int svc_authorise(struct svc_rqst *rqstp) if (aops) { rv = aops->release(rqstp); - module_put(aops->owner); + svc_put_auth_ops(aops); } return rv; } @@ -90,13 +107,14 @@ int svc_authorise(struct svc_rqst *rqstp) int svc_auth_register(rpc_authflavor_t flavor, struct auth_ops *aops) { + struct auth_ops *old; int rv = -EINVAL; - spin_lock(&authtab_lock); - if (flavor < RPC_AUTH_MAXFLAVOR && authtab[flavor] == NULL) { - authtab[flavor] = aops; - rv = 0; + + if (flavor < RPC_AUTH_MAXFLAVOR) { + old = cmpxchg((struct auth_ops ** __force)&authtab[flavor], NULL, aops); + if (old == NULL || old == aops) + rv = 0; } - spin_unlock(&authtab_lock); return rv; } EXPORT_SYMBOL_GPL(svc_auth_register); @@ -104,10 +122,8 @@ EXPORT_SYMBOL_GPL(svc_auth_register); void svc_auth_unregister(rpc_authflavor_t flavor) { - spin_lock(&authtab_lock); if (flavor < RPC_AUTH_MAXFLAVOR) - authtab[flavor] = NULL; - spin_unlock(&authtab_lock); + rcu_assign_pointer(authtab[flavor], NULL); } EXPORT_SYMBOL_GPL(svc_auth_unregister); @@ -127,10 +143,11 @@ static struct hlist_head auth_domain_table[DN_HASHMAX]; static DEFINE_SPINLOCK(auth_domain_lock); static void auth_domain_release(struct kref *kref) + __releases(&auth_domain_lock) { struct auth_domain *dom = container_of(kref, struct auth_domain, ref); - hlist_del(&dom->hash); + hlist_del_rcu(&dom->hash); dom->flavour->domain_release(dom); spin_unlock(&auth_domain_lock); } @@ -159,7 +176,7 @@ auth_domain_lookup(char *name, struct auth_domain *new) } } if (new) - hlist_add_head(&new->hash, head); + hlist_add_head_rcu(&new->hash, head); spin_unlock(&auth_domain_lock); return new; } @@ -167,6 +184,21 @@ EXPORT_SYMBOL_GPL(auth_domain_lookup); struct auth_domain *auth_domain_find(char *name) { - return auth_domain_lookup(name, NULL); + struct auth_domain *hp; + struct hlist_head *head; + + head = &auth_domain_table[hash_str(name, DN_HASHBITS)]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(hp, head, hash) { + if (strcmp(hp->name, name)==0) { + if (!kref_get_unless_zero(&hp->ref)) + hp = NULL; + rcu_read_unlock(); + return hp; + } + } + rcu_read_unlock(); + return NULL; } EXPORT_SYMBOL_GPL(auth_domain_find); diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c index af7f28fb8102..fb9041b92f72 100644 --- a/net/sunrpc/svcauth_unix.c +++ b/net/sunrpc/svcauth_unix.c @@ -37,20 +37,26 @@ struct unix_domain { extern struct auth_ops svcauth_null; extern struct auth_ops svcauth_unix; -static void svcauth_unix_domain_release(struct auth_domain *dom) +static void svcauth_unix_domain_release_rcu(struct rcu_head *head) { + struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head); struct unix_domain *ud = container_of(dom, struct unix_domain, h); kfree(dom->name); kfree(ud); } +static void svcauth_unix_domain_release(struct auth_domain *dom) +{ + call_rcu(&dom->rcu_head, svcauth_unix_domain_release_rcu); +} + struct auth_domain *unix_domain_find(char *name) { struct auth_domain *rv; struct unix_domain *new = NULL; - rv = auth_domain_lookup(name, NULL); + rv = auth_domain_find(name); while(1) { if (rv) { if (new && rv != &new->h) @@ -91,6 +97,7 @@ struct ip_map { char m_class[8]; /* e.g. "nfsd" */ struct in6_addr m_addr; struct unix_domain *m_client; + struct rcu_head m_rcu; }; static void ip_map_put(struct kref *kref) @@ -101,7 +108,7 @@ static void ip_map_put(struct kref *kref) if (test_bit(CACHE_VALID, &item->flags) && !test_bit(CACHE_NEGATIVE, &item->flags)) auth_domain_put(&im->m_client->h); - kfree(im); + kfree_rcu(im, m_rcu); } static inline int hash_ip6(const struct in6_addr *ip) @@ -280,9 +287,9 @@ static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, strcpy(ip.m_class, class); ip.m_addr = *addr; - ch = sunrpc_cache_lookup(cd, &ip.h, - hash_str(class, IP_HASHBITS) ^ - hash_ip6(addr)); + ch = sunrpc_cache_lookup_rcu(cd, &ip.h, + hash_str(class, IP_HASHBITS) ^ + hash_ip6(addr)); if (ch) return container_of(ch, struct ip_map, h); @@ -412,6 +419,7 @@ struct unix_gid { struct cache_head h; kuid_t uid; struct group_info *gi; + struct rcu_head rcu; }; static int unix_gid_hash(kuid_t uid) @@ -426,7 +434,7 @@ static void unix_gid_put(struct kref *kref) if (test_bit(CACHE_VALID, &item->flags) && !test_bit(CACHE_NEGATIVE, &item->flags)) put_group_info(ug->gi); - kfree(ug); + kfree_rcu(ug, rcu); } static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew) @@ -619,7 +627,7 @@ static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid) struct cache_head *ch; ug.uid = uid; - ch = sunrpc_cache_lookup(cd, &ug.h, unix_gid_hash(uid)); + ch = sunrpc_cache_lookup_rcu(cd, &ug.h, unix_gid_hash(uid)); if (ch) return container_of(ch, struct unix_gid, h); else diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index 5445145e639c..986f3ed7d1a2 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -325,59 +325,34 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) /* * Generic recvfrom routine. */ -static int svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, int nr, - int buflen) +static ssize_t svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, + unsigned int nr, size_t buflen, unsigned int base) { struct svc_sock *svsk = container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); - struct msghdr msg = { - .msg_flags = MSG_DONTWAIT, - }; - int len; + struct msghdr msg = { NULL }; + ssize_t len; rqstp->rq_xprt_hlen = 0; clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - iov_iter_kvec(&msg.msg_iter, READ | ITER_KVEC, iov, nr, buflen); - len = sock_recvmsg(svsk->sk_sock, &msg, msg.msg_flags); + iov_iter_kvec(&msg.msg_iter, READ, iov, nr, buflen); + if (base != 0) { + iov_iter_advance(&msg.msg_iter, base); + buflen -= base; + } + len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT); /* If we read a full record, then assume there may be more * data to read (stream based sockets only!) */ if (len == buflen) set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); - dprintk("svc: socket %p recvfrom(%p, %zu) = %d\n", + dprintk("svc: socket %p recvfrom(%p, %zu) = %zd\n", svsk, iov[0].iov_base, iov[0].iov_len, len); return len; } -static int svc_partial_recvfrom(struct svc_rqst *rqstp, - struct kvec *iov, int nr, - int buflen, unsigned int base) -{ - size_t save_iovlen; - void *save_iovbase; - unsigned int i; - int ret; - - if (base == 0) - return svc_recvfrom(rqstp, iov, nr, buflen); - - for (i = 0; i < nr; i++) { - if (iov[i].iov_len > base) - break; - base -= iov[i].iov_len; - } - save_iovlen = iov[i].iov_len; - save_iovbase = iov[i].iov_base; - iov[i].iov_len -= base; - iov[i].iov_base += base; - ret = svc_recvfrom(rqstp, &iov[i], nr - i, buflen); - iov[i].iov_len = save_iovlen; - iov[i].iov_base = save_iovbase; - return ret; -} - /* * Set socket snd and rcv buffer lengths */ @@ -962,7 +937,8 @@ static int svc_tcp_recv_record(struct svc_sock *svsk, struct svc_rqst *rqstp) want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; iov.iov_base = ((char *) &svsk->sk_reclen) + svsk->sk_tcplen; iov.iov_len = want; - if ((len = svc_recvfrom(rqstp, &iov, 1, want)) < 0) + len = svc_recvfrom(rqstp, &iov, 1, want, 0); + if (len < 0) goto error; svsk->sk_tcplen += len; @@ -1004,7 +980,7 @@ static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp) if (!bc_xprt) return -EAGAIN; - spin_lock(&bc_xprt->recv_lock); + spin_lock(&bc_xprt->queue_lock); req = xprt_lookup_rqst(bc_xprt, xid); if (!req) goto unlock_notfound; @@ -1022,7 +998,7 @@ static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp) memcpy(dst->iov_base, src->iov_base, src->iov_len); xprt_complete_rqst(req->rq_task, rqstp->rq_arg.len); rqstp->rq_arg.len = 0; - spin_unlock(&bc_xprt->recv_lock); + spin_unlock(&bc_xprt->queue_lock); return 0; unlock_notfound: printk(KERN_NOTICE @@ -1031,7 +1007,7 @@ unlock_notfound: __func__, ntohl(calldir), bc_xprt, ntohl(xid)); unlock_eagain: - spin_unlock(&bc_xprt->recv_lock); + spin_unlock(&bc_xprt->queue_lock); return -EAGAIN; } @@ -1088,14 +1064,13 @@ static int svc_tcp_recvfrom(struct svc_rqst *rqstp) vec = rqstp->rq_vec; - pnum = copy_pages_to_kvecs(&vec[0], &rqstp->rq_pages[0], - svsk->sk_datalen + want); + pnum = copy_pages_to_kvecs(&vec[0], &rqstp->rq_pages[0], base + want); rqstp->rq_respages = &rqstp->rq_pages[pnum]; rqstp->rq_next_page = rqstp->rq_respages + 1; /* Now receive data */ - len = svc_partial_recvfrom(rqstp, vec, pnum, want, base); + len = svc_recvfrom(rqstp, vec, pnum, base + want, base); if (len >= 0) { svsk->sk_tcplen += len; svsk->sk_datalen += len; diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c index 30afbd236656..2bbb8d38d2bf 100644 --- a/net/sunrpc/xdr.c +++ b/net/sunrpc/xdr.c @@ -15,6 +15,7 @@ #include <linux/errno.h> #include <linux/sunrpc/xdr.h> #include <linux/sunrpc/msg_prot.h> +#include <linux/bvec.h> /* * XDR functions for basic NFS types @@ -128,6 +129,39 @@ xdr_terminate_string(struct xdr_buf *buf, const u32 len) } EXPORT_SYMBOL_GPL(xdr_terminate_string); +size_t +xdr_buf_pagecount(struct xdr_buf *buf) +{ + if (!buf->page_len) + return 0; + return (buf->page_base + buf->page_len + PAGE_SIZE - 1) >> PAGE_SHIFT; +} + +int +xdr_alloc_bvec(struct xdr_buf *buf, gfp_t gfp) +{ + size_t i, n = xdr_buf_pagecount(buf); + + if (n != 0 && buf->bvec == NULL) { + buf->bvec = kmalloc_array(n, sizeof(buf->bvec[0]), gfp); + if (!buf->bvec) + return -ENOMEM; + for (i = 0; i < n; i++) { + buf->bvec[i].bv_page = buf->pages[i]; + buf->bvec[i].bv_len = PAGE_SIZE; + buf->bvec[i].bv_offset = 0; + } + } + return 0; +} + +void +xdr_free_bvec(struct xdr_buf *buf) +{ + kfree(buf->bvec); + buf->bvec = NULL; +} + void xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, struct page **pages, unsigned int base, unsigned int len) diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c index a8db2e3f8904..86bea4520c4d 100644 --- a/net/sunrpc/xprt.c +++ b/net/sunrpc/xprt.c @@ -68,8 +68,6 @@ static void xprt_init(struct rpc_xprt *xprt, struct net *net); static __be32 xprt_alloc_xid(struct rpc_xprt *xprt); static void xprt_connect_status(struct rpc_task *task); -static int __xprt_get_cong(struct rpc_xprt *, struct rpc_task *); -static void __xprt_put_cong(struct rpc_xprt *, struct rpc_rqst *); static void xprt_destroy(struct rpc_xprt *xprt); static DEFINE_SPINLOCK(xprt_list_lock); @@ -171,6 +169,17 @@ out: } EXPORT_SYMBOL_GPL(xprt_load_transport); +static void xprt_clear_locked(struct rpc_xprt *xprt) +{ + xprt->snd_task = NULL; + if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state)) { + smp_mb__before_atomic(); + clear_bit(XPRT_LOCKED, &xprt->state); + smp_mb__after_atomic(); + } else + queue_work(xprtiod_workqueue, &xprt->task_cleanup); +} + /** * xprt_reserve_xprt - serialize write access to transports * @task: task that is requesting access to the transport @@ -183,44 +192,53 @@ EXPORT_SYMBOL_GPL(xprt_load_transport); int xprt_reserve_xprt(struct rpc_xprt *xprt, struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - int priority; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { if (task == xprt->snd_task) return 1; goto out_sleep; } + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; xprt->snd_task = task; - if (req != NULL) - req->rq_ntrans++; return 1; +out_unlock: + xprt_clear_locked(xprt); out_sleep: dprintk("RPC: %5u failed to lock transport %p\n", task->tk_pid, xprt); - task->tk_timeout = 0; + task->tk_timeout = RPC_IS_SOFT(task) ? req->rq_timeout : 0; task->tk_status = -EAGAIN; - if (req == NULL) - priority = RPC_PRIORITY_LOW; - else if (!req->rq_ntrans) - priority = RPC_PRIORITY_NORMAL; - else - priority = RPC_PRIORITY_HIGH; - rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); + rpc_sleep_on(&xprt->sending, task, NULL); return 0; } EXPORT_SYMBOL_GPL(xprt_reserve_xprt); -static void xprt_clear_locked(struct rpc_xprt *xprt) +static bool +xprt_need_congestion_window_wait(struct rpc_xprt *xprt) { - xprt->snd_task = NULL; - if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state)) { - smp_mb__before_atomic(); - clear_bit(XPRT_LOCKED, &xprt->state); - smp_mb__after_atomic(); - } else - queue_work(xprtiod_workqueue, &xprt->task_cleanup); + return test_bit(XPRT_CWND_WAIT, &xprt->state); +} + +static void +xprt_set_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (!list_empty(&xprt->xmit_queue)) { + /* Peek at head of queue to see if it can make progress */ + if (list_first_entry(&xprt->xmit_queue, struct rpc_rqst, + rq_xmit)->rq_cong) + return; + } + set_bit(XPRT_CWND_WAIT, &xprt->state); +} + +static void +xprt_test_and_clear_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (!RPCXPRT_CONGESTED(xprt)) + clear_bit(XPRT_CWND_WAIT, &xprt->state); } /* @@ -230,11 +248,11 @@ static void xprt_clear_locked(struct rpc_xprt *xprt) * Same as xprt_reserve_xprt, but Van Jacobson congestion control is * integrated into the decision of whether a request is allowed to be * woken up and given access to the transport. + * Note that the lock is only granted if we know there are free slots. */ int xprt_reserve_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; - int priority; if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { if (task == xprt->snd_task) @@ -245,25 +263,19 @@ int xprt_reserve_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) xprt->snd_task = task; return 1; } - if (__xprt_get_cong(xprt, task)) { + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (!xprt_need_congestion_window_wait(xprt)) { xprt->snd_task = task; - req->rq_ntrans++; return 1; } +out_unlock: xprt_clear_locked(xprt); out_sleep: - if (req) - __xprt_put_cong(xprt, req); dprintk("RPC: %5u failed to lock transport %p\n", task->tk_pid, xprt); - task->tk_timeout = 0; + task->tk_timeout = RPC_IS_SOFT(task) ? req->rq_timeout : 0; task->tk_status = -EAGAIN; - if (req == NULL) - priority = RPC_PRIORITY_LOW; - else if (!req->rq_ntrans) - priority = RPC_PRIORITY_NORMAL; - else - priority = RPC_PRIORITY_HIGH; - rpc_sleep_on_priority(&xprt->sending, task, NULL, priority); + rpc_sleep_on(&xprt->sending, task, NULL); return 0; } EXPORT_SYMBOL_GPL(xprt_reserve_xprt_cong); @@ -272,6 +284,8 @@ static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task) { int retval; + if (test_bit(XPRT_LOCKED, &xprt->state) && xprt->snd_task == task) + return 1; spin_lock_bh(&xprt->transport_lock); retval = xprt->ops->reserve_xprt(xprt, task); spin_unlock_bh(&xprt->transport_lock); @@ -281,12 +295,8 @@ static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task) static bool __xprt_lock_write_func(struct rpc_task *task, void *data) { struct rpc_xprt *xprt = data; - struct rpc_rqst *req; - req = task->tk_rqstp; xprt->snd_task = task; - if (req) - req->rq_ntrans++; return true; } @@ -294,53 +304,30 @@ static void __xprt_lock_write_next(struct rpc_xprt *xprt) { if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) return; - + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, __xprt_lock_write_func, xprt)) return; +out_unlock: xprt_clear_locked(xprt); } -static bool __xprt_lock_write_cong_func(struct rpc_task *task, void *data) -{ - struct rpc_xprt *xprt = data; - struct rpc_rqst *req; - - req = task->tk_rqstp; - if (req == NULL) { - xprt->snd_task = task; - return true; - } - if (__xprt_get_cong(xprt, task)) { - xprt->snd_task = task; - req->rq_ntrans++; - return true; - } - return false; -} - static void __xprt_lock_write_next_cong(struct rpc_xprt *xprt) { if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) return; - if (RPCXPRT_CONGESTED(xprt)) + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (xprt_need_congestion_window_wait(xprt)) goto out_unlock; if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, - __xprt_lock_write_cong_func, xprt)) + __xprt_lock_write_func, xprt)) return; out_unlock: xprt_clear_locked(xprt); } -static void xprt_task_clear_bytes_sent(struct rpc_task *task) -{ - if (task != NULL) { - struct rpc_rqst *req = task->tk_rqstp; - if (req != NULL) - req->rq_bytes_sent = 0; - } -} - /** * xprt_release_xprt - allow other requests to use a transport * @xprt: transport with other tasks potentially waiting @@ -351,7 +338,6 @@ static void xprt_task_clear_bytes_sent(struct rpc_task *task) void xprt_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) { if (xprt->snd_task == task) { - xprt_task_clear_bytes_sent(task); xprt_clear_locked(xprt); __xprt_lock_write_next(xprt); } @@ -369,7 +355,6 @@ EXPORT_SYMBOL_GPL(xprt_release_xprt); void xprt_release_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) { if (xprt->snd_task == task) { - xprt_task_clear_bytes_sent(task); xprt_clear_locked(xprt); __xprt_lock_write_next_cong(xprt); } @@ -378,6 +363,8 @@ EXPORT_SYMBOL_GPL(xprt_release_xprt_cong); static inline void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *task) { + if (xprt->snd_task != task) + return; spin_lock_bh(&xprt->transport_lock); xprt->ops->release_xprt(xprt, task); spin_unlock_bh(&xprt->transport_lock); @@ -388,16 +375,16 @@ static inline void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *ta * overflowed. Put the task to sleep if this is the case. */ static int -__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_task *task) +__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; - if (req->rq_cong) return 1; dprintk("RPC: %5u xprt_cwnd_limited cong = %lu cwnd = %lu\n", - task->tk_pid, xprt->cong, xprt->cwnd); - if (RPCXPRT_CONGESTED(xprt)) + req->rq_task->tk_pid, xprt->cong, xprt->cwnd); + if (RPCXPRT_CONGESTED(xprt)) { + xprt_set_congestion_window_wait(xprt); return 0; + } req->rq_cong = 1; xprt->cong += RPC_CWNDSCALE; return 1; @@ -414,10 +401,32 @@ __xprt_put_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) return; req->rq_cong = 0; xprt->cong -= RPC_CWNDSCALE; + xprt_test_and_clear_congestion_window_wait(xprt); __xprt_lock_write_next_cong(xprt); } /** + * xprt_request_get_cong - Request congestion control credits + * @xprt: pointer to transport + * @req: pointer to RPC request + * + * Useful for transports that require congestion control. + */ +bool +xprt_request_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + bool ret = false; + + if (req->rq_cong) + return true; + spin_lock_bh(&xprt->transport_lock); + ret = __xprt_get_cong(xprt, req) != 0; + spin_unlock_bh(&xprt->transport_lock); + return ret; +} +EXPORT_SYMBOL_GPL(xprt_request_get_cong); + +/** * xprt_release_rqst_cong - housekeeping when request is complete * @task: RPC request that recently completed * @@ -431,6 +440,20 @@ void xprt_release_rqst_cong(struct rpc_task *task) } EXPORT_SYMBOL_GPL(xprt_release_rqst_cong); +/* + * Clear the congestion window wait flag and wake up the next + * entry on xprt->sending + */ +static void +xprt_clear_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_CWND_WAIT, &xprt->state)) { + spin_lock_bh(&xprt->transport_lock); + __xprt_lock_write_next_cong(xprt); + spin_unlock_bh(&xprt->transport_lock); + } +} + /** * xprt_adjust_cwnd - adjust transport congestion window * @xprt: pointer to xprt @@ -488,39 +511,46 @@ EXPORT_SYMBOL_GPL(xprt_wake_pending_tasks); /** * xprt_wait_for_buffer_space - wait for transport output buffer to clear - * @task: task to be put to sleep - * @action: function pointer to be executed after wait + * @xprt: transport * * Note that we only set the timer for the case of RPC_IS_SOFT(), since * we don't in general want to force a socket disconnection due to * an incomplete RPC call transmission. */ -void xprt_wait_for_buffer_space(struct rpc_task *task, rpc_action action) +void xprt_wait_for_buffer_space(struct rpc_xprt *xprt) { - struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = req->rq_xprt; - - task->tk_timeout = RPC_IS_SOFT(task) ? req->rq_timeout : 0; - rpc_sleep_on(&xprt->pending, task, action); + set_bit(XPRT_WRITE_SPACE, &xprt->state); } EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space); +static bool +xprt_clear_write_space_locked(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_WRITE_SPACE, &xprt->state)) { + __xprt_lock_write_next(xprt); + dprintk("RPC: write space: waking waiting task on " + "xprt %p\n", xprt); + return true; + } + return false; +} + /** * xprt_write_space - wake the task waiting for transport output buffer space * @xprt: transport with waiting tasks * * Can be called in a soft IRQ context, so xprt_write_space never sleeps. */ -void xprt_write_space(struct rpc_xprt *xprt) +bool xprt_write_space(struct rpc_xprt *xprt) { + bool ret; + + if (!test_bit(XPRT_WRITE_SPACE, &xprt->state)) + return false; spin_lock_bh(&xprt->transport_lock); - if (xprt->snd_task) { - dprintk("RPC: write space: waking waiting task on " - "xprt %p\n", xprt); - rpc_wake_up_queued_task_on_wq(xprtiod_workqueue, - &xprt->pending, xprt->snd_task); - } + ret = xprt_clear_write_space_locked(xprt); spin_unlock_bh(&xprt->transport_lock); + return ret; } EXPORT_SYMBOL_GPL(xprt_write_space); @@ -631,6 +661,7 @@ void xprt_disconnect_done(struct rpc_xprt *xprt) dprintk("RPC: disconnected transport %p\n", xprt); spin_lock_bh(&xprt->transport_lock); xprt_clear_connected(xprt); + xprt_clear_write_space_locked(xprt); xprt_wake_pending_tasks(xprt, -EAGAIN); spin_unlock_bh(&xprt->transport_lock); } @@ -654,6 +685,22 @@ void xprt_force_disconnect(struct rpc_xprt *xprt) } EXPORT_SYMBOL_GPL(xprt_force_disconnect); +static unsigned int +xprt_connect_cookie(struct rpc_xprt *xprt) +{ + return READ_ONCE(xprt->connect_cookie); +} + +static bool +xprt_request_retransmit_after_disconnect(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + return req->rq_connect_cookie != xprt_connect_cookie(xprt) || + !xprt_connected(xprt); +} + /** * xprt_conditional_disconnect - force a transport to disconnect * @xprt: transport to disconnect @@ -692,7 +739,7 @@ static void xprt_schedule_autodisconnect(struct rpc_xprt *xprt) __must_hold(&xprt->transport_lock) { - if (list_empty(&xprt->recv) && xprt_has_timer(xprt)) + if (RB_EMPTY_ROOT(&xprt->recv_queue) && xprt_has_timer(xprt)) mod_timer(&xprt->timer, xprt->last_used + xprt->idle_timeout); } @@ -702,7 +749,7 @@ xprt_init_autodisconnect(struct timer_list *t) struct rpc_xprt *xprt = from_timer(xprt, t, timer); spin_lock(&xprt->transport_lock); - if (!list_empty(&xprt->recv)) + if (!RB_EMPTY_ROOT(&xprt->recv_queue)) goto out_abort; /* Reset xprt->last_used to avoid connect/autodisconnect cycling */ xprt->last_used = jiffies; @@ -726,7 +773,6 @@ bool xprt_lock_connect(struct rpc_xprt *xprt, goto out; if (xprt->snd_task != task) goto out; - xprt_task_clear_bytes_sent(task); xprt->snd_task = cookie; ret = true; out: @@ -772,7 +818,6 @@ void xprt_connect(struct rpc_task *task) xprt->ops->close(xprt); if (!xprt_connected(xprt)) { - task->tk_rqstp->rq_bytes_sent = 0; task->tk_timeout = task->tk_rqstp->rq_timeout; task->tk_rqstp->rq_connect_cookie = xprt->connect_cookie; rpc_sleep_on(&xprt->pending, task, xprt_connect_status); @@ -789,17 +834,11 @@ void xprt_connect(struct rpc_task *task) static void xprt_connect_status(struct rpc_task *task) { - struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; - - if (task->tk_status == 0) { - xprt->stat.connect_count++; - xprt->stat.connect_time += (long)jiffies - xprt->stat.connect_start; + switch (task->tk_status) { + case 0: dprintk("RPC: %5u xprt_connect_status: connection established\n", task->tk_pid); - return; - } - - switch (task->tk_status) { + break; case -ECONNREFUSED: case -ECONNRESET: case -ECONNABORTED: @@ -816,28 +855,97 @@ static void xprt_connect_status(struct rpc_task *task) default: dprintk("RPC: %5u xprt_connect_status: error %d connecting to " "server %s\n", task->tk_pid, -task->tk_status, - xprt->servername); + task->tk_rqstp->rq_xprt->servername); task->tk_status = -EIO; } } +enum xprt_xid_rb_cmp { + XID_RB_EQUAL, + XID_RB_LEFT, + XID_RB_RIGHT, +}; +static enum xprt_xid_rb_cmp +xprt_xid_cmp(__be32 xid1, __be32 xid2) +{ + if (xid1 == xid2) + return XID_RB_EQUAL; + if ((__force u32)xid1 < (__force u32)xid2) + return XID_RB_LEFT; + return XID_RB_RIGHT; +} + +static struct rpc_rqst * +xprt_request_rb_find(struct rpc_xprt *xprt, __be32 xid) +{ + struct rb_node *n = xprt->recv_queue.rb_node; + struct rpc_rqst *req; + + while (n != NULL) { + req = rb_entry(n, struct rpc_rqst, rq_recv); + switch (xprt_xid_cmp(xid, req->rq_xid)) { + case XID_RB_LEFT: + n = n->rb_left; + break; + case XID_RB_RIGHT: + n = n->rb_right; + break; + case XID_RB_EQUAL: + return req; + } + } + return NULL; +} + +static void +xprt_request_rb_insert(struct rpc_xprt *xprt, struct rpc_rqst *new) +{ + struct rb_node **p = &xprt->recv_queue.rb_node; + struct rb_node *n = NULL; + struct rpc_rqst *req; + + while (*p != NULL) { + n = *p; + req = rb_entry(n, struct rpc_rqst, rq_recv); + switch(xprt_xid_cmp(new->rq_xid, req->rq_xid)) { + case XID_RB_LEFT: + p = &n->rb_left; + break; + case XID_RB_RIGHT: + p = &n->rb_right; + break; + case XID_RB_EQUAL: + WARN_ON_ONCE(new != req); + return; + } + } + rb_link_node(&new->rq_recv, n, p); + rb_insert_color(&new->rq_recv, &xprt->recv_queue); +} + +static void +xprt_request_rb_remove(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + rb_erase(&req->rq_recv, &xprt->recv_queue); +} + /** * xprt_lookup_rqst - find an RPC request corresponding to an XID * @xprt: transport on which the original request was transmitted * @xid: RPC XID of incoming reply * - * Caller holds xprt->recv_lock. + * Caller holds xprt->queue_lock. */ struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) { struct rpc_rqst *entry; - list_for_each_entry(entry, &xprt->recv, rq_list) - if (entry->rq_xid == xid) { - trace_xprt_lookup_rqst(xprt, xid, 0); - entry->rq_rtt = ktime_sub(ktime_get(), entry->rq_xtime); - return entry; - } + entry = xprt_request_rb_find(xprt, xid); + if (entry != NULL) { + trace_xprt_lookup_rqst(xprt, xid, 0); + entry->rq_rtt = ktime_sub(ktime_get(), entry->rq_xtime); + return entry; + } dprintk("RPC: xprt_lookup_rqst did not find xid %08x\n", ntohl(xid)); @@ -847,16 +955,22 @@ struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) } EXPORT_SYMBOL_GPL(xprt_lookup_rqst); +static bool +xprt_is_pinned_rqst(struct rpc_rqst *req) +{ + return atomic_read(&req->rq_pin) != 0; +} + /** * xprt_pin_rqst - Pin a request on the transport receive list * @req: Request to pin * * Caller must ensure this is atomic with the call to xprt_lookup_rqst() - * so should be holding the xprt transport lock. + * so should be holding the xprt receive lock. */ void xprt_pin_rqst(struct rpc_rqst *req) { - set_bit(RPC_TASK_MSG_RECV, &req->rq_task->tk_runstate); + atomic_inc(&req->rq_pin); } EXPORT_SYMBOL_GPL(xprt_pin_rqst); @@ -864,38 +978,87 @@ EXPORT_SYMBOL_GPL(xprt_pin_rqst); * xprt_unpin_rqst - Unpin a request on the transport receive list * @req: Request to pin * - * Caller should be holding the xprt transport lock. + * Caller should be holding the xprt receive lock. */ void xprt_unpin_rqst(struct rpc_rqst *req) { - struct rpc_task *task = req->rq_task; - - clear_bit(RPC_TASK_MSG_RECV, &task->tk_runstate); - if (test_bit(RPC_TASK_MSG_RECV_WAIT, &task->tk_runstate)) - wake_up_bit(&task->tk_runstate, RPC_TASK_MSG_RECV); + if (!test_bit(RPC_TASK_MSG_PIN_WAIT, &req->rq_task->tk_runstate)) { + atomic_dec(&req->rq_pin); + return; + } + if (atomic_dec_and_test(&req->rq_pin)) + wake_up_var(&req->rq_pin); } EXPORT_SYMBOL_GPL(xprt_unpin_rqst); static void xprt_wait_on_pinned_rqst(struct rpc_rqst *req) -__must_hold(&req->rq_xprt->recv_lock) { - struct rpc_task *task = req->rq_task; + wait_var_event(&req->rq_pin, !xprt_is_pinned_rqst(req)); +} - if (task && test_bit(RPC_TASK_MSG_RECV, &task->tk_runstate)) { - spin_unlock(&req->rq_xprt->recv_lock); - set_bit(RPC_TASK_MSG_RECV_WAIT, &task->tk_runstate); - wait_on_bit(&task->tk_runstate, RPC_TASK_MSG_RECV, - TASK_UNINTERRUPTIBLE); - clear_bit(RPC_TASK_MSG_RECV_WAIT, &task->tk_runstate); - spin_lock(&req->rq_xprt->recv_lock); - } +static bool +xprt_request_data_received(struct rpc_task *task) +{ + return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) && + READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) != 0; +} + +static bool +xprt_request_need_enqueue_receive(struct rpc_task *task, struct rpc_rqst *req) +{ + return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) && + READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) == 0; +} + +/** + * xprt_request_enqueue_receive - Add an request to the receive queue + * @task: RPC task + * + */ +void +xprt_request_enqueue_receive(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (!xprt_request_need_enqueue_receive(task, req)) + return; + spin_lock(&xprt->queue_lock); + + /* Update the softirq receive buffer */ + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + sizeof(req->rq_private_buf)); + + /* Add request to the receive list */ + xprt_request_rb_insert(xprt, req); + set_bit(RPC_TASK_NEED_RECV, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + + xprt_reset_majortimeo(req); + /* Turn off autodisconnect */ + del_singleshot_timer_sync(&xprt->timer); +} + +/** + * xprt_request_dequeue_receive_locked - Remove a request from the receive queue + * @task: RPC task + * + * Caller must hold xprt->queue_lock. + */ +static void +xprt_request_dequeue_receive_locked(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (test_and_clear_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) + xprt_request_rb_remove(req->rq_xprt, req); } /** * xprt_update_rtt - Update RPC RTT statistics * @task: RPC request that recently completed * - * Caller holds xprt->recv_lock. + * Caller holds xprt->queue_lock. */ void xprt_update_rtt(struct rpc_task *task) { @@ -917,7 +1080,7 @@ EXPORT_SYMBOL_GPL(xprt_update_rtt); * @task: RPC request that recently completed * @copied: actual number of bytes received from the transport * - * Caller holds xprt->recv_lock. + * Caller holds xprt->queue_lock. */ void xprt_complete_rqst(struct rpc_task *task, int copied) { @@ -930,12 +1093,12 @@ void xprt_complete_rqst(struct rpc_task *task, int copied) xprt->stat.recvs++; - list_del_init(&req->rq_list); req->rq_private_buf.len = copied; /* Ensure all writes are done before we update */ /* req->rq_reply_bytes_recvd */ smp_wmb(); req->rq_reply_bytes_recvd = copied; + xprt_request_dequeue_receive_locked(task); rpc_wake_up_queued_task(&xprt->pending, task); } EXPORT_SYMBOL_GPL(xprt_complete_rqst); @@ -957,6 +1120,172 @@ static void xprt_timer(struct rpc_task *task) } /** + * xprt_request_wait_receive - wait for the reply to an RPC request + * @task: RPC task about to send a request + * + */ +void xprt_request_wait_receive(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (!test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) + return; + /* + * Sleep on the pending queue if we're expecting a reply. + * The spinlock ensures atomicity between the test of + * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on(). + */ + spin_lock(&xprt->queue_lock); + if (test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) { + xprt->ops->set_retrans_timeout(task); + rpc_sleep_on(&xprt->pending, task, xprt_timer); + /* + * Send an extra queue wakeup call if the + * connection was dropped in case the call to + * rpc_sleep_on() raced. + */ + if (xprt_request_retransmit_after_disconnect(task)) + rpc_wake_up_queued_task_set_status(&xprt->pending, + task, -ENOTCONN); + } + spin_unlock(&xprt->queue_lock); +} + +static bool +xprt_request_need_enqueue_transmit(struct rpc_task *task, struct rpc_rqst *req) +{ + return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); +} + +/** + * xprt_request_enqueue_transmit - queue a task for transmission + * @task: pointer to rpc_task + * + * Add a task to the transmission queue. + */ +void +xprt_request_enqueue_transmit(struct rpc_task *task) +{ + struct rpc_rqst *pos, *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (xprt_request_need_enqueue_transmit(task, req)) { + spin_lock(&xprt->queue_lock); + /* + * Requests that carry congestion control credits are added + * to the head of the list to avoid starvation issues. + */ + if (req->rq_cong) { + xprt_clear_congestion_window_wait(xprt); + list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { + if (pos->rq_cong) + continue; + /* Note: req is added _before_ pos */ + list_add_tail(&req->rq_xmit, &pos->rq_xmit); + INIT_LIST_HEAD(&req->rq_xmit2); + goto out; + } + } else if (RPC_IS_SWAPPER(task)) { + list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { + if (pos->rq_cong || pos->rq_bytes_sent) + continue; + if (RPC_IS_SWAPPER(pos->rq_task)) + continue; + /* Note: req is added _before_ pos */ + list_add_tail(&req->rq_xmit, &pos->rq_xmit); + INIT_LIST_HEAD(&req->rq_xmit2); + goto out; + } + } else { + list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { + if (pos->rq_task->tk_owner != task->tk_owner) + continue; + list_add_tail(&req->rq_xmit2, &pos->rq_xmit2); + INIT_LIST_HEAD(&req->rq_xmit); + goto out; + } + } + list_add_tail(&req->rq_xmit, &xprt->xmit_queue); + INIT_LIST_HEAD(&req->rq_xmit2); +out: + set_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + } +} + +/** + * xprt_request_dequeue_transmit_locked - remove a task from the transmission queue + * @task: pointer to rpc_task + * + * Remove a task from the transmission queue + * Caller must hold xprt->queue_lock + */ +static void +xprt_request_dequeue_transmit_locked(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (!test_and_clear_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + return; + if (!list_empty(&req->rq_xmit)) { + list_del(&req->rq_xmit); + if (!list_empty(&req->rq_xmit2)) { + struct rpc_rqst *next = list_first_entry(&req->rq_xmit2, + struct rpc_rqst, rq_xmit2); + list_del(&req->rq_xmit2); + list_add_tail(&next->rq_xmit, &next->rq_xprt->xmit_queue); + } + } else + list_del(&req->rq_xmit2); +} + +/** + * xprt_request_dequeue_transmit - remove a task from the transmission queue + * @task: pointer to rpc_task + * + * Remove a task from the transmission queue + */ +static void +xprt_request_dequeue_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + spin_lock(&xprt->queue_lock); + xprt_request_dequeue_transmit_locked(task); + spin_unlock(&xprt->queue_lock); +} + +/** + * xprt_request_prepare - prepare an encoded request for transport + * @req: pointer to rpc_rqst + * + * Calls into the transport layer to do whatever is needed to prepare + * the request for transmission or receive. + */ +void +xprt_request_prepare(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + + if (xprt->ops->prepare_request) + xprt->ops->prepare_request(req); +} + +/** + * xprt_request_need_retransmit - Test if a task needs retransmission + * @task: pointer to rpc_task + * + * Test for whether a connection breakage requires the task to retransmit + */ +bool +xprt_request_need_retransmit(struct rpc_task *task) +{ + return xprt_request_retransmit_after_disconnect(task); +} + +/** * xprt_prepare_transmit - reserve the transport before sending a request * @task: RPC task about to send a request * @@ -965,32 +1294,18 @@ bool xprt_prepare_transmit(struct rpc_task *task) { struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; - bool ret = false; dprintk("RPC: %5u xprt_prepare_transmit\n", task->tk_pid); - spin_lock_bh(&xprt->transport_lock); - if (!req->rq_bytes_sent) { - if (req->rq_reply_bytes_recvd) { - task->tk_status = req->rq_reply_bytes_recvd; - goto out_unlock; - } - if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) - && xprt_connected(xprt) - && req->rq_connect_cookie == xprt->connect_cookie) { - xprt->ops->set_retrans_timeout(task); - rpc_sleep_on(&xprt->pending, task, xprt_timer); - goto out_unlock; - } - } - if (!xprt->ops->reserve_xprt(xprt, task)) { - task->tk_status = -EAGAIN; - goto out_unlock; + if (!xprt_lock_write(xprt, task)) { + /* Race breaker: someone may have transmitted us */ + if (!test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + rpc_wake_up_queued_task_set_status(&xprt->sending, + task, 0); + return false; + } - ret = true; -out_unlock: - spin_unlock_bh(&xprt->transport_lock); - return ret; + return true; } void xprt_end_transmit(struct rpc_task *task) @@ -999,54 +1314,62 @@ void xprt_end_transmit(struct rpc_task *task) } /** - * xprt_transmit - send an RPC request on a transport - * @task: controlling RPC task + * xprt_request_transmit - send an RPC request on a transport + * @req: pointer to request to transmit + * @snd_task: RPC task that owns the transport lock * - * We have to copy the iovec because sendmsg fiddles with its contents. + * This performs the transmission of a single request. + * Note that if the request is not the same as snd_task, then it + * does need to be pinned. + * Returns '0' on success. */ -void xprt_transmit(struct rpc_task *task) +static int +xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) { - struct rpc_rqst *req = task->tk_rqstp; - struct rpc_xprt *xprt = req->rq_xprt; + struct rpc_xprt *xprt = req->rq_xprt; + struct rpc_task *task = req->rq_task; unsigned int connect_cookie; + int is_retrans = RPC_WAS_SENT(task); int status; dprintk("RPC: %5u xprt_transmit(%u)\n", task->tk_pid, req->rq_slen); - if (!req->rq_reply_bytes_recvd) { - if (list_empty(&req->rq_list) && rpc_reply_expected(task)) { - /* - * Add to the list only if we're expecting a reply - */ - /* Update the softirq receive buffer */ - memcpy(&req->rq_private_buf, &req->rq_rcv_buf, - sizeof(req->rq_private_buf)); - /* Add request to the receive list */ - spin_lock(&xprt->recv_lock); - list_add_tail(&req->rq_list, &xprt->recv); - spin_unlock(&xprt->recv_lock); - xprt_reset_majortimeo(req); - /* Turn off autodisconnect */ - del_singleshot_timer_sync(&xprt->timer); + if (!req->rq_bytes_sent) { + if (xprt_request_data_received(task)) { + status = 0; + goto out_dequeue; } - } else if (!req->rq_bytes_sent) - return; + /* Verify that our message lies in the RPCSEC_GSS window */ + if (rpcauth_xmit_need_reencode(task)) { + status = -EBADMSG; + goto out_dequeue; + } + } + + /* + * Update req->rq_ntrans before transmitting to avoid races with + * xprt_update_rtt(), which needs to know that it is recording a + * reply to the first transmission. + */ + req->rq_ntrans++; connect_cookie = xprt->connect_cookie; - status = xprt->ops->send_request(task); + status = xprt->ops->send_request(req); trace_xprt_transmit(xprt, req->rq_xid, status); if (status != 0) { - task->tk_status = status; - return; + req->rq_ntrans--; + return status; } + + if (is_retrans) + task->tk_client->cl_stats->rpcretrans++; + xprt_inject_disconnect(xprt); dprintk("RPC: %5u xmit complete\n", task->tk_pid); task->tk_flags |= RPC_TASK_SENT; spin_lock_bh(&xprt->transport_lock); - xprt->ops->set_retrans_timeout(task); - xprt->stat.sends++; xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs; xprt->stat.bklog_u += xprt->backlog.qlen; @@ -1055,25 +1378,49 @@ void xprt_transmit(struct rpc_task *task) spin_unlock_bh(&xprt->transport_lock); req->rq_connect_cookie = connect_cookie; - if (rpc_reply_expected(task) && !READ_ONCE(req->rq_reply_bytes_recvd)) { - /* - * Sleep on the pending queue if we're expecting a reply. - * The spinlock ensures atomicity between the test of - * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on(). - */ - spin_lock(&xprt->recv_lock); - if (!req->rq_reply_bytes_recvd) { - rpc_sleep_on(&xprt->pending, task, xprt_timer); - /* - * Send an extra queue wakeup call if the - * connection was dropped in case the call to - * rpc_sleep_on() raced. - */ - if (!xprt_connected(xprt)) - xprt_wake_pending_tasks(xprt, -ENOTCONN); - } - spin_unlock(&xprt->recv_lock); +out_dequeue: + xprt_request_dequeue_transmit(task); + rpc_wake_up_queued_task_set_status(&xprt->sending, task, status); + return status; +} + +/** + * xprt_transmit - send an RPC request on a transport + * @task: controlling RPC task + * + * Attempts to drain the transmit queue. On exit, either the transport + * signalled an error that needs to be handled before transmission can + * resume, or @task finished transmitting, and detected that it already + * received a reply. + */ +void +xprt_transmit(struct rpc_task *task) +{ + struct rpc_rqst *next, *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int status; + + spin_lock(&xprt->queue_lock); + while (!list_empty(&xprt->xmit_queue)) { + next = list_first_entry(&xprt->xmit_queue, + struct rpc_rqst, rq_xmit); + xprt_pin_rqst(next); + spin_unlock(&xprt->queue_lock); + status = xprt_request_transmit(next, task); + if (status == -EBADMSG && next != req) + status = 0; + cond_resched(); + spin_lock(&xprt->queue_lock); + xprt_unpin_rqst(next); + if (status == 0) { + if (!xprt_request_data_received(task) || + test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + continue; + } else if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + task->tk_status = status; + break; } + spin_unlock(&xprt->queue_lock); } static void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task) @@ -1170,20 +1517,6 @@ out_init_req: } EXPORT_SYMBOL_GPL(xprt_alloc_slot); -void xprt_lock_and_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) -{ - /* Note: grabbing the xprt_lock_write() ensures that we throttle - * new slot allocation if the transport is congested (i.e. when - * reconnecting a stream transport or when out of socket write - * buffer space). - */ - if (xprt_lock_write(xprt, task)) { - xprt_alloc_slot(xprt, task); - xprt_release_write(xprt, task); - } -} -EXPORT_SYMBOL_GPL(xprt_lock_and_alloc_slot); - void xprt_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) { spin_lock(&xprt->reserve_lock); @@ -1250,6 +1583,60 @@ void xprt_free(struct rpc_xprt *xprt) } EXPORT_SYMBOL_GPL(xprt_free); +static void +xprt_init_connect_cookie(struct rpc_rqst *req, struct rpc_xprt *xprt) +{ + req->rq_connect_cookie = xprt_connect_cookie(xprt) - 1; +} + +static __be32 +xprt_alloc_xid(struct rpc_xprt *xprt) +{ + __be32 xid; + + spin_lock(&xprt->reserve_lock); + xid = (__force __be32)xprt->xid++; + spin_unlock(&xprt->reserve_lock); + return xid; +} + +static void +xprt_init_xid(struct rpc_xprt *xprt) +{ + xprt->xid = prandom_u32(); +} + +static void +xprt_request_init(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_rqst *req = task->tk_rqstp; + + req->rq_timeout = task->tk_client->cl_timeout->to_initval; + req->rq_task = task; + req->rq_xprt = xprt; + req->rq_buffer = NULL; + req->rq_xid = xprt_alloc_xid(xprt); + xprt_init_connect_cookie(req, xprt); + req->rq_bytes_sent = 0; + req->rq_snd_buf.len = 0; + req->rq_snd_buf.buflen = 0; + req->rq_rcv_buf.len = 0; + req->rq_rcv_buf.buflen = 0; + req->rq_release_snd_buf = NULL; + xprt_reset_majortimeo(req); + dprintk("RPC: %5u reserved req %p xid %08x\n", task->tk_pid, + req, ntohl(req->rq_xid)); +} + +static void +xprt_do_reserve(struct rpc_xprt *xprt, struct rpc_task *task) +{ + xprt->ops->alloc_slot(xprt, task); + if (task->tk_rqstp != NULL) + xprt_request_init(task); +} + /** * xprt_reserve - allocate an RPC request slot * @task: RPC task requesting a slot allocation @@ -1269,7 +1656,7 @@ void xprt_reserve(struct rpc_task *task) task->tk_timeout = 0; task->tk_status = -EAGAIN; if (!xprt_throttle_congested(xprt, task)) - xprt->ops->alloc_slot(xprt, task); + xprt_do_reserve(xprt, task); } /** @@ -1291,45 +1678,29 @@ void xprt_retry_reserve(struct rpc_task *task) task->tk_timeout = 0; task->tk_status = -EAGAIN; - xprt->ops->alloc_slot(xprt, task); -} - -static inline __be32 xprt_alloc_xid(struct rpc_xprt *xprt) -{ - __be32 xid; - - spin_lock(&xprt->reserve_lock); - xid = (__force __be32)xprt->xid++; - spin_unlock(&xprt->reserve_lock); - return xid; + xprt_do_reserve(xprt, task); } -static inline void xprt_init_xid(struct rpc_xprt *xprt) -{ - xprt->xid = prandom_u32(); -} - -void xprt_request_init(struct rpc_task *task) +static void +xprt_request_dequeue_all(struct rpc_task *task, struct rpc_rqst *req) { - struct rpc_xprt *xprt = task->tk_xprt; - struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; - INIT_LIST_HEAD(&req->rq_list); - req->rq_timeout = task->tk_client->cl_timeout->to_initval; - req->rq_task = task; - req->rq_xprt = xprt; - req->rq_buffer = NULL; - req->rq_xid = xprt_alloc_xid(xprt); - req->rq_connect_cookie = xprt->connect_cookie - 1; - req->rq_bytes_sent = 0; - req->rq_snd_buf.len = 0; - req->rq_snd_buf.buflen = 0; - req->rq_rcv_buf.len = 0; - req->rq_rcv_buf.buflen = 0; - req->rq_release_snd_buf = NULL; - xprt_reset_majortimeo(req); - dprintk("RPC: %5u reserved req %p xid %08x\n", task->tk_pid, - req, ntohl(req->rq_xid)); + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) || + test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) || + xprt_is_pinned_rqst(req)) { + spin_lock(&xprt->queue_lock); + xprt_request_dequeue_transmit_locked(task); + xprt_request_dequeue_receive_locked(task); + while (xprt_is_pinned_rqst(req)) { + set_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + xprt_wait_on_pinned_rqst(req); + spin_lock(&xprt->queue_lock); + clear_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate); + } + spin_unlock(&xprt->queue_lock); + } } /** @@ -1345,8 +1716,7 @@ void xprt_release(struct rpc_task *task) if (req == NULL) { if (task->tk_client) { xprt = task->tk_xprt; - if (xprt->snd_task == task) - xprt_release_write(xprt, task); + xprt_release_write(xprt, task); } return; } @@ -1356,12 +1726,7 @@ void xprt_release(struct rpc_task *task) task->tk_ops->rpc_count_stats(task, task->tk_calldata); else if (task->tk_client) rpc_count_iostats(task, task->tk_client->cl_metrics); - spin_lock(&xprt->recv_lock); - if (!list_empty(&req->rq_list)) { - list_del_init(&req->rq_list); - xprt_wait_on_pinned_rqst(req); - } - spin_unlock(&xprt->recv_lock); + xprt_request_dequeue_all(task, req); spin_lock_bh(&xprt->transport_lock); xprt->ops->release_xprt(xprt, task); if (xprt->ops->release_request) @@ -1372,6 +1737,7 @@ void xprt_release(struct rpc_task *task) if (req->rq_buffer) xprt->ops->buf_free(task); xprt_inject_disconnect(xprt); + xdr_free_bvec(&req->rq_rcv_buf); if (req->rq_cred != NULL) put_rpccred(req->rq_cred); task->tk_rqstp = NULL; @@ -1385,16 +1751,36 @@ void xprt_release(struct rpc_task *task) xprt_free_bc_request(req); } +#ifdef CONFIG_SUNRPC_BACKCHANNEL +void +xprt_init_bc_request(struct rpc_rqst *req, struct rpc_task *task) +{ + struct xdr_buf *xbufp = &req->rq_snd_buf; + + task->tk_rqstp = req; + req->rq_task = task; + xprt_init_connect_cookie(req, req->rq_xprt); + /* + * Set up the xdr_buf length. + * This also indicates that the buffer is XDR encoded already. + */ + xbufp->len = xbufp->head[0].iov_len + xbufp->page_len + + xbufp->tail[0].iov_len; + req->rq_bytes_sent = 0; +} +#endif + static void xprt_init(struct rpc_xprt *xprt, struct net *net) { kref_init(&xprt->kref); spin_lock_init(&xprt->transport_lock); spin_lock_init(&xprt->reserve_lock); - spin_lock_init(&xprt->recv_lock); + spin_lock_init(&xprt->queue_lock); INIT_LIST_HEAD(&xprt->free); - INIT_LIST_HEAD(&xprt->recv); + xprt->recv_queue = RB_ROOT; + INIT_LIST_HEAD(&xprt->xmit_queue); #if defined(CONFIG_SUNRPC_BACKCHANNEL) spin_lock_init(&xprt->bc_pa_lock); INIT_LIST_HEAD(&xprt->bc_pa_list); @@ -1407,7 +1793,7 @@ static void xprt_init(struct rpc_xprt *xprt, struct net *net) rpc_init_wait_queue(&xprt->binding, "xprt_binding"); rpc_init_wait_queue(&xprt->pending, "xprt_pending"); - rpc_init_priority_wait_queue(&xprt->sending, "xprt_sending"); + rpc_init_wait_queue(&xprt->sending, "xprt_sending"); rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog"); xprt_init_xid(xprt); diff --git a/net/sunrpc/xprtrdma/backchannel.c b/net/sunrpc/xprtrdma/backchannel.c index 90adeff4c06b..e5b367a3e517 100644 --- a/net/sunrpc/xprtrdma/backchannel.c +++ b/net/sunrpc/xprtrdma/backchannel.c @@ -51,12 +51,11 @@ static int rpcrdma_bc_setup_reqs(struct rpcrdma_xprt *r_xprt, rqst = &req->rl_slot; rqst->rq_xprt = xprt; - INIT_LIST_HEAD(&rqst->rq_list); INIT_LIST_HEAD(&rqst->rq_bc_list); __set_bit(RPC_BC_PA_IN_USE, &rqst->rq_bc_pa_state); - spin_lock_bh(&xprt->bc_pa_lock); + spin_lock(&xprt->bc_pa_lock); list_add(&rqst->rq_bc_pa_list, &xprt->bc_pa_list); - spin_unlock_bh(&xprt->bc_pa_lock); + spin_unlock(&xprt->bc_pa_lock); size = r_xprt->rx_data.inline_rsize; rb = rpcrdma_alloc_regbuf(size, DMA_TO_DEVICE, GFP_KERNEL); @@ -201,6 +200,9 @@ int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst) if (!xprt_connected(rqst->rq_xprt)) goto drop_connection; + if (!xprt_request_get_cong(rqst->rq_xprt, rqst)) + return -EBADSLT; + rc = rpcrdma_bc_marshal_reply(rqst); if (rc < 0) goto failed_marshal; @@ -228,16 +230,16 @@ void xprt_rdma_bc_destroy(struct rpc_xprt *xprt, unsigned int reqs) struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); struct rpc_rqst *rqst, *tmp; - spin_lock_bh(&xprt->bc_pa_lock); + spin_lock(&xprt->bc_pa_lock); list_for_each_entry_safe(rqst, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { list_del(&rqst->rq_bc_pa_list); - spin_unlock_bh(&xprt->bc_pa_lock); + spin_unlock(&xprt->bc_pa_lock); rpcrdma_bc_free_rqst(r_xprt, rqst); - spin_lock_bh(&xprt->bc_pa_lock); + spin_lock(&xprt->bc_pa_lock); } - spin_unlock_bh(&xprt->bc_pa_lock); + spin_unlock(&xprt->bc_pa_lock); } /** @@ -255,9 +257,9 @@ void xprt_rdma_bc_free_rqst(struct rpc_rqst *rqst) rpcrdma_recv_buffer_put(req->rl_reply); req->rl_reply = NULL; - spin_lock_bh(&xprt->bc_pa_lock); + spin_lock(&xprt->bc_pa_lock); list_add_tail(&rqst->rq_bc_pa_list, &xprt->bc_pa_list); - spin_unlock_bh(&xprt->bc_pa_lock); + spin_unlock(&xprt->bc_pa_lock); } /** diff --git a/net/sunrpc/xprtrdma/fmr_ops.c b/net/sunrpc/xprtrdma/fmr_ops.c index 0f7c465d9a5a..7f5632cd5a48 100644 --- a/net/sunrpc/xprtrdma/fmr_ops.c +++ b/net/sunrpc/xprtrdma/fmr_ops.c @@ -49,46 +49,7 @@ fmr_is_supported(struct rpcrdma_ia *ia) return true; } -static int -fmr_op_init_mr(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) -{ - static struct ib_fmr_attr fmr_attr = { - .max_pages = RPCRDMA_MAX_FMR_SGES, - .max_maps = 1, - .page_shift = PAGE_SHIFT - }; - - mr->fmr.fm_physaddrs = kcalloc(RPCRDMA_MAX_FMR_SGES, - sizeof(u64), GFP_KERNEL); - if (!mr->fmr.fm_physaddrs) - goto out_free; - - mr->mr_sg = kcalloc(RPCRDMA_MAX_FMR_SGES, - sizeof(*mr->mr_sg), GFP_KERNEL); - if (!mr->mr_sg) - goto out_free; - - sg_init_table(mr->mr_sg, RPCRDMA_MAX_FMR_SGES); - - mr->fmr.fm_mr = ib_alloc_fmr(ia->ri_pd, RPCRDMA_FMR_ACCESS_FLAGS, - &fmr_attr); - if (IS_ERR(mr->fmr.fm_mr)) - goto out_fmr_err; - - INIT_LIST_HEAD(&mr->mr_list); - return 0; - -out_fmr_err: - dprintk("RPC: %s: ib_alloc_fmr returned %ld\n", __func__, - PTR_ERR(mr->fmr.fm_mr)); - -out_free: - kfree(mr->mr_sg); - kfree(mr->fmr.fm_physaddrs); - return -ENOMEM; -} - -static int +static void __fmr_unmap(struct rpcrdma_mr *mr) { LIST_HEAD(l); @@ -97,13 +58,16 @@ __fmr_unmap(struct rpcrdma_mr *mr) list_add(&mr->fmr.fm_mr->list, &l); rc = ib_unmap_fmr(&l); list_del(&mr->fmr.fm_mr->list); - return rc; + if (rc) + pr_err("rpcrdma: final ib_unmap_fmr for %p failed %i\n", + mr, rc); } +/* Release an MR. + */ static void fmr_op_release_mr(struct rpcrdma_mr *mr) { - LIST_HEAD(unmap_list); int rc; kfree(mr->fmr.fm_physaddrs); @@ -112,10 +76,7 @@ fmr_op_release_mr(struct rpcrdma_mr *mr) /* In case this one was left mapped, try to unmap it * to prevent dealloc_fmr from failing with EBUSY */ - rc = __fmr_unmap(mr); - if (rc) - pr_err("rpcrdma: final ib_unmap_fmr for %p failed %i\n", - mr, rc); + __fmr_unmap(mr); rc = ib_dealloc_fmr(mr->fmr.fm_mr); if (rc) @@ -125,40 +86,68 @@ fmr_op_release_mr(struct rpcrdma_mr *mr) kfree(mr); } -/* Reset of a single FMR. +/* MRs are dynamically allocated, so simply clean up and release the MR. + * A replacement MR will subsequently be allocated on demand. */ static void -fmr_op_recover_mr(struct rpcrdma_mr *mr) +fmr_mr_recycle_worker(struct work_struct *work) { + struct rpcrdma_mr *mr = container_of(work, struct rpcrdma_mr, mr_recycle); struct rpcrdma_xprt *r_xprt = mr->mr_xprt; - int rc; - /* ORDER: invalidate first */ - rc = __fmr_unmap(mr); - if (rc) - goto out_release; - - /* ORDER: then DMA unmap */ - rpcrdma_mr_unmap_and_put(mr); + trace_xprtrdma_mr_recycle(mr); - r_xprt->rx_stats.mrs_recovered++; - return; - -out_release: - pr_err("rpcrdma: FMR reset failed (%d), %p released\n", rc, mr); - r_xprt->rx_stats.mrs_orphaned++; - - trace_xprtrdma_dma_unmap(mr); + trace_xprtrdma_mr_unmap(mr); ib_dma_unmap_sg(r_xprt->rx_ia.ri_device, mr->mr_sg, mr->mr_nents, mr->mr_dir); spin_lock(&r_xprt->rx_buf.rb_mrlock); list_del(&mr->mr_all); + r_xprt->rx_stats.mrs_recycled++; spin_unlock(&r_xprt->rx_buf.rb_mrlock); - fmr_op_release_mr(mr); } +static int +fmr_op_init_mr(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) +{ + static struct ib_fmr_attr fmr_attr = { + .max_pages = RPCRDMA_MAX_FMR_SGES, + .max_maps = 1, + .page_shift = PAGE_SHIFT + }; + + mr->fmr.fm_physaddrs = kcalloc(RPCRDMA_MAX_FMR_SGES, + sizeof(u64), GFP_KERNEL); + if (!mr->fmr.fm_physaddrs) + goto out_free; + + mr->mr_sg = kcalloc(RPCRDMA_MAX_FMR_SGES, + sizeof(*mr->mr_sg), GFP_KERNEL); + if (!mr->mr_sg) + goto out_free; + + sg_init_table(mr->mr_sg, RPCRDMA_MAX_FMR_SGES); + + mr->fmr.fm_mr = ib_alloc_fmr(ia->ri_pd, RPCRDMA_FMR_ACCESS_FLAGS, + &fmr_attr); + if (IS_ERR(mr->fmr.fm_mr)) + goto out_fmr_err; + + INIT_LIST_HEAD(&mr->mr_list); + INIT_WORK(&mr->mr_recycle, fmr_mr_recycle_worker); + return 0; + +out_fmr_err: + dprintk("RPC: %s: ib_alloc_fmr returned %ld\n", __func__, + PTR_ERR(mr->fmr.fm_mr)); + +out_free: + kfree(mr->mr_sg); + kfree(mr->fmr.fm_physaddrs); + return -ENOMEM; +} + /* On success, sets: * ep->rep_attr.cap.max_send_wr * ep->rep_attr.cap.max_recv_wr @@ -187,6 +176,7 @@ fmr_op_open(struct rpcrdma_ia *ia, struct rpcrdma_ep *ep, ia->ri_max_segs = max_t(unsigned int, 1, RPCRDMA_MAX_DATA_SEGS / RPCRDMA_MAX_FMR_SGES); + ia->ri_max_segs += 2; /* segments for head and tail buffers */ return 0; } @@ -244,7 +234,7 @@ fmr_op_map(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr_seg *seg, mr->mr_sg, i, mr->mr_dir); if (!mr->mr_nents) goto out_dmamap_err; - trace_xprtrdma_dma_map(mr); + trace_xprtrdma_mr_map(mr); for (i = 0, dma_pages = mr->fmr.fm_physaddrs; i < mr->mr_nents; i++) dma_pages[i] = sg_dma_address(&mr->mr_sg[i]); @@ -305,13 +295,13 @@ fmr_op_unmap_sync(struct rpcrdma_xprt *r_xprt, struct list_head *mrs) list_for_each_entry(mr, mrs, mr_list) { dprintk("RPC: %s: unmapping fmr %p\n", __func__, &mr->fmr); - trace_xprtrdma_localinv(mr); + trace_xprtrdma_mr_localinv(mr); list_add_tail(&mr->fmr.fm_mr->list, &unmap_list); } r_xprt->rx_stats.local_inv_needed++; rc = ib_unmap_fmr(&unmap_list); if (rc) - goto out_reset; + goto out_release; /* ORDER: Now DMA unmap all of the req's MRs, and return * them to the free MW list. @@ -324,13 +314,13 @@ fmr_op_unmap_sync(struct rpcrdma_xprt *r_xprt, struct list_head *mrs) return; -out_reset: +out_release: pr_err("rpcrdma: ib_unmap_fmr failed (%i)\n", rc); while (!list_empty(mrs)) { mr = rpcrdma_mr_pop(mrs); list_del(&mr->fmr.fm_mr->list); - fmr_op_recover_mr(mr); + rpcrdma_mr_recycle(mr); } } @@ -338,7 +328,6 @@ const struct rpcrdma_memreg_ops rpcrdma_fmr_memreg_ops = { .ro_map = fmr_op_map, .ro_send = fmr_op_send, .ro_unmap_sync = fmr_op_unmap_sync, - .ro_recover_mr = fmr_op_recover_mr, .ro_open = fmr_op_open, .ro_maxpages = fmr_op_maxpages, .ro_init_mr = fmr_op_init_mr, diff --git a/net/sunrpc/xprtrdma/frwr_ops.c b/net/sunrpc/xprtrdma/frwr_ops.c index 1bb00dd6ccdb..fc6378cc0c1c 100644 --- a/net/sunrpc/xprtrdma/frwr_ops.c +++ b/net/sunrpc/xprtrdma/frwr_ops.c @@ -97,6 +97,44 @@ out_not_supported: return false; } +static void +frwr_op_release_mr(struct rpcrdma_mr *mr) +{ + int rc; + + rc = ib_dereg_mr(mr->frwr.fr_mr); + if (rc) + pr_err("rpcrdma: final ib_dereg_mr for %p returned %i\n", + mr, rc); + kfree(mr->mr_sg); + kfree(mr); +} + +/* MRs are dynamically allocated, so simply clean up and release the MR. + * A replacement MR will subsequently be allocated on demand. + */ +static void +frwr_mr_recycle_worker(struct work_struct *work) +{ + struct rpcrdma_mr *mr = container_of(work, struct rpcrdma_mr, mr_recycle); + enum rpcrdma_frwr_state state = mr->frwr.fr_state; + struct rpcrdma_xprt *r_xprt = mr->mr_xprt; + + trace_xprtrdma_mr_recycle(mr); + + if (state != FRWR_FLUSHED_LI) { + trace_xprtrdma_mr_unmap(mr); + ib_dma_unmap_sg(r_xprt->rx_ia.ri_device, + mr->mr_sg, mr->mr_nents, mr->mr_dir); + } + + spin_lock(&r_xprt->rx_buf.rb_mrlock); + list_del(&mr->mr_all); + r_xprt->rx_stats.mrs_recycled++; + spin_unlock(&r_xprt->rx_buf.rb_mrlock); + frwr_op_release_mr(mr); +} + static int frwr_op_init_mr(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) { @@ -113,6 +151,7 @@ frwr_op_init_mr(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) goto out_list_err; INIT_LIST_HEAD(&mr->mr_list); + INIT_WORK(&mr->mr_recycle, frwr_mr_recycle_worker); sg_init_table(mr->mr_sg, depth); init_completion(&frwr->fr_linv_done); return 0; @@ -131,79 +170,6 @@ out_list_err: return rc; } -static void -frwr_op_release_mr(struct rpcrdma_mr *mr) -{ - int rc; - - rc = ib_dereg_mr(mr->frwr.fr_mr); - if (rc) - pr_err("rpcrdma: final ib_dereg_mr for %p returned %i\n", - mr, rc); - kfree(mr->mr_sg); - kfree(mr); -} - -static int -__frwr_mr_reset(struct rpcrdma_ia *ia, struct rpcrdma_mr *mr) -{ - struct rpcrdma_frwr *frwr = &mr->frwr; - int rc; - - rc = ib_dereg_mr(frwr->fr_mr); - if (rc) { - pr_warn("rpcrdma: ib_dereg_mr status %d, frwr %p orphaned\n", - rc, mr); - return rc; - } - - frwr->fr_mr = ib_alloc_mr(ia->ri_pd, ia->ri_mrtype, - ia->ri_max_frwr_depth); - if (IS_ERR(frwr->fr_mr)) { - pr_warn("rpcrdma: ib_alloc_mr status %ld, frwr %p orphaned\n", - PTR_ERR(frwr->fr_mr), mr); - return PTR_ERR(frwr->fr_mr); - } - - dprintk("RPC: %s: recovered FRWR %p\n", __func__, frwr); - frwr->fr_state = FRWR_IS_INVALID; - return 0; -} - -/* Reset of a single FRWR. Generate a fresh rkey by replacing the MR. - */ -static void -frwr_op_recover_mr(struct rpcrdma_mr *mr) -{ - enum rpcrdma_frwr_state state = mr->frwr.fr_state; - struct rpcrdma_xprt *r_xprt = mr->mr_xprt; - struct rpcrdma_ia *ia = &r_xprt->rx_ia; - int rc; - - rc = __frwr_mr_reset(ia, mr); - if (state != FRWR_FLUSHED_LI) { - trace_xprtrdma_dma_unmap(mr); - ib_dma_unmap_sg(ia->ri_device, - mr->mr_sg, mr->mr_nents, mr->mr_dir); - } - if (rc) - goto out_release; - - rpcrdma_mr_put(mr); - r_xprt->rx_stats.mrs_recovered++; - return; - -out_release: - pr_err("rpcrdma: FRWR reset failed %d, %p released\n", rc, mr); - r_xprt->rx_stats.mrs_orphaned++; - - spin_lock(&r_xprt->rx_buf.rb_mrlock); - list_del(&mr->mr_all); - spin_unlock(&r_xprt->rx_buf.rb_mrlock); - - frwr_op_release_mr(mr); -} - /* On success, sets: * ep->rep_attr.cap.max_send_wr * ep->rep_attr.cap.max_recv_wr @@ -276,6 +242,7 @@ frwr_op_open(struct rpcrdma_ia *ia, struct rpcrdma_ep *ep, ia->ri_max_segs = max_t(unsigned int, 1, RPCRDMA_MAX_DATA_SEGS / ia->ri_max_frwr_depth); + ia->ri_max_segs += 2; /* segments for head and tail buffers */ return 0; } @@ -384,7 +351,7 @@ frwr_op_map(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr_seg *seg, mr = NULL; do { if (mr) - rpcrdma_mr_defer_recovery(mr); + rpcrdma_mr_recycle(mr); mr = rpcrdma_mr_get(r_xprt); if (!mr) return ERR_PTR(-EAGAIN); @@ -417,7 +384,7 @@ frwr_op_map(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr_seg *seg, mr->mr_nents = ib_dma_map_sg(ia->ri_device, mr->mr_sg, i, mr->mr_dir); if (!mr->mr_nents) goto out_dmamap_err; - trace_xprtrdma_dma_map(mr); + trace_xprtrdma_mr_map(mr); ibmr = frwr->fr_mr; n = ib_map_mr_sg(ibmr, mr->mr_sg, mr->mr_nents, NULL, PAGE_SIZE); @@ -451,7 +418,7 @@ out_dmamap_err: out_mapmr_err: pr_err("rpcrdma: failed to map mr %p (%d/%d)\n", frwr->fr_mr, n, mr->mr_nents); - rpcrdma_mr_defer_recovery(mr); + rpcrdma_mr_recycle(mr); return ERR_PTR(-EIO); } @@ -499,7 +466,7 @@ frwr_op_reminv(struct rpcrdma_rep *rep, struct list_head *mrs) list_for_each_entry(mr, mrs, mr_list) if (mr->mr_handle == rep->rr_inv_rkey) { list_del_init(&mr->mr_list); - trace_xprtrdma_remoteinv(mr); + trace_xprtrdma_mr_remoteinv(mr); mr->frwr.fr_state = FRWR_IS_INVALID; rpcrdma_mr_unmap_and_put(mr); break; /* only one invalidated MR per RPC */ @@ -536,7 +503,7 @@ frwr_op_unmap_sync(struct rpcrdma_xprt *r_xprt, struct list_head *mrs) mr->frwr.fr_state = FRWR_IS_INVALID; frwr = &mr->frwr; - trace_xprtrdma_localinv(mr); + trace_xprtrdma_mr_localinv(mr); frwr->fr_cqe.done = frwr_wc_localinv; last = &frwr->fr_invwr; @@ -570,7 +537,7 @@ frwr_op_unmap_sync(struct rpcrdma_xprt *r_xprt, struct list_head *mrs) if (bad_wr != first) wait_for_completion(&frwr->fr_linv_done); if (rc) - goto reset_mrs; + goto out_release; /* ORDER: Now DMA unmap all of the MRs, and return * them to the free MR list. @@ -582,22 +549,21 @@ unmap: } return; -reset_mrs: +out_release: pr_err("rpcrdma: FRWR invalidate ib_post_send returned %i\n", rc); - /* Find and reset the MRs in the LOCAL_INV WRs that did not + /* Unmap and release the MRs in the LOCAL_INV WRs that did not * get posted. */ while (bad_wr) { frwr = container_of(bad_wr, struct rpcrdma_frwr, fr_invwr); mr = container_of(frwr, struct rpcrdma_mr, frwr); - - __frwr_mr_reset(ia, mr); - bad_wr = bad_wr->next; + + list_del(&mr->mr_list); + frwr_op_release_mr(mr); } - goto unmap; } const struct rpcrdma_memreg_ops rpcrdma_frwr_memreg_ops = { @@ -605,7 +571,6 @@ const struct rpcrdma_memreg_ops rpcrdma_frwr_memreg_ops = { .ro_send = frwr_op_send, .ro_reminv = frwr_op_reminv, .ro_unmap_sync = frwr_op_unmap_sync, - .ro_recover_mr = frwr_op_recover_mr, .ro_open = frwr_op_open, .ro_maxpages = frwr_op_maxpages, .ro_init_mr = frwr_op_init_mr, diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c index c8ae983c6cc0..9f53e0240035 100644 --- a/net/sunrpc/xprtrdma/rpc_rdma.c +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -71,7 +71,6 @@ static unsigned int rpcrdma_max_call_header_size(unsigned int maxsegs) size = RPCRDMA_HDRLEN_MIN; /* Maximum Read list size */ - maxsegs += 2; /* segment for head and tail buffers */ size = maxsegs * rpcrdma_readchunk_maxsz * sizeof(__be32); /* Minimal Read chunk size */ @@ -97,7 +96,6 @@ static unsigned int rpcrdma_max_reply_header_size(unsigned int maxsegs) size = RPCRDMA_HDRLEN_MIN; /* Maximum Write list size */ - maxsegs += 2; /* segment for head and tail buffers */ size = sizeof(__be32); /* segment count */ size += maxsegs * rpcrdma_segment_maxsz * sizeof(__be32); size += sizeof(__be32); /* list discriminator */ @@ -805,7 +803,7 @@ rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst) struct rpcrdma_mr *mr; mr = rpcrdma_mr_pop(&req->rl_registered); - rpcrdma_mr_defer_recovery(mr); + rpcrdma_mr_recycle(mr); } /* This implementation supports the following combinations @@ -866,7 +864,7 @@ rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst) out_err: switch (ret) { case -EAGAIN: - xprt_wait_for_buffer_space(rqst->rq_task, NULL); + xprt_wait_for_buffer_space(rqst->rq_xprt); break; case -ENOBUFS: break; @@ -1216,7 +1214,6 @@ void rpcrdma_complete_rqst(struct rpcrdma_rep *rep) struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; struct rpc_xprt *xprt = &r_xprt->rx_xprt; struct rpc_rqst *rqst = rep->rr_rqst; - unsigned long cwnd; int status; xprt->reestablish_timeout = 0; @@ -1238,15 +1235,10 @@ void rpcrdma_complete_rqst(struct rpcrdma_rep *rep) goto out_badheader; out: - spin_lock(&xprt->recv_lock); - cwnd = xprt->cwnd; - xprt->cwnd = r_xprt->rx_buf.rb_credits << RPC_CWNDSHIFT; - if (xprt->cwnd > cwnd) - xprt_release_rqst_cong(rqst->rq_task); - + spin_lock(&xprt->queue_lock); xprt_complete_rqst(rqst->rq_task, status); xprt_unpin_rqst(rqst); - spin_unlock(&xprt->recv_lock); + spin_unlock(&xprt->queue_lock); return; /* If the incoming reply terminated a pending RPC, the next @@ -1345,19 +1337,23 @@ void rpcrdma_reply_handler(struct rpcrdma_rep *rep) /* Match incoming rpcrdma_rep to an rpcrdma_req to * get context for handling any incoming chunks. */ - spin_lock(&xprt->recv_lock); + spin_lock(&xprt->queue_lock); rqst = xprt_lookup_rqst(xprt, rep->rr_xid); if (!rqst) goto out_norqst; xprt_pin_rqst(rqst); + spin_unlock(&xprt->queue_lock); if (credits == 0) credits = 1; /* don't deadlock */ else if (credits > buf->rb_max_requests) credits = buf->rb_max_requests; - buf->rb_credits = credits; - - spin_unlock(&xprt->recv_lock); + if (buf->rb_credits != credits) { + spin_lock_bh(&xprt->transport_lock); + buf->rb_credits = credits; + xprt->cwnd = credits << RPC_CWNDSHIFT; + spin_unlock_bh(&xprt->transport_lock); + } req = rpcr_to_rdmar(rqst); req->rl_reply = rep; @@ -1378,7 +1374,7 @@ out_badversion: * is corrupt. */ out_norqst: - spin_unlock(&xprt->recv_lock); + spin_unlock(&xprt->queue_lock); trace_xprtrdma_reply_rqst(rep); goto repost; diff --git a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c index a68180090554..f3c147d70286 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c +++ b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c @@ -5,8 +5,6 @@ * Support for backward direction RPCs on RPC/RDMA (server-side). */ -#include <linux/module.h> - #include <linux/sunrpc/svc_rdma.h> #include "xprt_rdma.h" @@ -32,7 +30,6 @@ int svc_rdma_handle_bc_reply(struct rpc_xprt *xprt, __be32 *rdma_resp, struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); struct kvec *dst, *src = &rcvbuf->head[0]; struct rpc_rqst *req; - unsigned long cwnd; u32 credits; size_t len; __be32 xid; @@ -56,7 +53,7 @@ int svc_rdma_handle_bc_reply(struct rpc_xprt *xprt, __be32 *rdma_resp, if (src->iov_len < 24) goto out_shortreply; - spin_lock(&xprt->recv_lock); + spin_lock(&xprt->queue_lock); req = xprt_lookup_rqst(xprt, xid); if (!req) goto out_notfound; @@ -66,6 +63,8 @@ int svc_rdma_handle_bc_reply(struct rpc_xprt *xprt, __be32 *rdma_resp, if (dst->iov_len < len) goto out_unlock; memcpy(dst->iov_base, p, len); + xprt_pin_rqst(req); + spin_unlock(&xprt->queue_lock); credits = be32_to_cpup(rdma_resp + 2); if (credits == 0) @@ -74,19 +73,17 @@ int svc_rdma_handle_bc_reply(struct rpc_xprt *xprt, __be32 *rdma_resp, credits = r_xprt->rx_buf.rb_bc_max_requests; spin_lock_bh(&xprt->transport_lock); - cwnd = xprt->cwnd; xprt->cwnd = credits << RPC_CWNDSHIFT; - if (xprt->cwnd > cwnd) - xprt_release_rqst_cong(req->rq_task); spin_unlock_bh(&xprt->transport_lock); - + spin_lock(&xprt->queue_lock); ret = 0; xprt_complete_rqst(req->rq_task, rcvbuf->len); + xprt_unpin_rqst(req); rcvbuf->len = 0; out_unlock: - spin_unlock(&xprt->recv_lock); + spin_unlock(&xprt->queue_lock); out: return ret; @@ -215,9 +212,8 @@ drop_connection: * connection. */ static int -xprt_rdma_bc_send_request(struct rpc_task *task) +xprt_rdma_bc_send_request(struct rpc_rqst *rqst) { - struct rpc_rqst *rqst = task->tk_rqstp; struct svc_xprt *sxprt = rqst->rq_xprt->bc_xprt; struct svcxprt_rdma *rdma; int ret; @@ -225,12 +221,7 @@ xprt_rdma_bc_send_request(struct rpc_task *task) dprintk("svcrdma: sending bc call with xid: %08x\n", be32_to_cpu(rqst->rq_xid)); - if (!mutex_trylock(&sxprt->xpt_mutex)) { - rpc_sleep_on(&sxprt->xpt_bc_pending, task, NULL); - if (!mutex_trylock(&sxprt->xpt_mutex)) - return -EAGAIN; - rpc_wake_up_queued_task(&sxprt->xpt_bc_pending, task); - } + mutex_lock(&sxprt->xpt_mutex); ret = -ENOTCONN; rdma = container_of(sxprt, struct svcxprt_rdma, sc_xprt); @@ -248,6 +239,7 @@ static void xprt_rdma_bc_close(struct rpc_xprt *xprt) { dprintk("svcrdma: %s: xprt %p\n", __func__, xprt); + xprt->cwnd = RPC_CWNDSHIFT; } static void @@ -256,7 +248,6 @@ xprt_rdma_bc_put(struct rpc_xprt *xprt) dprintk("svcrdma: %s: xprt %p\n", __func__, xprt); xprt_free(xprt); - module_put(THIS_MODULE); } static const struct rpc_xprt_ops xprt_rdma_bc_procs = { @@ -328,20 +319,9 @@ xprt_setup_rdma_bc(struct xprt_create *args) args->bc_xprt->xpt_bc_xprt = xprt; xprt->bc_xprt = args->bc_xprt; - if (!try_module_get(THIS_MODULE)) - goto out_fail; - /* Final put for backchannel xprt is in __svc_rdma_free */ xprt_get(xprt); return xprt; - -out_fail: - xprt_rdma_free_addresses(xprt); - args->bc_xprt->xpt_bc_xprt = NULL; - args->bc_xprt->xpt_bc_xps = NULL; - xprt_put(xprt); - xprt_free(xprt); - return ERR_PTR(-EINVAL); } struct xprt_class xprt_rdma_bc = { diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c index 2848cafd4a17..2f7ec8912f49 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_transport.c +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -475,10 +475,12 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) /* Qualify the transport resource defaults with the * capabilities of this particular device */ - newxprt->sc_max_send_sges = dev->attrs.max_send_sge; - /* transport hdr, head iovec, one page list entry, tail iovec */ - if (newxprt->sc_max_send_sges < 4) { - pr_err("svcrdma: too few Send SGEs available (%d)\n", + /* Transport header, head iovec, tail iovec */ + newxprt->sc_max_send_sges = 3; + /* Add one SGE per page list entry */ + newxprt->sc_max_send_sges += svcrdma_max_req_size / PAGE_SIZE; + if (newxprt->sc_max_send_sges > dev->attrs.max_send_sge) { + pr_err("svcrdma: too few Send SGEs available (%d needed)\n", newxprt->sc_max_send_sges); goto errout; } diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c index 143ce2579ba9..ae2a83828953 100644 --- a/net/sunrpc/xprtrdma/transport.c +++ b/net/sunrpc/xprtrdma/transport.c @@ -225,69 +225,59 @@ xprt_rdma_free_addresses(struct rpc_xprt *xprt) } } -void -rpcrdma_conn_func(struct rpcrdma_ep *ep) -{ - schedule_delayed_work(&ep->rep_connect_worker, 0); -} - -void -rpcrdma_connect_worker(struct work_struct *work) -{ - struct rpcrdma_ep *ep = - container_of(work, struct rpcrdma_ep, rep_connect_worker.work); - struct rpcrdma_xprt *r_xprt = - container_of(ep, struct rpcrdma_xprt, rx_ep); - struct rpc_xprt *xprt = &r_xprt->rx_xprt; - - spin_lock_bh(&xprt->transport_lock); - if (ep->rep_connected > 0) { - if (!xprt_test_and_set_connected(xprt)) - xprt_wake_pending_tasks(xprt, 0); - } else { - if (xprt_test_and_clear_connected(xprt)) - xprt_wake_pending_tasks(xprt, -ENOTCONN); - } - spin_unlock_bh(&xprt->transport_lock); -} - +/** + * xprt_rdma_connect_worker - establish connection in the background + * @work: worker thread context + * + * Requester holds the xprt's send lock to prevent activity on this + * transport while a fresh connection is being established. RPC tasks + * sleep on the xprt's pending queue waiting for connect to complete. + */ static void xprt_rdma_connect_worker(struct work_struct *work) { struct rpcrdma_xprt *r_xprt = container_of(work, struct rpcrdma_xprt, rx_connect_worker.work); struct rpc_xprt *xprt = &r_xprt->rx_xprt; - int rc = 0; - - xprt_clear_connected(xprt); + int rc; rc = rpcrdma_ep_connect(&r_xprt->rx_ep, &r_xprt->rx_ia); - if (rc) - xprt_wake_pending_tasks(xprt, rc); - xprt_clear_connecting(xprt); + if (r_xprt->rx_ep.rep_connected > 0) { + if (!xprt_test_and_set_connected(xprt)) { + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xprt_wake_pending_tasks(xprt, -EAGAIN); + } + } else { + if (xprt_test_and_clear_connected(xprt)) + xprt_wake_pending_tasks(xprt, rc); + } } +/** + * xprt_rdma_inject_disconnect - inject a connection fault + * @xprt: transport context + * + * If @xprt is connected, disconnect it to simulate spurious connection + * loss. + */ static void xprt_rdma_inject_disconnect(struct rpc_xprt *xprt) { - struct rpcrdma_xprt *r_xprt = container_of(xprt, struct rpcrdma_xprt, - rx_xprt); + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); trace_xprtrdma_inject_dsc(r_xprt); rdma_disconnect(r_xprt->rx_ia.ri_id); } -/* - * xprt_rdma_destroy +/** + * xprt_rdma_destroy - Full tear down of transport + * @xprt: doomed transport context * - * Destroy the xprt. - * Free all memory associated with the object, including its own. - * NOTE: none of the *destroy methods free memory for their top-level - * objects, even though they may have allocated it (they do free - * private memory). It's up to the caller to handle it. In this - * case (RDMA transport), all structure memory is inlined with the - * struct rpcrdma_xprt. + * Caller guarantees there will be no more calls to us with + * this @xprt. */ static void xprt_rdma_destroy(struct rpc_xprt *xprt) @@ -298,8 +288,6 @@ xprt_rdma_destroy(struct rpc_xprt *xprt) cancel_delayed_work_sync(&r_xprt->rx_connect_worker); - xprt_clear_connected(xprt); - rpcrdma_ep_destroy(&r_xprt->rx_ep, &r_xprt->rx_ia); rpcrdma_buffer_destroy(&r_xprt->rx_buf); rpcrdma_ia_close(&r_xprt->rx_ia); @@ -442,11 +430,12 @@ out1: } /** - * xprt_rdma_close - Close down RDMA connection - * @xprt: generic transport to be closed + * xprt_rdma_close - close a transport connection + * @xprt: transport context * - * Called during transport shutdown reconnect, or device - * removal. Caller holds the transport's write lock. + * Called during transport shutdown, reconnect, or device removal. + * Caller holds @xprt's send lock to prevent activity on this + * transport while the connection is torn down. */ static void xprt_rdma_close(struct rpc_xprt *xprt) @@ -468,6 +457,12 @@ xprt_rdma_close(struct rpc_xprt *xprt) xprt->reestablish_timeout = 0; xprt_disconnect_done(xprt); rpcrdma_ep_disconnect(ep, ia); + + /* Prepare @xprt for the next connection by reinitializing + * its credit grant to one (see RFC 8166, Section 3.3.3). + */ + r_xprt->rx_buf.rb_credits = 1; + xprt->cwnd = RPC_CWNDSHIFT; } /** @@ -519,6 +514,12 @@ xprt_rdma_timer(struct rpc_xprt *xprt, struct rpc_task *task) xprt_force_disconnect(xprt); } +/** + * xprt_rdma_connect - try to establish a transport connection + * @xprt: transport state + * @task: RPC scheduler context + * + */ static void xprt_rdma_connect(struct rpc_xprt *xprt, struct rpc_task *task) { @@ -638,13 +639,6 @@ rpcrdma_get_recvbuf(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req, * 0: Success; rq_buffer points to RPC buffer to use * ENOMEM: Out of memory, call again later * EIO: A permanent error occurred, do not retry - * - * The RDMA allocate/free functions need the task structure as a place - * to hide the struct rpcrdma_req, which is necessary for the actual - * send/recv sequence. - * - * xprt_rdma_allocate provides buffers that are already mapped for - * DMA, and a local DMA lkey is provided for each. */ static int xprt_rdma_allocate(struct rpc_task *task) @@ -693,7 +687,7 @@ xprt_rdma_free(struct rpc_task *task) /** * xprt_rdma_send_request - marshal and send an RPC request - * @task: RPC task with an RPC message in rq_snd_buf + * @rqst: RPC message in rq_snd_buf * * Caller holds the transport's write lock. * @@ -706,9 +700,8 @@ xprt_rdma_free(struct rpc_task *task) * sent. Do not try to send this message again. */ static int -xprt_rdma_send_request(struct rpc_task *task) +xprt_rdma_send_request(struct rpc_rqst *rqst) { - struct rpc_rqst *rqst = task->tk_rqstp; struct rpc_xprt *xprt = rqst->rq_xprt; struct rpcrdma_req *req = rpcr_to_rdmar(rqst); struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); @@ -722,6 +715,9 @@ xprt_rdma_send_request(struct rpc_task *task) if (!xprt_connected(xprt)) goto drop_connection; + if (!xprt_request_get_cong(xprt, rqst)) + return -EBADSLT; + rc = rpcrdma_marshal_req(r_xprt, rqst); if (rc < 0) goto failed_marshal; @@ -741,7 +737,7 @@ xprt_rdma_send_request(struct rpc_task *task) /* An RPC with no reply will throw off credit accounting, * so drop the connection to reset the credit grant. */ - if (!rpc_reply_expected(task)) + if (!rpc_reply_expected(rqst->rq_task)) goto drop_connection; return 0; @@ -766,7 +762,7 @@ void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) 0, /* need a local port? */ xprt->stat.bind_count, xprt->stat.connect_count, - xprt->stat.connect_time, + xprt->stat.connect_time / HZ, idle_time, xprt->stat.sends, xprt->stat.recvs, @@ -786,7 +782,7 @@ void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) r_xprt->rx_stats.bad_reply_count, r_xprt->rx_stats.nomsg_call_count); seq_printf(seq, "%lu %lu %lu %lu %lu %lu\n", - r_xprt->rx_stats.mrs_recovered, + r_xprt->rx_stats.mrs_recycled, r_xprt->rx_stats.mrs_orphaned, r_xprt->rx_stats.mrs_allocated, r_xprt->rx_stats.local_inv_needed, diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c index 956a5ea47b58..3ddba94c939f 100644 --- a/net/sunrpc/xprtrdma/verbs.c +++ b/net/sunrpc/xprtrdma/verbs.c @@ -108,20 +108,48 @@ rpcrdma_destroy_wq(void) } } +/** + * rpcrdma_disconnect_worker - Force a disconnect + * @work: endpoint to be disconnected + * + * Provider callbacks can possibly run in an IRQ context. This function + * is invoked in a worker thread to guarantee that disconnect wake-up + * calls are always done in process context. + */ +static void +rpcrdma_disconnect_worker(struct work_struct *work) +{ + struct rpcrdma_ep *ep = container_of(work, struct rpcrdma_ep, + rep_disconnect_worker.work); + struct rpcrdma_xprt *r_xprt = + container_of(ep, struct rpcrdma_xprt, rx_ep); + + xprt_force_disconnect(&r_xprt->rx_xprt); +} + +/** + * rpcrdma_qp_event_handler - Handle one QP event (error notification) + * @event: details of the event + * @context: ep that owns QP where event occurred + * + * Called from the RDMA provider (device driver) possibly in an interrupt + * context. + */ static void -rpcrdma_qp_async_error_upcall(struct ib_event *event, void *context) +rpcrdma_qp_event_handler(struct ib_event *event, void *context) { struct rpcrdma_ep *ep = context; struct rpcrdma_xprt *r_xprt = container_of(ep, struct rpcrdma_xprt, rx_ep); - trace_xprtrdma_qp_error(r_xprt, event); - pr_err("rpcrdma: %s on device %s ep %p\n", - ib_event_msg(event->event), event->device->name, context); + trace_xprtrdma_qp_event(r_xprt, event); + pr_err("rpcrdma: %s on device %s connected to %s:%s\n", + ib_event_msg(event->event), event->device->name, + rpcrdma_addrstr(r_xprt), rpcrdma_portstr(r_xprt)); if (ep->rep_connected == 1) { ep->rep_connected = -EIO; - rpcrdma_conn_func(ep); + schedule_delayed_work(&ep->rep_disconnect_worker, 0); wake_up_all(&ep->rep_connect_wait); } } @@ -219,38 +247,48 @@ rpcrdma_update_connect_private(struct rpcrdma_xprt *r_xprt, rpcrdma_set_max_header_sizes(r_xprt); } +/** + * rpcrdma_cm_event_handler - Handle RDMA CM events + * @id: rdma_cm_id on which an event has occurred + * @event: details of the event + * + * Called with @id's mutex held. Returns 1 if caller should + * destroy @id, otherwise 0. + */ static int -rpcrdma_conn_upcall(struct rdma_cm_id *id, struct rdma_cm_event *event) +rpcrdma_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) { - struct rpcrdma_xprt *xprt = id->context; - struct rpcrdma_ia *ia = &xprt->rx_ia; - struct rpcrdma_ep *ep = &xprt->rx_ep; - int connstate = 0; + struct rpcrdma_xprt *r_xprt = id->context; + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + struct rpcrdma_ep *ep = &r_xprt->rx_ep; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + + might_sleep(); - trace_xprtrdma_conn_upcall(xprt, event); + trace_xprtrdma_cm_event(r_xprt, event); switch (event->event) { case RDMA_CM_EVENT_ADDR_RESOLVED: case RDMA_CM_EVENT_ROUTE_RESOLVED: ia->ri_async_rc = 0; complete(&ia->ri_done); - break; + return 0; case RDMA_CM_EVENT_ADDR_ERROR: ia->ri_async_rc = -EPROTO; complete(&ia->ri_done); - break; + return 0; case RDMA_CM_EVENT_ROUTE_ERROR: ia->ri_async_rc = -ENETUNREACH; complete(&ia->ri_done); - break; + return 0; case RDMA_CM_EVENT_DEVICE_REMOVAL: #if IS_ENABLED(CONFIG_SUNRPC_DEBUG) pr_info("rpcrdma: removing device %s for %s:%s\n", ia->ri_device->name, - rpcrdma_addrstr(xprt), rpcrdma_portstr(xprt)); + rpcrdma_addrstr(r_xprt), rpcrdma_portstr(r_xprt)); #endif set_bit(RPCRDMA_IAF_REMOVING, &ia->ri_flags); ep->rep_connected = -ENODEV; - xprt_force_disconnect(&xprt->rx_xprt); + xprt_force_disconnect(xprt); wait_for_completion(&ia->ri_remove_done); ia->ri_id = NULL; @@ -258,41 +296,40 @@ rpcrdma_conn_upcall(struct rdma_cm_id *id, struct rdma_cm_event *event) /* Return 1 to ensure the core destroys the id. */ return 1; case RDMA_CM_EVENT_ESTABLISHED: - ++xprt->rx_xprt.connect_cookie; - connstate = 1; - rpcrdma_update_connect_private(xprt, &event->param.conn); - goto connected; + ++xprt->connect_cookie; + ep->rep_connected = 1; + rpcrdma_update_connect_private(r_xprt, &event->param.conn); + wake_up_all(&ep->rep_connect_wait); + break; case RDMA_CM_EVENT_CONNECT_ERROR: - connstate = -ENOTCONN; - goto connected; + ep->rep_connected = -ENOTCONN; + goto disconnected; case RDMA_CM_EVENT_UNREACHABLE: - connstate = -ENETUNREACH; - goto connected; + ep->rep_connected = -ENETUNREACH; + goto disconnected; case RDMA_CM_EVENT_REJECTED: dprintk("rpcrdma: connection to %s:%s rejected: %s\n", - rpcrdma_addrstr(xprt), rpcrdma_portstr(xprt), + rpcrdma_addrstr(r_xprt), rpcrdma_portstr(r_xprt), rdma_reject_msg(id, event->status)); - connstate = -ECONNREFUSED; + ep->rep_connected = -ECONNREFUSED; if (event->status == IB_CM_REJ_STALE_CONN) - connstate = -EAGAIN; - goto connected; + ep->rep_connected = -EAGAIN; + goto disconnected; case RDMA_CM_EVENT_DISCONNECTED: - ++xprt->rx_xprt.connect_cookie; - connstate = -ECONNABORTED; -connected: - ep->rep_connected = connstate; - rpcrdma_conn_func(ep); + ++xprt->connect_cookie; + ep->rep_connected = -ECONNABORTED; +disconnected: + xprt_force_disconnect(xprt); wake_up_all(&ep->rep_connect_wait); - /*FALLTHROUGH*/ + break; default: - dprintk("RPC: %s: %s:%s on %s/%s (ep 0x%p): %s\n", - __func__, - rpcrdma_addrstr(xprt), rpcrdma_portstr(xprt), - ia->ri_device->name, ia->ri_ops->ro_displayname, - ep, rdma_event_msg(event->event)); break; } + dprintk("RPC: %s: %s:%s on %s/%s: %s\n", __func__, + rpcrdma_addrstr(r_xprt), rpcrdma_portstr(r_xprt), + ia->ri_device->name, ia->ri_ops->ro_displayname, + rdma_event_msg(event->event)); return 0; } @@ -308,7 +345,7 @@ rpcrdma_create_id(struct rpcrdma_xprt *xprt, struct rpcrdma_ia *ia) init_completion(&ia->ri_done); init_completion(&ia->ri_remove_done); - id = rdma_create_id(xprt->rx_xprt.xprt_net, rpcrdma_conn_upcall, + id = rdma_create_id(xprt->rx_xprt.xprt_net, rpcrdma_cm_event_handler, xprt, RDMA_PS_TCP, IB_QPT_RC); if (IS_ERR(id)) { rc = PTR_ERR(id); @@ -519,7 +556,7 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, if (rc) return rc; - ep->rep_attr.event_handler = rpcrdma_qp_async_error_upcall; + ep->rep_attr.event_handler = rpcrdma_qp_event_handler; ep->rep_attr.qp_context = ep; ep->rep_attr.srq = NULL; ep->rep_attr.cap.max_send_sge = max_sge; @@ -542,7 +579,8 @@ rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, cdata->max_requests >> 2); ep->rep_send_count = ep->rep_send_batch; init_waitqueue_head(&ep->rep_connect_wait); - INIT_DELAYED_WORK(&ep->rep_connect_worker, rpcrdma_connect_worker); + INIT_DELAYED_WORK(&ep->rep_disconnect_worker, + rpcrdma_disconnect_worker); sendcq = ib_alloc_cq(ia->ri_device, NULL, ep->rep_attr.cap.max_send_wr + 1, @@ -615,7 +653,7 @@ out1: void rpcrdma_ep_destroy(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) { - cancel_delayed_work_sync(&ep->rep_connect_worker); + cancel_delayed_work_sync(&ep->rep_disconnect_worker); if (ia->ri_id && ia->ri_id->qp) { rpcrdma_ep_disconnect(ep, ia); @@ -728,6 +766,7 @@ rpcrdma_ep_connect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) { struct rpcrdma_xprt *r_xprt = container_of(ia, struct rpcrdma_xprt, rx_ia); + struct rpc_xprt *xprt = &r_xprt->rx_xprt; int rc; retry: @@ -754,6 +793,8 @@ retry: } ep->rep_connected = 0; + xprt_clear_connected(xprt); + rpcrdma_post_recvs(r_xprt, true); rc = rdma_connect(ia->ri_id, &ep->rep_remote_cma); @@ -877,7 +918,6 @@ static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt) sc->sc_xprt = r_xprt; buf->rb_sc_ctxs[i] = sc; } - buf->rb_flags = 0; return 0; @@ -978,39 +1018,6 @@ rpcrdma_sendctx_put_locked(struct rpcrdma_sendctx *sc) } static void -rpcrdma_mr_recovery_worker(struct work_struct *work) -{ - struct rpcrdma_buffer *buf = container_of(work, struct rpcrdma_buffer, - rb_recovery_worker.work); - struct rpcrdma_mr *mr; - - spin_lock(&buf->rb_recovery_lock); - while (!list_empty(&buf->rb_stale_mrs)) { - mr = rpcrdma_mr_pop(&buf->rb_stale_mrs); - spin_unlock(&buf->rb_recovery_lock); - - trace_xprtrdma_recover_mr(mr); - mr->mr_xprt->rx_ia.ri_ops->ro_recover_mr(mr); - - spin_lock(&buf->rb_recovery_lock); - } - spin_unlock(&buf->rb_recovery_lock); -} - -void -rpcrdma_mr_defer_recovery(struct rpcrdma_mr *mr) -{ - struct rpcrdma_xprt *r_xprt = mr->mr_xprt; - struct rpcrdma_buffer *buf = &r_xprt->rx_buf; - - spin_lock(&buf->rb_recovery_lock); - rpcrdma_mr_push(mr, &buf->rb_stale_mrs); - spin_unlock(&buf->rb_recovery_lock); - - schedule_delayed_work(&buf->rb_recovery_worker, 0); -} - -static void rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt) { struct rpcrdma_buffer *buf = &r_xprt->rx_buf; @@ -1019,7 +1026,7 @@ rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt) LIST_HEAD(free); LIST_HEAD(all); - for (count = 0; count < 3; count++) { + for (count = 0; count < ia->ri_max_segs; count++) { struct rpcrdma_mr *mr; int rc; @@ -1138,18 +1145,15 @@ rpcrdma_buffer_create(struct rpcrdma_xprt *r_xprt) struct rpcrdma_buffer *buf = &r_xprt->rx_buf; int i, rc; + buf->rb_flags = 0; buf->rb_max_requests = r_xprt->rx_data.max_requests; buf->rb_bc_srv_max_requests = 0; spin_lock_init(&buf->rb_mrlock); spin_lock_init(&buf->rb_lock); - spin_lock_init(&buf->rb_recovery_lock); INIT_LIST_HEAD(&buf->rb_mrs); INIT_LIST_HEAD(&buf->rb_all); - INIT_LIST_HEAD(&buf->rb_stale_mrs); INIT_DELAYED_WORK(&buf->rb_refresh_worker, rpcrdma_mr_refresh_worker); - INIT_DELAYED_WORK(&buf->rb_recovery_worker, - rpcrdma_mr_recovery_worker); rpcrdma_mrs_create(r_xprt); @@ -1233,7 +1237,6 @@ rpcrdma_mrs_destroy(struct rpcrdma_buffer *buf) void rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) { - cancel_delayed_work_sync(&buf->rb_recovery_worker); cancel_delayed_work_sync(&buf->rb_refresh_worker); rpcrdma_sendctxs_destroy(buf); @@ -1326,7 +1329,7 @@ rpcrdma_mr_unmap_and_put(struct rpcrdma_mr *mr) { struct rpcrdma_xprt *r_xprt = mr->mr_xprt; - trace_xprtrdma_dma_unmap(mr); + trace_xprtrdma_mr_unmap(mr); ib_dma_unmap_sg(r_xprt->rx_ia.ri_device, mr->mr_sg, mr->mr_nents, mr->mr_dir); __rpcrdma_mr_put(&r_xprt->rx_buf, mr); @@ -1518,9 +1521,11 @@ rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, bool temp) struct ib_recv_wr *wr, *bad_wr; int needed, count, rc; + rc = 0; + count = 0; needed = buf->rb_credits + (buf->rb_bc_srv_max_requests << 1); if (buf->rb_posted_receives > needed) - return; + goto out; needed -= buf->rb_posted_receives; count = 0; @@ -1556,7 +1561,7 @@ rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, bool temp) --needed; } if (!count) - return; + goto out; rc = ib_post_recv(r_xprt->rx_ia.ri_id->qp, wr, (const struct ib_recv_wr **)&bad_wr); @@ -1570,5 +1575,6 @@ rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, bool temp) } } buf->rb_posted_receives += count; +out: trace_xprtrdma_post_recvs(r_xprt, count, rc); } diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h index 2ca14f7c2d51..a13ccb643ce0 100644 --- a/net/sunrpc/xprtrdma/xprt_rdma.h +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -101,7 +101,7 @@ struct rpcrdma_ep { wait_queue_head_t rep_connect_wait; struct rpcrdma_connect_private rep_cm_private; struct rdma_conn_param rep_remote_cma; - struct delayed_work rep_connect_worker; + struct delayed_work rep_disconnect_worker; }; /* Pre-allocate extra Work Requests for handling backward receives @@ -280,6 +280,7 @@ struct rpcrdma_mr { u32 mr_handle; u32 mr_length; u64 mr_offset; + struct work_struct mr_recycle; struct list_head mr_all; }; @@ -411,9 +412,6 @@ struct rpcrdma_buffer { u32 rb_bc_max_requests; - spinlock_t rb_recovery_lock; /* protect rb_stale_mrs */ - struct list_head rb_stale_mrs; - struct delayed_work rb_recovery_worker; struct delayed_work rb_refresh_worker; }; #define rdmab_to_ia(b) (&container_of((b), struct rpcrdma_xprt, rx_buf)->rx_ia) @@ -452,7 +450,7 @@ struct rpcrdma_stats { unsigned long hardway_register_count; unsigned long failed_marshal_count; unsigned long bad_reply_count; - unsigned long mrs_recovered; + unsigned long mrs_recycled; unsigned long mrs_orphaned; unsigned long mrs_allocated; unsigned long empty_sendctx_q; @@ -481,7 +479,6 @@ struct rpcrdma_memreg_ops { struct list_head *mrs); void (*ro_unmap_sync)(struct rpcrdma_xprt *, struct list_head *); - void (*ro_recover_mr)(struct rpcrdma_mr *mr); int (*ro_open)(struct rpcrdma_ia *, struct rpcrdma_ep *, struct rpcrdma_create_data_internal *); @@ -559,7 +556,6 @@ int rpcrdma_ep_create(struct rpcrdma_ep *, struct rpcrdma_ia *, struct rpcrdma_create_data_internal *); void rpcrdma_ep_destroy(struct rpcrdma_ep *, struct rpcrdma_ia *); int rpcrdma_ep_connect(struct rpcrdma_ep *, struct rpcrdma_ia *); -void rpcrdma_conn_func(struct rpcrdma_ep *ep); void rpcrdma_ep_disconnect(struct rpcrdma_ep *, struct rpcrdma_ia *); int rpcrdma_ep_post(struct rpcrdma_ia *, struct rpcrdma_ep *, @@ -578,7 +574,12 @@ struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_buffer *buf); struct rpcrdma_mr *rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt); void rpcrdma_mr_put(struct rpcrdma_mr *mr); void rpcrdma_mr_unmap_and_put(struct rpcrdma_mr *mr); -void rpcrdma_mr_defer_recovery(struct rpcrdma_mr *mr); + +static inline void +rpcrdma_mr_recycle(struct rpcrdma_mr *mr) +{ + schedule_work(&mr->mr_recycle); +} struct rpcrdma_req *rpcrdma_buffer_get(struct rpcrdma_buffer *); void rpcrdma_buffer_put(struct rpcrdma_req *); @@ -652,7 +653,6 @@ static inline void rpcrdma_set_xdrlen(struct xdr_buf *xdr, size_t len) extern unsigned int xprt_rdma_max_inline_read; void xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap); void xprt_rdma_free_addresses(struct rpc_xprt *xprt); -void rpcrdma_connect_worker(struct work_struct *work); void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq); int xprt_rdma_init(void); void xprt_rdma_cleanup(void); diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 6b7539c0466e..ae77c71c1f64 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -47,13 +47,13 @@ #include <net/checksum.h> #include <net/udp.h> #include <net/tcp.h> +#include <linux/bvec.h> +#include <linux/uio.h> #include <trace/events/sunrpc.h> #include "sunrpc.h" -#define RPC_TCP_READ_CHUNK_SZ (3*512*1024) - static void xs_close(struct rpc_xprt *xprt); static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, struct socket *sock); @@ -129,7 +129,7 @@ static struct ctl_table xs_tunables_table[] = { .mode = 0644, .proc_handler = proc_dointvec_minmax, .extra1 = &xprt_min_resvport_limit, - .extra2 = &xprt_max_resvport + .extra2 = &xprt_max_resvport_limit }, { .procname = "max_resvport", @@ -137,7 +137,7 @@ static struct ctl_table xs_tunables_table[] = { .maxlen = sizeof(unsigned int), .mode = 0644, .proc_handler = proc_dointvec_minmax, - .extra1 = &xprt_min_resvport, + .extra1 = &xprt_min_resvport_limit, .extra2 = &xprt_max_resvport_limit }, { @@ -325,6 +325,362 @@ static void xs_free_peer_addresses(struct rpc_xprt *xprt) } } +static size_t +xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp) +{ + size_t i,n; + + if (!(buf->flags & XDRBUF_SPARSE_PAGES)) + return want; + if (want > buf->page_len) + want = buf->page_len; + n = (buf->page_base + want + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < n; i++) { + if (buf->pages[i]) + continue; + buf->bvec[i].bv_page = buf->pages[i] = alloc_page(gfp); + if (!buf->pages[i]) { + buf->page_len = (i * PAGE_SIZE) - buf->page_base; + return buf->page_len; + } + } + return want; +} + +static ssize_t +xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) +{ + ssize_t ret; + if (seek != 0) + iov_iter_advance(&msg->msg_iter, seek); + ret = sock_recvmsg(sock, msg, flags); + return ret > 0 ? ret + seek : ret; +} + +static ssize_t +xs_read_kvec(struct socket *sock, struct msghdr *msg, int flags, + struct kvec *kvec, size_t count, size_t seek) +{ + iov_iter_kvec(&msg->msg_iter, READ, kvec, 1, count); + return xs_sock_recvmsg(sock, msg, flags, seek); +} + +static ssize_t +xs_read_bvec(struct socket *sock, struct msghdr *msg, int flags, + struct bio_vec *bvec, unsigned long nr, size_t count, + size_t seek) +{ + iov_iter_bvec(&msg->msg_iter, READ, bvec, nr, count); + return xs_sock_recvmsg(sock, msg, flags, seek); +} + +static ssize_t +xs_read_discard(struct socket *sock, struct msghdr *msg, int flags, + size_t count) +{ + struct kvec kvec = { 0 }; + return xs_read_kvec(sock, msg, flags | MSG_TRUNC, &kvec, count, 0); +} + +static ssize_t +xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags, + struct xdr_buf *buf, size_t count, size_t seek, size_t *read) +{ + size_t want, seek_init = seek, offset = 0; + ssize_t ret; + + if (seek < buf->head[0].iov_len) { + want = min_t(size_t, count, buf->head[0].iov_len); + ret = xs_read_kvec(sock, msg, flags, &buf->head[0], want, seek); + if (ret <= 0) + goto sock_err; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + goto out; + if (ret != want) + goto eagain; + seek = 0; + } else { + seek -= buf->head[0].iov_len; + offset += buf->head[0].iov_len; + } + if (seek < buf->page_len) { + want = xs_alloc_sparse_pages(buf, + min_t(size_t, count - offset, buf->page_len), + GFP_NOWAIT); + ret = xs_read_bvec(sock, msg, flags, buf->bvec, + xdr_buf_pagecount(buf), + want + buf->page_base, + seek + buf->page_base); + if (ret <= 0) + goto sock_err; + offset += ret - buf->page_base; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + goto out; + if (ret != want) + goto eagain; + seek = 0; + } else { + seek -= buf->page_len; + offset += buf->page_len; + } + if (seek < buf->tail[0].iov_len) { + want = min_t(size_t, count - offset, buf->tail[0].iov_len); + ret = xs_read_kvec(sock, msg, flags, &buf->tail[0], want, seek); + if (ret <= 0) + goto sock_err; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + goto out; + if (ret != want) + goto eagain; + } else + offset += buf->tail[0].iov_len; + ret = -EMSGSIZE; + msg->msg_flags |= MSG_TRUNC; +out: + *read = offset - seek_init; + return ret; +eagain: + ret = -EAGAIN; + goto out; +sock_err: + offset += seek; + goto out; +} + +static void +xs_read_header(struct sock_xprt *transport, struct xdr_buf *buf) +{ + if (!transport->recv.copied) { + if (buf->head[0].iov_len >= transport->recv.offset) + memcpy(buf->head[0].iov_base, + &transport->recv.xid, + transport->recv.offset); + transport->recv.copied = transport->recv.offset; + } +} + +static bool +xs_read_stream_request_done(struct sock_xprt *transport) +{ + return transport->recv.fraghdr & cpu_to_be32(RPC_LAST_STREAM_FRAGMENT); +} + +static ssize_t +xs_read_stream_request(struct sock_xprt *transport, struct msghdr *msg, + int flags, struct rpc_rqst *req) +{ + struct xdr_buf *buf = &req->rq_private_buf; + size_t want, read; + ssize_t ret; + + xs_read_header(transport, buf); + + want = transport->recv.len - transport->recv.offset; + ret = xs_read_xdr_buf(transport->sock, msg, flags, buf, + transport->recv.copied + want, transport->recv.copied, + &read); + transport->recv.offset += read; + transport->recv.copied += read; + if (transport->recv.offset == transport->recv.len) { + if (xs_read_stream_request_done(transport)) + msg->msg_flags |= MSG_EOR; + return transport->recv.copied; + } + + switch (ret) { + case -EMSGSIZE: + return transport->recv.copied; + case 0: + return -ESHUTDOWN; + default: + if (ret < 0) + return ret; + } + return -EAGAIN; +} + +static size_t +xs_read_stream_headersize(bool isfrag) +{ + if (isfrag) + return sizeof(__be32); + return 3 * sizeof(__be32); +} + +static ssize_t +xs_read_stream_header(struct sock_xprt *transport, struct msghdr *msg, + int flags, size_t want, size_t seek) +{ + struct kvec kvec = { + .iov_base = &transport->recv.fraghdr, + .iov_len = want, + }; + return xs_read_kvec(transport->sock, msg, flags, &kvec, want, seek); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static ssize_t +xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct rpc_rqst *req; + ssize_t ret; + + /* Look up and lock the request corresponding to the given XID */ + req = xprt_lookup_bc_request(xprt, transport->recv.xid); + if (!req) { + printk(KERN_WARNING "Callback slot table overflowed\n"); + return -ESHUTDOWN; + } + + ret = xs_read_stream_request(transport, msg, flags, req); + if (msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + xprt_complete_bc_request(req, ret); + + return ret; +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +static ssize_t +xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + return -ESHUTDOWN; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +static ssize_t +xs_read_stream_reply(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct rpc_rqst *req; + ssize_t ret = 0; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->queue_lock); + req = xprt_lookup_rqst(xprt, transport->recv.xid); + if (!req) { + msg->msg_flags |= MSG_TRUNC; + goto out; + } + xprt_pin_rqst(req); + spin_unlock(&xprt->queue_lock); + + ret = xs_read_stream_request(transport, msg, flags, req); + + spin_lock(&xprt->queue_lock); + if (msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + xprt_complete_rqst(req->rq_task, ret); + xprt_unpin_rqst(req); +out: + spin_unlock(&xprt->queue_lock); + return ret; +} + +static ssize_t +xs_read_stream(struct sock_xprt *transport, int flags) +{ + struct msghdr msg = { 0 }; + size_t want, read = 0; + ssize_t ret = 0; + + if (transport->recv.len == 0) { + want = xs_read_stream_headersize(transport->recv.copied != 0); + ret = xs_read_stream_header(transport, &msg, flags, want, + transport->recv.offset); + if (ret <= 0) + goto out_err; + transport->recv.offset = ret; + if (ret != want) { + ret = -EAGAIN; + goto out_err; + } + transport->recv.len = be32_to_cpu(transport->recv.fraghdr) & + RPC_FRAGMENT_SIZE_MASK; + transport->recv.offset -= sizeof(transport->recv.fraghdr); + read = ret; + } + + switch (be32_to_cpu(transport->recv.calldir)) { + case RPC_CALL: + ret = xs_read_stream_call(transport, &msg, flags); + break; + case RPC_REPLY: + ret = xs_read_stream_reply(transport, &msg, flags); + } + if (msg.msg_flags & MSG_TRUNC) { + transport->recv.calldir = cpu_to_be32(-1); + transport->recv.copied = -1; + } + if (ret < 0) + goto out_err; + read += ret; + if (transport->recv.offset < transport->recv.len) { + ret = xs_read_discard(transport->sock, &msg, flags, + transport->recv.len - transport->recv.offset); + if (ret <= 0) + goto out_err; + transport->recv.offset += ret; + read += ret; + if (transport->recv.offset != transport->recv.len) + return -EAGAIN; + } + if (xs_read_stream_request_done(transport)) { + trace_xs_stream_read_request(transport); + transport->recv.copied = 0; + } + transport->recv.offset = 0; + transport->recv.len = 0; + return read; +out_err: + switch (ret) { + case 0: + case -ESHUTDOWN: + xprt_force_disconnect(&transport->xprt); + return -ESHUTDOWN; + } + return ret; +} + +static void xs_stream_data_receive(struct sock_xprt *transport) +{ + size_t read = 0; + ssize_t ret = 0; + + mutex_lock(&transport->recv_mutex); + if (transport->sock == NULL) + goto out; + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + for (;;) { + ret = xs_read_stream(transport, MSG_DONTWAIT); + if (ret <= 0) + break; + read += ret; + cond_resched(); + } +out: + mutex_unlock(&transport->recv_mutex); + trace_xs_stream_read_data(&transport->xprt, ret, read); +} + +static void xs_stream_data_receive_workfn(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, recv_worker); + xs_stream_data_receive(transport); +} + +static void +xs_stream_reset_connect(struct sock_xprt *transport) +{ + transport->recv.offset = 0; + transport->recv.len = 0; + transport->recv.copied = 0; + transport->xmit.offset = 0; + transport->xprt.stat.connect_count++; + transport->xprt.stat.connect_start = jiffies; +} + #define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) static int xs_send_kvec(struct socket *sock, struct sockaddr *addr, int addrlen, struct kvec *vec, unsigned int base, int more) @@ -440,28 +796,21 @@ out: return err; } -static void xs_nospace_callback(struct rpc_task *task) -{ - struct sock_xprt *transport = container_of(task->tk_rqstp->rq_xprt, struct sock_xprt, xprt); - - transport->inet->sk_write_pending--; -} - /** - * xs_nospace - place task on wait queue if transmit was incomplete - * @task: task to put to sleep + * xs_nospace - handle transmit was incomplete + * @req: pointer to RPC request * */ -static int xs_nospace(struct rpc_task *task) +static int xs_nospace(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct sock *sk = transport->inet; int ret = -EAGAIN; dprintk("RPC: %5u xmit incomplete (%u left of %u)\n", - task->tk_pid, req->rq_slen - req->rq_bytes_sent, + req->rq_task->tk_pid, + req->rq_slen - transport->xmit.offset, req->rq_slen); /* Protect against races with write_space */ @@ -471,7 +820,7 @@ static int xs_nospace(struct rpc_task *task) if (xprt_connected(xprt)) { /* wait for more buffer space */ sk->sk_write_pending++; - xprt_wait_for_buffer_space(task, xs_nospace_callback); + xprt_wait_for_buffer_space(xprt); } else ret = -ENOTCONN; @@ -491,6 +840,22 @@ static int xs_nospace(struct rpc_task *task) return ret; } +static void +xs_stream_prepare_request(struct rpc_rqst *req) +{ + req->rq_task->tk_status = xdr_alloc_bvec(&req->rq_rcv_buf, GFP_NOIO); +} + +/* + * Determine if the previous message in the stream was aborted before it + * could complete transmission. + */ +static bool +xs_send_request_was_aborted(struct sock_xprt *transport, struct rpc_rqst *req) +{ + return transport->xmit.offset != 0 && req->rq_bytes_sent == 0; +} + /* * Construct a stream transport record marker in @buf. */ @@ -503,7 +868,7 @@ static inline void xs_encode_stream_record_marker(struct xdr_buf *buf) /** * xs_local_send_request - write an RPC request to an AF_LOCAL socket - * @task: RPC task that manages the state of an RPC request + * @req: pointer to RPC request * * Return values: * 0: The request has been sent @@ -512,9 +877,8 @@ static inline void xs_encode_stream_record_marker(struct xdr_buf *buf) * ENOTCONN: Caller needs to invoke connect logic then call again * other: Some other error occured, the request was not sent */ -static int xs_local_send_request(struct rpc_task *task) +static int xs_local_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); @@ -522,25 +886,34 @@ static int xs_local_send_request(struct rpc_task *task) int status; int sent = 0; + /* Close the stream if the previous transmission was incomplete */ + if (xs_send_request_was_aborted(transport, req)) { + xs_close(xprt); + return -ENOTCONN; + } + xs_encode_stream_record_marker(&req->rq_snd_buf); xs_pktdump("packet data:", req->rq_svec->iov_base, req->rq_svec->iov_len); req->rq_xtime = ktime_get(); - status = xs_sendpages(transport->sock, NULL, 0, xdr, req->rq_bytes_sent, + status = xs_sendpages(transport->sock, NULL, 0, xdr, + transport->xmit.offset, true, &sent); dprintk("RPC: %s(%u) = %d\n", - __func__, xdr->len - req->rq_bytes_sent, status); + __func__, xdr->len - transport->xmit.offset, status); if (status == -EAGAIN && sock_writeable(transport->inet)) status = -ENOBUFS; if (likely(sent > 0) || status == 0) { - req->rq_bytes_sent += sent; - req->rq_xmit_bytes_sent += sent; + transport->xmit.offset += sent; + req->rq_bytes_sent = transport->xmit.offset; if (likely(req->rq_bytes_sent >= req->rq_slen)) { + req->rq_xmit_bytes_sent += transport->xmit.offset; req->rq_bytes_sent = 0; + transport->xmit.offset = 0; return 0; } status = -EAGAIN; @@ -550,7 +923,7 @@ static int xs_local_send_request(struct rpc_task *task) case -ENOBUFS: break; case -EAGAIN: - status = xs_nospace(task); + status = xs_nospace(req); break; default: dprintk("RPC: sendmsg returned unrecognized error %d\n", @@ -566,7 +939,7 @@ static int xs_local_send_request(struct rpc_task *task) /** * xs_udp_send_request - write an RPC request to a UDP socket - * @task: address of RPC task that manages the state of an RPC request + * @req: pointer to RPC request * * Return values: * 0: The request has been sent @@ -575,9 +948,8 @@ static int xs_local_send_request(struct rpc_task *task) * ENOTCONN: Caller needs to invoke connect logic then call again * other: Some other error occurred, the request was not sent */ -static int xs_udp_send_request(struct rpc_task *task) +static int xs_udp_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct xdr_buf *xdr = &req->rq_snd_buf; @@ -590,12 +962,16 @@ static int xs_udp_send_request(struct rpc_task *task) if (!xprt_bound(xprt)) return -ENOTCONN; + + if (!xprt_request_get_cong(xprt, req)) + return -EBADSLT; + req->rq_xtime = ktime_get(); status = xs_sendpages(transport->sock, xs_addr(xprt), xprt->addrlen, - xdr, req->rq_bytes_sent, true, &sent); + xdr, 0, true, &sent); dprintk("RPC: xs_udp_send_request(%u) = %d\n", - xdr->len - req->rq_bytes_sent, status); + xdr->len, status); /* firewall is blocking us, don't return -EAGAIN or we end up looping */ if (status == -EPERM) @@ -619,7 +995,7 @@ process_status: /* Should we call xs_close() here? */ break; case -EAGAIN: - status = xs_nospace(task); + status = xs_nospace(req); break; case -ENETUNREACH: case -ENOBUFS: @@ -639,7 +1015,7 @@ process_status: /** * xs_tcp_send_request - write an RPC request to a TCP socket - * @task: address of RPC task that manages the state of an RPC request + * @req: pointer to RPC request * * Return values: * 0: The request has been sent @@ -651,9 +1027,8 @@ process_status: * XXX: In the case of soft timeouts, should we eventually give up * if sendmsg is not able to make progress? */ -static int xs_tcp_send_request(struct rpc_task *task) +static int xs_tcp_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct rpc_xprt *xprt = req->rq_xprt; struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); struct xdr_buf *xdr = &req->rq_snd_buf; @@ -662,6 +1037,13 @@ static int xs_tcp_send_request(struct rpc_task *task) int status; int sent; + /* Close the stream if the previous transmission was incomplete */ + if (xs_send_request_was_aborted(transport, req)) { + if (transport->sock != NULL) + kernel_sock_shutdown(transport->sock, SHUT_RDWR); + return -ENOTCONN; + } + xs_encode_stream_record_marker(&req->rq_snd_buf); xs_pktdump("packet data:", @@ -671,7 +1053,7 @@ static int xs_tcp_send_request(struct rpc_task *task) * completes while the socket holds a reference to the pages, * then we may end up resending corrupted data. */ - if (task->tk_flags & RPC_TASK_SENT) + if (req->rq_task->tk_flags & RPC_TASK_SENT) zerocopy = false; if (test_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state)) @@ -684,17 +1066,20 @@ static int xs_tcp_send_request(struct rpc_task *task) while (1) { sent = 0; status = xs_sendpages(transport->sock, NULL, 0, xdr, - req->rq_bytes_sent, zerocopy, &sent); + transport->xmit.offset, + zerocopy, &sent); dprintk("RPC: xs_tcp_send_request(%u) = %d\n", - xdr->len - req->rq_bytes_sent, status); + xdr->len - transport->xmit.offset, status); /* If we've sent the entire packet, immediately * reset the count of bytes sent. */ - req->rq_bytes_sent += sent; - req->rq_xmit_bytes_sent += sent; + transport->xmit.offset += sent; + req->rq_bytes_sent = transport->xmit.offset; if (likely(req->rq_bytes_sent >= req->rq_slen)) { + req->rq_xmit_bytes_sent += transport->xmit.offset; req->rq_bytes_sent = 0; + transport->xmit.offset = 0; return 0; } @@ -732,7 +1117,7 @@ static int xs_tcp_send_request(struct rpc_task *task) /* Should we call xs_close() here? */ break; case -EAGAIN: - status = xs_nospace(task); + status = xs_nospace(req); break; case -ECONNRESET: case -ECONNREFUSED: @@ -749,35 +1134,6 @@ static int xs_tcp_send_request(struct rpc_task *task) return status; } -/** - * xs_tcp_release_xprt - clean up after a tcp transmission - * @xprt: transport - * @task: rpc task - * - * This cleans up if an error causes us to abort the transmission of a request. - * In this case, the socket may need to be reset in order to avoid confusing - * the server. - */ -static void xs_tcp_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) -{ - struct rpc_rqst *req; - - if (task != xprt->snd_task) - return; - if (task == NULL) - goto out_release; - req = task->tk_rqstp; - if (req == NULL) - goto out_release; - if (req->rq_bytes_sent == 0) - goto out_release; - if (req->rq_bytes_sent == req->rq_snd_buf.len) - goto out_release; - set_bit(XPRT_CLOSE_WAIT, &xprt->state); -out_release: - xprt_release_xprt(xprt, task); -} - static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk) { transport->old_data_ready = sk->sk_data_ready; @@ -921,114 +1277,6 @@ static void xs_destroy(struct rpc_xprt *xprt) module_put(THIS_MODULE); } -static int xs_local_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) -{ - struct xdr_skb_reader desc = { - .skb = skb, - .offset = sizeof(rpc_fraghdr), - .count = skb->len - sizeof(rpc_fraghdr), - }; - - if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) - return -1; - if (desc.count) - return -1; - return 0; -} - -/** - * xs_local_data_read_skb - * @xprt: transport - * @sk: socket - * @skb: skbuff - * - * Currently this assumes we can read the whole reply in a single gulp. - */ -static void xs_local_data_read_skb(struct rpc_xprt *xprt, - struct sock *sk, - struct sk_buff *skb) -{ - struct rpc_task *task; - struct rpc_rqst *rovr; - int repsize, copied; - u32 _xid; - __be32 *xp; - - repsize = skb->len - sizeof(rpc_fraghdr); - if (repsize < 4) { - dprintk("RPC: impossible RPC reply size %d\n", repsize); - return; - } - - /* Copy the XID from the skb... */ - xp = skb_header_pointer(skb, sizeof(rpc_fraghdr), sizeof(_xid), &_xid); - if (xp == NULL) - return; - - /* Look up and lock the request corresponding to the given XID */ - spin_lock(&xprt->recv_lock); - rovr = xprt_lookup_rqst(xprt, *xp); - if (!rovr) - goto out_unlock; - xprt_pin_rqst(rovr); - spin_unlock(&xprt->recv_lock); - task = rovr->rq_task; - - copied = rovr->rq_private_buf.buflen; - if (copied > repsize) - copied = repsize; - - if (xs_local_copy_to_xdr(&rovr->rq_private_buf, skb)) { - dprintk("RPC: sk_buff copy failed\n"); - spin_lock(&xprt->recv_lock); - goto out_unpin; - } - - spin_lock(&xprt->recv_lock); - xprt_complete_rqst(task, copied); -out_unpin: - xprt_unpin_rqst(rovr); - out_unlock: - spin_unlock(&xprt->recv_lock); -} - -static void xs_local_data_receive(struct sock_xprt *transport) -{ - struct sk_buff *skb; - struct sock *sk; - int err; - -restart: - mutex_lock(&transport->recv_mutex); - sk = transport->inet; - if (sk == NULL) - goto out; - for (;;) { - skb = skb_recv_datagram(sk, 0, 1, &err); - if (skb != NULL) { - xs_local_data_read_skb(&transport->xprt, sk, skb); - skb_free_datagram(sk, skb); - continue; - } - if (!test_and_clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) - break; - if (need_resched()) { - mutex_unlock(&transport->recv_mutex); - cond_resched(); - goto restart; - } - } -out: - mutex_unlock(&transport->recv_mutex); -} - -static void xs_local_data_receive_workfn(struct work_struct *work) -{ - struct sock_xprt *transport = - container_of(work, struct sock_xprt, recv_worker); - xs_local_data_receive(transport); -} - /** * xs_udp_data_read_skb - receive callback for UDP sockets * @xprt: transport @@ -1058,13 +1306,13 @@ static void xs_udp_data_read_skb(struct rpc_xprt *xprt, return; /* Look up and lock the request corresponding to the given XID */ - spin_lock(&xprt->recv_lock); + spin_lock(&xprt->queue_lock); rovr = xprt_lookup_rqst(xprt, *xp); if (!rovr) goto out_unlock; xprt_pin_rqst(rovr); xprt_update_rtt(rovr->rq_task); - spin_unlock(&xprt->recv_lock); + spin_unlock(&xprt->queue_lock); task = rovr->rq_task; if ((copied = rovr->rq_private_buf.buflen) > repsize) @@ -1072,7 +1320,7 @@ static void xs_udp_data_read_skb(struct rpc_xprt *xprt, /* Suck it into the iovec, verify checksum if not done by hw. */ if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb)) { - spin_lock(&xprt->recv_lock); + spin_lock(&xprt->queue_lock); __UDPX_INC_STATS(sk, UDP_MIB_INERRORS); goto out_unpin; } @@ -1081,13 +1329,13 @@ static void xs_udp_data_read_skb(struct rpc_xprt *xprt, spin_lock_bh(&xprt->transport_lock); xprt_adjust_cwnd(xprt, task, copied); spin_unlock_bh(&xprt->transport_lock); - spin_lock(&xprt->recv_lock); + spin_lock(&xprt->queue_lock); xprt_complete_rqst(task, copied); __UDPX_INC_STATS(sk, UDP_MIB_INDATAGRAMS); out_unpin: xprt_unpin_rqst(rovr); out_unlock: - spin_unlock(&xprt->recv_lock); + spin_unlock(&xprt->queue_lock); } static void xs_udp_data_receive(struct sock_xprt *transport) @@ -1096,25 +1344,18 @@ static void xs_udp_data_receive(struct sock_xprt *transport) struct sock *sk; int err; -restart: mutex_lock(&transport->recv_mutex); sk = transport->inet; if (sk == NULL) goto out; + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); for (;;) { skb = skb_recv_udp(sk, 0, 1, &err); - if (skb != NULL) { - xs_udp_data_read_skb(&transport->xprt, sk, skb); - consume_skb(skb); - continue; - } - if (!test_and_clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + if (skb == NULL) break; - if (need_resched()) { - mutex_unlock(&transport->recv_mutex); - cond_resched(); - goto restart; - } + xs_udp_data_read_skb(&transport->xprt, sk, skb); + consume_skb(skb); + cond_resched(); } out: mutex_unlock(&transport->recv_mutex); @@ -1163,263 +1404,7 @@ static void xs_tcp_force_close(struct rpc_xprt *xprt) xprt_force_disconnect(xprt); } -static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - size_t len, used; - char *p; - - p = ((char *) &transport->tcp_fraghdr) + transport->tcp_offset; - len = sizeof(transport->tcp_fraghdr) - transport->tcp_offset; - used = xdr_skb_read_bits(desc, p, len); - transport->tcp_offset += used; - if (used != len) - return; - - transport->tcp_reclen = ntohl(transport->tcp_fraghdr); - if (transport->tcp_reclen & RPC_LAST_STREAM_FRAGMENT) - transport->tcp_flags |= TCP_RCV_LAST_FRAG; - else - transport->tcp_flags &= ~TCP_RCV_LAST_FRAG; - transport->tcp_reclen &= RPC_FRAGMENT_SIZE_MASK; - - transport->tcp_flags &= ~TCP_RCV_COPY_FRAGHDR; - transport->tcp_offset = 0; - - /* Sanity check of the record length */ - if (unlikely(transport->tcp_reclen < 8)) { - dprintk("RPC: invalid TCP record fragment length\n"); - xs_tcp_force_close(xprt); - return; - } - dprintk("RPC: reading TCP record fragment of length %d\n", - transport->tcp_reclen); -} - -static void xs_tcp_check_fraghdr(struct sock_xprt *transport) -{ - if (transport->tcp_offset == transport->tcp_reclen) { - transport->tcp_flags |= TCP_RCV_COPY_FRAGHDR; - transport->tcp_offset = 0; - if (transport->tcp_flags & TCP_RCV_LAST_FRAG) { - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - transport->tcp_flags |= TCP_RCV_COPY_XID; - transport->tcp_copied = 0; - } - } -} - -static inline void xs_tcp_read_xid(struct sock_xprt *transport, struct xdr_skb_reader *desc) -{ - size_t len, used; - char *p; - - len = sizeof(transport->tcp_xid) - transport->tcp_offset; - dprintk("RPC: reading XID (%zu bytes)\n", len); - p = ((char *) &transport->tcp_xid) + transport->tcp_offset; - used = xdr_skb_read_bits(desc, p, len); - transport->tcp_offset += used; - if (used != len) - return; - transport->tcp_flags &= ~TCP_RCV_COPY_XID; - transport->tcp_flags |= TCP_RCV_READ_CALLDIR; - transport->tcp_copied = 4; - dprintk("RPC: reading %s XID %08x\n", - (transport->tcp_flags & TCP_RPC_REPLY) ? "reply for" - : "request with", - ntohl(transport->tcp_xid)); - xs_tcp_check_fraghdr(transport); -} - -static inline void xs_tcp_read_calldir(struct sock_xprt *transport, - struct xdr_skb_reader *desc) -{ - size_t len, used; - u32 offset; - char *p; - - /* - * We want transport->tcp_offset to be 8 at the end of this routine - * (4 bytes for the xid and 4 bytes for the call/reply flag). - * When this function is called for the first time, - * transport->tcp_offset is 4 (after having already read the xid). - */ - offset = transport->tcp_offset - sizeof(transport->tcp_xid); - len = sizeof(transport->tcp_calldir) - offset; - dprintk("RPC: reading CALL/REPLY flag (%zu bytes)\n", len); - p = ((char *) &transport->tcp_calldir) + offset; - used = xdr_skb_read_bits(desc, p, len); - transport->tcp_offset += used; - if (used != len) - return; - transport->tcp_flags &= ~TCP_RCV_READ_CALLDIR; - /* - * We don't yet have the XDR buffer, so we will write the calldir - * out after we get the buffer from the 'struct rpc_rqst' - */ - switch (ntohl(transport->tcp_calldir)) { - case RPC_REPLY: - transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; - transport->tcp_flags |= TCP_RCV_COPY_DATA; - transport->tcp_flags |= TCP_RPC_REPLY; - break; - case RPC_CALL: - transport->tcp_flags |= TCP_RCV_COPY_CALLDIR; - transport->tcp_flags |= TCP_RCV_COPY_DATA; - transport->tcp_flags &= ~TCP_RPC_REPLY; - break; - default: - dprintk("RPC: invalid request message type\n"); - xs_tcp_force_close(&transport->xprt); - } - xs_tcp_check_fraghdr(transport); -} - -static inline void xs_tcp_read_common(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc, - struct rpc_rqst *req) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct xdr_buf *rcvbuf; - size_t len; - ssize_t r; - - rcvbuf = &req->rq_private_buf; - - if (transport->tcp_flags & TCP_RCV_COPY_CALLDIR) { - /* - * Save the RPC direction in the XDR buffer - */ - memcpy(rcvbuf->head[0].iov_base + transport->tcp_copied, - &transport->tcp_calldir, - sizeof(transport->tcp_calldir)); - transport->tcp_copied += sizeof(transport->tcp_calldir); - transport->tcp_flags &= ~TCP_RCV_COPY_CALLDIR; - } - - len = desc->count; - if (len > transport->tcp_reclen - transport->tcp_offset) - desc->count = transport->tcp_reclen - transport->tcp_offset; - r = xdr_partial_copy_from_skb(rcvbuf, transport->tcp_copied, - desc, xdr_skb_read_bits); - - if (desc->count) { - /* Error when copying to the receive buffer, - * usually because we weren't able to allocate - * additional buffer pages. All we can do now - * is turn off TCP_RCV_COPY_DATA, so the request - * will not receive any additional updates, - * and time out. - * Any remaining data from this record will - * be discarded. - */ - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - dprintk("RPC: XID %08x truncated request\n", - ntohl(transport->tcp_xid)); - dprintk("RPC: xprt = %p, tcp_copied = %lu, " - "tcp_offset = %u, tcp_reclen = %u\n", - xprt, transport->tcp_copied, - transport->tcp_offset, transport->tcp_reclen); - return; - } - - transport->tcp_copied += r; - transport->tcp_offset += r; - desc->count = len - r; - - dprintk("RPC: XID %08x read %zd bytes\n", - ntohl(transport->tcp_xid), r); - dprintk("RPC: xprt = %p, tcp_copied = %lu, tcp_offset = %u, " - "tcp_reclen = %u\n", xprt, transport->tcp_copied, - transport->tcp_offset, transport->tcp_reclen); - - if (transport->tcp_copied == req->rq_private_buf.buflen) - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - else if (transport->tcp_offset == transport->tcp_reclen) { - if (transport->tcp_flags & TCP_RCV_LAST_FRAG) - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - } -} - -/* - * Finds the request corresponding to the RPC xid and invokes the common - * tcp read code to read the data. - */ -static inline int xs_tcp_read_reply(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct rpc_rqst *req; - - dprintk("RPC: read reply XID %08x\n", ntohl(transport->tcp_xid)); - - /* Find and lock the request corresponding to this xid */ - spin_lock(&xprt->recv_lock); - req = xprt_lookup_rqst(xprt, transport->tcp_xid); - if (!req) { - dprintk("RPC: XID %08x request not found!\n", - ntohl(transport->tcp_xid)); - spin_unlock(&xprt->recv_lock); - return -1; - } - xprt_pin_rqst(req); - spin_unlock(&xprt->recv_lock); - - xs_tcp_read_common(xprt, desc, req); - - spin_lock(&xprt->recv_lock); - if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) - xprt_complete_rqst(req->rq_task, transport->tcp_copied); - xprt_unpin_rqst(req); - spin_unlock(&xprt->recv_lock); - return 0; -} - #if defined(CONFIG_SUNRPC_BACKCHANNEL) -/* - * Obtains an rpc_rqst previously allocated and invokes the common - * tcp read code to read the data. The result is placed in the callback - * queue. - * If we're unable to obtain the rpc_rqst we schedule the closing of the - * connection and return -1. - */ -static int xs_tcp_read_callback(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - struct rpc_rqst *req; - - /* Look up the request corresponding to the given XID */ - req = xprt_lookup_bc_request(xprt, transport->tcp_xid); - if (req == NULL) { - printk(KERN_WARNING "Callback slot table overflowed\n"); - xprt_force_disconnect(xprt); - return -1; - } - - dprintk("RPC: read callback XID %08x\n", ntohl(req->rq_xid)); - xs_tcp_read_common(xprt, desc, req); - - if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) - xprt_complete_bc_request(req, transport->tcp_copied); - - return 0; -} - -static inline int _xs_tcp_read_data(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - - return (transport->tcp_flags & TCP_RPC_REPLY) ? - xs_tcp_read_reply(xprt, desc) : - xs_tcp_read_callback(xprt, desc); -} - static int xs_tcp_bc_up(struct svc_serv *serv, struct net *net) { int ret; @@ -1435,145 +1420,8 @@ static size_t xs_tcp_bc_maxpayload(struct rpc_xprt *xprt) { return PAGE_SIZE; } -#else -static inline int _xs_tcp_read_data(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - return xs_tcp_read_reply(xprt, desc); -} #endif /* CONFIG_SUNRPC_BACKCHANNEL */ -/* - * Read data off the transport. This can be either an RPC_CALL or an - * RPC_REPLY. Relay the processing to helper functions. - */ -static void xs_tcp_read_data(struct rpc_xprt *xprt, - struct xdr_skb_reader *desc) -{ - struct sock_xprt *transport = - container_of(xprt, struct sock_xprt, xprt); - - if (_xs_tcp_read_data(xprt, desc) == 0) - xs_tcp_check_fraghdr(transport); - else { - /* - * The transport_lock protects the request handling. - * There's no need to hold it to update the tcp_flags. - */ - transport->tcp_flags &= ~TCP_RCV_COPY_DATA; - } -} - -static inline void xs_tcp_read_discard(struct sock_xprt *transport, struct xdr_skb_reader *desc) -{ - size_t len; - - len = transport->tcp_reclen - transport->tcp_offset; - if (len > desc->count) - len = desc->count; - desc->count -= len; - desc->offset += len; - transport->tcp_offset += len; - dprintk("RPC: discarded %zu bytes\n", len); - xs_tcp_check_fraghdr(transport); -} - -static int xs_tcp_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, unsigned int offset, size_t len) -{ - struct rpc_xprt *xprt = rd_desc->arg.data; - struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); - struct xdr_skb_reader desc = { - .skb = skb, - .offset = offset, - .count = len, - }; - size_t ret; - - dprintk("RPC: xs_tcp_data_recv started\n"); - do { - trace_xs_tcp_data_recv(transport); - /* Read in a new fragment marker if necessary */ - /* Can we ever really expect to get completely empty fragments? */ - if (transport->tcp_flags & TCP_RCV_COPY_FRAGHDR) { - xs_tcp_read_fraghdr(xprt, &desc); - continue; - } - /* Read in the xid if necessary */ - if (transport->tcp_flags & TCP_RCV_COPY_XID) { - xs_tcp_read_xid(transport, &desc); - continue; - } - /* Read in the call/reply flag */ - if (transport->tcp_flags & TCP_RCV_READ_CALLDIR) { - xs_tcp_read_calldir(transport, &desc); - continue; - } - /* Read in the request data */ - if (transport->tcp_flags & TCP_RCV_COPY_DATA) { - xs_tcp_read_data(xprt, &desc); - continue; - } - /* Skip over any trailing bytes on short reads */ - xs_tcp_read_discard(transport, &desc); - } while (desc.count); - ret = len - desc.count; - if (ret < rd_desc->count) - rd_desc->count -= ret; - else - rd_desc->count = 0; - trace_xs_tcp_data_recv(transport); - dprintk("RPC: xs_tcp_data_recv done\n"); - return ret; -} - -static void xs_tcp_data_receive(struct sock_xprt *transport) -{ - struct rpc_xprt *xprt = &transport->xprt; - struct sock *sk; - read_descriptor_t rd_desc = { - .arg.data = xprt, - }; - unsigned long total = 0; - int read = 0; - -restart: - mutex_lock(&transport->recv_mutex); - sk = transport->inet; - if (sk == NULL) - goto out; - - /* We use rd_desc to pass struct xprt to xs_tcp_data_recv */ - for (;;) { - rd_desc.count = RPC_TCP_READ_CHUNK_SZ; - lock_sock(sk); - read = tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv); - if (rd_desc.count != 0 || read < 0) { - clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); - release_sock(sk); - break; - } - release_sock(sk); - total += read; - if (need_resched()) { - mutex_unlock(&transport->recv_mutex); - cond_resched(); - goto restart; - } - } - if (test_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) - queue_work(xprtiod_workqueue, &transport->recv_worker); -out: - mutex_unlock(&transport->recv_mutex); - trace_xs_tcp_data_ready(xprt, read, total); -} - -static void xs_tcp_data_receive_workfn(struct work_struct *work) -{ - struct sock_xprt *transport = - container_of(work, struct sock_xprt, recv_worker); - xs_tcp_data_receive(transport); -} - /** * xs_tcp_state_change - callback to handle TCP socket state changes * @sk: socket whose state has changed @@ -1600,17 +1448,13 @@ static void xs_tcp_state_change(struct sock *sk) case TCP_ESTABLISHED: spin_lock(&xprt->transport_lock); if (!xprt_test_and_set_connected(xprt)) { - - /* Reset TCP record info */ - transport->tcp_offset = 0; - transport->tcp_reclen = 0; - transport->tcp_copied = 0; - transport->tcp_flags = - TCP_RCV_COPY_FRAGHDR | TCP_RCV_COPY_XID; xprt->connect_cookie++; clear_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); xprt_clear_connecting(xprt); + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; xprt_wake_pending_tasks(xprt, -EAGAIN); } spin_unlock(&xprt->transport_lock); @@ -1675,7 +1519,8 @@ static void xs_write_space(struct sock *sk) if (!wq || test_and_clear_bit(SOCKWQ_ASYNC_NOSPACE, &wq->flags) == 0) goto out; - xprt_write_space(xprt); + if (xprt_write_space(xprt)) + sk->sk_write_pending--; out: rcu_read_unlock(); } @@ -1773,11 +1618,17 @@ static void xs_udp_timer(struct rpc_xprt *xprt, struct rpc_task *task) spin_unlock_bh(&xprt->transport_lock); } -static unsigned short xs_get_random_port(void) +static int xs_get_random_port(void) { - unsigned short range = xprt_max_resvport - xprt_min_resvport + 1; - unsigned short rand = (unsigned short) prandom_u32() % range; - return rand + xprt_min_resvport; + unsigned short min = xprt_min_resvport, max = xprt_max_resvport; + unsigned short range; + unsigned short rand; + + if (max < min) + return -EADDRINUSE; + range = max - min + 1; + rand = (unsigned short) prandom_u32() % range; + return rand + min; } /** @@ -1833,9 +1684,9 @@ static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock) transport->srcport = xs_sock_getport(sock); } -static unsigned short xs_get_srcport(struct sock_xprt *transport) +static int xs_get_srcport(struct sock_xprt *transport) { - unsigned short port = transport->srcport; + int port = transport->srcport; if (port == 0 && transport->xprt.resvport) port = xs_get_random_port(); @@ -1856,7 +1707,7 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock) { struct sockaddr_storage myaddr; int err, nloop = 0; - unsigned short port = xs_get_srcport(transport); + int port = xs_get_srcport(transport); unsigned short last; /* @@ -1874,8 +1725,8 @@ static int xs_bind(struct sock_xprt *transport, struct socket *sock) * transport->xprt.resvport == 1) xs_get_srcport above will * ensure that port is non-zero and we will bind as needed. */ - if (port == 0) - return 0; + if (port <= 0) + return port; memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen); do { @@ -2028,9 +1879,8 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt, write_unlock_bh(&sk->sk_callback_lock); } - /* Tell the socket layer to start connecting... */ - xprt->stat.connect_count++; - xprt->stat.connect_start = jiffies; + xs_stream_reset_connect(transport); + return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0); } @@ -2062,6 +1912,9 @@ static int xs_local_setup_socket(struct sock_xprt *transport) case 0: dprintk("RPC: xprt %p connected to %s\n", xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; xprt_set_connected(xprt); case -ENOBUFS: break; @@ -2386,9 +2239,10 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) xs_set_memalloc(xprt); + /* Reset TCP record info */ + xs_stream_reset_connect(transport); + /* Tell the socket layer to start connecting... */ - xprt->stat.connect_count++; - xprt->stat.connect_start = jiffies; set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); ret = kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); switch (ret) { @@ -2561,7 +2415,7 @@ static void xs_local_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) "%llu %llu %lu %llu %llu\n", xprt->stat.bind_count, xprt->stat.connect_count, - xprt->stat.connect_time, + xprt->stat.connect_time / HZ, idle_time, xprt->stat.sends, xprt->stat.recvs, @@ -2616,7 +2470,7 @@ static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) transport->srcport, xprt->stat.bind_count, xprt->stat.connect_count, - xprt->stat.connect_time, + xprt->stat.connect_time / HZ, idle_time, xprt->stat.sends, xprt->stat.recvs, @@ -2704,9 +2558,8 @@ static int bc_sendto(struct rpc_rqst *req) /* * The send routine. Borrows from svc_send */ -static int bc_send_request(struct rpc_task *task) +static int bc_send_request(struct rpc_rqst *req) { - struct rpc_rqst *req = task->tk_rqstp; struct svc_xprt *xprt; int len; @@ -2720,12 +2573,7 @@ static int bc_send_request(struct rpc_task *task) * Grab the mutex to serialize data as the connection is shared * with the fore channel */ - if (!mutex_trylock(&xprt->xpt_mutex)) { - rpc_sleep_on(&xprt->xpt_bc_pending, task, NULL); - if (!mutex_trylock(&xprt->xpt_mutex)) - return -EAGAIN; - rpc_wake_up_queued_task(&xprt->xpt_bc_pending, task); - } + mutex_lock(&xprt->xpt_mutex); if (test_bit(XPT_DEAD, &xprt->xpt_flags)) len = -ENOTCONN; else @@ -2761,7 +2609,7 @@ static void bc_destroy(struct rpc_xprt *xprt) static const struct rpc_xprt_ops xs_local_ops = { .reserve_xprt = xprt_reserve_xprt, - .release_xprt = xs_tcp_release_xprt, + .release_xprt = xprt_release_xprt, .alloc_slot = xprt_alloc_slot, .free_slot = xprt_free_slot, .rpcbind = xs_local_rpcbind, @@ -2769,6 +2617,7 @@ static const struct rpc_xprt_ops xs_local_ops = { .connect = xs_local_connect, .buf_alloc = rpc_malloc, .buf_free = rpc_free, + .prepare_request = xs_stream_prepare_request, .send_request = xs_local_send_request, .set_retrans_timeout = xprt_set_retrans_timeout_def, .close = xs_close, @@ -2803,14 +2652,15 @@ static const struct rpc_xprt_ops xs_udp_ops = { static const struct rpc_xprt_ops xs_tcp_ops = { .reserve_xprt = xprt_reserve_xprt, - .release_xprt = xs_tcp_release_xprt, - .alloc_slot = xprt_lock_and_alloc_slot, + .release_xprt = xprt_release_xprt, + .alloc_slot = xprt_alloc_slot, .free_slot = xprt_free_slot, .rpcbind = rpcb_getport_async, .set_port = xs_set_port, .connect = xs_connect, .buf_alloc = rpc_malloc, .buf_free = rpc_free, + .prepare_request = xs_stream_prepare_request, .send_request = xs_tcp_send_request, .set_retrans_timeout = xprt_set_retrans_timeout_def, .close = xs_tcp_shutdown, @@ -2952,9 +2802,8 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args) xprt->ops = &xs_local_ops; xprt->timeout = &xs_local_default_timeout; - INIT_WORK(&transport->recv_worker, xs_local_data_receive_workfn); - INIT_DELAYED_WORK(&transport->connect_worker, - xs_dummy_setup_socket); + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_DELAYED_WORK(&transport->connect_worker, xs_dummy_setup_socket); switch (sun->sun_family) { case AF_LOCAL: @@ -3106,7 +2955,7 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) xprt->connect_timeout = xprt->timeout->to_initval * (xprt->timeout->to_retries + 1); - INIT_WORK(&transport->recv_worker, xs_tcp_data_receive_workfn); + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); INIT_DELAYED_WORK(&transport->connect_worker, xs_tcp_setup_socket); switch (addr->sa_family) { @@ -3317,12 +3166,8 @@ static int param_set_uint_minmax(const char *val, static int param_set_portnr(const char *val, const struct kernel_param *kp) { - if (kp->arg == &xprt_min_resvport) - return param_set_uint_minmax(val, kp, - RPC_MIN_RESVPORT, - xprt_max_resvport); return param_set_uint_minmax(val, kp, - xprt_min_resvport, + RPC_MIN_RESVPORT, RPC_MAX_RESVPORT); } diff --git a/net/tipc/bearer.c b/net/tipc/bearer.c index 91891041e5e1..e65c3a8551e4 100644 --- a/net/tipc/bearer.c +++ b/net/tipc/bearer.c @@ -609,16 +609,18 @@ static int tipc_l2_device_event(struct notifier_block *nb, unsigned long evt, switch (evt) { case NETDEV_CHANGE: - if (netif_carrier_ok(dev)) + if (netif_carrier_ok(dev) && netif_oper_up(dev)) { + test_and_set_bit_lock(0, &b->up); break; - /* else: fall through */ - case NETDEV_UP: - test_and_set_bit_lock(0, &b->up); - break; + } + /* fall through */ case NETDEV_GOING_DOWN: clear_bit_unlock(0, &b->up); tipc_reset_bearer(net, b); break; + case NETDEV_UP: + test_and_set_bit_lock(0, &b->up); + break; case NETDEV_CHANGEMTU: if (tipc_mtu_bad(dev, 0)) { bearer_disable(net, b); diff --git a/net/tipc/group.c b/net/tipc/group.c index e82f13cb2dc5..06fee142f09f 100644 --- a/net/tipc/group.c +++ b/net/tipc/group.c @@ -666,6 +666,7 @@ static void tipc_group_create_event(struct tipc_group *grp, struct sk_buff *skb; struct tipc_msg *hdr; + memset(&evt, 0, sizeof(evt)); evt.event = event; evt.found_lower = m->instance; evt.found_upper = m->instance; diff --git a/net/tipc/link.c b/net/tipc/link.c index b1f0bee54eac..201c3b5bc96b 100644 --- a/net/tipc/link.c +++ b/net/tipc/link.c @@ -410,6 +410,11 @@ char *tipc_link_name(struct tipc_link *l) return l->name; } +u32 tipc_link_state(struct tipc_link *l) +{ + return l->state; +} + /** * tipc_link_create - create a new link * @n: pointer to associated node @@ -472,6 +477,8 @@ bool tipc_link_create(struct net *net, char *if_name, int bearer_id, l->in_session = false; l->bearer_id = bearer_id; l->tolerance = tolerance; + if (bc_rcvlink) + bc_rcvlink->tolerance = tolerance; l->net_plane = net_plane; l->advertised_mtu = mtu; l->mtu = mtu; @@ -838,12 +845,24 @@ static void link_prepare_wakeup(struct tipc_link *l) void tipc_link_reset(struct tipc_link *l) { + struct sk_buff_head list; + + __skb_queue_head_init(&list); + l->in_session = false; l->session++; l->mtu = l->advertised_mtu; + + spin_lock_bh(&l->wakeupq.lock); + skb_queue_splice_init(&l->wakeupq, &list); + spin_unlock_bh(&l->wakeupq.lock); + + spin_lock_bh(&l->inputq->lock); + skb_queue_splice_init(&list, l->inputq); + spin_unlock_bh(&l->inputq->lock); + __skb_queue_purge(&l->transmq); __skb_queue_purge(&l->deferdq); - skb_queue_splice_init(&l->wakeupq, l->inputq); __skb_queue_purge(&l->backlogq); l->backlog[TIPC_LOW_IMPORTANCE].len = 0; l->backlog[TIPC_MEDIUM_IMPORTANCE].len = 0; @@ -1021,7 +1040,8 @@ static int tipc_link_retrans(struct tipc_link *l, struct tipc_link *r, /* Detect repeated retransmit failures on same packet */ if (r->last_retransm != buf_seqno(skb)) { r->last_retransm = buf_seqno(skb); - r->stale_limit = jiffies + msecs_to_jiffies(l->tolerance); + r->stale_limit = jiffies + msecs_to_jiffies(r->tolerance); + r->stale_cnt = 0; } else if (++r->stale_cnt > 99 && time_after(jiffies, r->stale_limit)) { link_retransmit_failure(l, skb); if (link_is_bc_sndlink(l)) @@ -1380,6 +1400,36 @@ static void tipc_link_build_proto_msg(struct tipc_link *l, int mtyp, bool probe, __skb_queue_tail(xmitq, skb); } +void tipc_link_create_dummy_tnl_msg(struct tipc_link *l, + struct sk_buff_head *xmitq) +{ + u32 onode = tipc_own_addr(l->net); + struct tipc_msg *hdr, *ihdr; + struct sk_buff_head tnlq; + struct sk_buff *skb; + u32 dnode = l->addr; + + skb_queue_head_init(&tnlq); + skb = tipc_msg_create(TUNNEL_PROTOCOL, FAILOVER_MSG, + INT_H_SIZE, BASIC_H_SIZE, + dnode, onode, 0, 0, 0); + if (!skb) { + pr_warn("%sunable to create tunnel packet\n", link_co_err); + return; + } + + hdr = buf_msg(skb); + msg_set_msgcnt(hdr, 1); + msg_set_bearer_id(hdr, l->peer_bearer_id); + + ihdr = (struct tipc_msg *)msg_data(hdr); + tipc_msg_init(onode, ihdr, TIPC_LOW_IMPORTANCE, TIPC_DIRECT_MSG, + BASIC_H_SIZE, dnode); + msg_set_errcode(ihdr, TIPC_ERR_NO_PORT); + __skb_queue_tail(&tnlq, skb); + tipc_link_xmit(l, &tnlq, xmitq); +} + /* tipc_link_tnl_prepare(): prepare and return a list of tunnel packets * with contents of the link's transmit and backlog queues. */ @@ -1476,6 +1526,9 @@ bool tipc_link_validate_msg(struct tipc_link *l, struct tipc_msg *hdr) return false; if (session != curr_session) return false; + /* Extra sanity check */ + if (!link_is_up(l) && msg_ack(hdr)) + return false; if (!(l->peer_caps & TIPC_LINK_PROTO_SEQNO)) return true; /* Accept only STATE with new sequence number */ @@ -1533,9 +1586,10 @@ static int tipc_link_proto_rcv(struct tipc_link *l, struct sk_buff *skb, strncpy(if_name, data, TIPC_MAX_IF_NAME); /* Update own tolerance if peer indicates a non-zero value */ - if (in_range(peers_tol, TIPC_MIN_LINK_TOL, TIPC_MAX_LINK_TOL)) + if (in_range(peers_tol, TIPC_MIN_LINK_TOL, TIPC_MAX_LINK_TOL)) { l->tolerance = peers_tol; - + l->bc_rcvlink->tolerance = peers_tol; + } /* Update own priority if peer's priority is higher */ if (in_range(peers_prio, l->priority + 1, TIPC_MAX_LINK_PRI)) l->priority = peers_prio; @@ -1561,9 +1615,10 @@ static int tipc_link_proto_rcv(struct tipc_link *l, struct sk_buff *skb, l->rcv_nxt_state = msg_seqno(hdr) + 1; /* Update own tolerance if peer indicates a non-zero value */ - if (in_range(peers_tol, TIPC_MIN_LINK_TOL, TIPC_MAX_LINK_TOL)) + if (in_range(peers_tol, TIPC_MIN_LINK_TOL, TIPC_MAX_LINK_TOL)) { l->tolerance = peers_tol; - + l->bc_rcvlink->tolerance = peers_tol; + } /* Update own prio if peer indicates a different value */ if ((peers_prio != l->priority) && in_range(peers_prio, 1, TIPC_MAX_LINK_PRI)) { @@ -2180,6 +2235,8 @@ void tipc_link_set_tolerance(struct tipc_link *l, u32 tol, struct sk_buff_head *xmitq) { l->tolerance = tol; + if (l->bc_rcvlink) + l->bc_rcvlink->tolerance = tol; if (link_is_up(l)) tipc_link_build_proto_msg(l, STATE_MSG, 0, 0, 0, tol, 0, xmitq); } diff --git a/net/tipc/link.h b/net/tipc/link.h index 7bc494a33fdf..90488c538a4e 100644 --- a/net/tipc/link.h +++ b/net/tipc/link.h @@ -88,6 +88,8 @@ bool tipc_link_bc_create(struct net *net, u32 ownnode, u32 peer, struct tipc_link **link); void tipc_link_tnl_prepare(struct tipc_link *l, struct tipc_link *tnl, int mtyp, struct sk_buff_head *xmitq); +void tipc_link_create_dummy_tnl_msg(struct tipc_link *tnl, + struct sk_buff_head *xmitq); void tipc_link_build_reset_msg(struct tipc_link *l, struct sk_buff_head *xmitq); int tipc_link_fsm_evt(struct tipc_link *l, int evt); bool tipc_link_is_up(struct tipc_link *l); @@ -107,6 +109,7 @@ u16 tipc_link_rcv_nxt(struct tipc_link *l); u16 tipc_link_acked(struct tipc_link *l); u32 tipc_link_id(struct tipc_link *l); char *tipc_link_name(struct tipc_link *l); +u32 tipc_link_state(struct tipc_link *l); char tipc_link_plane(struct tipc_link *l); int tipc_link_prio(struct tipc_link *l); int tipc_link_window(struct tipc_link *l); diff --git a/net/tipc/name_distr.c b/net/tipc/name_distr.c index 51b4b96f89db..61219f0b9677 100644 --- a/net/tipc/name_distr.c +++ b/net/tipc/name_distr.c @@ -94,8 +94,9 @@ struct sk_buff *tipc_named_publish(struct net *net, struct publication *publ) list_add_tail_rcu(&publ->binding_node, &nt->node_scope); return NULL; } - list_add_tail_rcu(&publ->binding_node, &nt->cluster_scope); - + write_lock_bh(&nt->cluster_scope_lock); + list_add_tail(&publ->binding_node, &nt->cluster_scope); + write_unlock_bh(&nt->cluster_scope_lock); skb = named_prepare_buf(net, PUBLICATION, ITEM_SIZE, 0); if (!skb) { pr_warn("Publication distribution failure\n"); @@ -112,11 +113,13 @@ struct sk_buff *tipc_named_publish(struct net *net, struct publication *publ) */ struct sk_buff *tipc_named_withdraw(struct net *net, struct publication *publ) { + struct name_table *nt = tipc_name_table(net); struct sk_buff *buf; struct distr_item *item; + write_lock_bh(&nt->cluster_scope_lock); list_del(&publ->binding_node); - + write_unlock_bh(&nt->cluster_scope_lock); if (publ->scope == TIPC_NODE_SCOPE) return NULL; @@ -189,11 +192,10 @@ void tipc_named_node_up(struct net *net, u32 dnode) __skb_queue_head_init(&head); - rcu_read_lock(); + read_lock_bh(&nt->cluster_scope_lock); named_distribute(net, &head, dnode, &nt->cluster_scope); - rcu_read_unlock(); - tipc_node_xmit(net, &head, dnode, 0); + read_unlock_bh(&nt->cluster_scope_lock); } /** diff --git a/net/tipc/name_table.c b/net/tipc/name_table.c index 66d5b2c5987a..bff241f03525 100644 --- a/net/tipc/name_table.c +++ b/net/tipc/name_table.c @@ -744,6 +744,7 @@ int tipc_nametbl_init(struct net *net) INIT_LIST_HEAD(&nt->node_scope); INIT_LIST_HEAD(&nt->cluster_scope); + rwlock_init(&nt->cluster_scope_lock); tn->nametbl = nt; spin_lock_init(&tn->nametbl_lock); return 0; diff --git a/net/tipc/name_table.h b/net/tipc/name_table.h index 892bd750b85f..f79066334cc8 100644 --- a/net/tipc/name_table.h +++ b/net/tipc/name_table.h @@ -100,6 +100,7 @@ struct name_table { struct hlist_head services[TIPC_NAMETBL_SIZE]; struct list_head node_scope; struct list_head cluster_scope; + rwlock_t cluster_scope_lock; u32 local_publ_count; }; diff --git a/net/tipc/node.c b/net/tipc/node.c index 68014f1b6976..2afc4f8c37a7 100644 --- a/net/tipc/node.c +++ b/net/tipc/node.c @@ -111,6 +111,7 @@ struct tipc_node { int action_flags; struct list_head list; int state; + bool failover_sent; u16 sync_point; int link_cnt; u16 working_links; @@ -680,6 +681,7 @@ static void __tipc_node_link_up(struct tipc_node *n, int bearer_id, *slot0 = bearer_id; *slot1 = bearer_id; tipc_node_fsm_evt(n, SELF_ESTABL_CONTACT_EVT); + n->failover_sent = false; n->action_flags |= TIPC_NOTIFY_NODE_UP; tipc_link_set_active(nl, true); tipc_bcast_add_peer(n->net, nl, xmitq); @@ -911,6 +913,7 @@ void tipc_node_check_dest(struct net *net, u32 addr, bool reset = true; char *if_name; unsigned long intv; + u16 session; *dupl_addr = false; *respond = false; @@ -997,9 +1000,10 @@ void tipc_node_check_dest(struct net *net, u32 addr, goto exit; if_name = strchr(b->name, ':') + 1; + get_random_bytes(&session, sizeof(u16)); if (!tipc_link_create(net, if_name, b->identity, b->tolerance, b->net_plane, b->mtu, b->priority, - b->window, mod(tipc_net(net)->random), + b->window, session, tipc_own_addr(net), addr, peer_id, n->capabilities, tipc_bc_sndlink(n->net), n->bc_entry.link, @@ -1615,6 +1619,14 @@ static bool tipc_node_check_state(struct tipc_node *n, struct sk_buff *skb, tipc_skb_queue_splice_tail_init(tipc_link_inputq(pl), tipc_link_inputq(l)); } + /* If parallel link was already down, and this happened before + * the tunnel link came up, FAILOVER was never sent. Ensure that + * FAILOVER is sent to get peer out of NODE_FAILINGOVER state. + */ + if (n->state != NODE_FAILINGOVER && !n->failover_sent) { + tipc_link_create_dummy_tnl_msg(l, xmitq); + n->failover_sent = true; + } /* If pkts arrive out of order, use lowest calculated syncpt */ if (less(syncpt, n->sync_point)) n->sync_point = syncpt; diff --git a/net/tipc/socket.c b/net/tipc/socket.c index 595c5001b28d..636e6131769d 100644 --- a/net/tipc/socket.c +++ b/net/tipc/socket.c @@ -717,7 +717,7 @@ static __poll_t tipc_poll(struct file *file, struct socket *sock, struct tipc_sock *tsk = tipc_sk(sk); __poll_t revents = 0; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); if (sk->sk_shutdown & RCV_SHUTDOWN) revents |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; @@ -1198,6 +1198,7 @@ void tipc_sk_mcast_rcv(struct net *net, struct sk_buff_head *arrvq, * @skb: pointer to message buffer. */ static void tipc_sk_conn_proto_rcv(struct tipc_sock *tsk, struct sk_buff *skb, + struct sk_buff_head *inputq, struct sk_buff_head *xmitq) { struct tipc_msg *hdr = buf_msg(skb); @@ -1215,7 +1216,16 @@ static void tipc_sk_conn_proto_rcv(struct tipc_sock *tsk, struct sk_buff *skb, tipc_node_remove_conn(sock_net(sk), tsk_peer_node(tsk), tsk_peer_port(tsk)); sk->sk_state_change(sk); - goto exit; + + /* State change is ignored if socket already awake, + * - convert msg to abort msg and add to inqueue + */ + msg_set_user(hdr, TIPC_CRITICAL_IMPORTANCE); + msg_set_type(hdr, TIPC_CONN_MSG); + msg_set_size(hdr, BASIC_H_SIZE); + msg_set_hdr_sz(hdr, BASIC_H_SIZE); + __skb_queue_tail(inputq, skb); + return; } tsk->probe_unacked = false; @@ -1424,8 +1434,10 @@ static int __tipc_sendstream(struct socket *sock, struct msghdr *m, size_t dlen) /* Handle implicit connection setup */ if (unlikely(dest)) { rc = __tipc_sendmsg(sock, m, dlen); - if (dlen && (dlen == rc)) + if (dlen && dlen == rc) { + tsk->peer_caps = tipc_node_get_capabilities(net, dnode); tsk->snt_unacked = tsk_inc(tsk, dlen + msg_hdr_sz(hdr)); + } return rc; } @@ -1941,7 +1953,7 @@ static void tipc_sk_proto_rcv(struct sock *sk, switch (msg_user(hdr)) { case CONN_MANAGER: - tipc_sk_conn_proto_rcv(tsk, skb, xmitq); + tipc_sk_conn_proto_rcv(tsk, skb, inputq, xmitq); return; case SOCK_WAKEUP: tipc_dest_del(&tsk->cong_links, msg_orignode(hdr), 0); diff --git a/net/tipc/topsrv.c b/net/tipc/topsrv.c index d8956f7daac4..efb16f69bd2c 100644 --- a/net/tipc/topsrv.c +++ b/net/tipc/topsrv.c @@ -394,7 +394,7 @@ static int tipc_conn_rcv_from_sock(struct tipc_conn *con) iov.iov_base = &s; iov.iov_len = sizeof(s); msg.msg_name = NULL; - iov_iter_kvec(&msg.msg_iter, READ | ITER_KVEC, &iov, 1, iov.iov_len); + iov_iter_kvec(&msg.msg_iter, READ, &iov, 1, iov.iov_len); ret = sock_recvmsg(con->sock, &msg, MSG_DONTWAIT); if (ret == -EWOULDBLOCK) return -EWOULDBLOCK; @@ -651,7 +651,7 @@ int tipc_topsrv_start(struct net *net) srv->max_rcvbuf_size = sizeof(struct tipc_subscr); INIT_WORK(&srv->awork, tipc_topsrv_accept); - strncpy(srv->name, name, strlen(name) + 1); + strscpy(srv->name, name, sizeof(srv->name)); tn->topsrv = srv; atomic_set(&tn->subscription_count, 0); diff --git a/net/tipc/udp_media.c b/net/tipc/udp_media.c index 9783101bc4a9..10dc59ce9c82 100644 --- a/net/tipc/udp_media.c +++ b/net/tipc/udp_media.c @@ -650,6 +650,7 @@ static int tipc_udp_enable(struct net *net, struct tipc_bearer *b, struct udp_tunnel_sock_cfg tuncfg = {NULL}; struct nlattr *opts[TIPC_NLA_UDP_MAX + 1]; u8 node_id[NODE_ID_LEN] = {0,}; + int rmcast = 0; ub = kzalloc(sizeof(*ub), GFP_ATOMIC); if (!ub) @@ -680,6 +681,9 @@ static int tipc_udp_enable(struct net *net, struct tipc_bearer *b, if (err) goto err; + /* Checking remote ip address */ + rmcast = tipc_udp_is_mcast_addr(&remote); + /* Autoconfigure own node identity if needed */ if (!tipc_own_id(net)) { memcpy(node_id, local.ipv6.in6_u.u6_addr8, 16); @@ -705,7 +709,12 @@ static int tipc_udp_enable(struct net *net, struct tipc_bearer *b, goto err; } udp_conf.family = AF_INET; - udp_conf.local_ip.s_addr = htonl(INADDR_ANY); + + /* Switch to use ANY to receive packets from group */ + if (rmcast) + udp_conf.local_ip.s_addr = htonl(INADDR_ANY); + else + udp_conf.local_ip.s_addr = local.ipv4.s_addr; udp_conf.use_udp_checksums = false; ub->ifindex = dev->ifindex; if (tipc_mtu_bad(dev, sizeof(struct iphdr) + @@ -719,7 +728,10 @@ static int tipc_udp_enable(struct net *net, struct tipc_bearer *b, udp_conf.family = AF_INET6; udp_conf.use_udp6_tx_checksums = true; udp_conf.use_udp6_rx_checksums = true; - udp_conf.local_ip6 = in6addr_any; + if (rmcast) + udp_conf.local_ip6 = in6addr_any; + else + udp_conf.local_ip6 = local.ipv6; b->mtu = 1280; #endif } else { @@ -741,7 +753,7 @@ static int tipc_udp_enable(struct net *net, struct tipc_bearer *b, * is used if it's a multicast address. */ memcpy(&b->bcast_addr.value, &remote, sizeof(remote)); - if (tipc_udp_is_mcast_addr(&remote)) + if (rmcast) err = enable_mcast(ub, &remote); else err = tipc_udp_rcast_add(b, &remote); diff --git a/net/tls/Kconfig b/net/tls/Kconfig index 73f05ece53d0..99c1a19c17b1 100644 --- a/net/tls/Kconfig +++ b/net/tls/Kconfig @@ -8,6 +8,7 @@ config TLS select CRYPTO_AES select CRYPTO_GCM select STREAM_PARSER + select NET_SOCK_MSG default n ---help--- Enable kernel support for TLS protocol. This allows symmetric diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 961b07d4d41c..d753e362d2d9 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -421,7 +421,7 @@ last_record: tls_push_record_flags = flags; if (more) { tls_ctx->pending_open_record_frags = - record->num_frags; + !!record->num_frags; break; } @@ -489,7 +489,7 @@ int tls_device_sendpage(struct sock *sk, struct page *page, iov.iov_base = kaddr + offset; iov.iov_len = size; - iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size); + iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size); rc = tls_push_data(sk, &msg_iter, size, flags, TLS_RECORD_TYPE_DATA); kunmap(page); @@ -538,7 +538,7 @@ static int tls_device_push_pending_record(struct sock *sk, int flags) { struct iov_iter msg_iter; - iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0); + iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0); return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); } diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index b428069a1b05..311cec8e533d 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -620,12 +620,14 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage; prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; - prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; - prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; + prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read; + prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; - prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; - prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; + prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read; + prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; #ifdef CONFIG_TLS_DEVICE prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; @@ -713,8 +715,6 @@ EXPORT_SYMBOL(tls_unregister_device); static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .name = "tls", - .uid = TCP_ULP_TLS, - .user_visible = true, .owner = THIS_MODULE, .init = tls_init, }; @@ -724,7 +724,6 @@ static int __init tls_register(void) build_protos(tls_prots[TLSV4], &tcp_prot); tls_sw_proto_ops = inet_stream_ops; - tls_sw_proto_ops.poll = tls_sw_poll; tls_sw_proto_ops.splice_read = tls_sw_splice_read; #ifdef CONFIG_TLS_DEVICE diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index aa9fdce272b6..7b1af8b59cd2 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -4,6 +4,7 @@ * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved. * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved. * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved. + * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io * * This software is available to you under a choice of one of two * licenses. You may choose to be licensed under the terms of the GNU @@ -213,153 +214,89 @@ static int tls_do_decryption(struct sock *sk, return ret; } -static void trim_sg(struct sock *sk, struct scatterlist *sg, - int *sg_num_elem, unsigned int *sg_size, int target_size) -{ - int i = *sg_num_elem - 1; - int trim = *sg_size - target_size; - - if (trim <= 0) { - WARN_ON(trim < 0); - return; - } - - *sg_size = target_size; - while (trim >= sg[i].length) { - trim -= sg[i].length; - sk_mem_uncharge(sk, sg[i].length); - put_page(sg_page(&sg[i])); - i--; - - if (i < 0) - goto out; - } - - sg[i].length -= trim; - sk_mem_uncharge(sk, trim); - -out: - *sg_num_elem = i + 1; -} - -static void trim_both_sgl(struct sock *sk, int target_size) +static void tls_trim_both_msgs(struct sock *sk, int target_size) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; - trim_sg(sk, &rec->sg_plaintext_data[1], - &rec->sg_plaintext_num_elem, - &rec->sg_plaintext_size, - target_size); - + sk_msg_trim(sk, &rec->msg_plaintext, target_size); if (target_size > 0) target_size += tls_ctx->tx.overhead_size; - - trim_sg(sk, &rec->sg_encrypted_data[1], - &rec->sg_encrypted_num_elem, - &rec->sg_encrypted_size, - target_size); + sk_msg_trim(sk, &rec->msg_encrypted, target_size); } -static int alloc_encrypted_sg(struct sock *sk, int len) +static int tls_alloc_encrypted_msg(struct sock *sk, int len) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; - int rc = 0; - - rc = sk_alloc_sg(sk, len, - &rec->sg_encrypted_data[1], 0, - &rec->sg_encrypted_num_elem, - &rec->sg_encrypted_size, 0); - - if (rc == -ENOSPC) - rec->sg_encrypted_num_elem = - ARRAY_SIZE(rec->sg_encrypted_data) - 1; + struct sk_msg *msg_en = &rec->msg_encrypted; - return rc; + return sk_msg_alloc(sk, msg_en, len, 0); } -static int move_to_plaintext_sg(struct sock *sk, int required_size) +static int tls_clone_plaintext_msg(struct sock *sk, int required) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; - struct scatterlist *plain_sg = &rec->sg_plaintext_data[1]; - struct scatterlist *enc_sg = &rec->sg_encrypted_data[1]; - int enc_sg_idx = 0; + struct sk_msg *msg_pl = &rec->msg_plaintext; + struct sk_msg *msg_en = &rec->msg_encrypted; int skip, len; - if (rec->sg_plaintext_num_elem == MAX_SKB_FRAGS) - return -ENOSPC; - - /* We add page references worth len bytes from enc_sg at the - * end of plain_sg. It is guaranteed that sg_encrypted_data + /* We add page references worth len bytes from encrypted sg + * at the end of plaintext sg. It is guaranteed that msg_en * has enough required room (ensured by caller). */ - len = required_size - rec->sg_plaintext_size; + len = required - msg_pl->sg.size; - /* Skip initial bytes in sg_encrypted_data to be able - * to use same offset of both plain and encrypted data. + /* Skip initial bytes in msg_en's data to be able to use + * same offset of both plain and encrypted data. */ - skip = tls_ctx->tx.prepend_size + rec->sg_plaintext_size; + skip = tls_ctx->tx.prepend_size + msg_pl->sg.size; - while (enc_sg_idx < rec->sg_encrypted_num_elem) { - if (enc_sg[enc_sg_idx].length > skip) - break; - - skip -= enc_sg[enc_sg_idx].length; - enc_sg_idx++; - } - - /* unmark the end of plain_sg*/ - sg_unmark_end(plain_sg + rec->sg_plaintext_num_elem - 1); - - while (len) { - struct page *page = sg_page(&enc_sg[enc_sg_idx]); - int bytes = enc_sg[enc_sg_idx].length - skip; - int offset = enc_sg[enc_sg_idx].offset + skip; - - if (bytes > len) - bytes = len; - else - enc_sg_idx++; + return sk_msg_clone(sk, msg_pl, msg_en, skip, len); +} - /* Skipping is required only one time */ - skip = 0; +static struct tls_rec *tls_get_rec(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct sk_msg *msg_pl, *msg_en; + struct tls_rec *rec; + int mem_size; - /* Increment page reference */ - get_page(page); + mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); - sg_set_page(&plain_sg[rec->sg_plaintext_num_elem], page, - bytes, offset); + rec = kzalloc(mem_size, sk->sk_allocation); + if (!rec) + return NULL; - sk_mem_charge(sk, bytes); + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; - len -= bytes; - rec->sg_plaintext_size += bytes; + sk_msg_init(msg_pl); + sk_msg_init(msg_en); - rec->sg_plaintext_num_elem++; + sg_init_table(rec->sg_aead_in, 2); + sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, + sizeof(rec->aad_space)); + sg_unmark_end(&rec->sg_aead_in[1]); - if (rec->sg_plaintext_num_elem == MAX_SKB_FRAGS) - return -ENOSPC; - } + sg_init_table(rec->sg_aead_out, 2); + sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, + sizeof(rec->aad_space)); + sg_unmark_end(&rec->sg_aead_out[1]); - return 0; + return rec; } -static void free_sg(struct sock *sk, struct scatterlist *sg, - int *sg_num_elem, unsigned int *sg_size) +static void tls_free_rec(struct sock *sk, struct tls_rec *rec) { - int i, n = *sg_num_elem; - - for (i = 0; i < n; ++i) { - sk_mem_uncharge(sk, sg[i].length); - put_page(sg_page(&sg[i])); - } - *sg_num_elem = 0; - *sg_size = 0; + sk_msg_free(sk, &rec->msg_encrypted); + sk_msg_free(sk, &rec->msg_plaintext); + kfree(rec); } static void tls_free_open_rec(struct sock *sk) @@ -368,19 +305,10 @@ static void tls_free_open_rec(struct sock *sk) struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec = ctx->open_rec; - /* Return if there is no open record */ - if (!rec) - return; - - free_sg(sk, &rec->sg_encrypted_data[1], - &rec->sg_encrypted_num_elem, - &rec->sg_encrypted_size); - - free_sg(sk, &rec->sg_plaintext_data[1], - &rec->sg_plaintext_num_elem, - &rec->sg_plaintext_size); - - kfree(rec); + if (rec) { + tls_free_rec(sk, rec); + ctx->open_rec = NULL; + } } int tls_tx_records(struct sock *sk, int flags) @@ -388,6 +316,7 @@ int tls_tx_records(struct sock *sk, int flags) struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_rec *rec, *tmp; + struct sk_msg *msg_en; int tx_flags, rc = 0; if (tls_is_partially_sent_record(tls_ctx)) { @@ -407,9 +336,7 @@ int tls_tx_records(struct sock *sk, int flags) * Remove the head of tx_list */ list_del(&rec->list); - free_sg(sk, &rec->sg_plaintext_data[1], - &rec->sg_plaintext_num_elem, &rec->sg_plaintext_size); - + sk_msg_free(sk, &rec->msg_plaintext); kfree(rec); } @@ -421,17 +348,15 @@ int tls_tx_records(struct sock *sk, int flags) else tx_flags = flags; + msg_en = &rec->msg_encrypted; rc = tls_push_sg(sk, tls_ctx, - &rec->sg_encrypted_data[1], + &msg_en->sg.data[msg_en->sg.curr], 0, tx_flags); if (rc) goto tx_err; list_del(&rec->list); - free_sg(sk, &rec->sg_plaintext_data[1], - &rec->sg_plaintext_num_elem, - &rec->sg_plaintext_size); - + sk_msg_free(sk, &rec->msg_plaintext); kfree(rec); } else { break; @@ -451,15 +376,18 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) struct sock *sk = req->data; struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct scatterlist *sge; + struct sk_msg *msg_en; struct tls_rec *rec; bool ready = false; int pending; rec = container_of(aead_req, struct tls_rec, aead_req); + msg_en = &rec->msg_encrypted; - rec->sg_encrypted_data[1].offset -= tls_ctx->tx.prepend_size; - rec->sg_encrypted_data[1].length += tls_ctx->tx.prepend_size; - + sge = sk_msg_elem(msg_en, msg_en->sg.curr); + sge->offset -= tls_ctx->tx.prepend_size; + sge->length += tls_ctx->tx.prepend_size; /* Check if error is previously set on socket */ if (err || sk->sk_err) { @@ -497,31 +425,29 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) /* Schedule the transmission */ if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) - schedule_delayed_work(&ctx->tx_work.work, 2); + schedule_delayed_work(&ctx->tx_work.work, 1); } static int tls_do_encryption(struct sock *sk, struct tls_context *tls_ctx, struct tls_sw_context_tx *ctx, struct aead_request *aead_req, - size_t data_len) + size_t data_len, u32 start) { struct tls_rec *rec = ctx->open_rec; - struct scatterlist *plain_sg = rec->sg_plaintext_data; - struct scatterlist *enc_sg = rec->sg_encrypted_data; + struct sk_msg *msg_en = &rec->msg_encrypted; + struct scatterlist *sge = sk_msg_elem(msg_en, start); int rc; - /* Skip the first index as it contains AAD data */ - rec->sg_encrypted_data[1].offset += tls_ctx->tx.prepend_size; - rec->sg_encrypted_data[1].length -= tls_ctx->tx.prepend_size; + sge->offset += tls_ctx->tx.prepend_size; + sge->length -= tls_ctx->tx.prepend_size; - /* If it is inplace crypto, then pass same SG list as both src, dst */ - if (rec->inplace_crypto) - plain_sg = enc_sg; + msg_en->sg.curr = start; aead_request_set_tfm(aead_req, ctx->aead_send); aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); - aead_request_set_crypt(aead_req, plain_sg, enc_sg, + aead_request_set_crypt(aead_req, rec->sg_aead_in, + rec->sg_aead_out, data_len, tls_ctx->tx.iv); aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, @@ -534,8 +460,8 @@ static int tls_do_encryption(struct sock *sk, rc = crypto_aead_encrypt(aead_req); if (!rc || rc != -EINPROGRESS) { atomic_dec(&ctx->encrypt_pending); - rec->sg_encrypted_data[1].offset -= tls_ctx->tx.prepend_size; - rec->sg_encrypted_data[1].length += tls_ctx->tx.prepend_size; + sge->offset -= tls_ctx->tx.prepend_size; + sge->length += tls_ctx->tx.prepend_size; } if (!rc) { @@ -551,177 +477,318 @@ static int tls_do_encryption(struct sock *sk, return rc; } +static int tls_split_open_record(struct sock *sk, struct tls_rec *from, + struct tls_rec **to, struct sk_msg *msg_opl, + struct sk_msg *msg_oen, u32 split_point, + u32 tx_overhead_size, u32 *orig_end) +{ + u32 i, j, bytes = 0, apply = msg_opl->apply_bytes; + struct scatterlist *sge, *osge, *nsge; + u32 orig_size = msg_opl->sg.size; + struct scatterlist tmp = { }; + struct sk_msg *msg_npl; + struct tls_rec *new; + int ret; + + new = tls_get_rec(sk); + if (!new) + return -ENOMEM; + ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size + + tx_overhead_size, 0); + if (ret < 0) { + tls_free_rec(sk, new); + return ret; + } + + *orig_end = msg_opl->sg.end; + i = msg_opl->sg.start; + sge = sk_msg_elem(msg_opl, i); + while (apply && sge->length) { + if (sge->length > apply) { + u32 len = sge->length - apply; + + get_page(sg_page(sge)); + sg_set_page(&tmp, sg_page(sge), len, + sge->offset + apply); + sge->length = apply; + bytes += apply; + apply = 0; + } else { + apply -= sge->length; + bytes += sge->length; + } + + sk_msg_iter_var_next(i); + if (i == msg_opl->sg.end) + break; + sge = sk_msg_elem(msg_opl, i); + } + + msg_opl->sg.end = i; + msg_opl->sg.curr = i; + msg_opl->sg.copybreak = 0; + msg_opl->apply_bytes = 0; + msg_opl->sg.size = bytes; + + msg_npl = &new->msg_plaintext; + msg_npl->apply_bytes = apply; + msg_npl->sg.size = orig_size - bytes; + + j = msg_npl->sg.start; + nsge = sk_msg_elem(msg_npl, j); + if (tmp.length) { + memcpy(nsge, &tmp, sizeof(*nsge)); + sk_msg_iter_var_next(j); + nsge = sk_msg_elem(msg_npl, j); + } + + osge = sk_msg_elem(msg_opl, i); + while (osge->length) { + memcpy(nsge, osge, sizeof(*nsge)); + sg_unmark_end(nsge); + sk_msg_iter_var_next(i); + sk_msg_iter_var_next(j); + if (i == *orig_end) + break; + osge = sk_msg_elem(msg_opl, i); + nsge = sk_msg_elem(msg_npl, j); + } + + msg_npl->sg.end = j; + msg_npl->sg.curr = j; + msg_npl->sg.copybreak = 0; + + *to = new; + return 0; +} + +static void tls_merge_open_record(struct sock *sk, struct tls_rec *to, + struct tls_rec *from, u32 orig_end) +{ + struct sk_msg *msg_npl = &from->msg_plaintext; + struct sk_msg *msg_opl = &to->msg_plaintext; + struct scatterlist *osge, *nsge; + u32 i, j; + + i = msg_opl->sg.end; + sk_msg_iter_var_prev(i); + j = msg_npl->sg.start; + + osge = sk_msg_elem(msg_opl, i); + nsge = sk_msg_elem(msg_npl, j); + + if (sg_page(osge) == sg_page(nsge) && + osge->offset + osge->length == nsge->offset) { + osge->length += nsge->length; + put_page(sg_page(nsge)); + } + + msg_opl->sg.end = orig_end; + msg_opl->sg.curr = orig_end; + msg_opl->sg.copybreak = 0; + msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size; + msg_opl->sg.size += msg_npl->sg.size; + + sk_msg_free(sk, &to->msg_encrypted); + sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted); + + kfree(from); +} + static int tls_push_record(struct sock *sk, int flags, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct tls_rec *rec = ctx->open_rec; + struct tls_rec *rec = ctx->open_rec, *tmp = NULL; + u32 i, split_point, uninitialized_var(orig_end); + struct sk_msg *msg_pl, *msg_en; struct aead_request *req; + bool split; int rc; if (!rec) return 0; + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + + split_point = msg_pl->apply_bytes; + split = split_point && split_point < msg_pl->sg.size; + if (split) { + rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, + split_point, tls_ctx->tx.overhead_size, + &orig_end); + if (rc < 0) + return rc; + sk_msg_trim(sk, msg_en, msg_pl->sg.size + + tls_ctx->tx.overhead_size); + } + rec->tx_flags = flags; req = &rec->aead_req; - sg_mark_end(rec->sg_plaintext_data + rec->sg_plaintext_num_elem); - sg_mark_end(rec->sg_encrypted_data + rec->sg_encrypted_num_elem); + i = msg_pl->sg.end; + sk_msg_iter_var_prev(i); + sg_mark_end(sk_msg_elem(msg_pl, i)); + + i = msg_pl->sg.start; + sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ? + &msg_en->sg.data[i] : &msg_pl->sg.data[i]); + + i = msg_en->sg.end; + sk_msg_iter_var_prev(i); + sg_mark_end(sk_msg_elem(msg_en, i)); - tls_make_aad(rec->aad_space, rec->sg_plaintext_size, + i = msg_en->sg.start; + sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); + + tls_make_aad(rec->aad_space, msg_pl->sg.size, tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, record_type); tls_fill_prepend(tls_ctx, - page_address(sg_page(&rec->sg_encrypted_data[1])) + - rec->sg_encrypted_data[1].offset, - rec->sg_plaintext_size, record_type); - - tls_ctx->pending_open_record_frags = 0; + page_address(sg_page(&msg_en->sg.data[i])) + + msg_en->sg.data[i].offset, msg_pl->sg.size, + record_type); - rc = tls_do_encryption(sk, tls_ctx, ctx, req, rec->sg_plaintext_size); - if (rc == -EINPROGRESS) - return -EINPROGRESS; + tls_ctx->pending_open_record_frags = false; + rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i); if (rc < 0) { - tls_err_abort(sk, EBADMSG); + if (rc != -EINPROGRESS) { + tls_err_abort(sk, EBADMSG); + if (split) { + tls_ctx->pending_open_record_frags = true; + tls_merge_open_record(sk, rec, tmp, orig_end); + } + } return rc; + } else if (split) { + msg_pl = &tmp->msg_plaintext; + msg_en = &tmp->msg_encrypted; + sk_msg_trim(sk, msg_en, msg_pl->sg.size + + tls_ctx->tx.overhead_size); + tls_ctx->pending_open_record_frags = true; + ctx->open_rec = tmp; } return tls_tx_records(sk, flags); } -static int tls_sw_push_pending_record(struct sock *sk, int flags) -{ - return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA); -} - -static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, - int length, int *pages_used, - unsigned int *size_used, - struct scatterlist *to, int to_max_pages, - bool charge) +static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, + bool full_record, u8 record_type, + size_t *copied, int flags) { - struct page *pages[MAX_SKB_FRAGS]; - - size_t offset; - ssize_t copied, use; - int i = 0; - unsigned int size = *size_used; - int num_elem = *pages_used; - int rc = 0; - int maxpages; - - while (length > 0) { - i = 0; - maxpages = to_max_pages - num_elem; - if (maxpages == 0) { - rc = -EFAULT; - goto out; - } - copied = iov_iter_get_pages(from, pages, - length, - maxpages, &offset); - if (copied <= 0) { - rc = -EFAULT; - goto out; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct sk_msg msg_redir = { }; + struct sk_psock *psock; + struct sock *sk_redir; + struct tls_rec *rec; + int err = 0, send; + bool enospc; + + psock = sk_psock_get(sk); + if (!psock) + return tls_push_record(sk, flags, record_type); +more_data: + enospc = sk_msg_full(msg); + if (psock->eval == __SK_NONE) + psock->eval = sk_psock_msg_verdict(sk, psock, msg); + if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && + !enospc && !full_record) { + err = -ENOSPC; + goto out_err; + } + msg->cork_bytes = 0; + send = msg->sg.size; + if (msg->apply_bytes && msg->apply_bytes < send) + send = msg->apply_bytes; + + switch (psock->eval) { + case __SK_PASS: + err = tls_push_record(sk, flags, record_type); + if (err < 0) { + *copied -= sk_msg_free(sk, msg); + tls_free_open_rec(sk); + goto out_err; } - - iov_iter_advance(from, copied); - - length -= copied; - size += copied; - while (copied) { - use = min_t(int, copied, PAGE_SIZE - offset); - - sg_set_page(&to[num_elem], - pages[i], use, offset); - sg_unmark_end(&to[num_elem]); - if (charge) - sk_mem_charge(sk, use); - - offset = 0; - copied -= use; - - ++i; - ++num_elem; + break; + case __SK_REDIRECT: + sk_redir = psock->sk_redir; + memcpy(&msg_redir, msg, sizeof(*msg)); + if (msg->apply_bytes < send) + msg->apply_bytes = 0; + else + msg->apply_bytes -= send; + sk_msg_return_zero(sk, msg, send); + msg->sg.size -= send; + release_sock(sk); + err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags); + lock_sock(sk); + if (err < 0) { + *copied -= sk_msg_free_nocharge(sk, &msg_redir); + msg->sg.size = 0; } + if (msg->sg.size == 0) + tls_free_open_rec(sk); + break; + case __SK_DROP: + default: + sk_msg_free_partial(sk, msg, send); + if (msg->apply_bytes < send) + msg->apply_bytes = 0; + else + msg->apply_bytes -= send; + if (msg->sg.size == 0) + tls_free_open_rec(sk); + *copied -= send; + err = -EACCES; } - /* Mark the end in the last sg entry if newly added */ - if (num_elem > *pages_used) - sg_mark_end(&to[num_elem - 1]); -out: - if (rc) - iov_iter_revert(from, size - *size_used); - *size_used = size; - *pages_used = num_elem; + if (likely(!err)) { + bool reset_eval = !ctx->open_rec; - return rc; -} - -static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, - int bytes) -{ - struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct tls_rec *rec = ctx->open_rec; - struct scatterlist *sg = &rec->sg_plaintext_data[1]; - int copy, i, rc = 0; - - for (i = tls_ctx->pending_open_record_frags; - i < rec->sg_plaintext_num_elem; ++i) { - copy = sg[i].length; - if (copy_from_iter( - page_address(sg_page(&sg[i])) + sg[i].offset, - copy, from) != copy) { - rc = -EFAULT; - goto out; + rec = ctx->open_rec; + if (rec) { + msg = &rec->msg_plaintext; + if (!msg->apply_bytes) + reset_eval = true; } - bytes -= copy; - - ++tls_ctx->pending_open_record_frags; - - if (!bytes) - break; + if (reset_eval) { + psock->eval = __SK_NONE; + if (psock->sk_redir) { + sock_put(psock->sk_redir); + psock->sk_redir = NULL; + } + } + if (rec) + goto more_data; } - -out: - return rc; + out_err: + sk_psock_put(sk, psock); + return err; } -static struct tls_rec *get_rec(struct sock *sk) +static int tls_sw_push_pending_record(struct sock *sk, int flags) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); - struct tls_rec *rec; - int mem_size; - - /* Return if we already have an open record */ - if (ctx->open_rec) - return ctx->open_rec; - - mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); + struct tls_rec *rec = ctx->open_rec; + struct sk_msg *msg_pl; + size_t copied; - rec = kzalloc(mem_size, sk->sk_allocation); if (!rec) - return NULL; - - sg_init_table(&rec->sg_plaintext_data[0], - ARRAY_SIZE(rec->sg_plaintext_data)); - sg_init_table(&rec->sg_encrypted_data[0], - ARRAY_SIZE(rec->sg_encrypted_data)); - - sg_set_buf(&rec->sg_plaintext_data[0], rec->aad_space, - sizeof(rec->aad_space)); - sg_set_buf(&rec->sg_encrypted_data[0], rec->aad_space, - sizeof(rec->aad_space)); + return 0; - ctx->open_rec = rec; - rec->inplace_crypto = 1; + msg_pl = &rec->msg_plaintext; + copied = msg_pl->sg.size; + if (!copied) + return 0; - return rec; + return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA, + &copied, flags); } int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) @@ -732,9 +799,10 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send); bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC; unsigned char record_type = TLS_RECORD_TYPE_DATA; - bool is_kvec = msg->msg_iter.type & ITER_KVEC; + bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool eor = !(msg->msg_flags & MSG_MORE); size_t try_to_copy, copied = 0; + struct sk_msg *msg_pl, *msg_en; struct tls_rec *rec; int required_size; int num_async = 0; @@ -772,29 +840,35 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) goto send_end; } - rec = get_rec(sk); + if (ctx->open_rec) + rec = ctx->open_rec; + else + rec = ctx->open_rec = tls_get_rec(sk); if (!rec) { ret = -ENOMEM; goto send_end; } - orig_size = rec->sg_plaintext_size; + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + + orig_size = msg_pl->sg.size; full_record = false; try_to_copy = msg_data_left(msg); - record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size; + record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; if (try_to_copy >= record_room) { try_to_copy = record_room; full_record = true; } - required_size = rec->sg_plaintext_size + try_to_copy + + required_size = msg_pl->sg.size + try_to_copy + tls_ctx->tx.overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; alloc_encrypted: - ret = alloc_encrypted_sg(sk, required_size); + ret = tls_alloc_encrypted_msg(sk, required_size); if (ret) { if (ret != -ENOSPC) goto wait_for_memory; @@ -803,17 +877,15 @@ alloc_encrypted: * actually allocated. The difference is due * to max sg elements limit */ - try_to_copy -= required_size - rec->sg_encrypted_size; + try_to_copy -= required_size - msg_en->sg.size; full_record = true; } if (!is_kvec && (full_record || eor) && !async_capable) { - ret = zerocopy_from_iter(sk, &msg->msg_iter, - try_to_copy, &rec->sg_plaintext_num_elem, - &rec->sg_plaintext_size, - &rec->sg_plaintext_data[1], - ARRAY_SIZE(rec->sg_plaintext_data) - 1, - true); + u32 first = msg_pl->sg.end; + + ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter, + msg_pl, try_to_copy); if (ret) goto fallback_to_reg_send; @@ -821,25 +893,34 @@ alloc_encrypted: num_zc++; copied += try_to_copy; - ret = tls_push_record(sk, msg->msg_flags, record_type); + + sk_msg_sg_copy_set(msg_pl, first); + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, + msg->msg_flags); if (ret) { if (ret == -EINPROGRESS) num_async++; + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret == -ENOSPC) + goto rollback_iter; else if (ret != -EAGAIN) goto send_end; } continue; - +rollback_iter: + copied -= try_to_copy; + sk_msg_sg_copy_clear(msg_pl, first); + iov_iter_revert(&msg->msg_iter, + msg_pl->sg.size - orig_size); fallback_to_reg_send: - trim_sg(sk, &rec->sg_plaintext_data[1], - &rec->sg_plaintext_num_elem, - &rec->sg_plaintext_size, - orig_size); + sk_msg_trim(sk, msg_pl, orig_size); } - required_size = rec->sg_plaintext_size + try_to_copy; + required_size = msg_pl->sg.size + try_to_copy; - ret = move_to_plaintext_sg(sk, required_size); + ret = tls_clone_plaintext_msg(sk, required_size); if (ret) { if (ret != -ENOSPC) goto send_end; @@ -848,28 +929,36 @@ fallback_to_reg_send: * actually allocated. The difference is due * to max sg elements limit */ - try_to_copy -= required_size - rec->sg_plaintext_size; + try_to_copy -= required_size - msg_pl->sg.size; full_record = true; - - trim_sg(sk, &rec->sg_encrypted_data[1], - &rec->sg_encrypted_num_elem, - &rec->sg_encrypted_size, - rec->sg_plaintext_size + - tls_ctx->tx.overhead_size); + sk_msg_trim(sk, msg_en, msg_pl->sg.size + + tls_ctx->tx.overhead_size); } - ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy); - if (ret) + ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_pl, + try_to_copy); + if (ret < 0) goto trim_sgl; + /* Open records defined only if successfully copied, otherwise + * we would trim the sg but not reset the open record frags. + */ + tls_ctx->pending_open_record_frags = true; copied += try_to_copy; if (full_record || eor) { - ret = tls_push_record(sk, msg->msg_flags, record_type); + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, + msg->msg_flags); if (ret) { if (ret == -EINPROGRESS) num_async++; - else if (ret != -EAGAIN) + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret != -EAGAIN) { + if (ret == -ENOSPC) + ret = 0; goto send_end; + } } } @@ -881,11 +970,11 @@ wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { trim_sgl: - trim_both_sgl(sk, orig_size); + tls_trim_both_msgs(sk, orig_size); goto send_end; } - if (rec->sg_encrypted_size < required_size) + if (msg_en->sg.size < required_size) goto alloc_encrypted; } @@ -928,10 +1017,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); unsigned char record_type = TLS_RECORD_TYPE_DATA; - size_t orig_size = size; - struct scatterlist *sg; + struct sk_msg *msg_pl; struct tls_rec *rec; int num_async = 0; + size_t copied = 0; bool full_record; int record_room; int ret = 0; @@ -964,26 +1053,33 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, goto sendpage_end; } - rec = get_rec(sk); + if (ctx->open_rec) + rec = ctx->open_rec; + else + rec = ctx->open_rec = tls_get_rec(sk); if (!rec) { ret = -ENOMEM; goto sendpage_end; } + msg_pl = &rec->msg_plaintext; + full_record = false; - record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size; + record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; + copied = 0; copy = size; if (copy >= record_room) { copy = record_room; full_record = true; } - required_size = rec->sg_plaintext_size + copy + - tls_ctx->tx.overhead_size; + + required_size = msg_pl->sg.size + copy + + tls_ctx->tx.overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; alloc_payload: - ret = alloc_encrypted_sg(sk, required_size); + ret = tls_alloc_encrypted_msg(sk, required_size); if (ret) { if (ret != -ENOSPC) goto wait_for_memory; @@ -992,33 +1088,32 @@ alloc_payload: * actually allocated. The difference is due * to max sg elements limit */ - copy -= required_size - rec->sg_plaintext_size; + copy -= required_size - msg_pl->sg.size; full_record = true; } - get_page(page); - sg = &rec->sg_plaintext_data[1] + rec->sg_plaintext_num_elem; - sg_set_page(sg, page, copy, offset); - sg_unmark_end(sg); - - rec->sg_plaintext_num_elem++; - + sk_msg_page_add(msg_pl, page, copy, offset); sk_mem_charge(sk, copy); + offset += copy; size -= copy; - rec->sg_plaintext_size += copy; - tls_ctx->pending_open_record_frags = rec->sg_plaintext_num_elem; + copied += copy; - if (full_record || eor || - rec->sg_plaintext_num_elem == - ARRAY_SIZE(rec->sg_plaintext_data) - 1) { + tls_ctx->pending_open_record_frags = true; + if (full_record || eor || sk_msg_full(msg_pl)) { rec->inplace_crypto = 0; - ret = tls_push_record(sk, flags, record_type); + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, flags); if (ret) { if (ret == -EINPROGRESS) num_async++; - else if (ret != -EAGAIN) + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret != -EAGAIN) { + if (ret == -ENOSPC) + ret = 0; goto sendpage_end; + } } } continue; @@ -1027,7 +1122,7 @@ wait_for_sndbuf: wait_for_memory: ret = sk_stream_wait_memory(sk, &timeo); if (ret) { - trim_both_sgl(sk, rec->sg_plaintext_size); + tls_trim_both_msgs(sk, msg_pl->sg.size); goto sendpage_end; } @@ -1042,24 +1137,20 @@ wait_for_memory: } } sendpage_end: - if (orig_size > size) - ret = orig_size - size; - else - ret = sk_stream_error(sk, flags, ret); - + ret = sk_stream_error(sk, flags, ret); release_sock(sk); - return ret; + return copied ? copied : ret; } -static struct sk_buff *tls_wait_data(struct sock *sk, int flags, - long timeo, int *err) +static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock, + int flags, long timeo, int *err) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct sk_buff *skb; DEFINE_WAIT_FUNC(wait, woken_wake_function); - while (!(skb = ctx->recv_pkt)) { + while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) { if (sk->sk_err) { *err = sock_error(sk); return NULL; @@ -1078,7 +1169,10 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, add_wait_queue(sk_sleep(sk), &wait); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); - sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); + sk_wait_event(sk, &timeo, + ctx->recv_pkt != skb || + !sk_psock_queue_empty(psock), + &wait); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); remove_wait_queue(sk_sleep(sk), &wait); @@ -1092,6 +1186,64 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, return skb; } +static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from, + int length, int *pages_used, + unsigned int *size_used, + struct scatterlist *to, + int to_max_pages) +{ + int rc = 0, i = 0, num_elem = *pages_used, maxpages; + struct page *pages[MAX_SKB_FRAGS]; + unsigned int size = *size_used; + ssize_t copied, use; + size_t offset; + + while (length > 0) { + i = 0; + maxpages = to_max_pages - num_elem; + if (maxpages == 0) { + rc = -EFAULT; + goto out; + } + copied = iov_iter_get_pages(from, pages, + length, + maxpages, &offset); + if (copied <= 0) { + rc = -EFAULT; + goto out; + } + + iov_iter_advance(from, copied); + + length -= copied; + size += copied; + while (copied) { + use = min_t(int, copied, PAGE_SIZE - offset); + + sg_set_page(&to[num_elem], + pages[i], use, offset); + sg_unmark_end(&to[num_elem]); + /* We do not uncharge memory from this API */ + + offset = 0; + copied -= use; + + i++; + num_elem++; + } + } + /* Mark the end in the last sg entry if newly added */ + if (num_elem > *pages_used) + sg_mark_end(&to[num_elem - 1]); +out: + if (rc) + iov_iter_revert(from, size - *size_used); + *size_used = size; + *pages_used = num_elem; + + return rc; +} + /* This function decrypts the input skb into either out_iov or in out_sg * or in skb buffers itself. The input parameter 'zc' indicates if * zero-copy mode needs to be tried or not. With zero-copy mode, either @@ -1189,9 +1341,9 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE); *chunk = 0; - err = zerocopy_from_iter(sk, out_iov, data_len, &pages, - chunk, &sgout[1], - (n_sgout - 1), false); + err = tls_setup_from_iter(sk, out_iov, data_len, + &pages, chunk, &sgout[1], + (n_sgout - 1)); if (err < 0) goto fallback_to_reg_recv; } else if (out_sg) { @@ -1297,6 +1449,7 @@ int tls_sw_recvmsg(struct sock *sk, { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct sk_psock *psock; unsigned char control; struct strp_msg *rxm; struct sk_buff *skb; @@ -1304,7 +1457,7 @@ int tls_sw_recvmsg(struct sock *sk, bool cmsg = false; int target, err = 0; long timeo; - bool is_kvec = msg->msg_iter.type & ITER_KVEC; + bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); int num_async = 0; flags |= nonblock; @@ -1312,6 +1465,7 @@ int tls_sw_recvmsg(struct sock *sk, if (unlikely(flags & MSG_ERRQUEUE)) return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); + psock = sk_psock_get(sk); lock_sock(sk); target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); @@ -1321,9 +1475,20 @@ int tls_sw_recvmsg(struct sock *sk, bool async = false; int chunk = 0; - skb = tls_wait_data(sk, flags, timeo, &err); - if (!skb) + skb = tls_wait_data(sk, psock, flags, timeo, &err); + if (!skb) { + if (psock) { + int ret = __tcp_bpf_recvmsg(sk, psock, + msg, len, flags); + + if (ret > 0) { + copied += ret; + len -= ret; + continue; + } + } goto recv_end; + } rxm = strp_msg(skb); @@ -1429,6 +1594,8 @@ recv_end: } release_sock(sk); + if (psock) + sk_psock_put(sk, psock); return copied ? : err; } @@ -1451,7 +1618,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); - skb = tls_wait_data(sk, flags, timeo, &err); + skb = tls_wait_data(sk, NULL, flags, timeo, &err); if (!skb) goto splice_read_end; @@ -1485,23 +1652,20 @@ splice_read_end: return copied ? : err; } -unsigned int tls_sw_poll(struct file *file, struct socket *sock, - struct poll_table_struct *wait) +bool tls_sw_stream_read(const struct sock *sk) { - unsigned int ret; - struct sock *sk = sock->sk; struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + bool ingress_empty = true; + struct sk_psock *psock; - /* Grab POLLOUT and POLLHUP from the underlying socket */ - ret = ctx->sk_poll(file, sock, wait); - - /* Clear POLLIN bits, and set based on recv_pkt */ - ret &= ~(POLLIN | POLLRDNORM); - if (ctx->recv_pkt) - ret |= POLLIN | POLLRDNORM; + rcu_read_lock(); + psock = sk_psock(sk); + if (psock) + ingress_empty = list_empty(&psock->ingress_msg); + rcu_read_unlock(); - return ret; + return !ingress_empty || ctx->recv_pkt; } static int tls_read_size(struct strparser *strp, struct sk_buff *skb) @@ -1580,8 +1744,15 @@ static void tls_data_ready(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct sk_psock *psock; strp_data_ready(&ctx->strp); + + psock = sk_psock_get(sk); + if (psock && !list_empty(&psock->ingress_msg)) { + ctx->saved_data_ready(sk); + sk_psock_put(sk, psock); + } } void tls_sw_free_resources_tx(struct sock *sk) @@ -1619,25 +1790,15 @@ void tls_sw_free_resources_tx(struct sock *sk) rec = list_first_entry(&ctx->tx_list, struct tls_rec, list); - - free_sg(sk, &rec->sg_plaintext_data[1], - &rec->sg_plaintext_num_elem, - &rec->sg_plaintext_size); - list_del(&rec->list); + sk_msg_free(sk, &rec->msg_plaintext); kfree(rec); } list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { - free_sg(sk, &rec->sg_encrypted_data[1], - &rec->sg_encrypted_num_elem, - &rec->sg_encrypted_size); - - free_sg(sk, &rec->sg_plaintext_data[1], - &rec->sg_plaintext_num_elem, - &rec->sg_plaintext_size); - list_del(&rec->list); + sk_msg_free(sk, &rec->msg_encrypted); + sk_msg_free(sk, &rec->msg_plaintext); kfree(rec); } @@ -1829,8 +1990,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) sk->sk_data_ready = tls_data_ready; write_unlock_bh(&sk->sk_callback_lock); - sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; - strp_check_rcv(&sw_ctx_rx->strp); } diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index d1edfa3cad61..74d1eed7cbd4 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -225,6 +225,8 @@ static inline void unix_release_addr(struct unix_address *addr) static int unix_mkname(struct sockaddr_un *sunaddr, int len, unsigned int *hashp) { + *hashp = 0; + if (len <= sizeof(short) || len > sizeof(*sunaddr)) return -EINVAL; if (!sunaddr || sunaddr->sun_family != AF_UNIX) @@ -2640,7 +2642,7 @@ static __poll_t unix_poll(struct file *file, struct socket *sock, poll_table *wa struct sock *sk = sock->sk; __poll_t mask; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); mask = 0; /* exceptional events? */ @@ -2677,7 +2679,7 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock, unsigned int writable; __poll_t mask; - sock_poll_wait(file, wait); + sock_poll_wait(file, sock, wait); mask = 0; /* exceptional events? */ diff --git a/net/wireless/core.c b/net/wireless/core.c index a88551f3bc43..5bd01058b9e6 100644 --- a/net/wireless/core.c +++ b/net/wireless/core.c @@ -1019,36 +1019,49 @@ void cfg80211_cqm_config_free(struct wireless_dev *wdev) wdev->cqm_config = NULL; } -void cfg80211_unregister_wdev(struct wireless_dev *wdev) +static void __cfg80211_unregister_wdev(struct wireless_dev *wdev, bool sync) { struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); ASSERT_RTNL(); - if (WARN_ON(wdev->netdev)) - return; - nl80211_notify_iface(rdev, wdev, NL80211_CMD_DEL_INTERFACE); list_del_rcu(&wdev->list); - synchronize_rcu(); + if (sync) + synchronize_rcu(); rdev->devlist_generation++; + cfg80211_mlme_purge_registrations(wdev); + switch (wdev->iftype) { case NL80211_IFTYPE_P2P_DEVICE: - cfg80211_mlme_purge_registrations(wdev); cfg80211_stop_p2p_device(rdev, wdev); break; case NL80211_IFTYPE_NAN: cfg80211_stop_nan(rdev, wdev); break; default: - WARN_ON_ONCE(1); break; } +#ifdef CONFIG_CFG80211_WEXT + kzfree(wdev->wext.keys); +#endif + /* only initialized if we have a netdev */ + if (wdev->netdev) + flush_work(&wdev->disconnect_wk); + cfg80211_cqm_config_free(wdev); } + +void cfg80211_unregister_wdev(struct wireless_dev *wdev) +{ + if (WARN_ON(wdev->netdev)) + return; + + __cfg80211_unregister_wdev(wdev, true); +} EXPORT_SYMBOL(cfg80211_unregister_wdev); static const struct device_type wiphy_type = { @@ -1153,6 +1166,30 @@ void cfg80211_stop_iface(struct wiphy *wiphy, struct wireless_dev *wdev, } EXPORT_SYMBOL(cfg80211_stop_iface); +void cfg80211_init_wdev(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + mutex_init(&wdev->mtx); + INIT_LIST_HEAD(&wdev->event_list); + spin_lock_init(&wdev->event_lock); + INIT_LIST_HEAD(&wdev->mgmt_registrations); + spin_lock_init(&wdev->mgmt_registrations_lock); + + /* + * We get here also when the interface changes network namespaces, + * as it's registered into the new one, but we don't want it to + * change ID in that case. Checking if the ID is already assigned + * works, because 0 isn't considered a valid ID and the memory is + * 0-initialized. + */ + if (!wdev->identifier) + wdev->identifier = ++rdev->wdev_id; + list_add_rcu(&wdev->list, &rdev->wiphy.wdev_list); + rdev->devlist_generation++; + + nl80211_notify_iface(rdev, wdev, NL80211_CMD_NEW_INTERFACE); +} + static int cfg80211_netdev_notifier_call(struct notifier_block *nb, unsigned long state, void *ptr) { @@ -1178,23 +1215,6 @@ static int cfg80211_netdev_notifier_call(struct notifier_block *nb, * called within code protected by it when interfaces * are added with nl80211. */ - mutex_init(&wdev->mtx); - INIT_LIST_HEAD(&wdev->event_list); - spin_lock_init(&wdev->event_lock); - INIT_LIST_HEAD(&wdev->mgmt_registrations); - spin_lock_init(&wdev->mgmt_registrations_lock); - - /* - * We get here also when the interface changes network namespaces, - * as it's registered into the new one, but we don't want it to - * change ID in that case. Checking if the ID is already assigned - * works, because 0 isn't considered a valid ID and the memory is - * 0-initialized. - */ - if (!wdev->identifier) - wdev->identifier = ++rdev->wdev_id; - list_add_rcu(&wdev->list, &rdev->wiphy.wdev_list); - rdev->devlist_generation++; /* can only change netns with wiphy */ dev->features |= NETIF_F_NETNS_LOCAL; @@ -1223,7 +1243,7 @@ static int cfg80211_netdev_notifier_call(struct notifier_block *nb, INIT_WORK(&wdev->disconnect_wk, cfg80211_autodisconnect_wk); - nl80211_notify_iface(rdev, wdev, NL80211_CMD_NEW_INTERFACE); + cfg80211_init_wdev(rdev, wdev); break; case NETDEV_GOING_DOWN: cfg80211_leave(rdev, wdev); @@ -1238,7 +1258,7 @@ static int cfg80211_netdev_notifier_call(struct notifier_block *nb, list_for_each_entry_safe(pos, tmp, &rdev->sched_scan_req_list, list) { - if (WARN_ON(pos && pos->dev == wdev->netdev)) + if (WARN_ON(pos->dev == wdev->netdev)) cfg80211_stop_sched_scan_req(rdev, pos, false); } @@ -1302,17 +1322,8 @@ static int cfg80211_netdev_notifier_call(struct notifier_block *nb, * remove and clean it up. */ if (!list_empty(&wdev->list)) { - nl80211_notify_iface(rdev, wdev, - NL80211_CMD_DEL_INTERFACE); + __cfg80211_unregister_wdev(wdev, false); sysfs_remove_link(&dev->dev.kobj, "phy80211"); - list_del_rcu(&wdev->list); - rdev->devlist_generation++; - cfg80211_mlme_purge_registrations(wdev); -#ifdef CONFIG_CFG80211_WEXT - kzfree(wdev->wext.keys); -#endif - flush_work(&wdev->disconnect_wk); - cfg80211_cqm_config_free(wdev); } /* * synchronise (so that we won't find this netdev diff --git a/net/wireless/core.h b/net/wireless/core.h index 7f52ef569320..c61dbba8bf47 100644 --- a/net/wireless/core.h +++ b/net/wireless/core.h @@ -66,6 +66,7 @@ struct cfg80211_registered_device { /* protected by RTNL only */ int num_running_ifaces; int num_running_monitor_ifaces; + u64 cookie_counter; /* BSSes/scanning */ spinlock_t bss_lock; @@ -133,6 +134,16 @@ cfg80211_rdev_free_wowlan(struct cfg80211_registered_device *rdev) #endif } +static inline u64 cfg80211_assign_cookie(struct cfg80211_registered_device *rdev) +{ + u64 r = ++rdev->cookie_counter; + + if (WARN_ON(r == 0)) + r = ++rdev->cookie_counter; + + return r; +} + extern struct workqueue_struct *cfg80211_wq; extern struct list_head cfg80211_rdev_list; extern int cfg80211_rdev_list_generation; @@ -187,6 +198,9 @@ struct wiphy *wiphy_idx_to_wiphy(int wiphy_idx); int cfg80211_switch_netns(struct cfg80211_registered_device *rdev, struct net *net); +void cfg80211_init_wdev(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); + static inline void wdev_lock(struct wireless_dev *wdev) __acquires(wdev) { diff --git a/net/wireless/lib80211_crypt_tkip.c b/net/wireless/lib80211_crypt_tkip.c index e6bce1f130c9..b5e235573c8a 100644 --- a/net/wireless/lib80211_crypt_tkip.c +++ b/net/wireless/lib80211_crypt_tkip.c @@ -30,7 +30,7 @@ #include <net/iw_handler.h> #include <crypto/hash.h> -#include <crypto/skcipher.h> +#include <linux/crypto.h> #include <linux/crc32.h> #include <net/lib80211.h> @@ -64,9 +64,9 @@ struct lib80211_tkip_data { int key_idx; - struct crypto_skcipher *rx_tfm_arc4; + struct crypto_cipher *rx_tfm_arc4; struct crypto_shash *rx_tfm_michael; - struct crypto_skcipher *tx_tfm_arc4; + struct crypto_cipher *tx_tfm_arc4; struct crypto_shash *tx_tfm_michael; /* scratch buffers for virt_to_page() (crypto API) */ @@ -99,8 +99,7 @@ static void *lib80211_tkip_init(int key_idx) priv->key_idx = key_idx; - priv->tx_tfm_arc4 = crypto_alloc_skcipher("ecb(arc4)", 0, - CRYPTO_ALG_ASYNC); + priv->tx_tfm_arc4 = crypto_alloc_cipher("arc4", 0, CRYPTO_ALG_ASYNC); if (IS_ERR(priv->tx_tfm_arc4)) { priv->tx_tfm_arc4 = NULL; goto fail; @@ -112,8 +111,7 @@ static void *lib80211_tkip_init(int key_idx) goto fail; } - priv->rx_tfm_arc4 = crypto_alloc_skcipher("ecb(arc4)", 0, - CRYPTO_ALG_ASYNC); + priv->rx_tfm_arc4 = crypto_alloc_cipher("arc4", 0, CRYPTO_ALG_ASYNC); if (IS_ERR(priv->rx_tfm_arc4)) { priv->rx_tfm_arc4 = NULL; goto fail; @@ -130,9 +128,9 @@ static void *lib80211_tkip_init(int key_idx) fail: if (priv) { crypto_free_shash(priv->tx_tfm_michael); - crypto_free_skcipher(priv->tx_tfm_arc4); + crypto_free_cipher(priv->tx_tfm_arc4); crypto_free_shash(priv->rx_tfm_michael); - crypto_free_skcipher(priv->rx_tfm_arc4); + crypto_free_cipher(priv->rx_tfm_arc4); kfree(priv); } @@ -144,9 +142,9 @@ static void lib80211_tkip_deinit(void *priv) struct lib80211_tkip_data *_priv = priv; if (_priv) { crypto_free_shash(_priv->tx_tfm_michael); - crypto_free_skcipher(_priv->tx_tfm_arc4); + crypto_free_cipher(_priv->tx_tfm_arc4); crypto_free_shash(_priv->rx_tfm_michael); - crypto_free_skcipher(_priv->rx_tfm_arc4); + crypto_free_cipher(_priv->rx_tfm_arc4); } kfree(priv); } @@ -344,12 +342,10 @@ static int lib80211_tkip_hdr(struct sk_buff *skb, int hdr_len, static int lib80211_tkip_encrypt(struct sk_buff *skb, int hdr_len, void *priv) { struct lib80211_tkip_data *tkey = priv; - SKCIPHER_REQUEST_ON_STACK(req, tkey->tx_tfm_arc4); int len; u8 rc4key[16], *pos, *icv; u32 crc; - struct scatterlist sg; - int err; + int i; if (tkey->flags & IEEE80211_CRYPTO_TKIP_COUNTERMEASURES) { struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; @@ -374,14 +370,10 @@ static int lib80211_tkip_encrypt(struct sk_buff *skb, int hdr_len, void *priv) icv[2] = crc >> 16; icv[3] = crc >> 24; - crypto_skcipher_setkey(tkey->tx_tfm_arc4, rc4key, 16); - sg_init_one(&sg, pos, len + 4); - skcipher_request_set_tfm(req, tkey->tx_tfm_arc4); - skcipher_request_set_callback(req, 0, NULL, NULL); - skcipher_request_set_crypt(req, &sg, &sg, len + 4, NULL); - err = crypto_skcipher_encrypt(req); - skcipher_request_zero(req); - return err; + crypto_cipher_setkey(tkey->tx_tfm_arc4, rc4key, 16); + for (i = 0; i < len + 4; i++) + crypto_cipher_encrypt_one(tkey->tx_tfm_arc4, pos + i, pos + i); + return 0; } /* @@ -400,7 +392,6 @@ static inline int tkip_replay_check(u32 iv32_n, u16 iv16_n, static int lib80211_tkip_decrypt(struct sk_buff *skb, int hdr_len, void *priv) { struct lib80211_tkip_data *tkey = priv; - SKCIPHER_REQUEST_ON_STACK(req, tkey->rx_tfm_arc4); u8 rc4key[16]; u8 keyidx, *pos; u32 iv32; @@ -408,9 +399,8 @@ static int lib80211_tkip_decrypt(struct sk_buff *skb, int hdr_len, void *priv) struct ieee80211_hdr *hdr; u8 icv[4]; u32 crc; - struct scatterlist sg; int plen; - int err; + int i; hdr = (struct ieee80211_hdr *)skb->data; @@ -463,18 +453,9 @@ static int lib80211_tkip_decrypt(struct sk_buff *skb, int hdr_len, void *priv) plen = skb->len - hdr_len - 12; - crypto_skcipher_setkey(tkey->rx_tfm_arc4, rc4key, 16); - sg_init_one(&sg, pos, plen + 4); - skcipher_request_set_tfm(req, tkey->rx_tfm_arc4); - skcipher_request_set_callback(req, 0, NULL, NULL); - skcipher_request_set_crypt(req, &sg, &sg, plen + 4, NULL); - err = crypto_skcipher_decrypt(req); - skcipher_request_zero(req); - if (err) { - net_dbg_ratelimited("TKIP: failed to decrypt received packet from %pM\n", - hdr->addr2); - return -7; - } + crypto_cipher_setkey(tkey->rx_tfm_arc4, rc4key, 16); + for (i = 0; i < plen + 4; i++) + crypto_cipher_decrypt_one(tkey->rx_tfm_arc4, pos + i, pos + i); crc = ~crc32_le(~0, pos, plen); icv[0] = crc; @@ -660,9 +641,9 @@ static int lib80211_tkip_set_key(void *key, int len, u8 * seq, void *priv) struct lib80211_tkip_data *tkey = priv; int keyidx; struct crypto_shash *tfm = tkey->tx_tfm_michael; - struct crypto_skcipher *tfm2 = tkey->tx_tfm_arc4; + struct crypto_cipher *tfm2 = tkey->tx_tfm_arc4; struct crypto_shash *tfm3 = tkey->rx_tfm_michael; - struct crypto_skcipher *tfm4 = tkey->rx_tfm_arc4; + struct crypto_cipher *tfm4 = tkey->rx_tfm_arc4; keyidx = tkey->key_idx; memset(tkey, 0, sizeof(*tkey)); diff --git a/net/wireless/lib80211_crypt_wep.c b/net/wireless/lib80211_crypt_wep.c index d05f58b0fd04..6015f6b542a6 100644 --- a/net/wireless/lib80211_crypt_wep.c +++ b/net/wireless/lib80211_crypt_wep.c @@ -22,7 +22,7 @@ #include <net/lib80211.h> -#include <crypto/skcipher.h> +#include <linux/crypto.h> #include <linux/crc32.h> MODULE_AUTHOR("Jouni Malinen"); @@ -35,8 +35,8 @@ struct lib80211_wep_data { u8 key[WEP_KEY_LEN + 1]; u8 key_len; u8 key_idx; - struct crypto_skcipher *tx_tfm; - struct crypto_skcipher *rx_tfm; + struct crypto_cipher *tx_tfm; + struct crypto_cipher *rx_tfm; }; static void *lib80211_wep_init(int keyidx) @@ -48,13 +48,13 @@ static void *lib80211_wep_init(int keyidx) goto fail; priv->key_idx = keyidx; - priv->tx_tfm = crypto_alloc_skcipher("ecb(arc4)", 0, CRYPTO_ALG_ASYNC); + priv->tx_tfm = crypto_alloc_cipher("arc4", 0, CRYPTO_ALG_ASYNC); if (IS_ERR(priv->tx_tfm)) { priv->tx_tfm = NULL; goto fail; } - priv->rx_tfm = crypto_alloc_skcipher("ecb(arc4)", 0, CRYPTO_ALG_ASYNC); + priv->rx_tfm = crypto_alloc_cipher("arc4", 0, CRYPTO_ALG_ASYNC); if (IS_ERR(priv->rx_tfm)) { priv->rx_tfm = NULL; goto fail; @@ -66,8 +66,8 @@ static void *lib80211_wep_init(int keyidx) fail: if (priv) { - crypto_free_skcipher(priv->tx_tfm); - crypto_free_skcipher(priv->rx_tfm); + crypto_free_cipher(priv->tx_tfm); + crypto_free_cipher(priv->rx_tfm); kfree(priv); } return NULL; @@ -77,8 +77,8 @@ static void lib80211_wep_deinit(void *priv) { struct lib80211_wep_data *_priv = priv; if (_priv) { - crypto_free_skcipher(_priv->tx_tfm); - crypto_free_skcipher(_priv->rx_tfm); + crypto_free_cipher(_priv->tx_tfm); + crypto_free_cipher(_priv->rx_tfm); } kfree(priv); } @@ -129,12 +129,10 @@ static int lib80211_wep_build_iv(struct sk_buff *skb, int hdr_len, static int lib80211_wep_encrypt(struct sk_buff *skb, int hdr_len, void *priv) { struct lib80211_wep_data *wep = priv; - SKCIPHER_REQUEST_ON_STACK(req, wep->tx_tfm); u32 crc, klen, len; u8 *pos, *icv; - struct scatterlist sg; u8 key[WEP_KEY_LEN + 3]; - int err; + int i; /* other checks are in lib80211_wep_build_iv */ if (skb_tailroom(skb) < 4) @@ -162,14 +160,12 @@ static int lib80211_wep_encrypt(struct sk_buff *skb, int hdr_len, void *priv) icv[2] = crc >> 16; icv[3] = crc >> 24; - crypto_skcipher_setkey(wep->tx_tfm, key, klen); - sg_init_one(&sg, pos, len + 4); - skcipher_request_set_tfm(req, wep->tx_tfm); - skcipher_request_set_callback(req, 0, NULL, NULL); - skcipher_request_set_crypt(req, &sg, &sg, len + 4, NULL); - err = crypto_skcipher_encrypt(req); - skcipher_request_zero(req); - return err; + crypto_cipher_setkey(wep->tx_tfm, key, klen); + + for (i = 0; i < len + 4; i++) + crypto_cipher_encrypt_one(wep->tx_tfm, pos + i, pos + i); + + return 0; } /* Perform WEP decryption on given buffer. Buffer includes whole WEP part of @@ -182,12 +178,10 @@ static int lib80211_wep_encrypt(struct sk_buff *skb, int hdr_len, void *priv) static int lib80211_wep_decrypt(struct sk_buff *skb, int hdr_len, void *priv) { struct lib80211_wep_data *wep = priv; - SKCIPHER_REQUEST_ON_STACK(req, wep->rx_tfm); u32 crc, klen, plen; u8 key[WEP_KEY_LEN + 3]; u8 keyidx, *pos, icv[4]; - struct scatterlist sg; - int err; + int i; if (skb->len < hdr_len + 8) return -1; @@ -208,15 +202,9 @@ static int lib80211_wep_decrypt(struct sk_buff *skb, int hdr_len, void *priv) /* Apply RC4 to data and compute CRC32 over decrypted data */ plen = skb->len - hdr_len - 8; - crypto_skcipher_setkey(wep->rx_tfm, key, klen); - sg_init_one(&sg, pos, plen + 4); - skcipher_request_set_tfm(req, wep->rx_tfm); - skcipher_request_set_callback(req, 0, NULL, NULL); - skcipher_request_set_crypt(req, &sg, &sg, plen + 4, NULL); - err = crypto_skcipher_decrypt(req); - skcipher_request_zero(req); - if (err) - return -7; + crypto_cipher_setkey(wep->rx_tfm, key, klen); + for (i = 0; i < plen + 4; i++) + crypto_cipher_decrypt_one(wep->rx_tfm, pos + i, pos + i); crc = ~crc32_le(~0, pos, plen); icv[0] = crc; diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c index d5f9b5235cdd..744b5851bbf9 100644 --- a/net/wireless/nl80211.c +++ b/net/wireless/nl80211.c @@ -200,7 +200,46 @@ cfg80211_get_dev_from_info(struct net *netns, struct genl_info *info) return __cfg80211_rdev_from_attrs(netns, info->attrs); } +static int validate_ie_attr(const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + const u8 *pos; + int len; + + pos = nla_data(attr); + len = nla_len(attr); + + while (len) { + u8 elemlen; + + if (len < 2) + goto error; + len -= 2; + + elemlen = pos[1]; + if (elemlen > len) + goto error; + + len -= elemlen; + pos += 2 + elemlen; + } + + return 0; +error: + NL_SET_ERR_MSG_ATTR(extack, attr, "malformed information elements"); + return -EINVAL; +} + /* policy for the attributes */ +static const struct nla_policy +nl80211_ftm_responder_policy[NL80211_FTM_RESP_ATTR_MAX + 1] = { + [NL80211_FTM_RESP_ATTR_ENABLED] = { .type = NLA_FLAG, }, + [NL80211_FTM_RESP_ATTR_LCI] = { .type = NLA_BINARY, + .len = U8_MAX }, + [NL80211_FTM_RESP_ATTR_CIVICLOC] = { .type = NLA_BINARY, + .len = U8_MAX }, +}; + static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_WIPHY] = { .type = NLA_U32 }, [NL80211_ATTR_WIPHY_NAME] = { .type = NLA_NUL_STRING, @@ -213,14 +252,14 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_CENTER_FREQ1] = { .type = NLA_U32 }, [NL80211_ATTR_CENTER_FREQ2] = { .type = NLA_U32 }, - [NL80211_ATTR_WIPHY_RETRY_SHORT] = { .type = NLA_U8 }, - [NL80211_ATTR_WIPHY_RETRY_LONG] = { .type = NLA_U8 }, + [NL80211_ATTR_WIPHY_RETRY_SHORT] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_ATTR_WIPHY_RETRY_LONG] = NLA_POLICY_MIN(NLA_U8, 1), [NL80211_ATTR_WIPHY_FRAG_THRESHOLD] = { .type = NLA_U32 }, [NL80211_ATTR_WIPHY_RTS_THRESHOLD] = { .type = NLA_U32 }, [NL80211_ATTR_WIPHY_COVERAGE_CLASS] = { .type = NLA_U8 }, [NL80211_ATTR_WIPHY_DYN_ACK] = { .type = NLA_FLAG }, - [NL80211_ATTR_IFTYPE] = { .type = NLA_U32 }, + [NL80211_ATTR_IFTYPE] = NLA_POLICY_MAX(NLA_U32, NL80211_IFTYPE_MAX), [NL80211_ATTR_IFINDEX] = { .type = NLA_U32 }, [NL80211_ATTR_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ-1 }, @@ -230,24 +269,28 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_KEY] = { .type = NLA_NESTED, }, [NL80211_ATTR_KEY_DATA] = { .type = NLA_BINARY, .len = WLAN_MAX_KEY_LEN }, - [NL80211_ATTR_KEY_IDX] = { .type = NLA_U8 }, + [NL80211_ATTR_KEY_IDX] = NLA_POLICY_MAX(NLA_U8, 5), [NL80211_ATTR_KEY_CIPHER] = { .type = NLA_U32 }, [NL80211_ATTR_KEY_DEFAULT] = { .type = NLA_FLAG }, [NL80211_ATTR_KEY_SEQ] = { .type = NLA_BINARY, .len = 16 }, - [NL80211_ATTR_KEY_TYPE] = { .type = NLA_U32 }, + [NL80211_ATTR_KEY_TYPE] = + NLA_POLICY_MAX(NLA_U32, NUM_NL80211_KEYTYPES), [NL80211_ATTR_BEACON_INTERVAL] = { .type = NLA_U32 }, [NL80211_ATTR_DTIM_PERIOD] = { .type = NLA_U32 }, [NL80211_ATTR_BEACON_HEAD] = { .type = NLA_BINARY, .len = IEEE80211_MAX_DATA_LEN }, - [NL80211_ATTR_BEACON_TAIL] = { .type = NLA_BINARY, - .len = IEEE80211_MAX_DATA_LEN }, - [NL80211_ATTR_STA_AID] = { .type = NLA_U16 }, + [NL80211_ATTR_BEACON_TAIL] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), + [NL80211_ATTR_STA_AID] = + NLA_POLICY_RANGE(NLA_U16, 1, IEEE80211_MAX_AID), [NL80211_ATTR_STA_FLAGS] = { .type = NLA_NESTED }, [NL80211_ATTR_STA_LISTEN_INTERVAL] = { .type = NLA_U16 }, [NL80211_ATTR_STA_SUPPORTED_RATES] = { .type = NLA_BINARY, .len = NL80211_MAX_SUPP_RATES }, - [NL80211_ATTR_STA_PLINK_ACTION] = { .type = NLA_U8 }, + [NL80211_ATTR_STA_PLINK_ACTION] = + NLA_POLICY_MAX(NLA_U8, NUM_NL80211_PLINK_ACTIONS - 1), [NL80211_ATTR_STA_VLAN] = { .type = NLA_U32 }, [NL80211_ATTR_MNTR_FLAGS] = { /* NLA_NESTED can't be empty */ }, [NL80211_ATTR_MESH_ID] = { .type = NLA_BINARY, @@ -270,8 +313,9 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_HT_CAPABILITY] = { .len = NL80211_HT_CAPABILITY_LEN }, [NL80211_ATTR_MGMT_SUBTYPE] = { .type = NLA_U8 }, - [NL80211_ATTR_IE] = { .type = NLA_BINARY, - .len = IEEE80211_MAX_DATA_LEN }, + [NL80211_ATTR_IE] = NLA_POLICY_VALIDATE_FN(NLA_BINARY, + validate_ie_attr, + IEEE80211_MAX_DATA_LEN), [NL80211_ATTR_SCAN_FREQUENCIES] = { .type = NLA_NESTED }, [NL80211_ATTR_SCAN_SSIDS] = { .type = NLA_NESTED }, @@ -281,7 +325,9 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_REASON_CODE] = { .type = NLA_U16 }, [NL80211_ATTR_FREQ_FIXED] = { .type = NLA_FLAG }, [NL80211_ATTR_TIMED_OUT] = { .type = NLA_FLAG }, - [NL80211_ATTR_USE_MFP] = { .type = NLA_U32 }, + [NL80211_ATTR_USE_MFP] = NLA_POLICY_RANGE(NLA_U32, + NL80211_MFP_NO, + NL80211_MFP_OPTIONAL), [NL80211_ATTR_STA_FLAGS2] = { .len = sizeof(struct nl80211_sta_flag_update), }, @@ -301,7 +347,9 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_FRAME] = { .type = NLA_BINARY, .len = IEEE80211_MAX_DATA_LEN }, [NL80211_ATTR_FRAME_MATCH] = { .type = NLA_BINARY, }, - [NL80211_ATTR_PS_STATE] = { .type = NLA_U32 }, + [NL80211_ATTR_PS_STATE] = NLA_POLICY_RANGE(NLA_U32, + NL80211_PS_DISABLED, + NL80211_PS_ENABLED), [NL80211_ATTR_CQM] = { .type = NLA_NESTED, }, [NL80211_ATTR_LOCAL_STATE_CHANGE] = { .type = NLA_FLAG }, [NL80211_ATTR_AP_ISOLATE] = { .type = NLA_U8 }, @@ -314,15 +362,23 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_OFFCHANNEL_TX_OK] = { .type = NLA_FLAG }, [NL80211_ATTR_KEY_DEFAULT_TYPES] = { .type = NLA_NESTED }, [NL80211_ATTR_WOWLAN_TRIGGERS] = { .type = NLA_NESTED }, - [NL80211_ATTR_STA_PLINK_STATE] = { .type = NLA_U8 }, + [NL80211_ATTR_STA_PLINK_STATE] = + NLA_POLICY_MAX(NLA_U8, NUM_NL80211_PLINK_STATES - 1), + [NL80211_ATTR_MESH_PEER_AID] = + NLA_POLICY_RANGE(NLA_U16, 1, IEEE80211_MAX_AID), [NL80211_ATTR_SCHED_SCAN_INTERVAL] = { .type = NLA_U32 }, [NL80211_ATTR_REKEY_DATA] = { .type = NLA_NESTED }, [NL80211_ATTR_SCAN_SUPP_RATES] = { .type = NLA_NESTED }, - [NL80211_ATTR_HIDDEN_SSID] = { .type = NLA_U32 }, - [NL80211_ATTR_IE_PROBE_RESP] = { .type = NLA_BINARY, - .len = IEEE80211_MAX_DATA_LEN }, - [NL80211_ATTR_IE_ASSOC_RESP] = { .type = NLA_BINARY, - .len = IEEE80211_MAX_DATA_LEN }, + [NL80211_ATTR_HIDDEN_SSID] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_HIDDEN_SSID_NOT_IN_USE, + NL80211_HIDDEN_SSID_ZERO_CONTENTS), + [NL80211_ATTR_IE_PROBE_RESP] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), + [NL80211_ATTR_IE_ASSOC_RESP] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), [NL80211_ATTR_ROAM_SUPPORT] = { .type = NLA_FLAG }, [NL80211_ATTR_SCHED_SCAN_MATCH] = { .type = NLA_NESTED }, [NL80211_ATTR_TX_NO_CCK_RATE] = { .type = NLA_FLAG }, @@ -348,9 +404,12 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_AUTH_DATA] = { .type = NLA_BINARY, }, [NL80211_ATTR_VHT_CAPABILITY] = { .len = NL80211_VHT_CAPABILITY_LEN }, [NL80211_ATTR_SCAN_FLAGS] = { .type = NLA_U32 }, - [NL80211_ATTR_P2P_CTWINDOW] = { .type = NLA_U8 }, - [NL80211_ATTR_P2P_OPPPS] = { .type = NLA_U8 }, - [NL80211_ATTR_LOCAL_MESH_POWER_MODE] = {. type = NLA_U32 }, + [NL80211_ATTR_P2P_CTWINDOW] = NLA_POLICY_MAX(NLA_U8, 127), + [NL80211_ATTR_P2P_OPPPS] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_ATTR_LOCAL_MESH_POWER_MODE] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_MESH_POWER_UNKNOWN + 1, + NL80211_MESH_POWER_MAX), [NL80211_ATTR_ACL_POLICY] = {. type = NLA_U32 }, [NL80211_ATTR_MAC_ADDRS] = { .type = NLA_NESTED }, [NL80211_ATTR_STA_CAPABILITY] = { .type = NLA_U16 }, @@ -363,7 +422,8 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_MDID] = { .type = NLA_U16 }, [NL80211_ATTR_IE_RIC] = { .type = NLA_BINARY, .len = IEEE80211_MAX_DATA_LEN }, - [NL80211_ATTR_PEER_AID] = { .type = NLA_U16 }, + [NL80211_ATTR_PEER_AID] = + NLA_POLICY_RANGE(NLA_U16, 1, IEEE80211_MAX_AID), [NL80211_ATTR_CH_SWITCH_COUNT] = { .type = NLA_U32 }, [NL80211_ATTR_CH_SWITCH_BLOCK_TX] = { .type = NLA_FLAG }, [NL80211_ATTR_CSA_IES] = { .type = NLA_NESTED }, @@ -384,8 +444,9 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_SOCKET_OWNER] = { .type = NLA_FLAG }, [NL80211_ATTR_CSA_C_OFFSETS_TX] = { .type = NLA_BINARY }, [NL80211_ATTR_USE_RRM] = { .type = NLA_FLAG }, - [NL80211_ATTR_TSID] = { .type = NLA_U8 }, - [NL80211_ATTR_USER_PRIO] = { .type = NLA_U8 }, + [NL80211_ATTR_TSID] = NLA_POLICY_MAX(NLA_U8, IEEE80211_NUM_TIDS - 1), + [NL80211_ATTR_USER_PRIO] = + NLA_POLICY_MAX(NLA_U8, IEEE80211_NUM_UPS - 1), [NL80211_ATTR_ADMITTED_TIME] = { .type = NLA_U16 }, [NL80211_ATTR_SMPS_MODE] = { .type = NLA_U8 }, [NL80211_ATTR_MAC_MASK] = { .len = ETH_ALEN }, @@ -395,12 +456,13 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_REG_INDOOR] = { .type = NLA_FLAG }, [NL80211_ATTR_PBSS] = { .type = NLA_FLAG }, [NL80211_ATTR_BSS_SELECT] = { .type = NLA_NESTED }, - [NL80211_ATTR_STA_SUPPORT_P2P_PS] = { .type = NLA_U8 }, + [NL80211_ATTR_STA_SUPPORT_P2P_PS] = + NLA_POLICY_MAX(NLA_U8, NUM_NL80211_P2P_PS_STATUS - 1), [NL80211_ATTR_MU_MIMO_GROUP_DATA] = { .len = VHT_MUMIMO_GROUPS_DATA_LEN }, [NL80211_ATTR_MU_MIMO_FOLLOW_MAC_ADDR] = { .len = ETH_ALEN }, - [NL80211_ATTR_NAN_MASTER_PREF] = { .type = NLA_U8 }, + [NL80211_ATTR_NAN_MASTER_PREF] = NLA_POLICY_MIN(NLA_U8, 1), [NL80211_ATTR_BANDS] = { .type = NLA_U32 }, [NL80211_ATTR_NAN_FUNC] = { .type = NLA_NESTED }, [NL80211_ATTR_FILS_KEK] = { .type = NLA_BINARY, @@ -430,6 +492,11 @@ static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { [NL80211_ATTR_TXQ_QUANTUM] = { .type = NLA_U32 }, [NL80211_ATTR_HE_CAPABILITY] = { .type = NLA_BINARY, .len = NL80211_HE_MAX_CAPABILITY_LEN }, + + [NL80211_ATTR_FTM_RESPONDER] = { + .type = NLA_NESTED, + .validation_data = nl80211_ftm_responder_policy, + }, }; /* policy for the key attributes */ @@ -440,7 +507,7 @@ static const struct nla_policy nl80211_key_policy[NL80211_KEY_MAX + 1] = { [NL80211_KEY_SEQ] = { .type = NLA_BINARY, .len = 16 }, [NL80211_KEY_DEFAULT] = { .type = NLA_FLAG }, [NL80211_KEY_DEFAULT_MGMT] = { .type = NLA_FLAG }, - [NL80211_KEY_TYPE] = { .type = NLA_U32 }, + [NL80211_KEY_TYPE] = NLA_POLICY_MAX(NLA_U32, NUM_NL80211_KEYTYPES - 1), [NL80211_KEY_DEFAULT_TYPES] = { .type = NLA_NESTED }, }; @@ -491,7 +558,10 @@ nl80211_wowlan_tcp_policy[NUM_NL80211_WOWLAN_TCP] = { static const struct nla_policy nl80211_coalesce_policy[NUM_NL80211_ATTR_COALESCE_RULE] = { [NL80211_ATTR_COALESCE_RULE_DELAY] = { .type = NLA_U32 }, - [NL80211_ATTR_COALESCE_RULE_CONDITION] = { .type = NLA_U32 }, + [NL80211_ATTR_COALESCE_RULE_CONDITION] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_COALESCE_CONDITION_MATCH, + NL80211_COALESCE_CONDITION_NO_MATCH), [NL80211_ATTR_COALESCE_RULE_PKT_PATTERN] = { .type = NLA_NESTED }, }; @@ -567,8 +637,7 @@ nl80211_packet_pattern_policy[MAX_NL80211_PKTPAT + 1] = { [NL80211_PKTPAT_OFFSET] = { .type = NLA_U32 }, }; -static int nl80211_prepare_wdev_dump(struct sk_buff *skb, - struct netlink_callback *cb, +static int nl80211_prepare_wdev_dump(struct netlink_callback *cb, struct cfg80211_registered_device **rdev, struct wireless_dev **wdev) { @@ -582,7 +651,7 @@ static int nl80211_prepare_wdev_dump(struct sk_buff *skb, return err; *wdev = __cfg80211_wdev_from_attrs( - sock_net(skb->sk), + sock_net(cb->skb->sk), genl_family_attrbuf(&nl80211_fam)); if (IS_ERR(*wdev)) return PTR_ERR(*wdev); @@ -614,36 +683,6 @@ static int nl80211_prepare_wdev_dump(struct sk_buff *skb, return 0; } -/* IE validation */ -static bool is_valid_ie_attr(const struct nlattr *attr) -{ - const u8 *pos; - int len; - - if (!attr) - return true; - - pos = nla_data(attr); - len = nla_len(attr); - - while (len) { - u8 elemlen; - - if (len < 2) - return false; - len -= 2; - - elemlen = pos[1]; - if (elemlen > len) - return false; - - len -= elemlen; - pos += 2 + elemlen; - } - - return true; -} - /* message building helper */ static inline void *nl80211hdr_put(struct sk_buff *skb, u32 portid, u32 seq, int flags, u8 cmd) @@ -858,12 +897,8 @@ static int nl80211_parse_key_new(struct genl_info *info, struct nlattr *key, if (tb[NL80211_KEY_CIPHER]) k->p.cipher = nla_get_u32(tb[NL80211_KEY_CIPHER]); - if (tb[NL80211_KEY_TYPE]) { + if (tb[NL80211_KEY_TYPE]) k->type = nla_get_u32(tb[NL80211_KEY_TYPE]); - if (k->type < 0 || k->type >= NUM_NL80211_KEYTYPES) - return genl_err_attr(info, -EINVAL, - tb[NL80211_KEY_TYPE]); - } if (tb[NL80211_KEY_DEFAULT_TYPES]) { struct nlattr *kdt[NUM_NL80211_KEY_DEFAULT_TYPES]; @@ -910,13 +945,8 @@ static int nl80211_parse_key_old(struct genl_info *info, struct key_parse *k) if (k->defmgmt) k->def_multi = true; - if (info->attrs[NL80211_ATTR_KEY_TYPE]) { + if (info->attrs[NL80211_ATTR_KEY_TYPE]) k->type = nla_get_u32(info->attrs[NL80211_ATTR_KEY_TYPE]); - if (k->type < 0 || k->type >= NUM_NL80211_KEYTYPES) { - GENL_SET_ERR_MSG(info, "key type out of range"); - return -EINVAL; - } - } if (info->attrs[NL80211_ATTR_KEY_DEFAULT_TYPES]) { struct nlattr *kdt[NUM_NL80211_KEY_DEFAULT_TYPES]; @@ -2292,12 +2322,14 @@ static int nl80211_parse_chandef(struct cfg80211_registered_device *rdev, struct genl_info *info, struct cfg80211_chan_def *chandef) { + struct netlink_ext_ack *extack = info->extack; + struct nlattr **attrs = info->attrs; u32 control_freq; - if (!info->attrs[NL80211_ATTR_WIPHY_FREQ]) + if (!attrs[NL80211_ATTR_WIPHY_FREQ]) return -EINVAL; - control_freq = nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_FREQ]); + control_freq = nla_get_u32(attrs[NL80211_ATTR_WIPHY_FREQ]); chandef->chan = ieee80211_get_channel(&rdev->wiphy, control_freq); chandef->width = NL80211_CHAN_WIDTH_20_NOHT; @@ -2305,14 +2337,16 @@ static int nl80211_parse_chandef(struct cfg80211_registered_device *rdev, chandef->center_freq2 = 0; /* Primary channel not allowed */ - if (!chandef->chan || chandef->chan->flags & IEEE80211_CHAN_DISABLED) + if (!chandef->chan || chandef->chan->flags & IEEE80211_CHAN_DISABLED) { + NL_SET_ERR_MSG_ATTR(extack, attrs[NL80211_ATTR_WIPHY_FREQ], + "Channel is disabled"); return -EINVAL; + } - if (info->attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE]) { + if (attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE]) { enum nl80211_channel_type chantype; - chantype = nla_get_u32( - info->attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE]); + chantype = nla_get_u32(attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE]); switch (chantype) { case NL80211_CHAN_NO_HT: @@ -2322,42 +2356,56 @@ static int nl80211_parse_chandef(struct cfg80211_registered_device *rdev, cfg80211_chandef_create(chandef, chandef->chan, chantype); /* user input for center_freq is incorrect */ - if (info->attrs[NL80211_ATTR_CENTER_FREQ1] && - chandef->center_freq1 != nla_get_u32( - info->attrs[NL80211_ATTR_CENTER_FREQ1])) + if (attrs[NL80211_ATTR_CENTER_FREQ1] && + chandef->center_freq1 != nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ1])) { + NL_SET_ERR_MSG_ATTR(extack, + attrs[NL80211_ATTR_CENTER_FREQ1], + "bad center frequency 1"); return -EINVAL; + } /* center_freq2 must be zero */ - if (info->attrs[NL80211_ATTR_CENTER_FREQ2] && - nla_get_u32(info->attrs[NL80211_ATTR_CENTER_FREQ2])) + if (attrs[NL80211_ATTR_CENTER_FREQ2] && + nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ2])) { + NL_SET_ERR_MSG_ATTR(extack, + attrs[NL80211_ATTR_CENTER_FREQ2], + "center frequency 2 can't be used"); return -EINVAL; + } break; default: + NL_SET_ERR_MSG_ATTR(extack, + attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE], + "invalid channel type"); return -EINVAL; } - } else if (info->attrs[NL80211_ATTR_CHANNEL_WIDTH]) { + } else if (attrs[NL80211_ATTR_CHANNEL_WIDTH]) { chandef->width = - nla_get_u32(info->attrs[NL80211_ATTR_CHANNEL_WIDTH]); - if (info->attrs[NL80211_ATTR_CENTER_FREQ1]) + nla_get_u32(attrs[NL80211_ATTR_CHANNEL_WIDTH]); + if (attrs[NL80211_ATTR_CENTER_FREQ1]) chandef->center_freq1 = - nla_get_u32( - info->attrs[NL80211_ATTR_CENTER_FREQ1]); - if (info->attrs[NL80211_ATTR_CENTER_FREQ2]) + nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ1]); + if (attrs[NL80211_ATTR_CENTER_FREQ2]) chandef->center_freq2 = - nla_get_u32( - info->attrs[NL80211_ATTR_CENTER_FREQ2]); + nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ2]); } - if (!cfg80211_chandef_valid(chandef)) + if (!cfg80211_chandef_valid(chandef)) { + NL_SET_ERR_MSG(extack, "invalid channel definition"); return -EINVAL; + } if (!cfg80211_chandef_usable(&rdev->wiphy, chandef, - IEEE80211_CHAN_DISABLED)) + IEEE80211_CHAN_DISABLED)) { + NL_SET_ERR_MSG(extack, "(extension) channel is disabled"); return -EINVAL; + } if ((chandef->width == NL80211_CHAN_WIDTH_5 || chandef->width == NL80211_CHAN_WIDTH_10) && - !(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_5_10_MHZ)) + !(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_5_10_MHZ)) { + NL_SET_ERR_MSG(extack, "5/10 MHz not supported"); return -EINVAL; + } return 0; } @@ -2617,8 +2665,6 @@ static int nl80211_set_wiphy(struct sk_buff *skb, struct genl_info *info) if (info->attrs[NL80211_ATTR_WIPHY_RETRY_SHORT]) { retry_short = nla_get_u8( info->attrs[NL80211_ATTR_WIPHY_RETRY_SHORT]); - if (retry_short == 0) - return -EINVAL; changed |= WIPHY_PARAM_RETRY_SHORT; } @@ -2626,8 +2672,6 @@ static int nl80211_set_wiphy(struct sk_buff *skb, struct genl_info *info) if (info->attrs[NL80211_ATTR_WIPHY_RETRY_LONG]) { retry_long = nla_get_u8( info->attrs[NL80211_ATTR_WIPHY_RETRY_LONG]); - if (retry_long == 0) - return -EINVAL; changed |= WIPHY_PARAM_RETRY_LONG; } @@ -3119,8 +3163,6 @@ static int nl80211_set_interface(struct sk_buff *skb, struct genl_info *info) ntype = nla_get_u32(info->attrs[NL80211_ATTR_IFTYPE]); if (otype != ntype) change = true; - if (ntype > NL80211_IFTYPE_MAX) - return -EINVAL; } if (info->attrs[NL80211_ATTR_MESH_ID]) { @@ -3185,11 +3227,8 @@ static int nl80211_new_interface(struct sk_buff *skb, struct genl_info *info) if (!info->attrs[NL80211_ATTR_IFNAME]) return -EINVAL; - if (info->attrs[NL80211_ATTR_IFTYPE]) { + if (info->attrs[NL80211_ATTR_IFTYPE]) type = nla_get_u32(info->attrs[NL80211_ATTR_IFTYPE]); - if (type > NL80211_IFTYPE_MAX) - return -EINVAL; - } if (!rdev->ops->add_virtual_intf || !(rdev->wiphy.interface_modes & (1 << type))) @@ -3252,15 +3291,7 @@ static int nl80211_new_interface(struct sk_buff *skb, struct genl_info *info) * P2P Device and NAN do not have a netdev, so don't go * through the netdev notifier and must be added here */ - mutex_init(&wdev->mtx); - INIT_LIST_HEAD(&wdev->event_list); - spin_lock_init(&wdev->event_lock); - INIT_LIST_HEAD(&wdev->mgmt_registrations); - spin_lock_init(&wdev->mgmt_registrations_lock); - - wdev->identifier = ++rdev->wdev_id; - list_add_rcu(&wdev->list, &rdev->wiphy.wdev_list); - rdev->devlist_generation++; + cfg80211_init_wdev(rdev, wdev); break; default: break; @@ -3272,15 +3303,6 @@ static int nl80211_new_interface(struct sk_buff *skb, struct genl_info *info) return -ENOBUFS; } - /* - * For wdevs which have no associated netdev object (e.g. of type - * NL80211_IFTYPE_P2P_DEVICE), emit the NEW_INTERFACE event here. - * For all other types, the event will be generated from the - * netdev notifier - */ - if (!wdev->netdev) - nl80211_notify_iface(rdev, wdev, NL80211_CMD_NEW_INTERFACE); - return genlmsg_reply(msg, info); } @@ -3359,7 +3381,7 @@ static void get_key_callback(void *c, struct key_params *params) params->cipher))) goto nla_put_failure; - if (nla_put_u8(cookie->msg, NL80211_ATTR_KEY_IDX, cookie->idx)) + if (nla_put_u8(cookie->msg, NL80211_KEY_IDX, cookie->idx)) goto nla_put_failure; nla_nest_end(cookie->msg, key); @@ -3386,9 +3408,6 @@ static int nl80211_get_key(struct sk_buff *skb, struct genl_info *info) if (info->attrs[NL80211_ATTR_KEY_IDX]) key_idx = nla_get_u8(info->attrs[NL80211_ATTR_KEY_IDX]); - if (key_idx > 5) - return -EINVAL; - if (info->attrs[NL80211_ATTR_MAC]) mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); @@ -3396,8 +3415,6 @@ static int nl80211_get_key(struct sk_buff *skb, struct genl_info *info) if (info->attrs[NL80211_ATTR_KEY_TYPE]) { u32 kt = nla_get_u32(info->attrs[NL80211_ATTR_KEY_TYPE]); - if (kt >= NUM_NL80211_KEYTYPES) - return -EINVAL; if (kt != NL80211_KEYTYPE_GROUP && kt != NL80211_KEYTYPE_PAIRWISE) return -EINVAL; @@ -3756,6 +3773,7 @@ static bool ht_rateset_to_mask(struct ieee80211_supported_band *sband, return false; /* check availability */ + ridx = array_index_nospec(ridx, IEEE80211_HT_MCS_MASK_LEN); if (sband->ht_cap.mcs.rx_mask[ridx] & rbit) mcs[ridx] |= rbit; else @@ -3997,16 +4015,12 @@ static int validate_beacon_tx_rate(struct cfg80211_registered_device *rdev, return 0; } -static int nl80211_parse_beacon(struct nlattr *attrs[], +static int nl80211_parse_beacon(struct cfg80211_registered_device *rdev, + struct nlattr *attrs[], struct cfg80211_beacon_data *bcn) { bool haveinfo = false; - - if (!is_valid_ie_attr(attrs[NL80211_ATTR_BEACON_TAIL]) || - !is_valid_ie_attr(attrs[NL80211_ATTR_IE]) || - !is_valid_ie_attr(attrs[NL80211_ATTR_IE_PROBE_RESP]) || - !is_valid_ie_attr(attrs[NL80211_ATTR_IE_ASSOC_RESP])) - return -EINVAL; + int err; memset(bcn, 0, sizeof(*bcn)); @@ -4051,6 +4065,35 @@ static int nl80211_parse_beacon(struct nlattr *attrs[], bcn->probe_resp_len = nla_len(attrs[NL80211_ATTR_PROBE_RESP]); } + if (attrs[NL80211_ATTR_FTM_RESPONDER]) { + struct nlattr *tb[NL80211_FTM_RESP_ATTR_MAX + 1]; + + err = nla_parse_nested(tb, NL80211_FTM_RESP_ATTR_MAX, + attrs[NL80211_ATTR_FTM_RESPONDER], + NULL, NULL); + if (err) + return err; + + if (tb[NL80211_FTM_RESP_ATTR_ENABLED] && + wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_ENABLE_FTM_RESPONDER)) + bcn->ftm_responder = 1; + else + return -EOPNOTSUPP; + + if (tb[NL80211_FTM_RESP_ATTR_LCI]) { + bcn->lci = nla_data(tb[NL80211_FTM_RESP_ATTR_LCI]); + bcn->lci_len = nla_len(tb[NL80211_FTM_RESP_ATTR_LCI]); + } + + if (tb[NL80211_FTM_RESP_ATTR_CIVICLOC]) { + bcn->civicloc = nla_data(tb[NL80211_FTM_RESP_ATTR_CIVICLOC]); + bcn->civicloc_len = nla_len(tb[NL80211_FTM_RESP_ATTR_CIVICLOC]); + } + } else { + bcn->ftm_responder = -1; + } + return 0; } @@ -4197,7 +4240,7 @@ static int nl80211_start_ap(struct sk_buff *skb, struct genl_info *info) !info->attrs[NL80211_ATTR_BEACON_HEAD]) return -EINVAL; - err = nl80211_parse_beacon(info->attrs, ¶ms.beacon); + err = nl80211_parse_beacon(rdev, info->attrs, ¶ms.beacon); if (err) return err; @@ -4227,14 +4270,9 @@ static int nl80211_start_ap(struct sk_buff *skb, struct genl_info *info) return -EINVAL; } - if (info->attrs[NL80211_ATTR_HIDDEN_SSID]) { + if (info->attrs[NL80211_ATTR_HIDDEN_SSID]) params.hidden_ssid = nla_get_u32( info->attrs[NL80211_ATTR_HIDDEN_SSID]); - if (params.hidden_ssid != NL80211_HIDDEN_SSID_NOT_IN_USE && - params.hidden_ssid != NL80211_HIDDEN_SSID_ZERO_LEN && - params.hidden_ssid != NL80211_HIDDEN_SSID_ZERO_CONTENTS) - return -EINVAL; - } params.privacy = !!info->attrs[NL80211_ATTR_PRIVACY]; @@ -4264,8 +4302,6 @@ static int nl80211_start_ap(struct sk_buff *skb, struct genl_info *info) return -EINVAL; params.p2p_ctwindow = nla_get_u8(info->attrs[NL80211_ATTR_P2P_CTWINDOW]); - if (params.p2p_ctwindow > 127) - return -EINVAL; if (params.p2p_ctwindow != 0 && !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_CTWIN)) return -EINVAL; @@ -4277,8 +4313,6 @@ static int nl80211_start_ap(struct sk_buff *skb, struct genl_info *info) if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) return -EINVAL; tmp = nla_get_u8(info->attrs[NL80211_ATTR_P2P_OPPPS]); - if (tmp > 1) - return -EINVAL; params.p2p_opp_ps = tmp; if (params.p2p_opp_ps != 0 && !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_OPPPS)) @@ -4381,7 +4415,7 @@ static int nl80211_set_beacon(struct sk_buff *skb, struct genl_info *info) if (!wdev->beacon_interval) return -EINVAL; - err = nl80211_parse_beacon(info->attrs, ¶ms); + err = nl80211_parse_beacon(rdev, info->attrs, ¶ms); if (err) return err; @@ -4727,6 +4761,8 @@ static int nl80211_send_station(struct sk_buff *msg, u32 cmd, u32 portid, PUT_SINFO_U64(RX_DROP_MISC, rx_dropped_misc); PUT_SINFO_U64(BEACON_RX, rx_beacon); PUT_SINFO(BEACON_SIGNAL_AVG, rx_beacon_signal_avg, u8); + PUT_SINFO(RX_MPDUS, rx_mpdu_count, u32); + PUT_SINFO(FCS_ERROR_COUNT, fcs_err_count, u32); if (wiphy_ext_feature_isset(&rdev->wiphy, NL80211_EXT_FEATURE_ACK_SIGNAL_SUPPORT)) { PUT_SINFO(ACK_SIGNAL, ack_signal, u8); @@ -4810,7 +4846,7 @@ static int nl80211_dump_station(struct sk_buff *skb, int err; rtnl_lock(); - err = nl80211_prepare_wdev_dump(skb, cb, &rdev, &wdev); + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev); if (err) goto out_err; @@ -5215,17 +5251,11 @@ static int nl80211_set_station(struct sk_buff *skb, struct genl_info *info) else params.listen_interval = -1; - if (info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]) { - u8 tmp; - - tmp = nla_get_u8(info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]); - if (tmp >= NUM_NL80211_P2P_PS_STATUS) - return -EINVAL; - - params.support_p2p_ps = tmp; - } else { + if (info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]) + params.support_p2p_ps = + nla_get_u8(info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]); + else params.support_p2p_ps = -1; - } if (!info->attrs[NL80211_ATTR_MAC]) return -EINVAL; @@ -5255,38 +5285,23 @@ static int nl80211_set_station(struct sk_buff *skb, struct genl_info *info) if (parse_station_flags(info, dev->ieee80211_ptr->iftype, ¶ms)) return -EINVAL; - if (info->attrs[NL80211_ATTR_STA_PLINK_ACTION]) { + if (info->attrs[NL80211_ATTR_STA_PLINK_ACTION]) params.plink_action = nla_get_u8(info->attrs[NL80211_ATTR_STA_PLINK_ACTION]); - if (params.plink_action >= NUM_NL80211_PLINK_ACTIONS) - return -EINVAL; - } if (info->attrs[NL80211_ATTR_STA_PLINK_STATE]) { params.plink_state = nla_get_u8(info->attrs[NL80211_ATTR_STA_PLINK_STATE]); - if (params.plink_state >= NUM_NL80211_PLINK_STATES) - return -EINVAL; - if (info->attrs[NL80211_ATTR_MESH_PEER_AID]) { + if (info->attrs[NL80211_ATTR_MESH_PEER_AID]) params.peer_aid = nla_get_u16( info->attrs[NL80211_ATTR_MESH_PEER_AID]); - if (params.peer_aid > IEEE80211_MAX_AID) - return -EINVAL; - } params.sta_modify_mask |= STATION_PARAM_APPLY_PLINK_STATE; } - if (info->attrs[NL80211_ATTR_LOCAL_MESH_POWER_MODE]) { - enum nl80211_mesh_power_mode pm = nla_get_u32( + if (info->attrs[NL80211_ATTR_LOCAL_MESH_POWER_MODE]) + params.local_pm = nla_get_u32( info->attrs[NL80211_ATTR_LOCAL_MESH_POWER_MODE]); - if (pm <= NL80211_MESH_POWER_UNKNOWN || - pm > NL80211_MESH_POWER_MAX) - return -EINVAL; - - params.local_pm = pm; - } - if (info->attrs[NL80211_ATTR_OPMODE_NOTIF]) { params.opmode_notif_used = true; params.opmode_notif = @@ -5363,13 +5378,8 @@ static int nl80211_new_station(struct sk_buff *skb, struct genl_info *info) nla_get_u16(info->attrs[NL80211_ATTR_STA_LISTEN_INTERVAL]); if (info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]) { - u8 tmp; - - tmp = nla_get_u8(info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]); - if (tmp >= NUM_NL80211_P2P_PS_STATUS) - return -EINVAL; - - params.support_p2p_ps = tmp; + params.support_p2p_ps = + nla_get_u8(info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]); } else { /* * if not specified, assume it's supported for P2P GO interface, @@ -5383,8 +5393,6 @@ static int nl80211_new_station(struct sk_buff *skb, struct genl_info *info) params.aid = nla_get_u16(info->attrs[NL80211_ATTR_PEER_AID]); else params.aid = nla_get_u16(info->attrs[NL80211_ATTR_STA_AID]); - if (!params.aid || params.aid > IEEE80211_MAX_AID) - return -EINVAL; if (info->attrs[NL80211_ATTR_STA_CAPABILITY]) { params.capability = @@ -5424,12 +5432,9 @@ static int nl80211_new_station(struct sk_buff *skb, struct genl_info *info) nla_get_u8(info->attrs[NL80211_ATTR_OPMODE_NOTIF]); } - if (info->attrs[NL80211_ATTR_STA_PLINK_ACTION]) { + if (info->attrs[NL80211_ATTR_STA_PLINK_ACTION]) params.plink_action = nla_get_u8(info->attrs[NL80211_ATTR_STA_PLINK_ACTION]); - if (params.plink_action >= NUM_NL80211_PLINK_ACTIONS) - return -EINVAL; - } err = nl80211_parse_sta_channel_info(info, ¶ms); if (err) @@ -5661,7 +5666,7 @@ static int nl80211_dump_mpath(struct sk_buff *skb, int err; rtnl_lock(); - err = nl80211_prepare_wdev_dump(skb, cb, &rdev, &wdev); + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev); if (err) goto out_err; @@ -5857,7 +5862,7 @@ static int nl80211_dump_mpp(struct sk_buff *skb, int err; rtnl_lock(); - err = nl80211_prepare_wdev_dump(skb, cb, &rdev, &wdev); + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev); if (err) goto out_err; @@ -5939,9 +5944,7 @@ static int nl80211_set_bss(struct sk_buff *skb, struct genl_info *info) if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) return -EINVAL; params.p2p_ctwindow = - nla_get_s8(info->attrs[NL80211_ATTR_P2P_CTWINDOW]); - if (params.p2p_ctwindow < 0) - return -EINVAL; + nla_get_u8(info->attrs[NL80211_ATTR_P2P_CTWINDOW]); if (params.p2p_ctwindow != 0 && !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_CTWIN)) return -EINVAL; @@ -5953,8 +5956,6 @@ static int nl80211_set_bss(struct sk_buff *skb, struct genl_info *info) if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) return -EINVAL; tmp = nla_get_u8(info->attrs[NL80211_ATTR_P2P_OPPPS]); - if (tmp > 1) - return -EINVAL; params.p2p_opp_ps = tmp; if (params.p2p_opp_ps && !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_OPPPS)) @@ -6133,33 +6134,49 @@ static int nl80211_get_mesh_config(struct sk_buff *skb, return -ENOBUFS; } -static const struct nla_policy nl80211_meshconf_params_policy[NL80211_MESHCONF_ATTR_MAX+1] = { - [NL80211_MESHCONF_RETRY_TIMEOUT] = { .type = NLA_U16 }, - [NL80211_MESHCONF_CONFIRM_TIMEOUT] = { .type = NLA_U16 }, - [NL80211_MESHCONF_HOLDING_TIMEOUT] = { .type = NLA_U16 }, - [NL80211_MESHCONF_MAX_PEER_LINKS] = { .type = NLA_U16 }, - [NL80211_MESHCONF_MAX_RETRIES] = { .type = NLA_U8 }, - [NL80211_MESHCONF_TTL] = { .type = NLA_U8 }, - [NL80211_MESHCONF_ELEMENT_TTL] = { .type = NLA_U8 }, - [NL80211_MESHCONF_AUTO_OPEN_PLINKS] = { .type = NLA_U8 }, - [NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR] = { .type = NLA_U32 }, +static const struct nla_policy +nl80211_meshconf_params_policy[NL80211_MESHCONF_ATTR_MAX+1] = { + [NL80211_MESHCONF_RETRY_TIMEOUT] = + NLA_POLICY_RANGE(NLA_U16, 1, 255), + [NL80211_MESHCONF_CONFIRM_TIMEOUT] = + NLA_POLICY_RANGE(NLA_U16, 1, 255), + [NL80211_MESHCONF_HOLDING_TIMEOUT] = + NLA_POLICY_RANGE(NLA_U16, 1, 255), + [NL80211_MESHCONF_MAX_PEER_LINKS] = + NLA_POLICY_RANGE(NLA_U16, 0, 255), + [NL80211_MESHCONF_MAX_RETRIES] = NLA_POLICY_MAX(NLA_U8, 16), + [NL80211_MESHCONF_TTL] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_MESHCONF_ELEMENT_TTL] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_MESHCONF_AUTO_OPEN_PLINKS] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR] = + NLA_POLICY_RANGE(NLA_U32, 1, 255), [NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES] = { .type = NLA_U8 }, [NL80211_MESHCONF_PATH_REFRESH_TIME] = { .type = NLA_U32 }, - [NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT] = { .type = NLA_U16 }, + [NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT] = NLA_POLICY_MIN(NLA_U16, 1), [NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT] = { .type = NLA_U32 }, - [NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL] = { .type = NLA_U16 }, - [NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL] = { .type = NLA_U16 }, - [NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME] = { .type = NLA_U16 }, - [NL80211_MESHCONF_HWMP_ROOTMODE] = { .type = NLA_U8 }, - [NL80211_MESHCONF_HWMP_RANN_INTERVAL] = { .type = NLA_U16 }, - [NL80211_MESHCONF_GATE_ANNOUNCEMENTS] = { .type = NLA_U8 }, - [NL80211_MESHCONF_FORWARDING] = { .type = NLA_U8 }, - [NL80211_MESHCONF_RSSI_THRESHOLD] = { .type = NLA_U32 }, + [NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_ROOTMODE] = NLA_POLICY_MAX(NLA_U8, 4), + [NL80211_MESHCONF_HWMP_RANN_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_GATE_ANNOUNCEMENTS] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_MESHCONF_FORWARDING] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_MESHCONF_RSSI_THRESHOLD] = + NLA_POLICY_RANGE(NLA_S32, -255, 0), [NL80211_MESHCONF_HT_OPMODE] = { .type = NLA_U16 }, [NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT] = { .type = NLA_U32 }, - [NL80211_MESHCONF_HWMP_ROOT_INTERVAL] = { .type = NLA_U16 }, - [NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL] = { .type = NLA_U16 }, - [NL80211_MESHCONF_POWER_MODE] = { .type = NLA_U32 }, + [NL80211_MESHCONF_HWMP_ROOT_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_POWER_MODE] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_MESH_POWER_ACTIVE, + NL80211_MESH_POWER_MAX), [NL80211_MESHCONF_AWAKE_WINDOW] = { .type = NLA_U16 }, [NL80211_MESHCONF_PLINK_TIMEOUT] = { .type = NLA_U32 }, }; @@ -6172,68 +6189,12 @@ static const struct nla_policy [NL80211_MESH_SETUP_USERSPACE_AUTH] = { .type = NLA_FLAG }, [NL80211_MESH_SETUP_AUTH_PROTOCOL] = { .type = NLA_U8 }, [NL80211_MESH_SETUP_USERSPACE_MPM] = { .type = NLA_FLAG }, - [NL80211_MESH_SETUP_IE] = { .type = NLA_BINARY, - .len = IEEE80211_MAX_DATA_LEN }, + [NL80211_MESH_SETUP_IE] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), [NL80211_MESH_SETUP_USERSPACE_AMPE] = { .type = NLA_FLAG }, }; -static int nl80211_check_bool(const struct nlattr *nla, u8 min, u8 max, bool *out) -{ - u8 val = nla_get_u8(nla); - if (val < min || val > max) - return -EINVAL; - *out = val; - return 0; -} - -static int nl80211_check_u8(const struct nlattr *nla, u8 min, u8 max, u8 *out) -{ - u8 val = nla_get_u8(nla); - if (val < min || val > max) - return -EINVAL; - *out = val; - return 0; -} - -static int nl80211_check_u16(const struct nlattr *nla, u16 min, u16 max, u16 *out) -{ - u16 val = nla_get_u16(nla); - if (val < min || val > max) - return -EINVAL; - *out = val; - return 0; -} - -static int nl80211_check_u32(const struct nlattr *nla, u32 min, u32 max, u32 *out) -{ - u32 val = nla_get_u32(nla); - if (val < min || val > max) - return -EINVAL; - *out = val; - return 0; -} - -static int nl80211_check_s32(const struct nlattr *nla, s32 min, s32 max, s32 *out) -{ - s32 val = nla_get_s32(nla); - if (val < min || val > max) - return -EINVAL; - *out = val; - return 0; -} - -static int nl80211_check_power_mode(const struct nlattr *nla, - enum nl80211_mesh_power_mode min, - enum nl80211_mesh_power_mode max, - enum nl80211_mesh_power_mode *out) -{ - u32 val = nla_get_u32(nla); - if (val < min || val > max) - return -EINVAL; - *out = val; - return 0; -} - static int nl80211_parse_mesh_config(struct genl_info *info, struct mesh_config *cfg, u32 *mask_out) @@ -6242,13 +6203,12 @@ static int nl80211_parse_mesh_config(struct genl_info *info, u32 mask = 0; u16 ht_opmode; -#define FILL_IN_MESH_PARAM_IF_SET(tb, cfg, param, min, max, mask, attr, fn) \ -do { \ - if (tb[attr]) { \ - if (fn(tb[attr], min, max, &cfg->param)) \ - return -EINVAL; \ - mask |= (1 << (attr - 1)); \ - } \ +#define FILL_IN_MESH_PARAM_IF_SET(tb, cfg, param, mask, attr, fn) \ +do { \ + if (tb[attr]) { \ + cfg->param = fn(tb[attr]); \ + mask |= BIT((attr) - 1); \ + } \ } while (0) if (!info->attrs[NL80211_ATTR_MESH_CONFIG]) @@ -6263,75 +6223,73 @@ do { \ BUILD_BUG_ON(NL80211_MESHCONF_ATTR_MAX > 32); /* Fill in the params struct */ - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshRetryTimeout, 1, 255, - mask, NL80211_MESHCONF_RETRY_TIMEOUT, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshConfirmTimeout, 1, 255, - mask, NL80211_MESHCONF_CONFIRM_TIMEOUT, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHoldingTimeout, 1, 255, - mask, NL80211_MESHCONF_HOLDING_TIMEOUT, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshMaxPeerLinks, 0, 255, - mask, NL80211_MESHCONF_MAX_PEER_LINKS, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshMaxRetries, 0, 16, - mask, NL80211_MESHCONF_MAX_RETRIES, - nl80211_check_u8); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshTTL, 1, 255, - mask, NL80211_MESHCONF_TTL, nl80211_check_u8); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, element_ttl, 1, 255, - mask, NL80211_MESHCONF_ELEMENT_TTL, - nl80211_check_u8); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, auto_open_plinks, 0, 1, - mask, NL80211_MESHCONF_AUTO_OPEN_PLINKS, - nl80211_check_bool); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshRetryTimeout, mask, + NL80211_MESHCONF_RETRY_TIMEOUT, nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshConfirmTimeout, mask, + NL80211_MESHCONF_CONFIRM_TIMEOUT, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHoldingTimeout, mask, + NL80211_MESHCONF_HOLDING_TIMEOUT, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshMaxPeerLinks, mask, + NL80211_MESHCONF_MAX_PEER_LINKS, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshMaxRetries, mask, + NL80211_MESHCONF_MAX_RETRIES, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshTTL, mask, + NL80211_MESHCONF_TTL, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, element_ttl, mask, + NL80211_MESHCONF_ELEMENT_TTL, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, auto_open_plinks, mask, + NL80211_MESHCONF_AUTO_OPEN_PLINKS, + nla_get_u8); FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshNbrOffsetMaxNeighbor, - 1, 255, mask, + mask, NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR, - nl80211_check_u32); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPmaxPREQretries, 0, 255, - mask, NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES, - nl80211_check_u8); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, path_refresh_time, 1, 65535, - mask, NL80211_MESHCONF_PATH_REFRESH_TIME, - nl80211_check_u32); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, min_discovery_timeout, 1, 65535, - mask, NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT, - nl80211_check_u16); + nla_get_u32); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPmaxPREQretries, mask, + NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES, + nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, path_refresh_time, mask, + NL80211_MESHCONF_PATH_REFRESH_TIME, + nla_get_u32); + if (mask & BIT(NL80211_MESHCONF_PATH_REFRESH_TIME) && + (cfg->path_refresh_time < 1 || cfg->path_refresh_time > 65535)) + return -EINVAL; + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, min_discovery_timeout, mask, + NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT, + nla_get_u16); FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPactivePathTimeout, - 1, 65535, mask, + mask, NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT, - nl80211_check_u32); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPpreqMinInterval, - 1, 65535, mask, + nla_get_u32); + if (mask & BIT(NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT) && + (cfg->dot11MeshHWMPactivePathTimeout < 1 || + cfg->dot11MeshHWMPactivePathTimeout > 65535)) + return -EINVAL; + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPpreqMinInterval, mask, NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPperrMinInterval, - 1, 65535, mask, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPperrMinInterval, mask, NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL, - nl80211_check_u16); + nla_get_u16); FILL_IN_MESH_PARAM_IF_SET(tb, cfg, - dot11MeshHWMPnetDiameterTraversalTime, - 1, 65535, mask, + dot11MeshHWMPnetDiameterTraversalTime, mask, NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPRootMode, 0, 4, - mask, NL80211_MESHCONF_HWMP_ROOTMODE, - nl80211_check_u8); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPRannInterval, 1, 65535, - mask, NL80211_MESHCONF_HWMP_RANN_INTERVAL, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, - dot11MeshGateAnnouncementProtocol, 0, 1, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPRootMode, mask, + NL80211_MESHCONF_HWMP_ROOTMODE, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPRannInterval, mask, + NL80211_MESHCONF_HWMP_RANN_INTERVAL, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshGateAnnouncementProtocol, mask, NL80211_MESHCONF_GATE_ANNOUNCEMENTS, - nl80211_check_bool); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshForwarding, 0, 1, - mask, NL80211_MESHCONF_FORWARDING, - nl80211_check_bool); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, rssi_threshold, -255, 0, - mask, NL80211_MESHCONF_RSSI_THRESHOLD, - nl80211_check_s32); + nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshForwarding, mask, + NL80211_MESHCONF_FORWARDING, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, rssi_threshold, mask, + NL80211_MESHCONF_RSSI_THRESHOLD, + nla_get_s32); /* * Check HT operation mode based on * IEEE 802.11-2016 9.4.2.57 HT Operation element. @@ -6350,29 +6308,27 @@ do { \ cfg->ht_opmode = ht_opmode; mask |= (1 << (NL80211_MESHCONF_HT_OPMODE - 1)); } - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPactivePathToRootTimeout, - 1, 65535, mask, - NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT, - nl80211_check_u32); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMProotInterval, 1, 65535, - mask, NL80211_MESHCONF_HWMP_ROOT_INTERVAL, - nl80211_check_u16); FILL_IN_MESH_PARAM_IF_SET(tb, cfg, - dot11MeshHWMPconfirmationInterval, - 1, 65535, mask, + dot11MeshHWMPactivePathToRootTimeout, mask, + NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT, + nla_get_u32); + if (mask & BIT(NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT) && + (cfg->dot11MeshHWMPactivePathToRootTimeout < 1 || + cfg->dot11MeshHWMPactivePathToRootTimeout > 65535)) + return -EINVAL; + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMProotInterval, mask, + NL80211_MESHCONF_HWMP_ROOT_INTERVAL, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPconfirmationInterval, + mask, NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL, - nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, power_mode, - NL80211_MESH_POWER_ACTIVE, - NL80211_MESH_POWER_MAX, - mask, NL80211_MESHCONF_POWER_MODE, - nl80211_check_power_mode); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshAwakeWindowDuration, - 0, 65535, mask, - NL80211_MESHCONF_AWAKE_WINDOW, nl80211_check_u16); - FILL_IN_MESH_PARAM_IF_SET(tb, cfg, plink_timeout, 0, 0xffffffff, - mask, NL80211_MESHCONF_PLINK_TIMEOUT, - nl80211_check_u32); + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, power_mode, mask, + NL80211_MESHCONF_POWER_MODE, nla_get_u32); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshAwakeWindowDuration, mask, + NL80211_MESHCONF_AWAKE_WINDOW, nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, plink_timeout, mask, + NL80211_MESHCONF_PLINK_TIMEOUT, nla_get_u32); if (mask_out) *mask_out = mask; @@ -6415,8 +6371,6 @@ static int nl80211_parse_mesh_setup(struct genl_info *info, if (tb[NL80211_MESH_SETUP_IE]) { struct nlattr *ieattr = tb[NL80211_MESH_SETUP_IE]; - if (!is_valid_ie_attr(ieattr)) - return -EINVAL; setup->ie = nla_data(ieattr); setup->ie_len = nla_len(ieattr); } @@ -7049,9 +7003,6 @@ static int nl80211_trigger_scan(struct sk_buff *skb, struct genl_info *info) int err, tmp, n_ssids = 0, n_channels, i; size_t ie_len; - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; - wiphy = &rdev->wiphy; if (wdev->iftype == NL80211_IFTYPE_NAN) @@ -7405,9 +7356,6 @@ nl80211_parse_sched_scan(struct wiphy *wiphy, struct wireless_dev *wdev, struct nlattr *tb[NL80211_SCHED_SCAN_MATCH_ATTR_MAX + 1]; s32 default_match_rssi = NL80211_SCAN_RSSI_THOLD_OFF; - if (!is_valid_ie_attr(attrs[NL80211_ATTR_IE])) - return ERR_PTR(-EINVAL); - if (attrs[NL80211_ATTR_SCAN_FREQUENCIES]) { n_channels = validate_scan_freqs( attrs[NL80211_ATTR_SCAN_FREQUENCIES]); @@ -7767,7 +7715,7 @@ static int nl80211_start_sched_scan(struct sk_buff *skb, */ if (want_multi && rdev->wiphy.max_sched_scan_reqs > 1) { while (!sched_scan_req->reqid) - sched_scan_req->reqid = rdev->wiphy.cookie_counter++; + sched_scan_req->reqid = cfg80211_assign_cookie(rdev); } err = rdev_sched_scan_start(rdev, dev, sched_scan_req); @@ -7943,7 +7891,7 @@ static int nl80211_channel_switch(struct sk_buff *skb, struct genl_info *info) if (!need_new_beacon) goto skip_beacons; - err = nl80211_parse_beacon(info->attrs, ¶ms.beacon_after); + err = nl80211_parse_beacon(rdev, info->attrs, ¶ms.beacon_after); if (err) return err; @@ -7953,7 +7901,7 @@ static int nl80211_channel_switch(struct sk_buff *skb, struct genl_info *info) if (err) return err; - err = nl80211_parse_beacon(csa_attrs, ¶ms.beacon_csa); + err = nl80211_parse_beacon(rdev, csa_attrs, ¶ms.beacon_csa); if (err) return err; @@ -8190,7 +8138,7 @@ static int nl80211_dump_scan(struct sk_buff *skb, struct netlink_callback *cb) int err; rtnl_lock(); - err = nl80211_prepare_wdev_dump(skb, cb, &rdev, &wdev); + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev); if (err) { rtnl_unlock(); return err; @@ -8311,7 +8259,7 @@ static int nl80211_dump_survey(struct sk_buff *skb, struct netlink_callback *cb) bool radio_stats; rtnl_lock(); - res = nl80211_prepare_wdev_dump(skb, cb, &rdev, &wdev); + res = nl80211_prepare_wdev_dump(cb, &rdev, &wdev); if (res) goto out_err; @@ -8375,9 +8323,6 @@ static int nl80211_authenticate(struct sk_buff *skb, struct genl_info *info) struct key_parse key; bool local_state_change; - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; - if (!info->attrs[NL80211_ATTR_MAC]) return -EINVAL; @@ -8616,9 +8561,6 @@ static int nl80211_associate(struct sk_buff *skb, struct genl_info *info) dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid) return -EPERM; - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; - if (!info->attrs[NL80211_ATTR_MAC] || !info->attrs[NL80211_ATTR_SSID] || !info->attrs[NL80211_ATTR_WIPHY_FREQ]) @@ -8742,9 +8684,6 @@ static int nl80211_deauthenticate(struct sk_buff *skb, struct genl_info *info) dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid) return -EPERM; - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; - if (!info->attrs[NL80211_ATTR_MAC]) return -EINVAL; @@ -8793,9 +8732,6 @@ static int nl80211_disassociate(struct sk_buff *skb, struct genl_info *info) dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid) return -EPERM; - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; - if (!info->attrs[NL80211_ATTR_MAC]) return -EINVAL; @@ -8870,9 +8806,6 @@ static int nl80211_join_ibss(struct sk_buff *skb, struct genl_info *info) memset(&ibss, 0, sizeof(ibss)); - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; - if (!info->attrs[NL80211_ATTR_SSID] || !nla_len(info->attrs[NL80211_ATTR_SSID])) return -EINVAL; @@ -9310,9 +9243,6 @@ static int nl80211_connect(struct sk_buff *skb, struct genl_info *info) memset(&connect, 0, sizeof(connect)); - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; - if (!info->attrs[NL80211_ATTR_SSID] || !nla_len(info->attrs[NL80211_ATTR_SSID])) return -EINVAL; @@ -9371,11 +9301,6 @@ static int nl80211_connect(struct sk_buff *skb, struct genl_info *info) !wiphy_ext_feature_isset(&rdev->wiphy, NL80211_EXT_FEATURE_MFP_OPTIONAL)) return -EOPNOTSUPP; - - if (connect.mfp != NL80211_MFP_REQUIRED && - connect.mfp != NL80211_MFP_NO && - connect.mfp != NL80211_MFP_OPTIONAL) - return -EINVAL; } else { connect.mfp = NL80211_MFP_NO; } @@ -9548,8 +9473,6 @@ static int nl80211_update_connect_params(struct sk_buff *skb, return -EOPNOTSUPP; if (info->attrs[NL80211_ATTR_IE]) { - if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) - return -EINVAL; connect.ie = nla_data(info->attrs[NL80211_ATTR_IE]); connect.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); changed |= UPDATE_ASSOC_IES; @@ -10134,9 +10057,6 @@ static int nl80211_set_power_save(struct sk_buff *skb, struct genl_info *info) ps_state = nla_get_u32(info->attrs[NL80211_ATTR_PS_STATE]); - if (ps_state != NL80211_PS_DISABLED && ps_state != NL80211_PS_ENABLED) - return -EINVAL; - wdev = dev->ieee80211_ptr; if (!rdev->ops->set_power_mgmt) @@ -10234,7 +10154,7 @@ static int cfg80211_cqm_rssi_update(struct cfg80211_registered_device *rdev, struct wireless_dev *wdev = dev->ieee80211_ptr; s32 last, low, high; u32 hyst; - int i, n; + int i, n, low_index; int err; /* RSSI reporting disabled? */ @@ -10271,10 +10191,19 @@ static int cfg80211_cqm_rssi_update(struct cfg80211_registered_device *rdev, if (last < wdev->cqm_config->rssi_thresholds[i]) break; - low = i > 0 ? - (wdev->cqm_config->rssi_thresholds[i - 1] - hyst) : S32_MIN; - high = i < n ? - (wdev->cqm_config->rssi_thresholds[i] + hyst - 1) : S32_MAX; + low_index = i - 1; + if (low_index >= 0) { + low_index = array_index_nospec(low_index, n); + low = wdev->cqm_config->rssi_thresholds[low_index] - hyst; + } else { + low = S32_MIN; + } + if (i < n) { + i = array_index_nospec(i, n); + high = wdev->cqm_config->rssi_thresholds[i] + hyst - 1; + } else { + high = S32_MAX; + } return rdev_set_cqm_rssi_range_config(rdev, dev, low, high); } @@ -10690,8 +10619,7 @@ static int nl80211_send_wowlan_nd(struct sk_buff *msg, if (!scan_plan) return -ENOBUFS; - if (!scan_plan || - nla_put_u32(msg, NL80211_SCHED_SCAN_PLAN_INTERVAL, + if (nla_put_u32(msg, NL80211_SCHED_SCAN_PLAN_INTERVAL, req->scan_plans[i].interval) || (req->scan_plans[i].iterations && nla_put_u32(msg, NL80211_SCHED_SCAN_PLAN_ITERATIONS, @@ -11289,9 +11217,6 @@ static int nl80211_parse_coalesce_rule(struct cfg80211_registered_device *rdev, if (tb[NL80211_ATTR_COALESCE_RULE_CONDITION]) new_rule->condition = nla_get_u32(tb[NL80211_ATTR_COALESCE_RULE_CONDITION]); - if (new_rule->condition != NL80211_COALESCE_CONDITION_MATCH && - new_rule->condition != NL80211_COALESCE_CONDITION_NO_MATCH) - return -EINVAL; if (!tb[NL80211_ATTR_COALESCE_RULE_PKT_PATTERN]) return -EINVAL; @@ -11644,8 +11569,6 @@ static int nl80211_start_nan(struct sk_buff *skb, struct genl_info *info) conf.master_pref = nla_get_u8(info->attrs[NL80211_ATTR_NAN_MASTER_PREF]); - if (!conf.master_pref) - return -EINVAL; if (info->attrs[NL80211_ATTR_BANDS]) { u32 bands = nla_get_u32(info->attrs[NL80211_ATTR_BANDS]); @@ -11763,7 +11686,7 @@ static int nl80211_nan_add_func(struct sk_buff *skb, if (!func) return -ENOMEM; - func->cookie = wdev->wiphy->cookie_counter++; + func->cookie = cfg80211_assign_cookie(rdev); if (!tb[NL80211_NAN_FUNC_TYPE] || nla_get_u8(tb[NL80211_NAN_FUNC_TYPE]) > NL80211_NAN_FUNC_MAX_TYPE) { @@ -12209,8 +12132,7 @@ static int nl80211_update_ft_ies(struct sk_buff *skb, struct genl_info *info) return -EOPNOTSUPP; if (!info->attrs[NL80211_ATTR_MDID] || - !info->attrs[NL80211_ATTR_IE] || - !is_valid_ie_attr(info->attrs[NL80211_ATTR_IE])) + !info->attrs[NL80211_ATTR_IE]) return -EINVAL; memset(&ft_params, 0, sizeof(ft_params)); @@ -12630,12 +12552,7 @@ static int nl80211_add_tx_ts(struct sk_buff *skb, struct genl_info *info) return -EINVAL; tsid = nla_get_u8(info->attrs[NL80211_ATTR_TSID]); - if (tsid >= IEEE80211_NUM_TIDS) - return -EINVAL; - up = nla_get_u8(info->attrs[NL80211_ATTR_USER_PRIO]); - if (up >= IEEE80211_NUM_UPS) - return -EINVAL; /* WMM uses TIDs 0-7 even for TSPEC */ if (tsid >= IEEE80211_FIRST_TSPEC_TSID) { @@ -12993,6 +12910,76 @@ static int nl80211_tx_control_port(struct sk_buff *skb, struct genl_info *info) return err; } +static int nl80211_get_ftm_responder_stats(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_ftm_responder_stats ftm_stats = {}; + struct sk_buff *msg; + void *hdr; + struct nlattr *ftm_stats_attr; + int err; + + if (wdev->iftype != NL80211_IFTYPE_AP || !wdev->beacon_interval) + return -EOPNOTSUPP; + + err = rdev_get_ftm_responder_stats(rdev, dev, &ftm_stats); + if (err) + return err; + + if (!ftm_stats.filled) + return -ENODATA; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_FTM_RESPONDER_STATS); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + ftm_stats_attr = nla_nest_start(msg, NL80211_ATTR_FTM_RESPONDER_STATS); + if (!ftm_stats_attr) + goto nla_put_failure; + +#define SET_FTM(field, name, type) \ + do { if ((ftm_stats.filled & BIT(NL80211_FTM_STATS_ ## name)) && \ + nla_put_ ## type(msg, NL80211_FTM_STATS_ ## name, \ + ftm_stats.field)) \ + goto nla_put_failure; } while (0) +#define SET_FTM_U64(field, name) \ + do { if ((ftm_stats.filled & BIT(NL80211_FTM_STATS_ ## name)) && \ + nla_put_u64_64bit(msg, NL80211_FTM_STATS_ ## name, \ + ftm_stats.field, NL80211_FTM_STATS_PAD)) \ + goto nla_put_failure; } while (0) + + SET_FTM(success_num, SUCCESS_NUM, u32); + SET_FTM(partial_num, PARTIAL_NUM, u32); + SET_FTM(failed_num, FAILED_NUM, u32); + SET_FTM(asap_num, ASAP_NUM, u32); + SET_FTM(non_asap_num, NON_ASAP_NUM, u32); + SET_FTM_U64(total_duration_ms, TOTAL_DURATION_MSEC); + SET_FTM(unknown_triggers_num, UNKNOWN_TRIGGERS_NUM, u32); + SET_FTM(reschedule_requests_num, RESCHEDULE_REQUESTS_NUM, u32); + SET_FTM(out_of_window_triggers_num, OUT_OF_WINDOW_TRIGGERS_NUM, u32); +#undef SET_FTM + + nla_nest_end(msg, ftm_stats_attr); + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + +nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + #define NL80211_FLAG_NEED_WIPHY 0x01 #define NL80211_FLAG_NEED_NETDEV 0x02 #define NL80211_FLAG_NEED_RTNL 0x04 @@ -13904,6 +13891,13 @@ static const struct genl_ops nl80211_ops[] = { .internal_flags = NL80211_FLAG_NEED_NETDEV_UP | NL80211_FLAG_NEED_RTNL, }, + { + .cmd = NL80211_CMD_GET_FTM_RESPONDER_STATS, + .doit = nl80211_get_ftm_responder_stats, + .policy = nl80211_policy, + .internal_flags = NL80211_FLAG_NEED_NETDEV | + NL80211_FLAG_NEED_RTNL, + }, }; static struct genl_family nl80211_fam __ro_after_init = { diff --git a/net/wireless/rdev-ops.h b/net/wireless/rdev-ops.h index 364f5d67f05b..51380b5c32f2 100644 --- a/net/wireless/rdev-ops.h +++ b/net/wireless/rdev-ops.h @@ -1232,4 +1232,19 @@ rdev_external_auth(struct cfg80211_registered_device *rdev, return ret; } +static inline int +rdev_get_ftm_responder_stats(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_ftm_responder_stats *ftm_stats) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_get_ftm_responder_stats(&rdev->wiphy, dev, ftm_stats); + if (rdev->ops->get_ftm_responder_stats) + ret = rdev->ops->get_ftm_responder_stats(&rdev->wiphy, dev, + ftm_stats); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + #endif /* __CFG80211_RDEV_OPS */ diff --git a/net/wireless/reg.c b/net/wireless/reg.c index 56be68a27bb9..ecfb1a06dbb2 100644 --- a/net/wireless/reg.c +++ b/net/wireless/reg.c @@ -2667,11 +2667,12 @@ static void reg_process_hint(struct regulatory_request *reg_request) { struct wiphy *wiphy = NULL; enum reg_request_treatment treatment; + enum nl80211_reg_initiator initiator = reg_request->initiator; if (reg_request->wiphy_idx != WIPHY_IDX_INVALID) wiphy = wiphy_idx_to_wiphy(reg_request->wiphy_idx); - switch (reg_request->initiator) { + switch (initiator) { case NL80211_REGDOM_SET_BY_CORE: treatment = reg_process_hint_core(reg_request); break; @@ -2689,7 +2690,7 @@ static void reg_process_hint(struct regulatory_request *reg_request) treatment = reg_process_hint_country_ie(wiphy, reg_request); break; default: - WARN(1, "invalid initiator %d\n", reg_request->initiator); + WARN(1, "invalid initiator %d\n", initiator); goto out_free; } @@ -2704,7 +2705,7 @@ static void reg_process_hint(struct regulatory_request *reg_request) */ if (treatment == REG_REQ_ALREADY_SET && wiphy && wiphy->regulatory_flags & REGULATORY_STRICT_REG) { - wiphy_update_regulatory(wiphy, reg_request->initiator); + wiphy_update_regulatory(wiphy, initiator); wiphy_all_share_dfs_chan_state(wiphy); reg_check_channels(); } @@ -2873,6 +2874,7 @@ static int regulatory_hint_core(const char *alpha2) request->alpha2[0] = alpha2[0]; request->alpha2[1] = alpha2[1]; request->initiator = NL80211_REGDOM_SET_BY_CORE; + request->wiphy_idx = WIPHY_IDX_INVALID; queue_regulatory_request(request); @@ -3829,6 +3831,15 @@ static int __init regulatory_init_db(void) { int err; + /* + * It's possible that - due to other bugs/issues - cfg80211 + * never called regulatory_init() below, or that it failed; + * in that case, don't try to do any further work here as + * it's doomed to lead to crashes. + */ + if (IS_ERR_OR_NULL(reg_pdev)) + return -EINVAL; + err = load_builtin_regdb_keys(); if (err) return err; diff --git a/net/wireless/scan.c b/net/wireless/scan.c index d36c3eb7b931..d0e7472dd9fd 100644 --- a/net/wireless/scan.c +++ b/net/wireless/scan.c @@ -1058,13 +1058,23 @@ cfg80211_bss_update(struct cfg80211_registered_device *rdev, return NULL; } +/* + * Update RX channel information based on the available frame payload + * information. This is mainly for the 2.4 GHz band where frames can be received + * from neighboring channels and the Beacon frames use the DSSS Parameter Set + * element to indicate the current (transmitting) channel, but this might also + * be needed on other bands if RX frequency does not match with the actual + * operating channel of a BSS. + */ static struct ieee80211_channel * cfg80211_get_bss_channel(struct wiphy *wiphy, const u8 *ie, size_t ielen, - struct ieee80211_channel *channel) + struct ieee80211_channel *channel, + enum nl80211_bss_scan_width scan_width) { const u8 *tmp; u32 freq; int channel_number = -1; + struct ieee80211_channel *alt_channel; tmp = cfg80211_find_ie(WLAN_EID_DS_PARAMS, ie, ielen); if (tmp && tmp[1] == 1) { @@ -1078,16 +1088,45 @@ cfg80211_get_bss_channel(struct wiphy *wiphy, const u8 *ie, size_t ielen, } } - if (channel_number < 0) + if (channel_number < 0) { + /* No channel information in frame payload */ return channel; + } freq = ieee80211_channel_to_frequency(channel_number, channel->band); - channel = ieee80211_get_channel(wiphy, freq); - if (!channel) - return NULL; - if (channel->flags & IEEE80211_CHAN_DISABLED) + alt_channel = ieee80211_get_channel(wiphy, freq); + if (!alt_channel) { + if (channel->band == NL80211_BAND_2GHZ) { + /* + * Better not allow unexpected channels when that could + * be going beyond the 1-11 range (e.g., discovering + * BSS on channel 12 when radio is configured for + * channel 11. + */ + return NULL; + } + + /* No match for the payload channel number - ignore it */ + return channel; + } + + if (scan_width == NL80211_BSS_CHAN_WIDTH_10 || + scan_width == NL80211_BSS_CHAN_WIDTH_5) { + /* + * Ignore channel number in 5 and 10 MHz channels where there + * may not be an n:1 or 1:n mapping between frequencies and + * channel numbers. + */ + return channel; + } + + /* + * Use the channel determined through the payload channel number + * instead of the RX channel reported by the driver. + */ + if (alt_channel->flags & IEEE80211_CHAN_DISABLED) return NULL; - return channel; + return alt_channel; } /* Returned bss is reference counted and must be cleaned up appropriately. */ @@ -1112,7 +1151,8 @@ cfg80211_inform_bss_data(struct wiphy *wiphy, (data->signal < 0 || data->signal > 100))) return NULL; - channel = cfg80211_get_bss_channel(wiphy, ie, ielen, data->chan); + channel = cfg80211_get_bss_channel(wiphy, ie, ielen, data->chan, + data->scan_width); if (!channel) return NULL; @@ -1210,7 +1250,7 @@ cfg80211_inform_bss_frame_data(struct wiphy *wiphy, return NULL; channel = cfg80211_get_bss_channel(wiphy, mgmt->u.beacon.variable, - ielen, data->chan); + ielen, data->chan, data->scan_width); if (!channel) return NULL; diff --git a/net/wireless/trace.h b/net/wireless/trace.h index 5e7eec849200..c6a9446b4e6b 100644 --- a/net/wireless/trace.h +++ b/net/wireless/trace.h @@ -2368,6 +2368,140 @@ TRACE_EVENT(rdev_external_auth, __entry->bssid, __entry->ssid, __entry->status) ); +TRACE_EVENT(rdev_start_radar_detection, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_chan_def *chandef, + u32 cac_time_ms), + TP_ARGS(wiphy, netdev, chandef, cac_time_ms), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + CHAN_DEF_ENTRY + __field(u32, cac_time_ms) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + __entry->cac_time_ms = cac_time_ms; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " CHAN_DEF_PR_FMT + ", cac_time_ms=%u", + WIPHY_PR_ARG, NETDEV_PR_ARG, CHAN_DEF_PR_ARG, + __entry->cac_time_ms) +); + +TRACE_EVENT(rdev_set_mcast_rate, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + int *mcast_rate), + TP_ARGS(wiphy, netdev, mcast_rate), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __array(int, mcast_rate, NUM_NL80211_BANDS) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + memcpy(__entry->mcast_rate, mcast_rate, + sizeof(int) * NUM_NL80211_BANDS); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " + "mcast_rates [2.4GHz=0x%x, 5.2GHz=0x%x, 60GHz=0x%x]", + WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->mcast_rate[NL80211_BAND_2GHZ], + __entry->mcast_rate[NL80211_BAND_5GHZ], + __entry->mcast_rate[NL80211_BAND_60GHZ]) +); + +TRACE_EVENT(rdev_set_coalesce, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_coalesce *coalesce), + TP_ARGS(wiphy, coalesce), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, n_rules) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->n_rules = coalesce ? coalesce->n_rules : 0; + ), + TP_printk(WIPHY_PR_FMT ", n_rules=%d", + WIPHY_PR_ARG, __entry->n_rules) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_abort_scan, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_set_multicast_to_unicast, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const bool enabled), + TP_ARGS(wiphy, netdev, enabled), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(bool, enabled) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->enabled = enabled; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", unicast: %s", + WIPHY_PR_ARG, NETDEV_PR_ARG, + BOOL_TO_STR(__entry->enabled)) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_get_txq_stats, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_get_ftm_responder_stats, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_ftm_responder_stats *ftm_stats), + + TP_ARGS(wiphy, netdev, ftm_stats), + + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u64, timestamp) + __field(u32, success_num) + __field(u32, partial_num) + __field(u32, failed_num) + __field(u32, asap_num) + __field(u32, non_asap_num) + __field(u64, duration) + __field(u32, unknown_triggers) + __field(u32, reschedule) + __field(u32, out_of_window) + ), + + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->success_num = ftm_stats->success_num; + __entry->partial_num = ftm_stats->partial_num; + __entry->failed_num = ftm_stats->failed_num; + __entry->asap_num = ftm_stats->asap_num; + __entry->non_asap_num = ftm_stats->non_asap_num; + __entry->duration = ftm_stats->total_duration_ms; + __entry->unknown_triggers = ftm_stats->unknown_triggers_num; + __entry->reschedule = ftm_stats->reschedule_requests_num; + __entry->out_of_window = ftm_stats->out_of_window_triggers_num; + ), + + TP_printk(WIPHY_PR_FMT "Ftm responder stats: success %u, partial %u, " + "failed %u, asap %u, non asap %u, total duration %llu, unknown " + "triggers %u, rescheduled %u, out of window %u", WIPHY_PR_ARG, + __entry->success_num, __entry->partial_num, __entry->failed_num, + __entry->asap_num, __entry->non_asap_num, __entry->duration, + __entry->unknown_triggers, __entry->reschedule, + __entry->out_of_window) +); + /************************************************************* * cfg80211 exported functions traces * *************************************************************/ @@ -3160,105 +3294,6 @@ TRACE_EVENT(cfg80211_stop_iface, TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT, WIPHY_PR_ARG, WDEV_PR_ARG) ); - -TRACE_EVENT(rdev_start_radar_detection, - TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, - struct cfg80211_chan_def *chandef, - u32 cac_time_ms), - TP_ARGS(wiphy, netdev, chandef, cac_time_ms), - TP_STRUCT__entry( - WIPHY_ENTRY - NETDEV_ENTRY - CHAN_DEF_ENTRY - __field(u32, cac_time_ms) - ), - TP_fast_assign( - WIPHY_ASSIGN; - NETDEV_ASSIGN; - CHAN_DEF_ASSIGN(chandef); - __entry->cac_time_ms = cac_time_ms; - ), - TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " CHAN_DEF_PR_FMT - ", cac_time_ms=%u", - WIPHY_PR_ARG, NETDEV_PR_ARG, CHAN_DEF_PR_ARG, - __entry->cac_time_ms) -); - -TRACE_EVENT(rdev_set_mcast_rate, - TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, - int *mcast_rate), - TP_ARGS(wiphy, netdev, mcast_rate), - TP_STRUCT__entry( - WIPHY_ENTRY - NETDEV_ENTRY - __array(int, mcast_rate, NUM_NL80211_BANDS) - ), - TP_fast_assign( - WIPHY_ASSIGN; - NETDEV_ASSIGN; - memcpy(__entry->mcast_rate, mcast_rate, - sizeof(int) * NUM_NL80211_BANDS); - ), - TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " - "mcast_rates [2.4GHz=0x%x, 5.2GHz=0x%x, 60GHz=0x%x]", - WIPHY_PR_ARG, NETDEV_PR_ARG, - __entry->mcast_rate[NL80211_BAND_2GHZ], - __entry->mcast_rate[NL80211_BAND_5GHZ], - __entry->mcast_rate[NL80211_BAND_60GHZ]) -); - -TRACE_EVENT(rdev_set_coalesce, - TP_PROTO(struct wiphy *wiphy, struct cfg80211_coalesce *coalesce), - TP_ARGS(wiphy, coalesce), - TP_STRUCT__entry( - WIPHY_ENTRY - __field(int, n_rules) - ), - TP_fast_assign( - WIPHY_ASSIGN; - __entry->n_rules = coalesce ? coalesce->n_rules : 0; - ), - TP_printk(WIPHY_PR_FMT ", n_rules=%d", - WIPHY_PR_ARG, __entry->n_rules) -); - -DEFINE_EVENT(wiphy_wdev_evt, rdev_abort_scan, - TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), - TP_ARGS(wiphy, wdev) -); - -TRACE_EVENT(rdev_set_multicast_to_unicast, - TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, - const bool enabled), - TP_ARGS(wiphy, netdev, enabled), - TP_STRUCT__entry( - WIPHY_ENTRY - NETDEV_ENTRY - __field(bool, enabled) - ), - TP_fast_assign( - WIPHY_ASSIGN; - NETDEV_ASSIGN; - __entry->enabled = enabled; - ), - TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", unicast: %s", - WIPHY_PR_ARG, NETDEV_PR_ARG, - BOOL_TO_STR(__entry->enabled)) -); - -TRACE_EVENT(rdev_get_txq_stats, - TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), - TP_ARGS(wiphy, wdev), - TP_STRUCT__entry( - WIPHY_ENTRY - WDEV_ENTRY - ), - TP_fast_assign( - WIPHY_ASSIGN; - WDEV_ASSIGN; - ), - TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT, WIPHY_PR_ARG, WDEV_PR_ARG) -); #endif /* !__RDEV_OPS_TRACE || TRACE_HEADER_MULTI_READ */ #undef TRACE_INCLUDE_PATH diff --git a/net/wireless/wext-compat.c b/net/wireless/wext-compat.c index 167f7025ac98..06943d9c9835 100644 --- a/net/wireless/wext-compat.c +++ b/net/wireless/wext-compat.c @@ -1278,12 +1278,16 @@ static int cfg80211_wext_giwrate(struct net_device *dev, if (err) return err; - if (!(sinfo.filled & BIT_ULL(NL80211_STA_INFO_TX_BITRATE))) - return -EOPNOTSUPP; + if (!(sinfo.filled & BIT_ULL(NL80211_STA_INFO_TX_BITRATE))) { + err = -EOPNOTSUPP; + goto free; + } rate->value = 100000 * cfg80211_calculate_bitrate(&sinfo.txrate); - return 0; +free: + cfg80211_sinfo_release_content(&sinfo); + return err; } /* Get wireless statistics. Called by /proc/net/wireless and by SIOCGIWSTATS */ @@ -1293,7 +1297,7 @@ static struct iw_statistics *cfg80211_wireless_stats(struct net_device *dev) struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); /* we are under RTNL - globally locked - so can use static structs */ static struct iw_statistics wstats; - static struct station_info sinfo; + static struct station_info sinfo = {}; u8 bssid[ETH_ALEN]; if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION) @@ -1352,6 +1356,8 @@ static struct iw_statistics *cfg80211_wireless_stats(struct net_device *dev) if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_TX_FAILED)) wstats.discard.retries = sinfo.tx_failed; + cfg80211_sinfo_release_content(&sinfo); + return &wstats; } diff --git a/net/xdp/xdp_umem.c b/net/xdp/xdp_umem.c index 555427b3e0fe..a264cf2accd0 100644 --- a/net/xdp/xdp_umem.c +++ b/net/xdp/xdp_umem.c @@ -32,37 +32,49 @@ void xdp_del_sk_umem(struct xdp_umem *umem, struct xdp_sock *xs) { unsigned long flags; - if (xs->dev) { - spin_lock_irqsave(&umem->xsk_list_lock, flags); - list_del_rcu(&xs->list); - spin_unlock_irqrestore(&umem->xsk_list_lock, flags); - - if (umem->zc) - synchronize_net(); - } + spin_lock_irqsave(&umem->xsk_list_lock, flags); + list_del_rcu(&xs->list); + spin_unlock_irqrestore(&umem->xsk_list_lock, flags); } -int xdp_umem_query(struct net_device *dev, u16 queue_id) +/* The umem is stored both in the _rx struct and the _tx struct as we do + * not know if the device has more tx queues than rx, or the opposite. + * This might also change during run time. + */ +static void xdp_reg_umem_at_qid(struct net_device *dev, struct xdp_umem *umem, + u16 queue_id) { - struct netdev_bpf bpf; + if (queue_id < dev->real_num_rx_queues) + dev->_rx[queue_id].umem = umem; + if (queue_id < dev->real_num_tx_queues) + dev->_tx[queue_id].umem = umem; +} - ASSERT_RTNL(); +struct xdp_umem *xdp_get_umem_from_qid(struct net_device *dev, + u16 queue_id) +{ + if (queue_id < dev->real_num_rx_queues) + return dev->_rx[queue_id].umem; + if (queue_id < dev->real_num_tx_queues) + return dev->_tx[queue_id].umem; - memset(&bpf, 0, sizeof(bpf)); - bpf.command = XDP_QUERY_XSK_UMEM; - bpf.xsk.queue_id = queue_id; + return NULL; +} - if (!dev->netdev_ops->ndo_bpf) - return 0; - return dev->netdev_ops->ndo_bpf(dev, &bpf) ?: !!bpf.xsk.umem; +static void xdp_clear_umem_at_qid(struct net_device *dev, u16 queue_id) +{ + if (queue_id < dev->real_num_rx_queues) + dev->_rx[queue_id].umem = NULL; + if (queue_id < dev->real_num_tx_queues) + dev->_tx[queue_id].umem = NULL; } int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev, - u32 queue_id, u16 flags) + u16 queue_id, u16 flags) { bool force_zc, force_copy; struct netdev_bpf bpf; - int err; + int err = 0; force_zc = flags & XDP_ZEROCOPY; force_copy = flags & XDP_COPY; @@ -70,17 +82,23 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev, if (force_zc && force_copy) return -EINVAL; - if (force_copy) - return 0; + rtnl_lock(); + if (xdp_get_umem_from_qid(dev, queue_id)) { + err = -EBUSY; + goto out_rtnl_unlock; + } - if (!dev->netdev_ops->ndo_bpf || !dev->netdev_ops->ndo_xsk_async_xmit) - return force_zc ? -EOPNOTSUPP : 0; /* fail or fallback */ + xdp_reg_umem_at_qid(dev, umem, queue_id); + umem->dev = dev; + umem->queue_id = queue_id; + if (force_copy) + /* For copy-mode, we are done. */ + goto out_rtnl_unlock; - rtnl_lock(); - err = xdp_umem_query(dev, queue_id); - if (err) { - err = err < 0 ? -EOPNOTSUPP : -EBUSY; - goto err_rtnl_unlock; + if (!dev->netdev_ops->ndo_bpf || + !dev->netdev_ops->ndo_xsk_async_xmit) { + err = -EOPNOTSUPP; + goto err_unreg_umem; } bpf.command = XDP_SETUP_XSK_UMEM; @@ -89,18 +107,20 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev, err = dev->netdev_ops->ndo_bpf(dev, &bpf); if (err) - goto err_rtnl_unlock; + goto err_unreg_umem; rtnl_unlock(); dev_hold(dev); - umem->dev = dev; - umem->queue_id = queue_id; umem->zc = true; return 0; -err_rtnl_unlock: +err_unreg_umem: + xdp_clear_umem_at_qid(dev, queue_id); + if (!force_zc) + err = 0; /* fallback to copy mode */ +out_rtnl_unlock: rtnl_unlock(); - return force_zc ? err : 0; /* fail or fallback */ + return err; } static void xdp_umem_clear_dev(struct xdp_umem *umem) @@ -108,7 +128,7 @@ static void xdp_umem_clear_dev(struct xdp_umem *umem) struct netdev_bpf bpf; int err; - if (umem->dev) { + if (umem->zc) { bpf.command = XDP_SETUP_XSK_UMEM; bpf.xsk.umem = NULL; bpf.xsk.queue_id = umem->queue_id; @@ -119,9 +139,17 @@ static void xdp_umem_clear_dev(struct xdp_umem *umem) if (err) WARN(1, "failed to disable umem!\n"); + } + + if (umem->dev) { + rtnl_lock(); + xdp_clear_umem_at_qid(umem->dev, umem->queue_id); + rtnl_unlock(); + } + if (umem->zc) { dev_put(umem->dev); - umem->dev = NULL; + umem->zc = false; } } diff --git a/net/xdp/xdp_umem.h b/net/xdp/xdp_umem.h index c8be1ad3eb88..27603227601b 100644 --- a/net/xdp/xdp_umem.h +++ b/net/xdp/xdp_umem.h @@ -9,7 +9,7 @@ #include <net/xdp_sock.h> int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev, - u32 queue_id, u16 flags); + u16 queue_id, u16 flags); bool xdp_umem_validate_queues(struct xdp_umem *umem); void xdp_get_umem(struct xdp_umem *umem); void xdp_put_umem(struct xdp_umem *umem); diff --git a/net/xdp/xsk.c b/net/xdp/xsk.c index 5a432dfee4ee..07156f43d295 100644 --- a/net/xdp/xsk.c +++ b/net/xdp/xsk.c @@ -355,12 +355,18 @@ static int xsk_release(struct socket *sock) local_bh_enable(); if (xs->dev) { + struct net_device *dev = xs->dev; + /* Wait for driver to stop using the xdp socket. */ - synchronize_net(); - dev_put(xs->dev); + xdp_del_sk_umem(xs->umem, xs); xs->dev = NULL; + synchronize_net(); + dev_put(dev); } + xskq_destroy(xs->rx); + xskq_destroy(xs->tx); + sock_orphan(sk); sock->sk = NULL; @@ -419,13 +425,6 @@ static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len) } qid = sxdp->sxdp_queue_id; - - if ((xs->rx && qid >= dev->real_num_rx_queues) || - (xs->tx && qid >= dev->real_num_tx_queues)) { - err = -EINVAL; - goto out_unlock; - } - flags = sxdp->sxdp_flags; if (flags & XDP_SHARED_UMEM) { @@ -721,9 +720,6 @@ static void xsk_destruct(struct sock *sk) if (!sock_flag(sk, SOCK_DEAD)) return; - xskq_destroy(xs->rx); - xskq_destroy(xs->tx); - xdp_del_sk_umem(xs->umem, xs); xdp_put_umem(xs->umem); sk_refcnt_debug_dec(sk); @@ -758,6 +754,8 @@ static int xsk_create(struct net *net, struct socket *sock, int protocol, sk->sk_destruct = xsk_destruct; sk_refcnt_debug_inc(sk); + sock_set_flag(sk, SOCK_RCU_FREE); + xs = xdp_sk(sk); mutex_init(&xs->mutex); spin_lock_init(&xs->tx_completion_lock); diff --git a/net/xfrm/Kconfig b/net/xfrm/Kconfig index 4a9ee2d83158..140270a13d54 100644 --- a/net/xfrm/Kconfig +++ b/net/xfrm/Kconfig @@ -8,7 +8,6 @@ config XFRM config XFRM_OFFLOAD bool - depends on XFRM config XFRM_ALGO tristate diff --git a/net/xfrm/xfrm_hash.c b/net/xfrm/xfrm_hash.c index 2ad33ce1ea17..eca8d84d99bf 100644 --- a/net/xfrm/xfrm_hash.c +++ b/net/xfrm/xfrm_hash.c @@ -6,7 +6,7 @@ #include <linux/kernel.h> #include <linux/mm.h> -#include <linux/bootmem.h> +#include <linux/memblock.h> #include <linux/vmalloc.h> #include <linux/slab.h> #include <linux/xfrm.h> diff --git a/net/xfrm/xfrm_hash.h b/net/xfrm/xfrm_hash.h index 61be810389d8..ce66323102f9 100644 --- a/net/xfrm/xfrm_hash.h +++ b/net/xfrm/xfrm_hash.h @@ -13,7 +13,7 @@ static inline unsigned int __xfrm4_addr_hash(const xfrm_address_t *addr) static inline unsigned int __xfrm6_addr_hash(const xfrm_address_t *addr) { - return ntohl(addr->a6[2] ^ addr->a6[3]); + return jhash2((__force u32 *)addr->a6, 4, 0); } static inline unsigned int __xfrm4_daddr_saddr_hash(const xfrm_address_t *daddr, @@ -26,8 +26,7 @@ static inline unsigned int __xfrm4_daddr_saddr_hash(const xfrm_address_t *daddr, static inline unsigned int __xfrm6_daddr_saddr_hash(const xfrm_address_t *daddr, const xfrm_address_t *saddr) { - return ntohl(daddr->a6[2] ^ daddr->a6[3] ^ - saddr->a6[2] ^ saddr->a6[3]); + return __xfrm6_addr_hash(daddr) ^ __xfrm6_addr_hash(saddr); } static inline u32 __bits2mask32(__u8 bits) diff --git a/net/xfrm/xfrm_input.c b/net/xfrm/xfrm_input.c index b89c9c7f8c5c..684c0bc01e2c 100644 --- a/net/xfrm/xfrm_input.c +++ b/net/xfrm/xfrm_input.c @@ -131,7 +131,7 @@ struct sec_path *secpath_dup(struct sec_path *src) sp->len = 0; sp->olen = 0; - memset(sp->ovec, 0, sizeof(sp->ovec[XFRM_MAX_OFFLOAD_DEPTH])); + memset(sp->ovec, 0, sizeof(sp->ovec)); if (src) { int i; @@ -458,6 +458,7 @@ resume: XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR); goto drop; } + crypto_done = false; } while (!err); err = xfrm_rcv_cb(skb, family, x->type->proto, 0); diff --git a/net/xfrm/xfrm_interface.c b/net/xfrm/xfrm_interface.c index dc5b20bf29cf..d679fa0f44b3 100644 --- a/net/xfrm/xfrm_interface.c +++ b/net/xfrm/xfrm_interface.c @@ -116,6 +116,9 @@ static void xfrmi_unlink(struct xfrmi_net *xfrmn, struct xfrm_if *xi) static void xfrmi_dev_free(struct net_device *dev) { + struct xfrm_if *xi = netdev_priv(dev); + + gro_cells_destroy(&xi->gro_cells); free_percpu(dev->tstats); } @@ -561,9 +564,6 @@ static void xfrmi_get_stats64(struct net_device *dev, { int cpu; - if (!dev->tstats) - return; - for_each_possible_cpu(cpu) { struct pcpu_sw_netstats *stats; struct pcpu_sw_netstats tmp; diff --git a/net/xfrm/xfrm_output.c b/net/xfrm/xfrm_output.c index 2d42cb0c94b8..4ae87c5ce2e3 100644 --- a/net/xfrm/xfrm_output.c +++ b/net/xfrm/xfrm_output.c @@ -100,6 +100,10 @@ static int xfrm_output_one(struct sk_buff *skb, int err) spin_unlock_bh(&x->lock); skb_dst_force(skb); + if (!skb_dst(skb)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTERROR); + goto error_nolock; + } if (xfrm_offload(skb)) { x->type_offload->encap(x, skb); diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c index 3110c3fbee20..119a427d9b2b 100644 --- a/net/xfrm/xfrm_policy.c +++ b/net/xfrm/xfrm_policy.c @@ -632,9 +632,9 @@ static void xfrm_hash_rebuild(struct work_struct *work) break; } if (newpos) - hlist_add_behind(&policy->bydst, newpos); + hlist_add_behind_rcu(&policy->bydst, newpos); else - hlist_add_head(&policy->bydst, chain); + hlist_add_head_rcu(&policy->bydst, chain); } spin_unlock_bh(&net->xfrm.xfrm_policy_lock); @@ -774,9 +774,9 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) break; } if (newpos) - hlist_add_behind(&policy->bydst, newpos); + hlist_add_behind_rcu(&policy->bydst, newpos); else - hlist_add_head(&policy->bydst, chain); + hlist_add_head_rcu(&policy->bydst, chain); __xfrm_policy_link(policy, dir); /* After previous checking, family can either be AF_INET or AF_INET6 */ @@ -2491,6 +2491,10 @@ int __xfrm_route_forward(struct sk_buff *skb, unsigned short family) } skb_dst_force(skb); + if (!skb_dst(skb)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR); + return 0; + } dst = xfrm_lookup(net, skb_dst(skb), &fl, NULL, XFRM_LOOKUP_QUEUE); if (IS_ERR(dst)) { diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index b669262682c9..dc4a9f1fb941 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -2077,10 +2077,8 @@ int xfrm_user_policy(struct sock *sk, int optname, u8 __user *optval, int optlen struct xfrm_mgr *km; struct xfrm_policy *pol = NULL; -#ifdef CONFIG_COMPAT if (in_compat_syscall()) return -EOPNOTSUPP; -#endif if (!optval && !optlen) { xfrm_sk_policy_insert(sk, XFRM_POLICY_IN, NULL); diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c index 4791aa8b8185..c9a84e22f5d5 100644 --- a/net/xfrm/xfrm_user.c +++ b/net/xfrm/xfrm_user.c @@ -151,10 +151,16 @@ static int verify_newsa_info(struct xfrm_usersa_info *p, err = -EINVAL; switch (p->family) { case AF_INET: + if (p->sel.prefixlen_d > 32 || p->sel.prefixlen_s > 32) + goto out; + break; case AF_INET6: #if IS_ENABLED(CONFIG_IPV6) + if (p->sel.prefixlen_d > 128 || p->sel.prefixlen_s > 128) + goto out; + break; #else err = -EAFNOSUPPORT; @@ -1001,7 +1007,7 @@ static int xfrm_dump_sa(struct sk_buff *skb, struct netlink_callback *cb) int err; err = nlmsg_parse(cb->nlh, 0, attrs, XFRMA_MAX, xfrma_policy, - NULL); + cb->extack); if (err < 0) return err; @@ -1396,10 +1402,16 @@ static int verify_newpolicy_info(struct xfrm_userpolicy_info *p) switch (p->sel.family) { case AF_INET: + if (p->sel.prefixlen_d > 32 || p->sel.prefixlen_s > 32) + return -EINVAL; + break; case AF_INET6: #if IS_ENABLED(CONFIG_IPV6) + if (p->sel.prefixlen_d > 128 || p->sel.prefixlen_s > 128) + return -EINVAL; + break; #else return -EAFNOSUPPORT; @@ -1480,6 +1492,9 @@ static int validate_tmpl(int nr, struct xfrm_user_tmpl *ut, u16 family) (ut[i].family != prev_family)) return -EINVAL; + if (ut[i].mode >= XFRM_MODE_MAX) + return -EINVAL; + prev_family = ut[i].family; switch (ut[i].family) { @@ -2606,10 +2621,8 @@ static int xfrm_user_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, const struct xfrm_link *link; int type, err; -#ifdef CONFIG_COMPAT if (in_compat_syscall()) return -EOPNOTSUPP; -#endif type = nlh->nlmsg_type; if (type > XFRM_MSG_MAX) |
