Commit ecb2421b authored by Florian Westphal's avatar Florian Westphal Committed by Pablo Neira Ayuso

netfilter: add and use nf_ct_netns_get/put

currently aliased to try_module_get/_put.
Will be changed in next patch when we add functions to make use of ->net
argument to store usercount per l3proto tracker.

This is needed to avoid registering the conntrack hooks in all netns and
later only enable connection tracking in those that need conntrack.
Signed-off-by: default avatarFlorian Westphal <fw@strlen.de>
Signed-off-by: default avatarPablo Neira Ayuso <pablo@netfilter.org>
parent a379854d
...@@ -181,6 +181,10 @@ static inline void nf_ct_put(struct nf_conn *ct) ...@@ -181,6 +181,10 @@ static inline void nf_ct_put(struct nf_conn *ct)
int nf_ct_l3proto_try_module_get(unsigned short l3proto); int nf_ct_l3proto_try_module_get(unsigned short l3proto);
void nf_ct_l3proto_module_put(unsigned short l3proto); void nf_ct_l3proto_module_put(unsigned short l3proto);
/* load module; enable/disable conntrack in this namespace */
int nf_ct_netns_get(struct net *net, u8 nfproto);
void nf_ct_netns_put(struct net *net, u8 nfproto);
/* /*
* Allocate a hashtable of hlist_head (if nulls == 0), * Allocate a hashtable of hlist_head (if nulls == 0),
* or hlist_nulls_head (if nulls == 1) * or hlist_nulls_head (if nulls == 1)
......
...@@ -419,7 +419,7 @@ static int clusterip_tg_check(const struct xt_tgchk_param *par) ...@@ -419,7 +419,7 @@ static int clusterip_tg_check(const struct xt_tgchk_param *par)
} }
cipinfo->config = config; cipinfo->config = config;
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -444,7 +444,7 @@ static void clusterip_tg_destroy(const struct xt_tgdtor_param *par) ...@@ -444,7 +444,7 @@ static void clusterip_tg_destroy(const struct xt_tgdtor_param *par)
clusterip_config_put(cipinfo->config); clusterip_config_put(cipinfo->config);
nf_ct_l3proto_module_put(par->family); nf_ct_netns_get(par->net, par->family);
} }
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
......
...@@ -418,12 +418,12 @@ static int synproxy_tg4_check(const struct xt_tgchk_param *par) ...@@ -418,12 +418,12 @@ static int synproxy_tg4_check(const struct xt_tgchk_param *par)
e->ip.invflags & XT_INV_PROTO) e->ip.invflags & XT_INV_PROTO)
return -EINVAL; return -EINVAL;
return nf_ct_l3proto_try_module_get(par->family); return nf_ct_netns_get(par->net, par->family);
} }
static void synproxy_tg4_destroy(const struct xt_tgdtor_param *par) static void synproxy_tg4_destroy(const struct xt_tgdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_target synproxy_tg4_reg __read_mostly = { static struct xt_target synproxy_tg4_reg __read_mostly = {
......
...@@ -440,12 +440,12 @@ static int synproxy_tg6_check(const struct xt_tgchk_param *par) ...@@ -440,12 +440,12 @@ static int synproxy_tg6_check(const struct xt_tgchk_param *par)
e->ipv6.invflags & XT_INV_PROTO) e->ipv6.invflags & XT_INV_PROTO)
return -EINVAL; return -EINVAL;
return nf_ct_l3proto_try_module_get(par->family); return nf_ct_netns_get(par->net, par->family);
} }
static void synproxy_tg6_destroy(const struct xt_tgdtor_param *par) static void synproxy_tg6_destroy(const struct xt_tgdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_target synproxy_tg6_reg __read_mostly = { static struct xt_target synproxy_tg6_reg __read_mostly = {
......
...@@ -125,6 +125,18 @@ void nf_ct_l3proto_module_put(unsigned short l3proto) ...@@ -125,6 +125,18 @@ void nf_ct_l3proto_module_put(unsigned short l3proto)
} }
EXPORT_SYMBOL_GPL(nf_ct_l3proto_module_put); EXPORT_SYMBOL_GPL(nf_ct_l3proto_module_put);
int nf_ct_netns_get(struct net *net, u8 nfproto)
{
return nf_ct_l3proto_try_module_get(nfproto);
}
EXPORT_SYMBOL_GPL(nf_ct_netns_get);
void nf_ct_netns_put(struct net *net, u8 nfproto)
{
nf_ct_l3proto_module_put(nfproto);
}
EXPORT_SYMBOL_GPL(nf_ct_netns_put);
struct nf_conntrack_l4proto * struct nf_conntrack_l4proto *
nf_ct_l4proto_find_get(u_int16_t l3num, u_int8_t l4num) nf_ct_l4proto_find_get(u_int16_t l3num, u_int8_t l4num)
{ {
......
...@@ -208,37 +208,37 @@ static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = { ...@@ -208,37 +208,37 @@ static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
[NFTA_CT_SREG] = { .type = NLA_U32 }, [NFTA_CT_SREG] = { .type = NLA_U32 },
}; };
static int nft_ct_l3proto_try_module_get(uint8_t family) static int nft_ct_netns_get(struct net *net, uint8_t family)
{ {
int err; int err;
if (family == NFPROTO_INET) { if (family == NFPROTO_INET) {
err = nf_ct_l3proto_try_module_get(NFPROTO_IPV4); err = nf_ct_netns_get(net, NFPROTO_IPV4);
if (err < 0) if (err < 0)
goto err1; goto err1;
err = nf_ct_l3proto_try_module_get(NFPROTO_IPV6); err = nf_ct_netns_get(net, NFPROTO_IPV6);
if (err < 0) if (err < 0)
goto err2; goto err2;
} else { } else {
err = nf_ct_l3proto_try_module_get(family); err = nf_ct_netns_get(net, family);
if (err < 0) if (err < 0)
goto err1; goto err1;
} }
return 0; return 0;
err2: err2:
nf_ct_l3proto_module_put(NFPROTO_IPV4); nf_ct_netns_put(net, NFPROTO_IPV4);
err1: err1:
return err; return err;
} }
static void nft_ct_l3proto_module_put(uint8_t family) static void nft_ct_netns_put(struct net *net, uint8_t family)
{ {
if (family == NFPROTO_INET) { if (family == NFPROTO_INET) {
nf_ct_l3proto_module_put(NFPROTO_IPV4); nf_ct_netns_put(net, NFPROTO_IPV4);
nf_ct_l3proto_module_put(NFPROTO_IPV6); nf_ct_netns_put(net, NFPROTO_IPV6);
} else } else
nf_ct_l3proto_module_put(family); nf_ct_netns_put(net, family);
} }
static int nft_ct_get_init(const struct nft_ctx *ctx, static int nft_ct_get_init(const struct nft_ctx *ctx,
...@@ -342,7 +342,7 @@ static int nft_ct_get_init(const struct nft_ctx *ctx, ...@@ -342,7 +342,7 @@ static int nft_ct_get_init(const struct nft_ctx *ctx,
if (err < 0) if (err < 0)
return err; return err;
err = nft_ct_l3proto_try_module_get(ctx->afi->family); err = nft_ct_netns_get(ctx->net, ctx->afi->family);
if (err < 0) if (err < 0)
return err; return err;
...@@ -390,7 +390,7 @@ static int nft_ct_set_init(const struct nft_ctx *ctx, ...@@ -390,7 +390,7 @@ static int nft_ct_set_init(const struct nft_ctx *ctx,
if (err < 0) if (err < 0)
goto err1; goto err1;
err = nft_ct_l3proto_try_module_get(ctx->afi->family); err = nft_ct_netns_get(ctx->net, ctx->afi->family);
if (err < 0) if (err < 0)
goto err1; goto err1;
...@@ -405,7 +405,7 @@ static int nft_ct_set_init(const struct nft_ctx *ctx, ...@@ -405,7 +405,7 @@ static int nft_ct_set_init(const struct nft_ctx *ctx,
static void nft_ct_get_destroy(const struct nft_ctx *ctx, static void nft_ct_get_destroy(const struct nft_ctx *ctx,
const struct nft_expr *expr) const struct nft_expr *expr)
{ {
nft_ct_l3proto_module_put(ctx->afi->family); nf_ct_netns_put(ctx->net, ctx->afi->family);
} }
static void nft_ct_set_destroy(const struct nft_ctx *ctx, static void nft_ct_set_destroy(const struct nft_ctx *ctx,
...@@ -423,7 +423,7 @@ static void nft_ct_set_destroy(const struct nft_ctx *ctx, ...@@ -423,7 +423,7 @@ static void nft_ct_set_destroy(const struct nft_ctx *ctx,
break; break;
} }
nft_ct_l3proto_module_put(ctx->afi->family); nft_ct_netns_put(ctx->net, ctx->afi->family);
} }
static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr) static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr)
......
...@@ -106,7 +106,7 @@ static int connsecmark_tg_check(const struct xt_tgchk_param *par) ...@@ -106,7 +106,7 @@ static int connsecmark_tg_check(const struct xt_tgchk_param *par)
return -EINVAL; return -EINVAL;
} }
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -115,7 +115,7 @@ static int connsecmark_tg_check(const struct xt_tgchk_param *par) ...@@ -115,7 +115,7 @@ static int connsecmark_tg_check(const struct xt_tgchk_param *par)
static void connsecmark_tg_destroy(const struct xt_tgdtor_param *par) static void connsecmark_tg_destroy(const struct xt_tgdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_target connsecmark_tg_reg __read_mostly = { static struct xt_target connsecmark_tg_reg __read_mostly = {
......
...@@ -216,7 +216,7 @@ static int xt_ct_tg_check(const struct xt_tgchk_param *par, ...@@ -216,7 +216,7 @@ static int xt_ct_tg_check(const struct xt_tgchk_param *par,
goto err1; goto err1;
#endif #endif
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
goto err1; goto err1;
...@@ -260,7 +260,7 @@ static int xt_ct_tg_check(const struct xt_tgchk_param *par, ...@@ -260,7 +260,7 @@ static int xt_ct_tg_check(const struct xt_tgchk_param *par,
err3: err3:
nf_ct_tmpl_free(ct); nf_ct_tmpl_free(ct);
err2: err2:
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
err1: err1:
return ret; return ret;
} }
...@@ -341,7 +341,7 @@ static void xt_ct_tg_destroy(const struct xt_tgdtor_param *par, ...@@ -341,7 +341,7 @@ static void xt_ct_tg_destroy(const struct xt_tgdtor_param *par,
if (help) if (help)
module_put(help->helper->me); module_put(help->helper->me);
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
xt_ct_destroy_timeout(ct); xt_ct_destroy_timeout(ct);
nf_ct_put(info->ct); nf_ct_put(info->ct);
......
...@@ -110,7 +110,7 @@ static int connbytes_mt_check(const struct xt_mtchk_param *par) ...@@ -110,7 +110,7 @@ static int connbytes_mt_check(const struct xt_mtchk_param *par)
sinfo->direction != XT_CONNBYTES_DIR_BOTH) sinfo->direction != XT_CONNBYTES_DIR_BOTH)
return -EINVAL; return -EINVAL;
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -129,7 +129,7 @@ static int connbytes_mt_check(const struct xt_mtchk_param *par) ...@@ -129,7 +129,7 @@ static int connbytes_mt_check(const struct xt_mtchk_param *par)
static void connbytes_mt_destroy(const struct xt_mtdtor_param *par) static void connbytes_mt_destroy(const struct xt_mtdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_match connbytes_mt_reg __read_mostly = { static struct xt_match connbytes_mt_reg __read_mostly = {
......
...@@ -61,7 +61,7 @@ static int connlabel_mt_check(const struct xt_mtchk_param *par) ...@@ -61,7 +61,7 @@ static int connlabel_mt_check(const struct xt_mtchk_param *par)
return -EINVAL; return -EINVAL;
} }
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) { if (ret < 0) {
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -70,14 +70,14 @@ static int connlabel_mt_check(const struct xt_mtchk_param *par) ...@@ -70,14 +70,14 @@ static int connlabel_mt_check(const struct xt_mtchk_param *par)
ret = nf_connlabels_get(par->net, info->bit); ret = nf_connlabels_get(par->net, info->bit);
if (ret < 0) if (ret < 0)
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
return ret; return ret;
} }
static void connlabel_mt_destroy(const struct xt_mtdtor_param *par) static void connlabel_mt_destroy(const struct xt_mtdtor_param *par)
{ {
nf_connlabels_put(par->net); nf_connlabels_put(par->net);
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_match connlabels_mt_reg __read_mostly = { static struct xt_match connlabels_mt_reg __read_mostly = {
......
...@@ -368,7 +368,7 @@ static int connlimit_mt_check(const struct xt_mtchk_param *par) ...@@ -368,7 +368,7 @@ static int connlimit_mt_check(const struct xt_mtchk_param *par)
net_get_random_once(&connlimit_rnd, sizeof(connlimit_rnd)); net_get_random_once(&connlimit_rnd, sizeof(connlimit_rnd));
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) { if (ret < 0) {
pr_info("cannot load conntrack support for " pr_info("cannot load conntrack support for "
"address family %u\n", par->family); "address family %u\n", par->family);
...@@ -378,7 +378,7 @@ static int connlimit_mt_check(const struct xt_mtchk_param *par) ...@@ -378,7 +378,7 @@ static int connlimit_mt_check(const struct xt_mtchk_param *par)
/* init private data */ /* init private data */
info->data = kmalloc(sizeof(struct xt_connlimit_data), GFP_KERNEL); info->data = kmalloc(sizeof(struct xt_connlimit_data), GFP_KERNEL);
if (info->data == NULL) { if (info->data == NULL) {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
return -ENOMEM; return -ENOMEM;
} }
...@@ -414,7 +414,7 @@ static void connlimit_mt_destroy(const struct xt_mtdtor_param *par) ...@@ -414,7 +414,7 @@ static void connlimit_mt_destroy(const struct xt_mtdtor_param *par)
const struct xt_connlimit_info *info = par->matchinfo; const struct xt_connlimit_info *info = par->matchinfo;
unsigned int i; unsigned int i;
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
for (i = 0; i < ARRAY_SIZE(info->data->climit_root4); ++i) for (i = 0; i < ARRAY_SIZE(info->data->climit_root4); ++i)
destroy_tree(&info->data->climit_root4[i]); destroy_tree(&info->data->climit_root4[i]);
......
...@@ -77,7 +77,7 @@ static int connmark_tg_check(const struct xt_tgchk_param *par) ...@@ -77,7 +77,7 @@ static int connmark_tg_check(const struct xt_tgchk_param *par)
{ {
int ret; int ret;
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -86,7 +86,7 @@ static int connmark_tg_check(const struct xt_tgchk_param *par) ...@@ -86,7 +86,7 @@ static int connmark_tg_check(const struct xt_tgchk_param *par)
static void connmark_tg_destroy(const struct xt_tgdtor_param *par) static void connmark_tg_destroy(const struct xt_tgdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static bool static bool
...@@ -107,7 +107,7 @@ static int connmark_mt_check(const struct xt_mtchk_param *par) ...@@ -107,7 +107,7 @@ static int connmark_mt_check(const struct xt_mtchk_param *par)
{ {
int ret; int ret;
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -116,7 +116,7 @@ static int connmark_mt_check(const struct xt_mtchk_param *par) ...@@ -116,7 +116,7 @@ static int connmark_mt_check(const struct xt_mtchk_param *par)
static void connmark_mt_destroy(const struct xt_mtdtor_param *par) static void connmark_mt_destroy(const struct xt_mtdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_target connmark_tg_reg __read_mostly = { static struct xt_target connmark_tg_reg __read_mostly = {
......
...@@ -271,7 +271,7 @@ static int conntrack_mt_check(const struct xt_mtchk_param *par) ...@@ -271,7 +271,7 @@ static int conntrack_mt_check(const struct xt_mtchk_param *par)
{ {
int ret; int ret;
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -280,7 +280,7 @@ static int conntrack_mt_check(const struct xt_mtchk_param *par) ...@@ -280,7 +280,7 @@ static int conntrack_mt_check(const struct xt_mtchk_param *par)
static void conntrack_mt_destroy(const struct xt_mtdtor_param *par) static void conntrack_mt_destroy(const struct xt_mtdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_match conntrack_mt_reg[] __read_mostly = { static struct xt_match conntrack_mt_reg[] __read_mostly = {
......
...@@ -59,7 +59,7 @@ static int helper_mt_check(const struct xt_mtchk_param *par) ...@@ -59,7 +59,7 @@ static int helper_mt_check(const struct xt_mtchk_param *par)
struct xt_helper_info *info = par->matchinfo; struct xt_helper_info *info = par->matchinfo;
int ret; int ret;
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) { if (ret < 0) {
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -71,7 +71,7 @@ static int helper_mt_check(const struct xt_mtchk_param *par) ...@@ -71,7 +71,7 @@ static int helper_mt_check(const struct xt_mtchk_param *par)
static void helper_mt_destroy(const struct xt_mtdtor_param *par) static void helper_mt_destroy(const struct xt_mtdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_match helper_mt_reg __read_mostly = { static struct xt_match helper_mt_reg __read_mostly = {
......
...@@ -43,7 +43,7 @@ static int state_mt_check(const struct xt_mtchk_param *par) ...@@ -43,7 +43,7 @@ static int state_mt_check(const struct xt_mtchk_param *par)
{ {
int ret; int ret;
ret = nf_ct_l3proto_try_module_get(par->family); ret = nf_ct_netns_get(par->net, par->family);
if (ret < 0) if (ret < 0)
pr_info("cannot load conntrack support for proto=%u\n", pr_info("cannot load conntrack support for proto=%u\n",
par->family); par->family);
...@@ -52,7 +52,7 @@ static int state_mt_check(const struct xt_mtchk_param *par) ...@@ -52,7 +52,7 @@ static int state_mt_check(const struct xt_mtchk_param *par)
static void state_mt_destroy(const struct xt_mtdtor_param *par) static void state_mt_destroy(const struct xt_mtdtor_param *par)
{ {
nf_ct_l3proto_module_put(par->family); nf_ct_netns_put(par->net, par->family);
} }
static struct xt_match state_mt_reg __read_mostly = { static struct xt_match state_mt_reg __read_mostly = {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment