summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--net/ipv4/ipmr.c56
-rw-r--r--net/ipv6/ip6mr.c52
2 files changed, 84 insertions, 24 deletions
diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c
index c58dd78509a2..383ea8b91cc7 100644
--- a/net/ipv4/ipmr.c
+++ b/net/ipv4/ipmr.c
@@ -120,6 +120,11 @@ static void ipmr_expire_process(struct timer_list *t);
lockdep_rtnl_is_held() || \
list_empty(&net->ipv4.mr_tables))
+static bool ipmr_can_free_table(struct net *net)
+{
+ return !check_net(net) || !net->ipv4.mr_rules_ops;
+}
+
static struct mr_table *ipmr_mr_table_iter(struct net *net,
struct mr_table *mrt)
{
@@ -137,7 +142,7 @@ static struct mr_table *ipmr_mr_table_iter(struct net *net,
return ret;
}
-static struct mr_table *ipmr_get_table(struct net *net, u32 id)
+static struct mr_table *__ipmr_get_table(struct net *net, u32 id)
{
struct mr_table *mrt;
@@ -148,6 +153,16 @@ static struct mr_table *ipmr_get_table(struct net *net, u32 id)
return NULL;
}
+static struct mr_table *ipmr_get_table(struct net *net, u32 id)
+{
+ struct mr_table *mrt;
+
+ rcu_read_lock();
+ mrt = __ipmr_get_table(net, id);
+ rcu_read_unlock();
+ return mrt;
+}
+
static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4,
struct mr_table **mrt)
{
@@ -189,7 +204,7 @@ static int ipmr_rule_action(struct fib_rule *rule, struct flowi *flp,
arg->table = fib_rule_get_table(rule, arg);
- mrt = ipmr_get_table(rule->fr_net, arg->table);
+ mrt = __ipmr_get_table(rule->fr_net, arg->table);
if (!mrt)
return -EAGAIN;
res->mrt = mrt;
@@ -302,6 +317,11 @@ EXPORT_SYMBOL(ipmr_rule_default);
#define ipmr_for_each_table(mrt, net) \
for (mrt = net->ipv4.mrt; mrt; mrt = NULL)
+static bool ipmr_can_free_table(struct net *net)
+{
+ return !check_net(net);
+}
+
static struct mr_table *ipmr_mr_table_iter(struct net *net,
struct mr_table *mrt)
{
@@ -315,6 +335,8 @@ static struct mr_table *ipmr_get_table(struct net *net, u32 id)
return net->ipv4.mrt;
}
+#define __ipmr_get_table ipmr_get_table
+
static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4,
struct mr_table **mrt)
{
@@ -403,7 +425,7 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id)
if (id != RT_TABLE_DEFAULT && id >= 1000000000)
return ERR_PTR(-EINVAL);
- mrt = ipmr_get_table(net, id);
+ mrt = __ipmr_get_table(net, id);
if (mrt)
return mrt;
@@ -413,6 +435,10 @@ static struct mr_table *ipmr_new_table(struct net *net, u32 id)
static void ipmr_free_table(struct mr_table *mrt)
{
+ struct net *net = read_pnet(&mrt->net);
+
+ DEBUG_NET_WARN_ON_ONCE(!ipmr_can_free_table(net));
+
timer_shutdown_sync(&mrt->ipmr_expire_timer);
mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC |
MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC);
@@ -1374,7 +1400,7 @@ int ip_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval,
goto out_unlock;
}
- mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT);
+ mrt = __ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT);
if (!mrt) {
ret = -ENOENT;
goto out_unlock;
@@ -2262,11 +2288,13 @@ int ipmr_get_route(struct net *net, struct sk_buff *skb,
struct mr_table *mrt;
int err;
- mrt = ipmr_get_table(net, RT_TABLE_DEFAULT);
- if (!mrt)
+ rcu_read_lock();
+ mrt = __ipmr_get_table(net, RT_TABLE_DEFAULT);
+ if (!mrt) {
+ rcu_read_unlock();
return -ENOENT;
+ }
- rcu_read_lock();
cache = ipmr_cache_find(mrt, saddr, daddr);
if (!cache && skb->dev) {
int vif = ipmr_find_vif(mrt, skb->dev);
@@ -2550,7 +2578,7 @@ static int ipmr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh,
grp = nla_get_in_addr_default(tb[RTA_DST], 0);
tableid = nla_get_u32_default(tb[RTA_TABLE], 0);
- mrt = ipmr_get_table(net, tableid ? tableid : RT_TABLE_DEFAULT);
+ mrt = __ipmr_get_table(net, tableid ? tableid : RT_TABLE_DEFAULT);
if (!mrt) {
err = -ENOENT;
goto errout_free;
@@ -2604,7 +2632,7 @@ static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb)
if (filter.table_id) {
struct mr_table *mrt;
- mrt = ipmr_get_table(sock_net(skb->sk), filter.table_id);
+ mrt = __ipmr_get_table(sock_net(skb->sk), filter.table_id);
if (!mrt) {
if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IPMR)
return skb->len;
@@ -2712,7 +2740,7 @@ static int rtm_to_ipmr_mfcc(struct net *net, struct nlmsghdr *nlh,
break;
}
}
- mrt = ipmr_get_table(net, tblid);
+ mrt = __ipmr_get_table(net, tblid);
if (!mrt) {
ret = -ENOENT;
goto out;
@@ -2920,13 +2948,15 @@ static void *ipmr_vif_seq_start(struct seq_file *seq, loff_t *pos)
struct net *net = seq_file_net(seq);
struct mr_table *mrt;
- mrt = ipmr_get_table(net, RT_TABLE_DEFAULT);
- if (!mrt)
+ rcu_read_lock();
+ mrt = __ipmr_get_table(net, RT_TABLE_DEFAULT);
+ if (!mrt) {
+ rcu_read_unlock();
return ERR_PTR(-ENOENT);
+ }
iter->mrt = mrt;
- rcu_read_lock();
return mr_vif_seq_start(seq, pos);
}
diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c
index d66f58932a79..4147890fe98f 100644
--- a/net/ipv6/ip6mr.c
+++ b/net/ipv6/ip6mr.c
@@ -108,6 +108,11 @@ static void ipmr_expire_process(struct timer_list *t);
lockdep_rtnl_is_held() || \
list_empty(&net->ipv6.mr6_tables))
+static bool ip6mr_can_free_table(struct net *net)
+{
+ return !check_net(net) || !net->ipv6.mr6_rules_ops;
+}
+
static struct mr_table *ip6mr_mr_table_iter(struct net *net,
struct mr_table *mrt)
{
@@ -125,7 +130,7 @@ static struct mr_table *ip6mr_mr_table_iter(struct net *net,
return ret;
}
-static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
+static struct mr_table *__ip6mr_get_table(struct net *net, u32 id)
{
struct mr_table *mrt;
@@ -136,6 +141,16 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
return NULL;
}
+static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
+{
+ struct mr_table *mrt;
+
+ rcu_read_lock();
+ mrt = __ip6mr_get_table(net, id);
+ rcu_read_unlock();
+ return mrt;
+}
+
static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6,
struct mr_table **mrt)
{
@@ -177,7 +192,7 @@ static int ip6mr_rule_action(struct fib_rule *rule, struct flowi *flp,
arg->table = fib_rule_get_table(rule, arg);
- mrt = ip6mr_get_table(rule->fr_net, arg->table);
+ mrt = __ip6mr_get_table(rule->fr_net, arg->table);
if (!mrt)
return -EAGAIN;
res->mrt = mrt;
@@ -291,6 +306,11 @@ EXPORT_SYMBOL(ip6mr_rule_default);
#define ip6mr_for_each_table(mrt, net) \
for (mrt = net->ipv6.mrt6; mrt; mrt = NULL)
+static bool ip6mr_can_free_table(struct net *net)
+{
+ return !check_net(net);
+}
+
static struct mr_table *ip6mr_mr_table_iter(struct net *net,
struct mr_table *mrt)
{
@@ -304,6 +324,8 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id)
return net->ipv6.mrt6;
}
+#define __ip6mr_get_table ip6mr_get_table
+
static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6,
struct mr_table **mrt)
{
@@ -382,7 +404,7 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id)
{
struct mr_table *mrt;
- mrt = ip6mr_get_table(net, id);
+ mrt = __ip6mr_get_table(net, id);
if (mrt)
return mrt;
@@ -392,6 +414,10 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id)
static void ip6mr_free_table(struct mr_table *mrt)
{
+ struct net *net = read_pnet(&mrt->net);
+
+ DEBUG_NET_WARN_ON_ONCE(!ip6mr_can_free_table(net));
+
timer_shutdown_sync(&mrt->ipmr_expire_timer);
mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC |
MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC);
@@ -411,13 +437,15 @@ static void *ip6mr_vif_seq_start(struct seq_file *seq, loff_t *pos)
struct net *net = seq_file_net(seq);
struct mr_table *mrt;
- mrt = ip6mr_get_table(net, RT6_TABLE_DFLT);
- if (!mrt)
+ rcu_read_lock();
+ mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT);
+ if (!mrt) {
+ rcu_read_unlock();
return ERR_PTR(-ENOENT);
+ }
iter->mrt = mrt;
- rcu_read_lock();
return mr_vif_seq_start(seq, pos);
}
@@ -2278,11 +2306,13 @@ int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm,
struct mfc6_cache *cache;
struct rt6_info *rt = dst_rt6_info(skb_dst(skb));
- mrt = ip6mr_get_table(net, RT6_TABLE_DFLT);
- if (!mrt)
+ rcu_read_lock();
+ mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT);
+ if (!mrt) {
+ rcu_read_unlock();
return -ENOENT;
+ }
- rcu_read_lock();
cache = ip6mr_cache_find(mrt, &rt->rt6i_src.addr, &rt->rt6i_dst.addr);
if (!cache && skb->dev) {
int vif = ip6mr_find_vif(mrt, skb->dev);
@@ -2562,7 +2592,7 @@ static int ip6mr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh,
grp = nla_get_in6_addr(tb[RTA_DST]);
tableid = nla_get_u32_default(tb[RTA_TABLE], 0);
- mrt = ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT);
+ mrt = __ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT);
if (!mrt) {
NL_SET_ERR_MSG_MOD(extack, "MR table does not exist");
return -ENOENT;
@@ -2609,7 +2639,7 @@ static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb)
if (filter.table_id) {
struct mr_table *mrt;
- mrt = ip6mr_get_table(sock_net(skb->sk), filter.table_id);
+ mrt = __ip6mr_get_table(sock_net(skb->sk), filter.table_id);
if (!mrt) {
if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IP6MR)
return skb->len;