Diffstat (limited to 'net')
33 files changed, 1436 insertions, 352 deletions
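Several hunks below lean on the scope-based cleanup helpers from <linux/cleanup.h>: the new net/9p/trans_usbg.c uses guard(mutex) and guard(spinlock_irqsave), and net/core/sock_map.c converts fdget()/fdput() pairs to CLASS(fd, ...). For context, a minimal sketch of the idiom; the struct demo type and the demo_* functions are hypothetical and not part of this diff:

	#include <linux/cleanup.h>
	#include <linux/file.h>
	#include <linux/mutex.h>

	struct demo {				/* hypothetical example type */
		struct mutex lock;
		bool ready;
	};

	static int demo_update(struct demo *d)
	{
		/* mutex_lock() now; mutex_unlock() runs at scope exit */
		guard(mutex)(&d->lock);

		if (!d->ready)
			return -EBUSY;		/* early return: unlock still runs */

		return 0;
	}

	static int demo_use_fd(int ufd)
	{
		/* fdget(ufd) now; fdput() runs at scope exit */
		CLASS(fd, f)(ufd);

		if (!fd_file(f))
			return -EBADF;

		/* ... use fd_file(f) ... */
		return 0;
	}

This is why the converted sock_map.c error paths below can simply return: the fdput() is emitted automatically on every exit from the scope.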
diff --git a/net/9p/Kconfig b/net/9p/Kconfig index bcdab9c23b40..63f988f0c9e8 100644 --- a/net/9p/Kconfig +++ b/net/9p/Kconfig @@ -40,6 +40,12 @@ config NET_9P_XEN This builds support for a transport for 9pfs between two Xen domains. +config NET_9P_USBG + bool "9P USB Gadget Transport" + depends on USB_GADGET=y || USB_GADGET=NET_9P + help + This builds support for a transport for 9pfs over + usb gadget. config NET_9P_RDMA depends on INET && INFINIBAND && INFINIBAND_ADDR_TRANS diff --git a/net/9p/Makefile b/net/9p/Makefile index 1df9b344c30b..22794a451c3f 100644 --- a/net/9p/Makefile +++ b/net/9p/Makefile @@ -4,6 +4,7 @@ obj-$(CONFIG_NET_9P_FD) += 9pnet_fd.o obj-$(CONFIG_NET_9P_XEN) += 9pnet_xen.o obj-$(CONFIG_NET_9P_VIRTIO) += 9pnet_virtio.o obj-$(CONFIG_NET_9P_RDMA) += 9pnet_rdma.o +obj-$(CONFIG_NET_9P_USBG) += 9pnet_usbg.o 9pnet-objs := \ mod.o \ @@ -23,3 +24,6 @@ obj-$(CONFIG_NET_9P_RDMA) += 9pnet_rdma.o 9pnet_rdma-objs := \ trans_rdma.o \ + +9pnet_usbg-objs := \ + trans_usbg.o \ diff --git a/net/9p/trans_usbg.c b/net/9p/trans_usbg.c new file mode 100644 index 000000000000..975b76839dca --- /dev/null +++ b/net/9p/trans_usbg.c @@ -0,0 +1,956 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * trans_usbg.c - USB peripheral usb9pfs configuration driver and transport. + * + * Copyright (C) 2024 Michael Grzeschik <m.grzeschik@pengutronix.de> + */ + +/* Gadget usb9pfs only needs two bulk endpoints, and will use the usb9pfs + * transport to mount host exported filesystem via usb gadget. + */ + +/* +--------------------------+ | +--------------------------+ + * | 9PFS mounting client | | | 9PFS exporting server | + * SW | | | | | + * | (this:trans_usbg) | | |(e.g. diod or nfs-ganesha)| + * +-------------^------------+ | +-------------^------------+ + * | | | + * ------------------|------------------------------------|------------- + * | | | + * +-------------v------------+ | +-------------v------------+ + * | | | | | + * HW | USB Device Controller <---------> USB Host Controller | + * | | | | | + * +--------------------------+ | +--------------------------+ + */ + +#include <linux/cleanup.h> +#include <linux/kernel.h> +#include <linux/module.h> +#include <linux/usb/composite.h> +#include <linux/usb/func_utils.h> + +#include <net/9p/9p.h> +#include <net/9p/client.h> +#include <net/9p/transport.h> + +#define DEFAULT_BUFLEN 16384 + +struct f_usb9pfs { + struct p9_client *client; + + /* 9p request lock for en/dequeue */ + spinlock_t lock; + + struct usb_request *in_req; + struct usb_request *out_req; + + struct usb_ep *in_ep; + struct usb_ep *out_ep; + + struct completion send; + struct completion received; + + unsigned int buflen; + + struct usb_function function; +}; + +static inline struct f_usb9pfs *func_to_usb9pfs(struct usb_function *f) +{ + return container_of(f, struct f_usb9pfs, function); +} + +struct f_usb9pfs_opts { + struct usb_function_instance func_inst; + unsigned int buflen; + + struct f_usb9pfs_dev *dev; + + /* Read/write access to configfs attributes is handled by configfs. + * + * This is to protect the data from concurrent access by read/write + * and create symlink/remove symlink. 
+ */ + struct mutex lock; + int refcnt; +}; + +struct f_usb9pfs_dev { + struct f_usb9pfs *usb9pfs; + struct f_usb9pfs_opts *opts; + char tag[41]; + bool inuse; + + struct list_head usb9pfs_instance; +}; + +static DEFINE_MUTEX(usb9pfs_lock); +static struct list_head usbg_instance_list; + +static int usb9pfs_queue_tx(struct f_usb9pfs *usb9pfs, struct p9_req_t *p9_tx_req, + gfp_t gfp_flags) +{ + struct usb_composite_dev *cdev = usb9pfs->function.config->cdev; + struct usb_request *req = usb9pfs->in_req; + int ret; + + if (!(p9_tx_req->tc.size % usb9pfs->in_ep->maxpacket)) + req->zero = 1; + + req->buf = p9_tx_req->tc.sdata; + req->length = p9_tx_req->tc.size; + req->context = p9_tx_req; + + dev_dbg(&cdev->gadget->dev, "%s usb9pfs send --> %d/%d, zero: %d\n", + usb9pfs->in_ep->name, req->actual, req->length, req->zero); + + ret = usb_ep_queue(usb9pfs->in_ep, req, gfp_flags); + if (ret) + req->context = NULL; + + dev_dbg(&cdev->gadget->dev, "tx submit --> %d\n", ret); + + return ret; +} + +static int usb9pfs_queue_rx(struct f_usb9pfs *usb9pfs, struct usb_request *req, + gfp_t gfp_flags) +{ + struct usb_composite_dev *cdev = usb9pfs->function.config->cdev; + int ret; + + ret = usb_ep_queue(usb9pfs->out_ep, req, gfp_flags); + + dev_dbg(&cdev->gadget->dev, "rx submit --> %d\n", ret); + + return ret; +} + +static int usb9pfs_transmit(struct f_usb9pfs *usb9pfs, struct p9_req_t *p9_req) +{ + int ret = 0; + + guard(spinlock_irqsave)(&usb9pfs->lock); + + ret = usb9pfs_queue_tx(usb9pfs, p9_req, GFP_ATOMIC); + if (ret) + return ret; + + list_del(&p9_req->req_list); + + p9_req_get(p9_req); + + return ret; +} + +static void usb9pfs_tx_complete(struct usb_ep *ep, struct usb_request *req) +{ + struct f_usb9pfs *usb9pfs = ep->driver_data; + struct usb_composite_dev *cdev = usb9pfs->function.config->cdev; + struct p9_req_t *p9_tx_req = req->context; + unsigned long flags; + + /* reset zero packets */ + req->zero = 0; + + if (req->status) { + dev_err(&cdev->gadget->dev, "%s usb9pfs complete --> %d, %d/%d\n", + ep->name, req->status, req->actual, req->length); + return; + } + + dev_dbg(&cdev->gadget->dev, "%s usb9pfs complete --> %d, %d/%d\n", + ep->name, req->status, req->actual, req->length); + + spin_lock_irqsave(&usb9pfs->lock, flags); + WRITE_ONCE(p9_tx_req->status, REQ_STATUS_SENT); + + p9_req_put(usb9pfs->client, p9_tx_req); + + req->context = NULL; + + spin_unlock_irqrestore(&usb9pfs->lock, flags); + + complete(&usb9pfs->send); +} + +static struct p9_req_t *usb9pfs_rx_header(struct f_usb9pfs *usb9pfs, void *buf) +{ + struct p9_req_t *p9_rx_req; + struct p9_fcall rc; + int ret; + + /* start by reading header */ + rc.sdata = buf; + rc.offset = 0; + rc.capacity = P9_HDRSZ; + rc.size = P9_HDRSZ; + + p9_debug(P9_DEBUG_TRANS, "mux %p got %zu bytes\n", usb9pfs, + rc.capacity - rc.offset); + + ret = p9_parse_header(&rc, &rc.size, NULL, NULL, 0); + if (ret) { + p9_debug(P9_DEBUG_ERROR, + "error parsing header: %d\n", ret); + return NULL; + } + + p9_debug(P9_DEBUG_TRANS, + "mux %p pkt: size: %d bytes tag: %d\n", + usb9pfs, rc.size, rc.tag); + + p9_rx_req = p9_tag_lookup(usb9pfs->client, rc.tag); + if (!p9_rx_req || p9_rx_req->status != REQ_STATUS_SENT) { + p9_debug(P9_DEBUG_ERROR, "Unexpected packet tag %d\n", rc.tag); + return NULL; + } + + if (rc.size > p9_rx_req->rc.capacity) { + p9_debug(P9_DEBUG_ERROR, + "requested packet size too big: %d for tag %d with capacity %zd\n", + rc.size, rc.tag, p9_rx_req->rc.capacity); + p9_req_put(usb9pfs->client, p9_rx_req); + return NULL; + } + + if (!p9_rx_req->rc.sdata) {
p9_debug(P9_DEBUG_ERROR, + "No recv fcall for tag %d (req %p), disconnecting!\n", + rc.tag, p9_rx_req); + p9_req_put(usb9pfs->client, p9_rx_req); + return NULL; + } + + return p9_rx_req; +} + +static void usb9pfs_rx_complete(struct usb_ep *ep, struct usb_request *req) +{ + struct f_usb9pfs *usb9pfs = ep->driver_data; + struct usb_composite_dev *cdev = usb9pfs->function.config->cdev; + struct p9_req_t *p9_rx_req; + + if (req->status) { + dev_err(&cdev->gadget->dev, "%s usb9pfs complete --> %d, %d/%d\n", + ep->name, req->status, req->actual, req->length); + return; + } + + p9_rx_req = usb9pfs_rx_header(usb9pfs, req->buf); + if (!p9_rx_req) + return; + + memcpy(p9_rx_req->rc.sdata, req->buf, req->actual); + + p9_rx_req->rc.size = req->actual; + + p9_client_cb(usb9pfs->client, p9_rx_req, REQ_STATUS_RCVD); + p9_req_put(usb9pfs->client, p9_rx_req); + + complete(&usb9pfs->received); +} + +static void disable_ep(struct usb_composite_dev *cdev, struct usb_ep *ep) +{ + int value; + + value = usb_ep_disable(ep); + if (value < 0) + dev_info(&cdev->gadget->dev, + "disable %s --> %d\n", ep->name, value); +} + +static void disable_usb9pfs(struct f_usb9pfs *usb9pfs) +{ + struct usb_composite_dev *cdev = + usb9pfs->function.config->cdev; + + if (usb9pfs->in_req) { + usb_ep_free_request(usb9pfs->in_ep, usb9pfs->in_req); + usb9pfs->in_req = NULL; + } + + if (usb9pfs->out_req) { + usb_ep_free_request(usb9pfs->out_ep, usb9pfs->out_req); + usb9pfs->out_req = NULL; + } + + disable_ep(cdev, usb9pfs->in_ep); + disable_ep(cdev, usb9pfs->out_ep); + dev_dbg(&cdev->gadget->dev, "%s disabled\n", + usb9pfs->function.name); +} + +static int alloc_requests(struct usb_composite_dev *cdev, + struct f_usb9pfs *usb9pfs) +{ + int ret; + + usb9pfs->in_req = usb_ep_alloc_request(usb9pfs->in_ep, GFP_ATOMIC); + if (!usb9pfs->in_req) { + ret = -ENOENT; + goto fail; + } + + usb9pfs->out_req = alloc_ep_req(usb9pfs->out_ep, usb9pfs->buflen); + if (!usb9pfs->out_req) { + ret = -ENOENT; + goto fail_in; + } + + usb9pfs->in_req->complete = usb9pfs_tx_complete; + usb9pfs->out_req->complete = usb9pfs_rx_complete; + + /* length will be set in complete routine */ + usb9pfs->in_req->context = usb9pfs; + usb9pfs->out_req->context = usb9pfs; + + return 0; + +fail_in: + usb_ep_free_request(usb9pfs->in_ep, usb9pfs->in_req); +fail: + return ret; +} + +static int enable_endpoint(struct usb_composite_dev *cdev, + struct f_usb9pfs *usb9pfs, struct usb_ep *ep) +{ + int ret; + + ret = config_ep_by_speed(cdev->gadget, &usb9pfs->function, ep); + if (ret) + return ret; + + ret = usb_ep_enable(ep); + if (ret < 0) + return ret; + + ep->driver_data = usb9pfs; + + return 0; +} + +static int +enable_usb9pfs(struct usb_composite_dev *cdev, struct f_usb9pfs *usb9pfs) +{ + struct p9_client *client; + int ret = 0; + + ret = enable_endpoint(cdev, usb9pfs, usb9pfs->in_ep); + if (ret) + goto out; + + ret = enable_endpoint(cdev, usb9pfs, usb9pfs->out_ep); + if (ret) + goto disable_in; + + ret = alloc_requests(cdev, usb9pfs); + if (ret) + goto disable_out; + + client = usb9pfs->client; + if (client) + client->status = Connected; + + dev_dbg(&cdev->gadget->dev, "%s enabled\n", usb9pfs->function.name); + return 0; + +disable_out: + usb_ep_disable(usb9pfs->out_ep); +disable_in: + usb_ep_disable(usb9pfs->in_ep); +out: + return ret; +} + +static int p9_usbg_create(struct p9_client *client, const char *devname, char *args) +{ + struct f_usb9pfs_dev *dev; + struct f_usb9pfs *usb9pfs; + int ret = -ENOENT; + int found = 0; + + if (!devname) + return -EINVAL; + + 
guard(mutex)(&usb9pfs_lock); + + list_for_each_entry(dev, &usbg_instance_list, usb9pfs_instance) { + if (!strncmp(devname, dev->tag, strlen(devname))) { + if (!dev->inuse) { + dev->inuse = true; + found = 1; + break; + } + ret = -EBUSY; + break; + } + } + + if (!found) { + pr_err("no channels available for device %s\n", devname); + return ret; + } + + usb9pfs = dev->usb9pfs; + if (!usb9pfs) + return -EINVAL; + + client->trans = (void *)usb9pfs; + if (!usb9pfs->in_req) + client->status = Disconnected; + else + client->status = Connected; + usb9pfs->client = client; + + client->trans_mod->maxsize = usb9pfs->buflen; + + complete(&usb9pfs->received); + + return 0; +} + +static void usb9pfs_clear_tx(struct f_usb9pfs *usb9pfs) +{ + struct p9_req_t *req; + + guard(spinlock_irqsave)(&usb9pfs->lock); + + req = usb9pfs->in_req->context; + if (!req) + return; + + if (!req->t_err) + req->t_err = -ECONNRESET; + + p9_client_cb(usb9pfs->client, req, REQ_STATUS_ERROR); +} + +static void p9_usbg_close(struct p9_client *client) +{ + struct f_usb9pfs *usb9pfs; + struct f_usb9pfs_dev *dev; + struct f_usb9pfs_opts *opts; + + if (!client) + return; + + usb9pfs = client->trans; + if (!usb9pfs) + return; + + client->status = Disconnected; + + usb9pfs_clear_tx(usb9pfs); + + opts = container_of(usb9pfs->function.fi, + struct f_usb9pfs_opts, func_inst); + + dev = opts->dev; + + mutex_lock(&usb9pfs_lock); + dev->inuse = false; + mutex_unlock(&usb9pfs_lock); +} + +static int p9_usbg_request(struct p9_client *client, struct p9_req_t *p9_req) +{ + struct f_usb9pfs *usb9pfs = client->trans; + int ret; + + if (client->status != Connected) + return -EBUSY; + + ret = wait_for_completion_killable(&usb9pfs->received); + if (ret) + return ret; + + ret = usb9pfs_transmit(usb9pfs, p9_req); + if (ret) + return ret; + + ret = wait_for_completion_killable(&usb9pfs->send); + if (ret) + return ret; + + return usb9pfs_queue_rx(usb9pfs, usb9pfs->out_req, GFP_ATOMIC); +} + +static int p9_usbg_cancel(struct p9_client *client, struct p9_req_t *req) +{ + struct f_usb9pfs *usb9pfs = client->trans; + int ret = 1; + + p9_debug(P9_DEBUG_TRANS, "client %p req %p\n", client, req); + + guard(spinlock_irqsave)(&usb9pfs->lock); + + if (req->status == REQ_STATUS_UNSENT) { + list_del(&req->req_list); + WRITE_ONCE(req->status, REQ_STATUS_FLSHD); + p9_req_put(client, req); + ret = 0; + } + + return ret; +} + +static struct p9_trans_module p9_usbg_trans = { + .name = "usbg", + .create = p9_usbg_create, + .close = p9_usbg_close, + .request = p9_usbg_request, + .cancel = p9_usbg_cancel, + .owner = THIS_MODULE, +}; + +/*-------------------------------------------------------------------------*/ + +#define USB_PROTOCOL_9PFS 0x09 + +static struct usb_interface_descriptor usb9pfs_intf = { + .bLength = sizeof(usb9pfs_intf), + .bDescriptorType = USB_DT_INTERFACE, + + .bNumEndpoints = 2, + .bInterfaceClass = USB_CLASS_VENDOR_SPEC, + .bInterfaceSubClass = USB_SUBCLASS_VENDOR_SPEC, + .bInterfaceProtocol = USB_PROTOCOL_9PFS, + + /* .iInterface = DYNAMIC */ +}; + +/* full speed support: */ + +static struct usb_endpoint_descriptor fs_usb9pfs_source_desc = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + + .bEndpointAddress = USB_DIR_IN, + .bmAttributes = USB_ENDPOINT_XFER_BULK, +}; + +static struct usb_endpoint_descriptor fs_usb9pfs_sink_desc = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + + .bEndpointAddress = USB_DIR_OUT, + .bmAttributes = USB_ENDPOINT_XFER_BULK, +}; + +static struct usb_descriptor_header 
*fs_usb9pfs_descs[] = { + (struct usb_descriptor_header *)&usb9pfs_intf, + (struct usb_descriptor_header *)&fs_usb9pfs_sink_desc, + (struct usb_descriptor_header *)&fs_usb9pfs_source_desc, + NULL, +}; + +/* high speed support: */ + +static struct usb_endpoint_descriptor hs_usb9pfs_source_desc = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + + .bmAttributes = USB_ENDPOINT_XFER_BULK, + .wMaxPacketSize = cpu_to_le16(512), +}; + +static struct usb_endpoint_descriptor hs_usb9pfs_sink_desc = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + + .bmAttributes = USB_ENDPOINT_XFER_BULK, + .wMaxPacketSize = cpu_to_le16(512), +}; + +static struct usb_descriptor_header *hs_usb9pfs_descs[] = { + (struct usb_descriptor_header *)&usb9pfs_intf, + (struct usb_descriptor_header *)&hs_usb9pfs_source_desc, + (struct usb_descriptor_header *)&hs_usb9pfs_sink_desc, + NULL, +}; + +/* super speed support: */ + +static struct usb_endpoint_descriptor ss_usb9pfs_source_desc = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + + .bmAttributes = USB_ENDPOINT_XFER_BULK, + .wMaxPacketSize = cpu_to_le16(1024), +}; + +static struct usb_ss_ep_comp_descriptor ss_usb9pfs_source_comp_desc = { + .bLength = USB_DT_SS_EP_COMP_SIZE, + .bDescriptorType = USB_DT_SS_ENDPOINT_COMP, + .bMaxBurst = 0, + .bmAttributes = 0, + .wBytesPerInterval = 0, +}; + +static struct usb_endpoint_descriptor ss_usb9pfs_sink_desc = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + + .bmAttributes = USB_ENDPOINT_XFER_BULK, + .wMaxPacketSize = cpu_to_le16(1024), +}; + +static struct usb_ss_ep_comp_descriptor ss_usb9pfs_sink_comp_desc = { + .bLength = USB_DT_SS_EP_COMP_SIZE, + .bDescriptorType = USB_DT_SS_ENDPOINT_COMP, + .bMaxBurst = 0, + .bmAttributes = 0, + .wBytesPerInterval = 0, +}; + +static struct usb_descriptor_header *ss_usb9pfs_descs[] = { + (struct usb_descriptor_header *)&usb9pfs_intf, + (struct usb_descriptor_header *)&ss_usb9pfs_source_desc, + (struct usb_descriptor_header *)&ss_usb9pfs_source_comp_desc, + (struct usb_descriptor_header *)&ss_usb9pfs_sink_desc, + (struct usb_descriptor_header *)&ss_usb9pfs_sink_comp_desc, + NULL, +}; + +/* function-specific strings: */ +static struct usb_string strings_usb9pfs[] = { + [0].s = "usb9pfs input to output", + { } /* end of list */ +}; + +static struct usb_gadget_strings stringtab_usb9pfs = { + .language = 0x0409, /* en-us */ + .strings = strings_usb9pfs, +}; + +static struct usb_gadget_strings *usb9pfs_strings[] = { + &stringtab_usb9pfs, + NULL, +}; + +/*-------------------------------------------------------------------------*/ + +static int usb9pfs_func_bind(struct usb_configuration *c, + struct usb_function *f) +{ + struct f_usb9pfs *usb9pfs = func_to_usb9pfs(f); + struct f_usb9pfs_opts *opts; + struct usb_composite_dev *cdev = c->cdev; + int ret; + int id; + + /* allocate interface ID(s) */ + id = usb_interface_id(c, f); + if (id < 0) + return id; + usb9pfs_intf.bInterfaceNumber = id; + + id = usb_string_id(cdev); + if (id < 0) + return id; + strings_usb9pfs[0].id = id; + usb9pfs_intf.iInterface = id; + + /* allocate endpoints */ + usb9pfs->in_ep = usb_ep_autoconfig(cdev->gadget, + &fs_usb9pfs_source_desc); + if (!usb9pfs->in_ep) + goto autoconf_fail; + + usb9pfs->out_ep = usb_ep_autoconfig(cdev->gadget, + &fs_usb9pfs_sink_desc); + if (!usb9pfs->out_ep) + goto autoconf_fail; + + /* support high speed hardware */ + hs_usb9pfs_source_desc.bEndpointAddress = + 
fs_usb9pfs_source_desc.bEndpointAddress; + hs_usb9pfs_sink_desc.bEndpointAddress = + fs_usb9pfs_sink_desc.bEndpointAddress; + + /* support super speed hardware */ + ss_usb9pfs_source_desc.bEndpointAddress = + fs_usb9pfs_source_desc.bEndpointAddress; + ss_usb9pfs_sink_desc.bEndpointAddress = + fs_usb9pfs_sink_desc.bEndpointAddress; + + ret = usb_assign_descriptors(f, fs_usb9pfs_descs, hs_usb9pfs_descs, + ss_usb9pfs_descs, ss_usb9pfs_descs); + if (ret) + return ret; + + opts = container_of(f->fi, struct f_usb9pfs_opts, func_inst); + opts->dev->usb9pfs = usb9pfs; + + dev_dbg(&cdev->gadget->dev, "%s speed %s: IN/%s, OUT/%s\n", + (gadget_is_superspeed(c->cdev->gadget) ? "super" : + (gadget_is_dualspeed(c->cdev->gadget) ? "dual" : "full")), + f->name, usb9pfs->in_ep->name, usb9pfs->out_ep->name); + + return 0; + +autoconf_fail: + ERROR(cdev, "%s: can't autoconfigure on %s\n", + f->name, cdev->gadget->name); + return -ENODEV; +} + +static void usb9pfs_func_unbind(struct usb_configuration *c, + struct usb_function *f) +{ + struct f_usb9pfs *usb9pfs = func_to_usb9pfs(f); + + disable_usb9pfs(usb9pfs); +} + +static void usb9pfs_free_func(struct usb_function *f) +{ + struct f_usb9pfs *usb9pfs = func_to_usb9pfs(f); + struct f_usb9pfs_opts *opts; + + kfree(usb9pfs); + + opts = container_of(f->fi, struct f_usb9pfs_opts, func_inst); + + mutex_lock(&opts->lock); + opts->refcnt--; + mutex_unlock(&opts->lock); + + usb_free_all_descriptors(f); +} + +static int usb9pfs_set_alt(struct usb_function *f, + unsigned int intf, unsigned int alt) +{ + struct f_usb9pfs *usb9pfs = func_to_usb9pfs(f); + struct usb_composite_dev *cdev = f->config->cdev; + + return enable_usb9pfs(cdev, usb9pfs); +} + +static void usb9pfs_disable(struct usb_function *f) +{ + struct f_usb9pfs *usb9pfs = func_to_usb9pfs(f); + + usb9pfs_clear_tx(usb9pfs); +} + +static struct usb_function *usb9pfs_alloc(struct usb_function_instance *fi) +{ + struct f_usb9pfs_opts *usb9pfs_opts; + struct f_usb9pfs *usb9pfs; + + usb9pfs = kzalloc(sizeof(*usb9pfs), GFP_KERNEL); + if (!usb9pfs) + return ERR_PTR(-ENOMEM); + + spin_lock_init(&usb9pfs->lock); + + init_completion(&usb9pfs->send); + init_completion(&usb9pfs->received); + + usb9pfs_opts = container_of(fi, struct f_usb9pfs_opts, func_inst); + + mutex_lock(&usb9pfs_opts->lock); + usb9pfs_opts->refcnt++; + mutex_unlock(&usb9pfs_opts->lock); + + usb9pfs->buflen = usb9pfs_opts->buflen; + + usb9pfs->function.name = "usb9pfs"; + usb9pfs->function.bind = usb9pfs_func_bind; + usb9pfs->function.unbind = usb9pfs_func_unbind; + usb9pfs->function.set_alt = usb9pfs_set_alt; + usb9pfs->function.disable = usb9pfs_disable; + usb9pfs->function.strings = usb9pfs_strings; + + usb9pfs->function.free_func = usb9pfs_free_func; + + return &usb9pfs->function; +} + +static inline struct f_usb9pfs_opts *to_f_usb9pfs_opts(struct config_item *item) +{ + return container_of(to_config_group(item), struct f_usb9pfs_opts, + func_inst.group); +} + +static inline struct f_usb9pfs_opts *fi_to_f_usb9pfs_opts(struct usb_function_instance *fi) +{ + return container_of(fi, struct f_usb9pfs_opts, func_inst); +} + +static void usb9pfs_attr_release(struct config_item *item) +{ + struct f_usb9pfs_opts *usb9pfs_opts = to_f_usb9pfs_opts(item); + + usb_put_function_instance(&usb9pfs_opts->func_inst); +} + +static struct configfs_item_operations usb9pfs_item_ops = { + .release = usb9pfs_attr_release, +}; + +static ssize_t f_usb9pfs_opts_buflen_show(struct config_item *item, char *page) +{ + struct f_usb9pfs_opts *opts = to_f_usb9pfs_opts(item); + int 
ret; + + mutex_lock(&opts->lock); + ret = sysfs_emit(page, "%d\n", opts->buflen); + mutex_unlock(&opts->lock); + + return ret; +} + +static ssize_t f_usb9pfs_opts_buflen_store(struct config_item *item, + const char *page, size_t len) +{ + struct f_usb9pfs_opts *opts = to_f_usb9pfs_opts(item); + int ret; + u32 num; + + guard(mutex)(&opts->lock); + + if (opts->refcnt) + return -EBUSY; + + ret = kstrtou32(page, 0, &num); + if (ret) + return ret; + + opts->buflen = num; + + return len; +} + +CONFIGFS_ATTR(f_usb9pfs_opts_, buflen); + +static struct configfs_attribute *usb9pfs_attrs[] = { + &f_usb9pfs_opts_attr_buflen, + NULL, +}; + +static const struct config_item_type usb9pfs_func_type = { + .ct_item_ops = &usb9pfs_item_ops, + .ct_attrs = usb9pfs_attrs, + .ct_owner = THIS_MODULE, +}; + +static struct f_usb9pfs_dev *_usb9pfs_do_find_dev(const char *tag) +{ + struct f_usb9pfs_dev *usb9pfs_dev; + + if (!tag) + return NULL; + + list_for_each_entry(usb9pfs_dev, &usbg_instance_list, usb9pfs_instance) { + if (strcmp(usb9pfs_dev->tag, tag) == 0) + return usb9pfs_dev; + } + + return NULL; +} + +static int usb9pfs_tag_instance(struct f_usb9pfs_dev *dev, const char *tag) +{ + struct f_usb9pfs_dev *existing; + int ret = 0; + + guard(mutex)(&usb9pfs_lock); + + existing = _usb9pfs_do_find_dev(tag); + if (!existing) + strscpy(dev->tag, tag, ARRAY_SIZE(dev->tag)); + else if (existing != dev) + ret = -EBUSY; + + return ret; +} + +static int usb9pfs_set_inst_tag(struct usb_function_instance *fi, const char *tag) +{ + if (strlen(tag) >= sizeof_field(struct f_usb9pfs_dev, tag)) + return -ENAMETOOLONG; + return usb9pfs_tag_instance(fi_to_f_usb9pfs_opts(fi)->dev, tag); +} + +static void usb9pfs_free_instance(struct usb_function_instance *fi) +{ + struct f_usb9pfs_opts *usb9pfs_opts = + container_of(fi, struct f_usb9pfs_opts, func_inst); + struct f_usb9pfs_dev *dev = usb9pfs_opts->dev; + + mutex_lock(&usb9pfs_lock); + list_del(&dev->usb9pfs_instance); + mutex_unlock(&usb9pfs_lock); + + kfree(usb9pfs_opts); +} + +static struct usb_function_instance *usb9pfs_alloc_instance(void) +{ + struct f_usb9pfs_opts *usb9pfs_opts; + struct f_usb9pfs_dev *dev; + + usb9pfs_opts = kzalloc(sizeof(*usb9pfs_opts), GFP_KERNEL); + if (!usb9pfs_opts) + return ERR_PTR(-ENOMEM); + + mutex_init(&usb9pfs_opts->lock); + + usb9pfs_opts->func_inst.set_inst_name = usb9pfs_set_inst_tag; + usb9pfs_opts->func_inst.free_func_inst = usb9pfs_free_instance; + + usb9pfs_opts->buflen = DEFAULT_BUFLEN; + + dev = kzalloc(sizeof(*dev), GFP_KERNEL); + if (!dev) { + kfree(usb9pfs_opts); + return ERR_PTR(-ENOMEM); + } + + usb9pfs_opts->dev = dev; + dev->opts = usb9pfs_opts; + + config_group_init_type_name(&usb9pfs_opts->func_inst.group, "", + &usb9pfs_func_type); + + mutex_lock(&usb9pfs_lock); + list_add_tail(&dev->usb9pfs_instance, &usbg_instance_list); + mutex_unlock(&usb9pfs_lock); + + return &usb9pfs_opts->func_inst; +} +DECLARE_USB_FUNCTION(usb9pfs, usb9pfs_alloc_instance, usb9pfs_alloc); + +static int __init usb9pfs_modinit(void) +{ + int ret; + + INIT_LIST_HEAD(&usbg_instance_list); + + ret = usb_function_register(&usb9pfsusb_func); + if (!ret) + v9fs_register_trans(&p9_usbg_trans); + + return ret; +} + +static void __exit usb9pfs_modexit(void) +{ + usb_function_unregister(&usb9pfsusb_func); + v9fs_unregister_trans(&p9_usbg_trans); +} + +module_init(usb9pfs_modinit); +module_exit(usb9pfs_modexit); + +MODULE_ALIAS_9P("usbg"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("USB gadget 9pfs transport"); +MODULE_AUTHOR("Michael Grzeschik"); diff --git
a/net/core/net_namespace.c b/net/core/net_namespace.c index 11e4dd4f09ed..e39479f1c9a4 100644 --- a/net/core/net_namespace.c +++ b/net/core/net_namespace.c @@ -697,11 +697,11 @@ struct net *get_net_ns_by_fd(int fd) struct fd f = fdget(fd); struct net *net = ERR_PTR(-EINVAL); - if (!f.file) + if (!fd_file(f)) return ERR_PTR(-EBADF); - if (proc_ns_file(f.file)) { - struct ns_common *ns = get_proc_ns(file_inode(f.file)); + if (proc_ns_file(fd_file(f))) { + struct ns_common *ns = get_proc_ns(file_inode(fd_file(f))); if (ns->ops == &netns_operations) net = get_net(container_of(ns, struct net, ns)); } diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 724b6856fcc3..242c91a6e3d3 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -67,46 +67,39 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr) int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog) { - u32 ufd = attr->target_fd; struct bpf_map *map; - struct fd f; int ret; if (attr->attach_flags || attr->replace_bpf_fd) return -EINVAL; - f = fdget(ufd); + CLASS(fd, f)(attr->target_fd); map = __bpf_map_get(f); if (IS_ERR(map)) return PTR_ERR(map); mutex_lock(&sockmap_mutex); ret = sock_map_prog_update(map, prog, NULL, NULL, attr->attach_type); mutex_unlock(&sockmap_mutex); - fdput(f); return ret; } int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype) { - u32 ufd = attr->target_fd; struct bpf_prog *prog; struct bpf_map *map; - struct fd f; int ret; if (attr->attach_flags || attr->replace_bpf_fd) return -EINVAL; - f = fdget(ufd); + CLASS(fd, f)(attr->target_fd); map = __bpf_map_get(f); if (IS_ERR(map)) return PTR_ERR(map); prog = bpf_prog_get(attr->attach_bpf_fd); - if (IS_ERR(prog)) { - ret = PTR_ERR(prog); - goto put_map; - } + if (IS_ERR(prog)) + return PTR_ERR(prog); if (prog->type != ptype) { ret = -EINVAL; @@ -118,8 +111,6 @@ int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype) mutex_unlock(&sockmap_mutex); put_prog: bpf_prog_put(prog); -put_map: - fdput(f); return ret; } @@ -1551,18 +1542,17 @@ int sock_map_bpf_prog_query(const union bpf_attr *attr, union bpf_attr __user *uattr) { __u32 __user *prog_ids = u64_to_user_ptr(attr->query.prog_ids); - u32 prog_cnt = 0, flags = 0, ufd = attr->target_fd; + u32 prog_cnt = 0, flags = 0; struct bpf_prog **pprog; struct bpf_prog *prog; struct bpf_map *map; - struct fd f; u32 id = 0; int ret; if (attr->query.query_flags) return -EINVAL; - f = fdget(ufd); + CLASS(fd, f)(attr->target_fd); map = __bpf_map_get(f); if (IS_ERR(map)) return PTR_ERR(map); @@ -1594,7 +1584,6 @@ end: copy_to_user(&uattr->query.prog_cnt, &prog_cnt, sizeof(prog_cnt))) ret = -EFAULT; - fdput(f); return ret; } diff --git a/net/ipv4/netfilter/nf_reject_ipv4.c b/net/ipv4/netfilter/nf_reject_ipv4.c index 04504b2b51df..87fd945a0d27 100644 --- a/net/ipv4/netfilter/nf_reject_ipv4.c +++ b/net/ipv4/netfilter/nf_reject_ipv4.c @@ -239,9 +239,8 @@ static int nf_reject_fill_skb_dst(struct sk_buff *skb_in) void nf_send_reset(struct net *net, struct sock *sk, struct sk_buff *oldskb, int hook) { - struct sk_buff *nskb; - struct iphdr *niph; const struct tcphdr *oth; + struct sk_buff *nskb; struct tcphdr _oth; oth = nf_reject_ip_tcphdr_get(oldskb, &_oth, hook); @@ -266,14 +265,12 @@ void nf_send_reset(struct net *net, struct sock *sk, struct sk_buff *oldskb, nskb->mark = IP4_REPLY_MARK(net, oldskb->mark); skb_reserve(nskb, LL_MAX_HEADER); - niph = nf_reject_iphdr_put(nskb, oldskb, IPPROTO_TCP, - ip4_dst_hoplimit(skb_dst(nskb))); + 
nf_reject_iphdr_put(nskb, oldskb, IPPROTO_TCP, + ip4_dst_hoplimit(skb_dst(nskb))); nf_reject_ip_tcphdr_put(nskb, oldskb, oth); if (ip_route_me_harder(net, sk, nskb, RTN_UNSPEC)) goto free_nskb; - niph = ip_hdr(nskb); - /* "Never happens" */ if (nskb->len > dst_mtu(skb_dst(nskb))) goto free_nskb; @@ -290,6 +287,7 @@ void nf_send_reset(struct net *net, struct sock *sk, struct sk_buff *oldskb, */ if (nf_bridge_info_exists(oldskb)) { struct ethhdr *oeth = eth_hdr(oldskb); + struct iphdr *niph = ip_hdr(nskb); struct net_device *br_indev; br_indev = nf_bridge_get_physindev(oldskb, net); diff --git a/net/ipv6/Kconfig b/net/ipv6/Kconfig index 08d4b7132d4c..1c9c686d9522 100644 --- a/net/ipv6/Kconfig +++ b/net/ipv6/Kconfig @@ -323,6 +323,7 @@ config IPV6_RPL_LWTUNNEL bool "IPv6: RPL Source Routing Header support" depends on IPV6 select LWTUNNEL + select DST_CACHE help Support for RFC6554 RPL Source Routing Header using the lightweight tunnels mechanism. diff --git a/net/ipv6/netfilter/nf_reject_ipv6.c b/net/ipv6/netfilter/nf_reject_ipv6.c index dedee264b8f6..7db0437140bf 100644 --- a/net/ipv6/netfilter/nf_reject_ipv6.c +++ b/net/ipv6/netfilter/nf_reject_ipv6.c @@ -223,33 +223,23 @@ void nf_reject_ip6_tcphdr_put(struct sk_buff *nskb, const struct tcphdr *oth, unsigned int otcplen) { struct tcphdr *tcph; - int needs_ack; skb_reset_transport_header(nskb); - tcph = skb_put(nskb, sizeof(struct tcphdr)); + tcph = skb_put_zero(nskb, sizeof(struct tcphdr)); /* Truncate to length (no data) */ tcph->doff = sizeof(struct tcphdr)/4; tcph->source = oth->dest; tcph->dest = oth->source; if (oth->ack) { - needs_ack = 0; tcph->seq = oth->ack_seq; - tcph->ack_seq = 0; } else { - needs_ack = 1; tcph->ack_seq = htonl(ntohl(oth->seq) + oth->syn + oth->fin + otcplen - (oth->doff<<2)); - tcph->seq = 0; + tcph->ack = 1; } - /* Reset flags */ - ((u_int8_t *)tcph)[13] = 0; tcph->rst = 1; - tcph->ack = needs_ack; - tcph->window = 0; - tcph->urg_ptr = 0; - tcph->check = 0; /* Adjust TCP checksum */ tcph->check = csum_ipv6_magic(&ipv6_hdr(nskb)->saddr, @@ -283,7 +273,6 @@ void nf_send_reset6(struct net *net, struct sock *sk, struct sk_buff *oldskb, const struct tcphdr *otcph; unsigned int otcplen, hh_len; const struct ipv6hdr *oip6h = ipv6_hdr(oldskb); - struct ipv6hdr *ip6h; struct dst_entry *dst = NULL; struct flowi6 fl6; @@ -339,8 +328,7 @@ void nf_send_reset6(struct net *net, struct sock *sk, struct sk_buff *oldskb, nskb->mark = fl6.flowi6_mark; skb_reserve(nskb, hh_len + dst->header_len); - ip6h = nf_reject_ip6hdr_put(nskb, oldskb, IPPROTO_TCP, - ip6_dst_hoplimit(dst)); + nf_reject_ip6hdr_put(nskb, oldskb, IPPROTO_TCP, ip6_dst_hoplimit(dst)); nf_reject_ip6_tcphdr_put(nskb, oldskb, otcph, otcplen); nf_ct_attach(nskb, oldskb); @@ -355,6 +343,7 @@ void nf_send_reset6(struct net *net, struct sock *sk, struct sk_buff *oldskb, */ if (nf_bridge_info_exists(oldskb)) { struct ethhdr *oeth = eth_hdr(oldskb); + struct ipv6hdr *ip6h = ipv6_hdr(nskb); struct net_device *br_indev; br_indev = nf_bridge_get_physindev(oldskb, net); diff --git a/net/mac80211/rc80211_minstrel_ht_debugfs.c b/net/mac80211/rc80211_minstrel_ht_debugfs.c index 25b8a67a63a4..85149c774505 100644 --- a/net/mac80211/rc80211_minstrel_ht_debugfs.c +++ b/net/mac80211/rc80211_minstrel_ht_debugfs.c @@ -187,7 +187,6 @@ static const struct file_operations minstrel_ht_stat_fops = { .open = minstrel_ht_stats_open, .read = minstrel_stats_read, .release = minstrel_stats_release, - .llseek = no_llseek, }; static char * @@ -323,7 +322,6 @@ static const struct file_operations 
minstrel_ht_stat_csv_fops = { .open = minstrel_ht_stats_csv_open, .read = minstrel_stats_read, .release = minstrel_stats_release, - .llseek = no_llseek, }; void diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c index d3cb53b008f5..9db3e2b0b1c3 100644 --- a/net/netfilter/nf_conntrack_core.c +++ b/net/netfilter/nf_conntrack_core.c @@ -988,6 +988,56 @@ static void __nf_conntrack_insert_prepare(struct nf_conn *ct) tstamp->start = ktime_get_real_ns(); } +/** + * nf_ct_match_reverse - check if ct1 and ct2 refer to identical flow + * @ct1: conntrack in hash table to check against + * @ct2: merge candidate + * + * returns true if ct1 and ct2 happen to refer to the same flow, but + * in opposing directions, i.e. + * ct1: a:b -> c:d + * ct2: c:d -> a:b + * for both directions. If so, @ct2 should not have been created + * as the skb should have been picked up as ESTABLISHED flow. + * But ct1 was not yet committed to hash table before skb that created + * ct2 had arrived. + * + * Note we don't compare netns because ct entries in different net + * namespace cannot clash to begin with. + * + * Return: true if ct1 and ct2 are identical when swapping origin/reply. + */ +static bool +nf_ct_match_reverse(const struct nf_conn *ct1, const struct nf_conn *ct2) +{ + u16 id1, id2; + + if (!nf_ct_tuple_equal(&ct1->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + &ct2->tuplehash[IP_CT_DIR_REPLY].tuple)) + return false; + + if (!nf_ct_tuple_equal(&ct1->tuplehash[IP_CT_DIR_REPLY].tuple, + &ct2->tuplehash[IP_CT_DIR_ORIGINAL].tuple)) + return false; + + id1 = nf_ct_zone_id(nf_ct_zone(ct1), IP_CT_DIR_ORIGINAL); + id2 = nf_ct_zone_id(nf_ct_zone(ct2), IP_CT_DIR_REPLY); + if (id1 != id2) + return false; + + id1 = nf_ct_zone_id(nf_ct_zone(ct1), IP_CT_DIR_REPLY); + id2 = nf_ct_zone_id(nf_ct_zone(ct2), IP_CT_DIR_ORIGINAL); + + return id1 == id2; +} + +static int nf_ct_can_merge(const struct nf_conn *ct, + const struct nf_conn *loser_ct) +{ + return nf_ct_match(ct, loser_ct) || + nf_ct_match_reverse(ct, loser_ct); +} + /* caller must hold locks to prevent concurrent changes */ static int __nf_ct_resolve_clash(struct sk_buff *skb, struct nf_conntrack_tuple_hash *h) @@ -999,11 +1049,7 @@ static int __nf_ct_resolve_clash(struct sk_buff *skb, loser_ct = nf_ct_get(skb, &ctinfo); - if (nf_ct_is_dying(ct)) - return NF_DROP; - - if (((ct->status & IPS_NAT_DONE_MASK) == 0) || - nf_ct_match(ct, loser_ct)) { + if (nf_ct_can_merge(ct, loser_ct)) { struct net *net = nf_ct_net(ct); nf_conntrack_get(&ct->ct_general); @@ -2151,80 +2197,6 @@ static void nf_conntrack_attach(struct sk_buff *nskb, const struct sk_buff *skb) nf_conntrack_get(skb_nfct(nskb)); } -static int __nf_conntrack_update(struct net *net, struct sk_buff *skb, - struct nf_conn *ct, - enum ip_conntrack_info ctinfo) -{ - const struct nf_nat_hook *nat_hook; - struct nf_conntrack_tuple_hash *h; - struct nf_conntrack_tuple tuple; - unsigned int status; - int dataoff; - u16 l3num; - u8 l4num; - - l3num = nf_ct_l3num(ct); - - dataoff = get_l4proto(skb, skb_network_offset(skb), l3num, &l4num); - if (dataoff <= 0) - return NF_DROP; - - if (!nf_ct_get_tuple(skb, skb_network_offset(skb), dataoff, l3num, - l4num, net, &tuple)) - return NF_DROP; - - if (ct->status & IPS_SRC_NAT) { - memcpy(tuple.src.u3.all, - ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3.all, - sizeof(tuple.src.u3.all)); - tuple.src.u.all = - ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u.all; - } - - if (ct->status & IPS_DST_NAT) { - memcpy(tuple.dst.u3.all, -
ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u3.all, - sizeof(tuple.dst.u3.all)); - tuple.dst.u.all = - ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u.all; - } - - h = nf_conntrack_find_get(net, nf_ct_zone(ct), &tuple); - if (!h) - return NF_ACCEPT; - - /* Store status bits of the conntrack that is clashing to re-do NAT - * mangling according to what it has been done already to this packet. - */ - status = ct->status; - - nf_ct_put(ct); - ct = nf_ct_tuplehash_to_ctrack(h); - nf_ct_set(skb, ct, ctinfo); - - nat_hook = rcu_dereference(nf_nat_hook); - if (!nat_hook) - return NF_ACCEPT; - - if (status & IPS_SRC_NAT) { - unsigned int verdict = nat_hook->manip_pkt(skb, ct, - NF_NAT_MANIP_SRC, - IP_CT_DIR_ORIGINAL); - if (verdict != NF_ACCEPT) - return verdict; - } - - if (status & IPS_DST_NAT) { - unsigned int verdict = nat_hook->manip_pkt(skb, ct, - NF_NAT_MANIP_DST, - IP_CT_DIR_ORIGINAL); - if (verdict != NF_ACCEPT) - return verdict; - } - - return NF_ACCEPT; -} - /* This packet is coming from userspace via nf_queue, complete the packet * processing after the helper invocation in nf_confirm(). */ @@ -2288,17 +2260,6 @@ static int nf_conntrack_update(struct net *net, struct sk_buff *skb) if (!ct) return NF_ACCEPT; - if (!nf_ct_is_confirmed(ct)) { - int ret = __nf_conntrack_update(net, skb, ct, ctinfo); - - if (ret != NF_ACCEPT) - return ret; - - ct = nf_ct_get(skb, &ctinfo); - if (!ct) - return NF_ACCEPT; - } - return nf_confirm_cthelper(skb, ct, ctinfo); } diff --git a/net/netfilter/nf_conntrack_netlink.c b/net/netfilter/nf_conntrack_netlink.c index 123e2e933e9b..6a1239433830 100644 --- a/net/netfilter/nf_conntrack_netlink.c +++ b/net/netfilter/nf_conntrack_netlink.c @@ -382,7 +382,7 @@ nla_put_failure: #define ctnetlink_dump_secctx(a, b) (0) #endif -#ifdef CONFIG_NF_CONNTRACK_LABELS +#ifdef CONFIG_NF_CONNTRACK_EVENTS static inline int ctnetlink_label_size(const struct nf_conn *ct) { struct nf_conn_labels *labels = nf_ct_labels_find(ct); @@ -391,6 +391,7 @@ static inline int ctnetlink_label_size(const struct nf_conn *ct) return 0; return nla_total_size(sizeof(labels->bits)); } +#endif static int ctnetlink_dump_labels(struct sk_buff *skb, const struct nf_conn *ct) @@ -411,10 +412,6 @@ ctnetlink_dump_labels(struct sk_buff *skb, const struct nf_conn *ct) return 0; } -#else -#define ctnetlink_dump_labels(a, b) (0) -#define ctnetlink_label_size(a) (0) -#endif #define master_tuple(ct) &(ct->master->tuplehash[IP_CT_DIR_ORIGINAL].tuple) @@ -652,7 +649,6 @@ static size_t ctnetlink_proto_size(const struct nf_conn *ct) return len + len4; } -#endif static inline size_t ctnetlink_acct_size(const struct nf_conn *ct) { @@ -690,6 +686,7 @@ static inline size_t ctnetlink_timestamp_size(const struct nf_conn *ct) return 0; #endif } +#endif #ifdef CONFIG_NF_CONNTRACK_EVENTS static size_t ctnetlink_nlmsg_size(const struct nf_conn *ct) diff --git a/net/netfilter/nf_nat_core.c b/net/netfilter/nf_nat_core.c index 6d8da6dddf99..4085c436e306 100644 --- a/net/netfilter/nf_nat_core.c +++ b/net/netfilter/nf_nat_core.c @@ -183,7 +183,35 @@ hash_by_src(const struct net *net, return reciprocal_scale(hash, nf_nat_htable_size); } -/* Is this tuple already taken? (not by us) */ +/** + * nf_nat_used_tuple - check if proposed nat tuple clashes with existing entry + * @tuple: proposed NAT binding + * @ignored_conntrack: our (unconfirmed) conntrack entry + * + * A conntrack entry can be inserted to the connection tracking table + * if there is no existing entry with an identical tuple in either direction. 
+ * + * Example: + * INITIATOR -> NAT/PAT -> RESPONDER + * + * INITIATOR passes through NAT/PAT ("us") and SNAT is done (saddr rewrite). + * Then, later, NAT/PAT itself also connects to RESPONDER. + * + * This will not work if the SNAT done earlier has the same IP:PORT source pair. + * + * Conntrack table has: + * ORIGINAL: $IP_INITIATOR:$SPORT -> $IP_RESPONDER:$DPORT + * REPLY: $IP_RESPONDER:$DPORT -> $IP_NAT:$SPORT + * + * and a new locally originating connection wants: + * ORIGINAL: $IP_NAT:$SPORT -> $IP_RESPONDER:$DPORT + * REPLY: $IP_RESPONDER:$DPORT -> $IP_NAT:$SPORT + * + * ... which would mean incoming packets cannot be distinguished between + * the existing and the newly added entry (identical IP_CT_DIR_REPLY tuple). + * + * Return: true if the proposed NAT mapping collides with an existing entry. + */ static int nf_nat_used_tuple(const struct nf_conntrack_tuple *tuple, const struct nf_conn *ignored_conntrack) @@ -200,6 +228,94 @@ nf_nat_used_tuple(const struct nf_conntrack_tuple *tuple, return nf_conntrack_tuple_taken(&reply, ignored_conntrack); } +static bool nf_nat_allow_clash(const struct nf_conn *ct) +{ + return nf_ct_l4proto_find(nf_ct_protonum(ct))->allow_clash; +} + +/** + * nf_nat_used_tuple_new - check if to-be-inserted conntrack collides with existing entry + * @tuple: proposed NAT binding + * @ignored_ct: our (unconfirmed) conntrack entry + * + * Same as nf_nat_used_tuple(), but also check for rare clash in reverse + * direction. Should be called only when @tuple has not been altered, i.e. + * @ignored_ct will not be subject to NAT. + * + * Return: true if the proposed NAT mapping collides with an existing entry. + */ +static noinline bool +nf_nat_used_tuple_new(const struct nf_conntrack_tuple *tuple, + const struct nf_conn *ignored_ct) +{ + static const unsigned long uses_nat = IPS_NAT_MASK | IPS_SEQ_ADJUST_BIT; + const struct nf_conntrack_tuple_hash *thash; + const struct nf_conntrack_zone *zone; + struct nf_conn *ct; + bool taken = true; + struct net *net; + + if (!nf_nat_used_tuple(tuple, ignored_ct)) + return false; + + if (!nf_nat_allow_clash(ignored_ct)) + return true; + + /* Initial choice clashes with existing conntrack. + * Check for (rare) reverse collision. + * + * This can happen when new packets are received in both directions + * at the exact same time on different CPUs. + * + * Without SMP, first packet creates new conntrack entry and second + * packet is resolved as established reply packet. + * + * With parallel processing, both packets could be picked up as + * new and both get their own ct entry allocated. + * + * If @ignored_ct and the colliding ct are not subject to NAT then + * pretend the tuple is available and let later clash resolution + * handle this at insertion time. + * + * Without it, the 'reply' packet has its source port rewritten + * by the nat engine. + */ + if (READ_ONCE(ignored_ct->status) & uses_nat) + return true; + + net = nf_ct_net(ignored_ct); + zone = nf_ct_zone(ignored_ct); + + thash = nf_conntrack_find_get(net, zone, tuple); + if (unlikely(!thash)) /* clashing entry went away */ + return false; + + ct = nf_ct_tuplehash_to_ctrack(thash); + + /* NB: IP_CT_DIR_ORIGINAL should be impossible because + * nf_nat_used_tuple() handles origin collisions. + * + * Handle remote chance other CPU confirmed its ct right after. + */ + if (thash->tuple.dst.dir != IP_CT_DIR_REPLY) + goto out; + + /* clashing connection subject to NAT? Retry with new tuple.
*/ + if (READ_ONCE(ct->status) & uses_nat) + goto out; + + if (nf_ct_tuple_equal(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + &ignored_ct->tuplehash[IP_CT_DIR_REPLY].tuple) && + nf_ct_tuple_equal(&ct->tuplehash[IP_CT_DIR_REPLY].tuple, + &ignored_ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple)) { + taken = false; + goto out; + } +out: + nf_ct_put(ct); + return taken; +} + static bool nf_nat_may_kill(struct nf_conn *ct, unsigned long flags) { static const unsigned long flags_refuse = IPS_FIXED_TIMEOUT | @@ -611,7 +727,7 @@ get_unique_tuple(struct nf_conntrack_tuple *tuple, !(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) { /* try the original tuple first */ if (nf_in_range(orig_tuple, range)) { - if (!nf_nat_used_tuple(orig_tuple, ct)) { + if (!nf_nat_used_tuple_new(orig_tuple, ct)) { *tuple = *orig_tuple; return; } @@ -1208,7 +1324,6 @@ static const struct nf_nat_hook nat_hook = { #ifdef CONFIG_XFRM .decode_session = __nf_nat_decode_session, #endif - .manip_pkt = nf_nat_manip_pkt, .remove_nat_bysrc = nf_nat_cleanup_conntrack, }; diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c index 57259b5f3ef5..a24fe62650a7 100644 --- a/net/netfilter/nf_tables_api.c +++ b/net/netfilter/nf_tables_api.c @@ -1849,7 +1849,7 @@ static int nft_dump_basechain_hook(struct sk_buff *skb, int family, if (!hook_list) hook_list = &basechain->hook_list; - list_for_each_entry(hook, hook_list, list) { + list_for_each_entry_rcu(hook, hook_list, list) { if (!first) first = hook; @@ -6684,7 +6684,7 @@ static int nft_setelem_catchall_insert(const struct net *net, } } - catchall = kmalloc(sizeof(*catchall), GFP_KERNEL); + catchall = kmalloc(sizeof(*catchall), GFP_KERNEL_ACCOUNT); if (!catchall) return -ENOMEM; @@ -9207,7 +9207,7 @@ static void nf_tables_flowtable_destroy(struct nft_flowtable *flowtable) flowtable->data.type->setup(&flowtable->data, hook->ops.dev, FLOW_BLOCK_UNBIND); list_del_rcu(&hook->list); - kfree(hook); + kfree_rcu(hook, rcu); } kfree(flowtable->name); module_put(flowtable->data.type->owner); diff --git a/net/netfilter/nft_compat.c b/net/netfilter/nft_compat.c index 52cdfee17f73..7ca4f0d21fe2 100644 --- a/net/netfilter/nft_compat.c +++ b/net/netfilter/nft_compat.c @@ -535,7 +535,7 @@ nft_match_large_init(const struct nft_ctx *ctx, const struct nft_expr *expr, struct xt_match *m = expr->ops->data; int ret; - priv->info = kmalloc(XT_ALIGN(m->matchsize), GFP_KERNEL); + priv->info = kmalloc(XT_ALIGN(m->matchsize), GFP_KERNEL_ACCOUNT); if (!priv->info) return -ENOMEM; @@ -808,7 +808,7 @@ nft_match_select_ops(const struct nft_ctx *ctx, goto err; } - ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL); + ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL_ACCOUNT); if (!ops) { err = -ENOMEM; goto err; @@ -898,7 +898,7 @@ nft_target_select_ops(const struct nft_ctx *ctx, goto err; } - ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL); + ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL_ACCOUNT); if (!ops) { err = -ENOMEM; goto err; diff --git a/net/netfilter/nft_log.c b/net/netfilter/nft_log.c index 5defe6e4fd98..e35588137995 100644 --- a/net/netfilter/nft_log.c +++ b/net/netfilter/nft_log.c @@ -163,7 +163,7 @@ static int nft_log_init(const struct nft_ctx *ctx, nla = tb[NFTA_LOG_PREFIX]; if (nla != NULL) { - priv->prefix = kmalloc(nla_len(nla) + 1, GFP_KERNEL); + priv->prefix = kmalloc(nla_len(nla) + 1, GFP_KERNEL_ACCOUNT); if (priv->prefix == NULL) return -ENOMEM; nla_strscpy(priv->prefix, nla, nla_len(nla) + 1); diff --git a/net/netfilter/nft_meta.c b/net/netfilter/nft_meta.c 
index 8c8eb14d647b..05cd1e6e6a2f 100644 --- a/net/netfilter/nft_meta.c +++ b/net/netfilter/nft_meta.c @@ -952,7 +952,7 @@ static int nft_secmark_obj_init(const struct nft_ctx *ctx, if (tb[NFTA_SECMARK_CTX] == NULL) return -EINVAL; - priv->ctx = nla_strdup(tb[NFTA_SECMARK_CTX], GFP_KERNEL); + priv->ctx = nla_strdup(tb[NFTA_SECMARK_CTX], GFP_KERNEL_ACCOUNT); if (!priv->ctx) return -ENOMEM; diff --git a/net/netfilter/nft_numgen.c b/net/netfilter/nft_numgen.c index 7d29db7c2ac0..bd058babfc82 100644 --- a/net/netfilter/nft_numgen.c +++ b/net/netfilter/nft_numgen.c @@ -66,7 +66,7 @@ static int nft_ng_inc_init(const struct nft_ctx *ctx, if (priv->offset + priv->modulus - 1 < priv->offset) return -EOVERFLOW; - priv->counter = kmalloc(sizeof(*priv->counter), GFP_KERNEL); + priv->counter = kmalloc(sizeof(*priv->counter), GFP_KERNEL_ACCOUNT); if (!priv->counter) return -ENOMEM; diff --git a/net/netfilter/nft_set_pipapo.c b/net/netfilter/nft_set_pipapo.c index eb4c4a4ac7ac..7be342b495f5 100644 --- a/net/netfilter/nft_set_pipapo.c +++ b/net/netfilter/nft_set_pipapo.c @@ -663,7 +663,7 @@ static int pipapo_realloc_mt(struct nft_pipapo_field *f, check_add_overflow(rules, extra, &rules_alloc)) return -EOVERFLOW; - new_mt = kvmalloc_array(rules_alloc, sizeof(*new_mt), GFP_KERNEL); + new_mt = kvmalloc_array(rules_alloc, sizeof(*new_mt), GFP_KERNEL_ACCOUNT); if (!new_mt) return -ENOMEM; @@ -936,7 +936,7 @@ static void pipapo_lt_bits_adjust(struct nft_pipapo_field *f) return; } - new_lt = kvzalloc(lt_size + NFT_PIPAPO_ALIGN_HEADROOM, GFP_KERNEL); + new_lt = kvzalloc(lt_size + NFT_PIPAPO_ALIGN_HEADROOM, GFP_KERNEL_ACCOUNT); if (!new_lt) return; @@ -1212,7 +1212,7 @@ static int pipapo_realloc_scratch(struct nft_pipapo_match *clone, scratch = kzalloc_node(struct_size(scratch, map, bsize_max * 2) + NFT_PIPAPO_ALIGN_HEADROOM, - GFP_KERNEL, cpu_to_node(i)); + GFP_KERNEL_ACCOUNT, cpu_to_node(i)); if (!scratch) { /* On failure, there's no need to undo previous * allocations: this means that some scratch maps have @@ -1427,7 +1427,7 @@ static struct nft_pipapo_match *pipapo_clone(struct nft_pipapo_match *old) struct nft_pipapo_match *new; int i; - new = kmalloc(struct_size(new, f, old->field_count), GFP_KERNEL); + new = kmalloc(struct_size(new, f, old->field_count), GFP_KERNEL_ACCOUNT); if (!new) return NULL; @@ -1457,7 +1457,7 @@ static struct nft_pipapo_match *pipapo_clone(struct nft_pipapo_match *old) new_lt = kvzalloc(src->groups * NFT_PIPAPO_BUCKETS(src->bb) * src->bsize * sizeof(*dst->lt) + NFT_PIPAPO_ALIGN_HEADROOM, - GFP_KERNEL); + GFP_KERNEL_ACCOUNT); if (!new_lt) goto out_lt; @@ -1470,7 +1470,8 @@ static struct nft_pipapo_match *pipapo_clone(struct nft_pipapo_match *old) if (src->rules > 0) { dst->mt = kvmalloc_array(src->rules_alloc, - sizeof(*src->mt), GFP_KERNEL); + sizeof(*src->mt), + GFP_KERNEL_ACCOUNT); if (!dst->mt) goto out_mt; diff --git a/net/netfilter/nft_tunnel.c b/net/netfilter/nft_tunnel.c index 60a76e6e348e..5c6ed68cc6e0 100644 --- a/net/netfilter/nft_tunnel.c +++ b/net/netfilter/nft_tunnel.c @@ -509,13 +509,14 @@ static int nft_tunnel_obj_init(const struct nft_ctx *ctx, return err; } - md = metadata_dst_alloc(priv->opts.len, METADATA_IP_TUNNEL, GFP_KERNEL); + md = metadata_dst_alloc(priv->opts.len, METADATA_IP_TUNNEL, + GFP_KERNEL_ACCOUNT); if (!md) return -ENOMEM; memcpy(&md->u.tun_info, &info, sizeof(info)); #ifdef CONFIG_DST_CACHE - err = dst_cache_init(&md->u.tun_info.dst_cache, GFP_KERNEL); + err = dst_cache_init(&md->u.tun_info.dst_cache, GFP_KERNEL_ACCOUNT); if (err < 0) { 
metadata_dst_free(md); return err; diff --git a/net/qrtr/af_qrtr.c b/net/qrtr/af_qrtr.c index 41ece61eb57a..00c51cf693f3 100644 --- a/net/qrtr/af_qrtr.c +++ b/net/qrtr/af_qrtr.c @@ -884,7 +884,7 @@ static int qrtr_bcast_enqueue(struct qrtr_node *node, struct sk_buff *skb, mutex_lock(&qrtr_node_lock); list_for_each_entry(node, &qrtr_all_nodes, item) { - skbn = skb_clone(skb, GFP_KERNEL); + skbn = pskb_copy(skb, GFP_KERNEL); if (!skbn) break; skb_set_owner_w(skbn, skb->sk); diff --git a/net/rfkill/core.c b/net/rfkill/core.c index 13a5126bc36e..7d3e82e4c2fc 100644 --- a/net/rfkill/core.c +++ b/net/rfkill/core.c @@ -1394,7 +1394,6 @@ static const struct file_operations rfkill_fops = { .release = rfkill_fop_release, .unlocked_ioctl = rfkill_fop_ioctl, .compat_ioctl = compat_ptr_ioctl, - .llseek = no_llseek, }; #define RFKILL_NAME "rfkill" diff --git a/net/socket.c b/net/socket.c index 8d8b84fa404a..601ad74930ef 100644 --- a/net/socket.c +++ b/net/socket.c @@ -153,7 +153,6 @@ static void sock_show_fdinfo(struct seq_file *m, struct file *f) static const struct file_operations socket_file_ops = { .owner = THIS_MODULE, - .llseek = no_llseek, .read_iter = sock_read_iter, .write_iter = sock_write_iter, .poll = sock_poll, @@ -556,10 +555,10 @@ static struct socket *sockfd_lookup_light(int fd, int *err, int *fput_needed) struct socket *sock; *err = -EBADF; - if (f.file) { - sock = sock_from_file(f.file); + if (fd_file(f)) { + sock = sock_from_file(fd_file(f)); if (likely(sock)) { - *fput_needed = f.flags & FDPUT_FPUT; + *fput_needed = f.word & FDPUT_FPUT; return sock; } *err = -ENOTSOCK; @@ -2014,8 +2013,8 @@ int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr, struct fd f; f = fdget(fd); - if (f.file) { - ret = __sys_accept4_file(f.file, upeer_sockaddr, + if (fd_file(f)) { + ret = __sys_accept4_file(fd_file(f), upeer_sockaddr, upeer_addrlen, flags); fdput(f); } @@ -2076,12 +2075,12 @@ int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen) struct fd f; f = fdget(fd); - if (f.file) { + if (fd_file(f)) { struct sockaddr_storage address; ret = move_addr_to_kernel(uservaddr, addrlen, &address); if (!ret) - ret = __sys_connect_file(f.file, &address, addrlen, 0); + ret = __sys_connect_file(fd_file(f), &address, addrlen, 0); fdput(f); } diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c index 95ff74706104..1bd3e531b0e0 100644 --- a/net/sunrpc/cache.c +++ b/net/sunrpc/cache.c @@ -731,11 +731,10 @@ static bool cache_defer_req(struct cache_req *req, struct cache_head *item) static void cache_revisit_request(struct cache_head *item) { struct cache_deferred_req *dreq; - struct list_head pending; struct hlist_node *tmp; int hash = DFR_HASH(item); + LIST_HEAD(pending); - INIT_LIST_HEAD(&pending); spin_lock(&cache_defer_lock); hlist_for_each_entry_safe(dreq, tmp, &cache_defer_hash[hash], hash) @@ -756,10 +755,8 @@ static void cache_revisit_request(struct cache_head *item) void cache_clean_deferred(void *owner) { struct cache_deferred_req *dreq, *tmp; - struct list_head pending; + LIST_HEAD(pending); - - INIT_LIST_HEAD(&pending); spin_lock(&cache_defer_lock); list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) { @@ -1085,9 +1082,8 @@ static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch) { struct cache_queue *cq, *tmp; struct cache_request *cr; - struct list_head dequeued; + LIST_HEAD(dequeued); - INIT_LIST_HEAD(&dequeued); spin_lock(&queue_lock); list_for_each_entry_safe(cq, tmp, &detail->queue, list) if (!cq->reader) { @@ -1596,7 +1592,6 @@ static 
int cache_release_procfs(struct inode *inode, struct file *filp) } static const struct proc_ops cache_channel_proc_ops = { - .proc_lseek = no_llseek, .proc_read = cache_read_procfs, .proc_write = cache_write_procfs, .proc_poll = cache_poll_procfs, @@ -1662,7 +1657,6 @@ static const struct proc_ops cache_flush_proc_ops = { .proc_read = read_flush_procfs, .proc_write = write_flush_procfs, .proc_release = release_flush_procfs, - .proc_lseek = no_llseek, }; static void remove_cache_proc_entries(struct cache_detail *cd) @@ -1815,7 +1809,6 @@ static int cache_release_pipefs(struct inode *inode, struct file *filp) const struct file_operations cache_file_operations_pipefs = { .owner = THIS_MODULE, - .llseek = no_llseek, .read = cache_read_pipefs, .write = cache_write_pipefs, .poll = cache_poll_pipefs, @@ -1881,7 +1874,6 @@ const struct file_operations cache_flush_operations_pipefs = { .read = read_flush_pipefs, .write = write_flush_pipefs, .release = release_flush_pipefs, - .llseek = no_llseek, }; int sunrpc_cache_register_pipefs(struct dentry *parent, diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c index 09f29a95f2bc..0090162ee8c3 100644 --- a/net/sunrpc/clnt.c +++ b/net/sunrpc/clnt.c @@ -48,13 +48,8 @@ # define RPCDBG_FACILITY RPCDBG_CALL #endif -/* - * All RPC clients are linked into this list - */ - static DECLARE_WAIT_QUEUE_HEAD(destroy_wait); - static void call_start(struct rpc_task *task); static void call_reserve(struct rpc_task *task); static void call_reserveresult(struct rpc_task *task); @@ -546,7 +541,7 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args) .connect_timeout = args->connect_timeout, .reconnect_timeout = args->reconnect_timeout, }; - char servername[48]; + char servername[RPC_MAXNETNAMELEN]; struct rpc_clnt *clnt; int i; @@ -1893,12 +1888,6 @@ call_allocate(struct rpc_task *task) if (req->rq_buffer) return; - if (proc->p_proc != 0) { - BUG_ON(proc->p_arglen == 0); - if (proc->p_decode != NULL) - BUG_ON(proc->p_replen == 0); - } - /* * Calculate the size (in quads) of the RPC call * and reply headers, and convert both values diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c index 910a5d850d04..7ce3721c06ca 100644 --- a/net/sunrpc/rpc_pipe.c +++ b/net/sunrpc/rpc_pipe.c @@ -385,7 +385,6 @@ rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) static const struct file_operations rpc_pipe_fops = { .owner = THIS_MODULE, - .llseek = no_llseek, .read = rpc_pipe_read, .write = rpc_pipe_write, .poll = rpc_pipe_poll, diff --git a/net/sunrpc/sunrpc.h b/net/sunrpc/sunrpc.h index d4a362c9e4b3..e3c6e3b63f0b 100644 --- a/net/sunrpc/sunrpc.h +++ b/net/sunrpc/sunrpc.h @@ -36,7 +36,11 @@ static inline int sock_is_loopback(struct sock *sk) return loopback; } +struct svc_serv; +struct svc_rqst; int rpc_clients_notifier_register(void); void rpc_clients_notifier_unregister(void); void auth_domain_cleanup(void); +void svc_sock_update_bufs(struct svc_serv *serv); +enum svc_auth_status svc_authenticate(struct svc_rqst *rqstp); #endif /* _NET_SUNRPC_SUNRPC_H */ diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c index 88a59cfa5583..7e7f4e0390c7 100644 --- a/net/sunrpc/svc.c +++ b/net/sunrpc/svc.c @@ -32,6 +32,7 @@ #include <trace/events/sunrpc.h> #include "fail.h" +#include "sunrpc.h" #define RPCDBG_FACILITY RPCDBG_SVCDSP @@ -417,7 +418,7 @@ struct svc_pool *svc_pool_for_cpu(struct svc_serv *serv) return &serv->sv_pools[pidx % serv->sv_nrpools]; } -int svc_rpcb_setup(struct svc_serv *serv, struct net *net) +static int svc_rpcb_setup(struct svc_serv *serv, struct net 
*net) { int err; @@ -429,7 +430,6 @@ int svc_rpcb_setup(struct svc_serv *serv, struct net *net) svc_unregister(serv, net); return 0; } -EXPORT_SYMBOL_GPL(svc_rpcb_setup); void svc_rpcb_cleanup(struct svc_serv *serv, struct net *net) { @@ -440,10 +440,11 @@ EXPORT_SYMBOL_GPL(svc_rpcb_cleanup); static int svc_uses_rpcbind(struct svc_serv *serv) { - struct svc_program *progp; - unsigned int i; + unsigned int p, i; + + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; - for (progp = serv->sv_program; progp; progp = progp->pg_next) { for (i = 0; i < progp->pg_nvers; i++) { if (progp->pg_vers[i] == NULL) continue; @@ -480,7 +481,7 @@ __svc_init_bc(struct svc_serv *serv) * Create an RPC service */ static struct svc_serv * -__svc_create(struct svc_program *prog, struct svc_stat *stats, +__svc_create(struct svc_program *prog, int nprogs, struct svc_stat *stats, unsigned int bufsize, int npools, int (*threadfn)(void *data)) { struct svc_serv *serv; @@ -491,7 +492,8 @@ __svc_create(struct svc_program *prog, struct svc_stat *stats, if (!(serv = kzalloc(sizeof(*serv), GFP_KERNEL))) return NULL; serv->sv_name = prog->pg_name; - serv->sv_program = prog; + serv->sv_programs = prog; + serv->sv_nprogs = nprogs; serv->sv_stats = stats; if (bufsize > RPCSVC_MAXPAYLOAD) bufsize = RPCSVC_MAXPAYLOAD; @@ -499,17 +501,18 @@ __svc_create(struct svc_program *prog, struct svc_stat *stats, serv->sv_max_mesg = roundup(serv->sv_max_payload + PAGE_SIZE, PAGE_SIZE); serv->sv_threadfn = threadfn; xdrsize = 0; - while (prog) { - prog->pg_lovers = prog->pg_nvers-1; - for (vers=0; vers<prog->pg_nvers ; vers++) - if (prog->pg_vers[vers]) { - prog->pg_hivers = vers; - if (prog->pg_lovers > vers) - prog->pg_lovers = vers; - if (prog->pg_vers[vers]->vs_xdrsize > xdrsize) - xdrsize = prog->pg_vers[vers]->vs_xdrsize; + for (i = 0; i < nprogs; i++) { + struct svc_program *progp = &prog[i]; + + progp->pg_lovers = progp->pg_nvers-1; + for (vers = 0; vers < progp->pg_nvers ; vers++) + if (progp->pg_vers[vers]) { + progp->pg_hivers = vers; + if (progp->pg_lovers > vers) + progp->pg_lovers = vers; + if (progp->pg_vers[vers]->vs_xdrsize > xdrsize) + xdrsize = progp->pg_vers[vers]->vs_xdrsize; } - prog = prog->pg_next; } serv->sv_xdrsize = xdrsize; INIT_LIST_HEAD(&serv->sv_tempsocks); @@ -558,13 +561,14 @@ __svc_create(struct svc_program *prog, struct svc_stat *stats, struct svc_serv *svc_create(struct svc_program *prog, unsigned int bufsize, int (*threadfn)(void *data)) { - return __svc_create(prog, NULL, bufsize, 1, threadfn); + return __svc_create(prog, 1, NULL, bufsize, 1, threadfn); } EXPORT_SYMBOL_GPL(svc_create); /** * svc_create_pooled - Create an RPC service with pooled threads - * @prog: the RPC program the new service will handle + * @prog: Array of RPC programs the new service will handle + * @nprogs: Number of programs in the array * @stats: the stats struct if desired * @bufsize: maximum message size for @prog * @threadfn: a function to service RPC requests for @prog @@ -572,6 +576,7 @@ EXPORT_SYMBOL_GPL(svc_create); * Returns an instantiated struct svc_serv object or NULL. 
*/ struct svc_serv *svc_create_pooled(struct svc_program *prog, + unsigned int nprogs, struct svc_stat *stats, unsigned int bufsize, int (*threadfn)(void *data)) @@ -579,7 +584,7 @@ struct svc_serv *svc_create_pooled(struct svc_program *prog, struct svc_serv *serv; unsigned int npools = svc_pool_map_get(); - serv = __svc_create(prog, stats, bufsize, npools, threadfn); + serv = __svc_create(prog, nprogs, stats, bufsize, npools, threadfn); if (!serv) goto out_err; serv->sv_is_pooled = true; @@ -602,16 +607,16 @@ svc_destroy(struct svc_serv **servp) *servp = NULL; - dprintk("svc: svc_destroy(%s)\n", serv->sv_program->pg_name); + dprintk("svc: svc_destroy(%s)\n", serv->sv_programs->pg_name); timer_shutdown_sync(&serv->sv_temptimer); /* * Remaining transports at this point are not expected. */ WARN_ONCE(!list_empty(&serv->sv_permsocks), - "SVC: permsocks remain for %s\n", serv->sv_program->pg_name); + "SVC: permsocks remain for %s\n", serv->sv_programs->pg_name); WARN_ONCE(!list_empty(&serv->sv_tempsocks), - "SVC: tempsocks remain for %s\n", serv->sv_program->pg_name); + "SVC: tempsocks remain for %s\n", serv->sv_programs->pg_name); cache_clean_deferred(serv); @@ -664,8 +669,21 @@ svc_release_buffer(struct svc_rqst *rqstp) put_page(rqstp->rq_pages[i]); } -struct svc_rqst * -svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node) +static void +svc_rqst_free(struct svc_rqst *rqstp) +{ + folio_batch_release(&rqstp->rq_fbatch); + svc_release_buffer(rqstp); + if (rqstp->rq_scratch_page) + put_page(rqstp->rq_scratch_page); + kfree(rqstp->rq_resp); + kfree(rqstp->rq_argp); + kfree(rqstp->rq_auth_data); + kfree_rcu(rqstp, rq_rcu_head); +} + +static struct svc_rqst * +svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) { struct svc_rqst *rqstp; @@ -693,27 +711,10 @@ svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node) if (!svc_init_buffer(rqstp, serv->sv_max_mesg, node)) goto out_enomem; - return rqstp; -out_enomem: - svc_rqst_free(rqstp); - return NULL; -} -EXPORT_SYMBOL_GPL(svc_rqst_alloc); - -static struct svc_rqst * -svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) -{ - struct svc_rqst *rqstp; + rqstp->rq_err = -EAGAIN; /* No error yet */ - rqstp = svc_rqst_alloc(serv, pool, node); - if (!rqstp) - return ERR_PTR(-ENOMEM); - - spin_lock_bh(&serv->sv_lock); serv->sv_nrthreads += 1; - spin_unlock_bh(&serv->sv_lock); - - atomic_inc(&pool->sp_nrthreads); + pool->sp_nrthreads += 1; /* Protected by whatever lock the service uses when calling * svc_set_num_threads() @@ -721,6 +722,10 @@ svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) list_add_rcu(&rqstp->rq_all, &pool->sp_all_threads); return rqstp; + +out_enomem: + svc_rqst_free(rqstp); + return NULL; } /** @@ -768,31 +773,22 @@ svc_pool_victim(struct svc_serv *serv, struct svc_pool *target_pool, struct svc_pool *pool; unsigned int i; -retry: pool = target_pool; - if (pool != NULL) { - if (atomic_inc_not_zero(&pool->sp_nrthreads)) - goto found_pool; - return NULL; - } else { + if (!pool) { for (i = 0; i < serv->sv_nrpools; i++) { pool = &serv->sv_pools[--(*state) % serv->sv_nrpools]; - if (atomic_inc_not_zero(&pool->sp_nrthreads)) - goto found_pool; + if (pool->sp_nrthreads) + break; } - return NULL; } -found_pool: - set_bit(SP_VICTIM_REMAINS, &pool->sp_flags); - set_bit(SP_NEED_VICTIM, &pool->sp_flags); - if (!atomic_dec_and_test(&pool->sp_nrthreads)) + if (pool && pool->sp_nrthreads) { + set_bit(SP_VICTIM_REMAINS, &pool->sp_flags); + 
set_bit(SP_NEED_VICTIM, &pool->sp_flags); return pool; - /* Nothing left in this pool any more */ - clear_bit(SP_NEED_VICTIM, &pool->sp_flags); - clear_bit(SP_VICTIM_REMAINS, &pool->sp_flags); - goto retry; + } + return NULL; } static int @@ -803,6 +799,7 @@ svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) struct svc_pool *chosen_pool; unsigned int state = serv->sv_nrthreads-1; int node; + int err; do { nrservs--; @@ -810,8 +807,8 @@ svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) node = svc_pool_map_get_node(chosen_pool->sp_id); rqstp = svc_prepare_thread(serv, chosen_pool, node); - if (IS_ERR(rqstp)) - return PTR_ERR(rqstp); + if (!rqstp) + return -ENOMEM; task = kthread_create_on_node(serv->sv_threadfn, rqstp, node, "%s", serv->sv_name); if (IS_ERR(task)) { @@ -825,6 +822,13 @@ svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) svc_sock_update_bufs(serv); wake_up_process(task); + + wait_var_event(&rqstp->rq_err, rqstp->rq_err != -EAGAIN); + err = rqstp->rq_err; + if (err) { + svc_exit_thread(rqstp); + return err; + } } while (nrservs > 0); return 0; @@ -871,7 +875,7 @@ svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) if (!pool) nrservs -= serv->sv_nrthreads; else - nrservs -= atomic_read(&pool->sp_nrthreads); + nrservs -= pool->sp_nrthreads; if (nrservs > 0) return svc_start_kthreads(serv, pool, nrservs); @@ -933,25 +937,21 @@ void svc_rqst_release_pages(struct svc_rqst *rqstp) } } -/* - * Called from a server thread as it's exiting. Caller must hold the "service - * mutex" for the service. +/** + * svc_exit_thread - finalise the termination of a sunrpc server thread + * @rqstp: the svc_rqst which represents the thread. + * + * When a thread started with svc_new_thread() exits it must call + * svc_exit_thread() as its last act. This must be done with the + * service mutex held. Normally this is held by a DIFFERENT thread, the + * one that is calling svc_set_num_threads() and which will wait for + * SP_VICTIM_REMAINS to be cleared before dropping the mutex. If the + * thread exits for any reason other than svc_thread_should_stop() + * returning %true (which indicated that svc_set_num_threads() is + * waiting for it to exit), then it must take the service mutex itself, + * which can only safely be done using mutex_try_lock(). 
*/ void -svc_rqst_free(struct svc_rqst *rqstp) -{ - folio_batch_release(&rqstp->rq_fbatch); - svc_release_buffer(rqstp); - if (rqstp->rq_scratch_page) - put_page(rqstp->rq_scratch_page); - kfree(rqstp->rq_resp); - kfree(rqstp->rq_argp); - kfree(rqstp->rq_auth_data); - kfree_rcu(rqstp, rq_rcu_head); -} -EXPORT_SYMBOL_GPL(svc_rqst_free); - -void svc_exit_thread(struct svc_rqst *rqstp) { struct svc_serv *serv = rqstp->rq_server; @@ -959,11 +959,8 @@ svc_exit_thread(struct svc_rqst *rqstp) list_del_rcu(&rqstp->rq_all); - atomic_dec(&pool->sp_nrthreads); - - spin_lock_bh(&serv->sv_lock); + pool->sp_nrthreads -= 1; serv->sv_nrthreads -= 1; - spin_unlock_bh(&serv->sv_lock); svc_sock_update_bufs(serv); svc_rqst_free(rqstp); @@ -1098,6 +1095,7 @@ static int __svc_register(struct net *net, const char *progname, return error; } +static int svc_rpcbind_set_version(struct net *net, const struct svc_program *progp, u32 version, int family, @@ -1108,7 +1106,6 @@ int svc_rpcbind_set_version(struct net *net, version, family, proto, port); } -EXPORT_SYMBOL_GPL(svc_rpcbind_set_version); int svc_generic_rpcbind_set(struct net *net, const struct svc_program *progp, @@ -1156,15 +1153,16 @@ int svc_register(const struct svc_serv *serv, struct net *net, const int family, const unsigned short proto, const unsigned short port) { - struct svc_program *progp; - unsigned int i; + unsigned int p, i; int error = 0; WARN_ON_ONCE(proto == 0 && port == 0); if (proto == 0 && port == 0) return -EINVAL; - for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; + for (i = 0; i < progp->pg_nvers; i++) { error = progp->pg_rpcbind_set(net, progp, i, @@ -1216,13 +1214,14 @@ static void __svc_unregister(struct net *net, const u32 program, const u32 versi static void svc_unregister(const struct svc_serv *serv, struct net *net) { struct sighand_struct *sighand; - struct svc_program *progp; unsigned long flags; - unsigned int i; + unsigned int p, i; clear_thread_flag(TIF_SIGPENDING); - for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (p = 0; p < serv->sv_nprogs; p++) { + struct svc_program *progp = &serv->sv_programs[p]; + for (i = 0; i < progp->pg_nvers; i++) { if (progp->pg_vers[i] == NULL) continue; @@ -1328,7 +1327,7 @@ svc_process_common(struct svc_rqst *rqstp) struct svc_process_info process; enum svc_auth_status auth_res; unsigned int aoffset; - int rc; + int pr, rc; __be32 *p; /* Will be turned off only when NFSv4 Sessions are used */ @@ -1352,9 +1351,12 @@ svc_process_common(struct svc_rqst *rqstp) rqstp->rq_vers = be32_to_cpup(p++); rqstp->rq_proc = be32_to_cpup(p); - for (progp = serv->sv_program; progp; progp = progp->pg_next) + for (pr = 0; pr < serv->sv_nprogs; pr++) { + progp = &serv->sv_programs[pr]; + if (rqstp->rq_prog == progp->pg_prog) break; + } /* * Decode auth data, and add verifier to reply buffer. 
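To see the shape of the new dispatch in isolation: svc_process_common() now resolves a request's program number with a bounded scan over the sv_programs array instead of walking a pg_next list. Below is a minimal userspace sketch of that lookup, assuming pared-down stand-ins for struct svc_serv and struct svc_program; svc_find_program() and the sample program numbers are illustrative only, not kernel API.

#include <stdio.h>

struct svc_program {
	unsigned int pg_prog;	/* RPC program number */
	const char *pg_name;
};

struct svc_serv {
	struct svc_program *sv_programs;	/* array, no longer a pg_next list */
	unsigned int sv_nprogs;
};

/* Mirrors the new lookup loop in svc_process_common(): the array form
 * needs an explicit element count, where the old list was terminated
 * by a NULL pg_next pointer. */
static struct svc_program *
svc_find_program(struct svc_serv *serv, unsigned int prog)
{
	for (unsigned int p = 0; p < serv->sv_nprogs; p++)
		if (serv->sv_programs[p].pg_prog == prog)
			return &serv->sv_programs[p];
	return NULL;	/* caller answers with PROG_UNAVAIL */
}

int main(void)
{
	struct svc_program progs[] = {
		{ 100003, "nfsd" },
		{ 100227, "nfs_acl" },
	};
	struct svc_serv serv = { progs, 2 };
	struct svc_program *hit = svc_find_program(&serv, 100227);

	printf("%s\n", hit ? hit->pg_name : "unavailable");
	return 0;
}

Passing the programs as one array, with the new nprogs argument to svc_create_pooled(), presumably lets a single svc_serv host several related programs without each program carrying an intrusive next pointer.
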
@@ -1526,6 +1528,14 @@ err_system_err: goto sendit; } +/* + * Drop request + */ +static void svc_drop(struct svc_rqst *rqstp) +{ + trace_svc_drop(rqstp); +} + /** * svc_process - Execute one RPC transaction * @rqstp: RPC transaction context diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c index d3735ab3e6d1..43c57124de52 100644 --- a/net/sunrpc/svc_xprt.c +++ b/net/sunrpc/svc_xprt.c @@ -268,7 +268,7 @@ static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name, spin_unlock(&svc_xprt_class_lock); newxprt = xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); if (IS_ERR(newxprt)) { - trace_svc_xprt_create_err(serv->sv_program->pg_name, + trace_svc_xprt_create_err(serv->sv_programs->pg_name, xcl->xcl_name, sap, len, newxprt); module_put(xcl->xcl_owner); @@ -905,15 +905,6 @@ void svc_recv(struct svc_rqst *rqstp) } EXPORT_SYMBOL_GPL(svc_recv); -/* - * Drop request - */ -void svc_drop(struct svc_rqst *rqstp) -{ - trace_svc_drop(rqstp); -} -EXPORT_SYMBOL_GPL(svc_drop); - /** * svc_send - Return reply to client * @rqstp: RPC transaction context diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c index 1619211f0960..55b4d2874188 100644 --- a/net/sunrpc/svcauth.c +++ b/net/sunrpc/svcauth.c @@ -18,6 +18,7 @@ #include <linux/sunrpc/svcauth.h> #include <linux/err.h> #include <linux/hash.h> +#include <linux/user_namespace.h> #include <trace/events/sunrpc.h> @@ -98,7 +99,6 @@ enum svc_auth_status svc_authenticate(struct svc_rqst *rqstp) rqstp->rq_authop = aops; return aops->accept(rqstp); } -EXPORT_SYMBOL_GPL(svc_authenticate); /** * svc_set_client - Assign an appropriate 'auth_domain' as the client @@ -176,6 +176,33 @@ rpc_authflavor_t svc_auth_flavor(struct svc_rqst *rqstp) } EXPORT_SYMBOL_GPL(svc_auth_flavor); +/** + * svcauth_map_clnt_to_svc_cred_local - maps a generic cred + * to a svc_cred suitable for use in nfsd. + * @clnt: rpc_clnt associated with nfs client + * @cred: generic cred associated with nfs client + * @svc: returned svc_cred that is suitable for use in nfsd + */ +void svcauth_map_clnt_to_svc_cred_local(struct rpc_clnt *clnt, + const struct cred *cred, + struct svc_cred *svc) +{ + struct user_namespace *userns = clnt->cl_cred ? + clnt->cl_cred->user_ns : &init_user_ns; + + memset(svc, 0, sizeof(struct svc_cred)); + + svc->cr_uid = KUIDT_INIT(from_kuid_munged(userns, cred->fsuid)); + svc->cr_gid = KGIDT_INIT(from_kgid_munged(userns, cred->fsgid)); + svc->cr_flavor = clnt->cl_auth->au_flavor; + if (cred->group_info) + svc->cr_group_info = get_group_info(cred->group_info); + /* These aren't relevant for local (network is bypassed) */ + svc->cr_principal = NULL; + svc->cr_gss_mech = NULL; +} +EXPORT_SYMBOL_GPL(svcauth_map_clnt_to_svc_cred_local); + /************************************************** * 'auth_domains' are stored in a hash table indexed by name. 
* When the last reference to an 'auth_domain' is dropped, diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c index 04b45588ae6f..8ca98b146ec8 100644 --- a/net/sunrpc/svcauth_unix.c +++ b/net/sunrpc/svcauth_unix.c @@ -697,7 +697,8 @@ svcauth_unix_set_client(struct svc_rqst *rqstp) rqstp->rq_auth_stat = rpc_autherr_badcred; ipm = ip_map_cached_get(xprt); if (ipm == NULL) - ipm = __ip_map_lookup(sn->ip_map_cache, rqstp->rq_server->sv_program->pg_class, + ipm = __ip_map_lookup(sn->ip_map_cache, + rqstp->rq_server->sv_programs->pg_class, &sin6->sin6_addr); if (ipm == NULL) diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index 6b3f01beb294..825ec5357691 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -1378,7 +1378,6 @@ void svc_sock_update_bufs(struct svc_serv *serv) set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); spin_unlock_bh(&serv->sv_lock); } -EXPORT_SYMBOL_GPL(svc_sock_update_bufs); /* * Initialize socket for RPC use and create svc_sock struct diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c index f15750cacacf..c3fbf0779d4a 100644 --- a/net/sunrpc/xprtrdma/svc_rdma_transport.c +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -339,7 +339,6 @@ static int svc_rdma_cma_handler(struct rdma_cm_id *cma_id, svc_xprt_enqueue(xprt); break; case RDMA_CM_EVENT_DISCONNECTED: - case RDMA_CM_EVENT_DEVICE_REMOVAL: svc_xprt_deferred_close(xprt); break; default: @@ -370,7 +369,7 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, listen_id = svc_rdma_create_listen_id(net, sa, cma_xprt); if (IS_ERR(listen_id)) { kfree(cma_xprt); - return (struct svc_xprt *)listen_id; + return ERR_CAST(listen_id); } cma_xprt->sc_cm_id = listen_id; @@ -384,6 +383,16 @@ static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, return &cma_xprt->sc_xprt; } +static void svc_rdma_xprt_done(struct rpcrdma_notification *rn) +{ + struct svcxprt_rdma *rdma = container_of(rn, struct svcxprt_rdma, + sc_rn); + struct rdma_cm_id *id = rdma->sc_cm_id; + + trace_svcrdma_device_removal(id); + svc_xprt_close(&rdma->sc_xprt); +} + /* * This is the xpo_recvfrom function for listening endpoints. Its * purpose is to accept incoming connections. 
The CMA callback handler @@ -425,6 +434,9 @@ static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) dev = newxprt->sc_cm_id->device; newxprt->sc_port_num = newxprt->sc_cm_id->port_num; + if (rpcrdma_rn_register(dev, &newxprt->sc_rn, svc_rdma_xprt_done)) + goto errout; + newxprt->sc_max_req_size = svcrdma_max_req_size; newxprt->sc_max_requests = svcrdma_max_requests; newxprt->sc_max_bc_requests = svcrdma_max_bc_requests; @@ -580,6 +592,7 @@ static void __svc_rdma_free(struct work_struct *work) { struct svcxprt_rdma *rdma = container_of(work, struct svcxprt_rdma, sc_work); + struct ib_device *device = rdma->sc_cm_id->device; /* This blocks until the Completion Queues are empty */ if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) @@ -608,6 +621,7 @@ static void __svc_rdma_free(struct work_struct *work) /* Destroy the CM ID */ rdma_destroy_id(rdma->sc_cm_id); + rpcrdma_rn_unregister(device, &rdma->sc_rn); kfree(rdma); } diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index e0160da4ef43..85e423921734 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -94,6 +94,63 @@ out_rcu: return ret; } +/* Caller need to hold vsock->tx_lock on vq */ +static int virtio_transport_send_skb(struct sk_buff *skb, struct virtqueue *vq, + struct virtio_vsock *vsock) +{ + int ret, in_sg = 0, out_sg = 0; + struct scatterlist **sgs; + + sgs = vsock->out_sgs; + sg_init_one(sgs[out_sg], virtio_vsock_hdr(skb), + sizeof(*virtio_vsock_hdr(skb))); + out_sg++; + + if (!skb_is_nonlinear(skb)) { + if (skb->len > 0) { + sg_init_one(sgs[out_sg], skb->data, skb->len); + out_sg++; + } + } else { + struct skb_shared_info *si; + int i; + + /* If skb is nonlinear, then its buffer must contain + * only header and nothing more. Data is stored in + * the fragged part. + */ + WARN_ON_ONCE(skb_headroom(skb) != sizeof(*virtio_vsock_hdr(skb))); + + si = skb_shinfo(skb); + + for (i = 0; i < si->nr_frags; i++) { + skb_frag_t *skb_frag = &si->frags[i]; + void *va; + + /* We will use 'page_to_virt()' for the userspace page + * here, because virtio or dma-mapping layers will call + * 'virt_to_phys()' later to fill the buffer descriptor. + * We don't touch memory at "virtual" address of this page. + */ + va = page_to_virt(skb_frag_page(skb_frag)); + sg_init_one(sgs[out_sg], + va + skb_frag_off(skb_frag), + skb_frag_size(skb_frag)); + out_sg++; + } + } + + ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, skb, GFP_KERNEL); + /* Usually this means that there is no more space available in + * the vq + */ + if (ret < 0) + return ret; + + virtio_transport_deliver_tap_pkt(skb); + return 0; +} + static void virtio_transport_send_pkt_work(struct work_struct *work) { @@ -111,66 +168,22 @@ virtio_transport_send_pkt_work(struct work_struct *work) vq = vsock->vqs[VSOCK_VQ_TX]; for (;;) { - int ret, in_sg = 0, out_sg = 0; - struct scatterlist **sgs; struct sk_buff *skb; bool reply; + int ret; skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue); if (!skb) break; reply = virtio_vsock_skb_reply(skb); - sgs = vsock->out_sgs; - sg_init_one(sgs[out_sg], virtio_vsock_hdr(skb), - sizeof(*virtio_vsock_hdr(skb))); - out_sg++; - - if (!skb_is_nonlinear(skb)) { - if (skb->len > 0) { - sg_init_one(sgs[out_sg], skb->data, skb->len); - out_sg++; - } - } else { - struct skb_shared_info *si; - int i; - - /* If skb is nonlinear, then its buffer must contain - * only header and nothing more. Data is stored in - * the fragged part. 
- */ - WARN_ON_ONCE(skb_headroom(skb) != sizeof(*virtio_vsock_hdr(skb))); - - si = skb_shinfo(skb); - - for (i = 0; i < si->nr_frags; i++) { - skb_frag_t *skb_frag = &si->frags[i]; - void *va; - /* We will use 'page_to_virt()' for the userspace page - * here, because virtio or dma-mapping layers will call - * 'virt_to_phys()' later to fill the buffer descriptor. - * We don't touch memory at "virtual" address of this page. - */ - va = page_to_virt(skb_frag_page(skb_frag)); - sg_init_one(sgs[out_sg], - va + skb_frag_off(skb_frag), - skb_frag_size(skb_frag)); - out_sg++; - } - } - - ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, skb, GFP_KERNEL); - /* Usually this means that there is no more space available in - * the vq - */ + ret = virtio_transport_send_skb(skb, vq, vsock); if (ret < 0) { virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb); break; } - virtio_transport_deliver_tap_pkt(skb); - if (reply) { struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; int val; @@ -195,6 +208,28 @@ out: queue_work(virtio_vsock_workqueue, &vsock->rx_work); } +/* Caller need to hold RCU for vsock. + * Returns 0 if the packet is successfully put on the vq. + */ +static int virtio_transport_send_skb_fast_path(struct virtio_vsock *vsock, struct sk_buff *skb) +{ + struct virtqueue *vq = vsock->vqs[VSOCK_VQ_TX]; + int ret; + + /* Inside RCU, can't sleep! */ + ret = mutex_trylock(&vsock->tx_lock); + if (unlikely(ret == 0)) + return -EBUSY; + + ret = virtio_transport_send_skb(skb, vq, vsock); + if (ret == 0) + virtqueue_kick(vq); + + mutex_unlock(&vsock->tx_lock); + + return ret; +} + static int virtio_transport_send_pkt(struct sk_buff *skb) { @@ -218,11 +253,20 @@ virtio_transport_send_pkt(struct sk_buff *skb) goto out_rcu; } - if (virtio_vsock_skb_reply(skb)) - atomic_inc(&vsock->queued_replies); + /* If send_pkt_queue is empty, we can safely bypass this queue + * because packet order is maintained and (try) to put the packet + * on the virtqueue using virtio_transport_send_skb_fast_path. + * If this fails we simply put the packet on the intermediate + * queue and schedule the worker. + */ + if (!skb_queue_empty_lockless(&vsock->send_pkt_queue) || + virtio_transport_send_skb_fast_path(vsock, skb)) { + if (virtio_vsock_skb_reply(skb)) + atomic_inc(&vsock->queued_replies); - virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb); - queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); + virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb); + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); + } out_rcu: rcu_read_unlock();
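The comment in that final hunk carries the key invariant: the fast path may only run while send_pkt_queue is empty, since sending a new packet directly while older ones still sit in the queue would reorder the stream, and it must use mutex_trylock() because the caller holds the RCU read lock and cannot sleep. The following userspace sketch illustrates the same decision under those assumptions; pthreads stands in for the kernel mutex, and struct tx_path, device_send() and defer_to_worker() are invented stand-ins for the virtqueue and worker machinery, not kernel APIs.

#include <pthread.h>
#include <stdio.h>

struct pkt {
	struct pkt *next;
	int id;
};

struct tx_path {
	pthread_mutex_t tx_lock;	/* stands in for vsock->tx_lock */
	struct pkt *queue_head;		/* stands in for send_pkt_queue */
};

/* Stand-in for virtio_transport_send_skb() plus virtqueue_kick():
 * pretend the ring always has room and report success. */
static int device_send(struct pkt *p)
{
	printf("fast path: pkt %d placed on the ring\n", p->id);
	return 0;
}

/* Stand-in for virtio_vsock_skb_queue_tail() plus queue_work():
 * park the packet for the worker thread. */
static void defer_to_worker(struct tx_path *tx, struct pkt *p)
{
	p->next = tx->queue_head;
	tx->queue_head = p;
	printf("slow path: pkt %d deferred to the worker\n", p->id);
}

static void send_pkt(struct tx_path *tx, struct pkt *p)
{
	/* Bypass only when nothing is queued ahead of us (preserving
	 * order) and the lock is free (the real caller runs under the
	 * RCU read lock, so it must not sleep on tx_lock). */
	if (!tx->queue_head && pthread_mutex_trylock(&tx->tx_lock) == 0) {
		int ret = device_send(p);

		pthread_mutex_unlock(&tx->tx_lock);
		if (ret == 0)
			return;
		/* ring full: fall through and let the worker retry */
	}
	defer_to_worker(tx, p);
}

int main(void)
{
	struct tx_path tx = { PTHREAD_MUTEX_INITIALIZER, NULL };
	struct pkt a = { NULL, 1 }, b = { NULL, 2 }, c = { NULL, 3 };

	send_pkt(&tx, &a);		/* queue empty, lock free: fast path */
	defer_to_worker(&tx, &b);	/* simulate a worker backlog */
	send_pkt(&tx, &c);		/* queue non-empty: must go behind pkt 2 */
	return 0;
}

If the trylock fails or the ring is full, nothing is lost: the packet simply takes the pre-existing worker path, which retries once the device signals free space, exactly as the slow path in virtio_transport_send_pkt() does above.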