diff options
Diffstat (limited to 'net/xfrm/xfrm_state.c')
| -rw-r--r-- | net/xfrm/xfrm_state.c | 116 |
1 files changed, 64 insertions, 52 deletions
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index 98b362d51836..1748d374abca 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -53,7 +53,7 @@ static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task); static HLIST_HEAD(xfrm_state_gc_list); static HLIST_HEAD(xfrm_state_dev_gc_list); -static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x) +static inline bool xfrm_state_hold_rcu(struct xfrm_state *x) { return refcount_inc_not_zero(&x->refcnt); } @@ -870,7 +870,7 @@ xfrm_state_flush_secctx_check(struct net *net, u8 proto, bool task_valid) for (i = 0; i <= net->xfrm.state_hmask; i++) { struct xfrm_state *x; - hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + i, bydst) { if (xfrm_id_proto_match(x->id.proto, proto) && (err = security_xfrm_state_delete(x)) != 0) { xfrm_audit_state_delete(x, 0, task_valid); @@ -891,7 +891,7 @@ xfrm_dev_state_flush_secctx_check(struct net *net, struct net_device *dev, bool struct xfrm_state *x; struct xfrm_dev_offload *xso; - hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + i, bydst) { xso = &x->xso; if (xso->dev == dev && @@ -931,7 +931,7 @@ int xfrm_state_flush(struct net *net, u8 proto, bool task_valid) for (i = 0; i <= net->xfrm.state_hmask; i++) { struct xfrm_state *x; restart: - hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + i, bydst) { if (!xfrm_state_kern(x) && xfrm_id_proto_match(x->id.proto, proto)) { xfrm_state_hold(x); @@ -973,7 +973,7 @@ int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_vali err = -ESRCH; for (i = 0; i <= net->xfrm.state_hmask; i++) { restart: - hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + i, bydst) { xso = &x->xso; if (!xfrm_state_kern(x) && xso->dev == dev) { @@ -1563,23 +1563,23 @@ found: list_add(&x->km.all, &net->xfrm.state_all); h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family); XFRM_STATE_INSERT(bydst, &x->bydst, - net->xfrm.state_bydst + h, + xfrm_state_deref_prot(net->xfrm.state_bydst, net) + h, x->xso.type); h = xfrm_src_hash(net, daddr, saddr, encap_family); XFRM_STATE_INSERT(bysrc, &x->bysrc, - net->xfrm.state_bysrc + h, + xfrm_state_deref_prot(net->xfrm.state_bysrc, net) + h, x->xso.type); INIT_HLIST_NODE(&x->state_cache); if (x->id.spi) { h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family); XFRM_STATE_INSERT(byspi, &x->byspi, - net->xfrm.state_byspi + h, + xfrm_state_deref_prot(net->xfrm.state_byspi, net) + h, x->xso.type); } if (x->km.seq) { h = xfrm_seq_hash(net, x->km.seq); XFRM_STATE_INSERT(byseq, &x->byseq, - net->xfrm.state_byseq + h, + xfrm_state_deref_prot(net->xfrm.state_byseq, net) + h, x->xso.type); } x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires; @@ -1652,7 +1652,7 @@ xfrm_stateonly_find(struct net *net, u32 mark, u32 if_id, spin_lock_bh(&net->xfrm.xfrm_state_lock); h = xfrm_dst_hash(net, daddr, saddr, reqid, family); - hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + h, bydst) { if (x->props.family == family && x->props.reqid == reqid && (mark & x->mark.m) == x->mark.v && @@ -1703,18 +1703,12 @@ static struct xfrm_state *xfrm_state_lookup_spi_proto(struct net *net, __be32 sp struct xfrm_state *x; unsigned int i; - rcu_read_lock(); for (i = 0; i <= net->xfrm.state_hmask; i++) { - hlist_for_each_entry_rcu(x, &net->xfrm.state_byspi[i], byspi) { - if (x->id.spi == spi && x->id.proto == proto) { - if (!xfrm_state_hold_rcu(x)) - continue; - rcu_read_unlock(); + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_byspi, net) + i, byspi) { + if (x->id.spi == spi && x->id.proto == proto) return x; - } } } - rcu_read_unlock(); return NULL; } @@ -1730,25 +1724,29 @@ static void __xfrm_state_insert(struct xfrm_state *x) h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr, x->props.reqid, x->props.family); - XFRM_STATE_INSERT(bydst, &x->bydst, net->xfrm.state_bydst + h, + XFRM_STATE_INSERT(bydst, &x->bydst, + xfrm_state_deref_prot(net->xfrm.state_bydst, net) + h, x->xso.type); h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family); - XFRM_STATE_INSERT(bysrc, &x->bysrc, net->xfrm.state_bysrc + h, + XFRM_STATE_INSERT(bysrc, &x->bysrc, + xfrm_state_deref_prot(net->xfrm.state_bysrc, net) + h, x->xso.type); if (x->id.spi) { h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family); - XFRM_STATE_INSERT(byspi, &x->byspi, net->xfrm.state_byspi + h, + XFRM_STATE_INSERT(byspi, &x->byspi, + xfrm_state_deref_prot(net->xfrm.state_byspi, net) + h, x->xso.type); } if (x->km.seq) { h = xfrm_seq_hash(net, x->km.seq); - XFRM_STATE_INSERT(byseq, &x->byseq, net->xfrm.state_byseq + h, + XFRM_STATE_INSERT(byseq, &x->byseq, + xfrm_state_deref_prot(net->xfrm.state_byseq, net) + h, x->xso.type); } @@ -1775,7 +1773,7 @@ static void __xfrm_state_bump_genids(struct xfrm_state *xnew) u32 cpu_id = xnew->pcpu_num; h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family); - hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + h, bydst) { if (x->props.family == family && x->props.reqid == reqid && x->if_id == if_id && @@ -1811,7 +1809,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, struct xfrm_state *x; u32 mark = m->v & m->m; - hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + h, bydst) { if (x->props.reqid != reqid || x->props.mode != mode || x->props.family != family || @@ -1868,10 +1866,12 @@ static struct xfrm_state *__find_acq_core(struct net *net, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL_SOFT); list_add(&x->km.all, &net->xfrm.state_all); - XFRM_STATE_INSERT(bydst, &x->bydst, net->xfrm.state_bydst + h, + XFRM_STATE_INSERT(bydst, &x->bydst, + xfrm_state_deref_prot(net->xfrm.state_bydst, net) + h, x->xso.type); h = xfrm_src_hash(net, daddr, saddr, family); - XFRM_STATE_INSERT(bysrc, &x->bysrc, net->xfrm.state_bysrc + h, + XFRM_STATE_INSERT(bysrc, &x->bysrc, + xfrm_state_deref_prot(net->xfrm.state_bysrc, net) + h, x->xso.type); net->xfrm.state_num++; @@ -2091,7 +2091,7 @@ struct xfrm_state *xfrm_migrate_state_find(struct xfrm_migrate *m, struct net *n if (m->reqid) { h = xfrm_dst_hash(net, &m->old_daddr, &m->old_saddr, m->reqid, m->old_family); - hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + h, bydst) { if (x->props.mode != m->mode || x->id.proto != m->proto) continue; @@ -2110,7 +2110,7 @@ struct xfrm_state *xfrm_migrate_state_find(struct xfrm_migrate *m, struct net *n } else { h = xfrm_src_hash(net, &m->old_daddr, &m->old_saddr, m->old_family); - hlist_for_each_entry(x, net->xfrm.state_bysrc+h, bysrc) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bysrc, net) + h, bysrc) { if (x->props.mode != m->mode || x->id.proto != m->proto) continue; @@ -2264,6 +2264,7 @@ out: err = 0; x->km.state = XFRM_STATE_DEAD; + xfrm_dev_state_delete(x); __xfrm_state_put(x); } @@ -2312,7 +2313,7 @@ void xfrm_state_update_stats(struct net *net) spin_lock_bh(&net->xfrm.xfrm_state_lock); for (i = 0; i <= net->xfrm.state_hmask; i++) { - hlist_for_each_entry(x, net->xfrm.state_bydst + i, bydst) + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_bydst, net) + i, bydst) xfrm_dev_state_update_stats(x); } spin_unlock_bh(&net->xfrm.xfrm_state_lock); @@ -2503,7 +2504,7 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 s unsigned int h = xfrm_seq_hash(net, seq); struct xfrm_state *x; - hlist_for_each_entry_rcu(x, net->xfrm.state_byseq + h, byseq) { + hlist_for_each_entry(x, xfrm_state_deref_prot(net->xfrm.state_byseq, net) + h, byseq) { if (x->km.seq == seq && (mark & x->mark.m) == x->mark.v && x->pcpu_num == pcpu_num && @@ -2602,12 +2603,13 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high, if (!x0) { x->id.spi = newspi; h = xfrm_spi_hash(net, &x->id.daddr, newspi, x->id.proto, x->props.family); - XFRM_STATE_INSERT(byspi, &x->byspi, net->xfrm.state_byspi + h, x->xso.type); + XFRM_STATE_INSERT(byspi, &x->byspi, + xfrm_state_deref_prot(net->xfrm.state_byspi, net) + h, + x->xso.type); spin_unlock_bh(&net->xfrm.xfrm_state_lock); err = 0; goto unlock; } - xfrm_state_put(x0); spin_unlock_bh(&net->xfrm.xfrm_state_lock); next: @@ -3258,6 +3260,7 @@ EXPORT_SYMBOL(xfrm_init_state); int __net_init xfrm_state_init(struct net *net) { + struct hlist_head *ndst, *nsrc, *nspi, *nseq; unsigned int sz; if (net_eq(net, &init_net)) @@ -3268,18 +3271,25 @@ int __net_init xfrm_state_init(struct net *net) sz = sizeof(struct hlist_head) * 8; - net->xfrm.state_bydst = xfrm_hash_alloc(sz); - if (!net->xfrm.state_bydst) + ndst = xfrm_hash_alloc(sz); + if (!ndst) goto out_bydst; - net->xfrm.state_bysrc = xfrm_hash_alloc(sz); - if (!net->xfrm.state_bysrc) + rcu_assign_pointer(net->xfrm.state_bydst, ndst); + + nsrc = xfrm_hash_alloc(sz); + if (!nsrc) goto out_bysrc; - net->xfrm.state_byspi = xfrm_hash_alloc(sz); - if (!net->xfrm.state_byspi) + rcu_assign_pointer(net->xfrm.state_bysrc, nsrc); + + nspi = xfrm_hash_alloc(sz); + if (!nspi) goto out_byspi; - net->xfrm.state_byseq = xfrm_hash_alloc(sz); - if (!net->xfrm.state_byseq) + rcu_assign_pointer(net->xfrm.state_byspi, nspi); + + nseq = xfrm_hash_alloc(sz); + if (!nseq) goto out_byseq; + rcu_assign_pointer(net->xfrm.state_byseq, nseq); net->xfrm.state_cache_input = alloc_percpu(struct hlist_head); if (!net->xfrm.state_cache_input) @@ -3295,17 +3305,19 @@ int __net_init xfrm_state_init(struct net *net) return 0; out_state_cache_input: - xfrm_hash_free(net->xfrm.state_byseq, sz); + xfrm_hash_free(nseq, sz); out_byseq: - xfrm_hash_free(net->xfrm.state_byspi, sz); + xfrm_hash_free(nspi, sz); out_byspi: - xfrm_hash_free(net->xfrm.state_bysrc, sz); + xfrm_hash_free(nsrc, sz); out_bysrc: - xfrm_hash_free(net->xfrm.state_bydst, sz); + xfrm_hash_free(ndst, sz); out_bydst: return -ENOMEM; } +#define xfrm_state_deref_netexit(table) \ + rcu_dereference_protected((table), true /* netns is going away */) void xfrm_state_fini(struct net *net) { unsigned int sz; @@ -3318,17 +3330,17 @@ void xfrm_state_fini(struct net *net) WARN_ON(!list_empty(&net->xfrm.state_all)); for (i = 0; i <= net->xfrm.state_hmask; i++) { - WARN_ON(!hlist_empty(net->xfrm.state_byseq + i)); - WARN_ON(!hlist_empty(net->xfrm.state_byspi + i)); - WARN_ON(!hlist_empty(net->xfrm.state_bysrc + i)); - WARN_ON(!hlist_empty(net->xfrm.state_bydst + i)); + WARN_ON(!hlist_empty(xfrm_state_deref_netexit(net->xfrm.state_byseq) + i)); + WARN_ON(!hlist_empty(xfrm_state_deref_netexit(net->xfrm.state_byspi) + i)); + WARN_ON(!hlist_empty(xfrm_state_deref_netexit(net->xfrm.state_bysrc) + i)); + WARN_ON(!hlist_empty(xfrm_state_deref_netexit(net->xfrm.state_bydst) + i)); } sz = (net->xfrm.state_hmask + 1) * sizeof(struct hlist_head); - xfrm_hash_free(net->xfrm.state_byseq, sz); - xfrm_hash_free(net->xfrm.state_byspi, sz); - xfrm_hash_free(net->xfrm.state_bysrc, sz); - xfrm_hash_free(net->xfrm.state_bydst, sz); + xfrm_hash_free(xfrm_state_deref_netexit(net->xfrm.state_byseq), sz); + xfrm_hash_free(xfrm_state_deref_netexit(net->xfrm.state_byspi), sz); + xfrm_hash_free(xfrm_state_deref_netexit(net->xfrm.state_bysrc), sz); + xfrm_hash_free(xfrm_state_deref_netexit(net->xfrm.state_bydst), sz); free_percpu(net->xfrm.state_cache_input); } |
