Commit 41e3d179 authored by Trond Myklebust's avatar Trond Myklebust

RPC: clean up the RPCSEC_GSS kerberos and spkm3 context import functions

Signed-off-by: default avatarTrond Myklebust <Trond.Myklebust@netapp.com>
parent 25326faa
...@@ -33,8 +33,9 @@ struct gss_ctx { ...@@ -33,8 +33,9 @@ struct gss_ctx {
/* gss-api prototypes; note that these are somewhat simplified versions of /* gss-api prototypes; note that these are somewhat simplified versions of
* the prototypes specified in RFC 2744. */ * the prototypes specified in RFC 2744. */
u32 gss_import_sec_context( int gss_import_sec_context(
struct xdr_netobj *input_token, const void* input_token,
size_t bufsize,
struct gss_api_mech *mech, struct gss_api_mech *mech,
struct gss_ctx **ctx_id); struct gss_ctx **ctx_id);
u32 gss_get_mic( u32 gss_get_mic(
...@@ -80,8 +81,9 @@ struct gss_api_mech { ...@@ -80,8 +81,9 @@ struct gss_api_mech {
/* and must provide the following operations: */ /* and must provide the following operations: */
struct gss_api_ops { struct gss_api_ops {
u32 (*gss_import_sec_context)( int (*gss_import_sec_context)(
struct xdr_netobj *input_token, const void *input_token,
size_t bufsize,
struct gss_ctx *ctx_id); struct gss_ctx *ctx_id);
u32 (*gss_get_mic)( u32 (*gss_get_mic)(
struct gss_ctx *ctx_id, struct gss_ctx *ctx_id,
......
...@@ -272,7 +272,7 @@ gss_parse_init_downcall(struct gss_api_mech *gm, struct xdr_netobj *buf, ...@@ -272,7 +272,7 @@ gss_parse_init_downcall(struct gss_api_mech *gm, struct xdr_netobj *buf,
goto err_free_wire_ctx; goto err_free_wire_ctx;
if (p != end) if (p != end)
goto err_free_wire_ctx; goto err_free_wire_ctx;
if (gss_import_sec_context(&tmp_buf, gm, &ctx->gc_gss_ctx)) if (gss_import_sec_context(tmp_buf.data, tmp_buf.len, gm, &ctx->gc_gss_ctx))
goto err_free_wire_ctx; goto err_free_wire_ctx;
*gc = ctx; *gc = ctx;
return 0; return 0;
......
...@@ -48,46 +48,48 @@ ...@@ -48,46 +48,48 @@
# define RPCDBG_FACILITY RPCDBG_AUTH # define RPCDBG_FACILITY RPCDBG_AUTH
#endif #endif
static inline int static const void *
get_bytes(char **ptr, const char *end, void *res, int len) simple_get_bytes(const void *p, const void *end, void *res, int len)
{ {
char *p, *q; const void *q = (const void *)((const char *)p + len);
p = *ptr; if (unlikely(q > end || q < p))
q = p + len; return ERR_PTR(-EFAULT);
if (q > end || q < p)
return -1;
memcpy(res, p, len); memcpy(res, p, len);
*ptr = q; return q;
return 0;
} }
static inline int static const void *
get_netobj(char **ptr, const char *end, struct xdr_netobj *res) simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res)
{ {
char *p, *q; const void *q;
p = *ptr; unsigned int len;
if (get_bytes(&p, end, &res->len, sizeof(res->len)))
return -1; p = simple_get_bytes(p, end, &len, sizeof(len));
q = p + res->len; if (IS_ERR(p))
if (q > end || q < p) return p;
return -1; q = (const void *)((const char *)p + len);
if (!(res->data = kmalloc(res->len, GFP_KERNEL))) if (unlikely(q > end || q < p))
return -1; return ERR_PTR(-EFAULT);
memcpy(res->data, p, res->len); res->data = kmalloc(len, GFP_KERNEL);
*ptr = q; if (unlikely(res->data == NULL))
return 0; return ERR_PTR(-ENOMEM);
memcpy(res->data, p, len);
res->len = len;
return q;
} }
static inline int static inline const void *
get_key(char **p, char *end, struct crypto_tfm **res) get_key(const void *p, const void *end, struct crypto_tfm **res)
{ {
struct xdr_netobj key; struct xdr_netobj key;
int alg, alg_mode; int alg, alg_mode;
char *alg_name; char *alg_name;
if (get_bytes(p, end, &alg, sizeof(alg))) p = simple_get_bytes(p, end, &alg, sizeof(alg));
if (IS_ERR(p))
goto out_err; goto out_err;
if ((get_netobj(p, end, &key))) p = simple_get_netobj(p, end, &key);
if (IS_ERR(p))
goto out_err; goto out_err;
switch (alg) { switch (alg) {
...@@ -105,50 +107,63 @@ get_key(char **p, char *end, struct crypto_tfm **res) ...@@ -105,50 +107,63 @@ get_key(char **p, char *end, struct crypto_tfm **res)
goto out_err_free_tfm; goto out_err_free_tfm;
kfree(key.data); kfree(key.data);
return 0; return p;
out_err_free_tfm: out_err_free_tfm:
crypto_free_tfm(*res); crypto_free_tfm(*res);
out_err_free_key: out_err_free_key:
kfree(key.data); kfree(key.data);
p = ERR_PTR(-EINVAL);
out_err: out_err:
return -1; return p;
} }
static u32 static int
gss_import_sec_context_kerberos(struct xdr_netobj *inbuf, gss_import_sec_context_kerberos(const void *p,
size_t len,
struct gss_ctx *ctx_id) struct gss_ctx *ctx_id)
{ {
char *p = inbuf->data; const void *end = (const void *)((const char *)p + len);
char *end = inbuf->data + inbuf->len;
struct krb5_ctx *ctx; struct krb5_ctx *ctx;
if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL))) if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL)))
goto out_err; goto out_err;
memset(ctx, 0, sizeof(*ctx)); memset(ctx, 0, sizeof(*ctx));
if (get_bytes(&p, end, &ctx->initiate, sizeof(ctx->initiate))) p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate));
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->seed_init, sizeof(ctx->seed_init))) p = simple_get_bytes(p, end, &ctx->seed_init, sizeof(ctx->seed_init));
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_bytes(&p, end, ctx->seed, sizeof(ctx->seed))) p = simple_get_bytes(p, end, ctx->seed, sizeof(ctx->seed));
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->signalg, sizeof(ctx->signalg))) p = simple_get_bytes(p, end, &ctx->signalg, sizeof(ctx->signalg));
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->sealalg, sizeof(ctx->sealalg))) p = simple_get_bytes(p, end, &ctx->sealalg, sizeof(ctx->sealalg));
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->endtime, sizeof(ctx->endtime))) p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime));
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->seq_send, sizeof(ctx->seq_send))) p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send));
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_netobj(&p, end, &ctx->mech_used)) p = simple_get_netobj(p, end, &ctx->mech_used);
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_key(&p, end, &ctx->enc)) p = get_key(p, end, &ctx->enc);
if (IS_ERR(p))
goto out_err_free_mech; goto out_err_free_mech;
if (get_key(&p, end, &ctx->seq)) p = get_key(p, end, &ctx->seq);
if (IS_ERR(p))
goto out_err_free_key1; goto out_err_free_key1;
if (p != end) if (p != end) {
p = ERR_PTR(-EFAULT);
goto out_err_free_key2; goto out_err_free_key2;
}
ctx_id->internal_ctx_id = ctx; ctx_id->internal_ctx_id = ctx;
dprintk("RPC: Succesfully imported new context.\n"); dprintk("RPC: Succesfully imported new context.\n");
...@@ -163,7 +178,7 @@ gss_import_sec_context_kerberos(struct xdr_netobj *inbuf, ...@@ -163,7 +178,7 @@ gss_import_sec_context_kerberos(struct xdr_netobj *inbuf,
out_err_free_ctx: out_err_free_ctx:
kfree(ctx); kfree(ctx);
out_err: out_err:
return GSS_S_FAILURE; return PTR_ERR(p);
} }
static void static void
......
...@@ -233,8 +233,8 @@ EXPORT_SYMBOL(gss_mech_put); ...@@ -233,8 +233,8 @@ EXPORT_SYMBOL(gss_mech_put);
/* The mech could probably be determined from the token instead, but it's just /* The mech could probably be determined from the token instead, but it's just
* as easy for now to pass it in. */ * as easy for now to pass it in. */
u32 int
gss_import_sec_context(struct xdr_netobj *input_token, gss_import_sec_context(const void *input_token, size_t bufsize,
struct gss_api_mech *mech, struct gss_api_mech *mech,
struct gss_ctx **ctx_id) struct gss_ctx **ctx_id)
{ {
...@@ -244,7 +244,7 @@ gss_import_sec_context(struct xdr_netobj *input_token, ...@@ -244,7 +244,7 @@ gss_import_sec_context(struct xdr_netobj *input_token,
(*ctx_id)->mech_type = gss_mech_get(mech); (*ctx_id)->mech_type = gss_mech_get(mech);
return mech->gm_ops return mech->gm_ops
->gss_import_sec_context(input_token, *ctx_id); ->gss_import_sec_context(input_token, bufsize, *ctx_id);
} }
/* gss_get_mic: compute a mic over message and return mic_token. */ /* gss_get_mic: compute a mic over message and return mic_token. */
......
...@@ -49,52 +49,51 @@ ...@@ -49,52 +49,51 @@
# define RPCDBG_FACILITY RPCDBG_AUTH # define RPCDBG_FACILITY RPCDBG_AUTH
#endif #endif
static inline int static const void *
get_bytes(char **ptr, const char *end, void *res, int len) simple_get_bytes(const void *p, const void *end, void *res, int len)
{ {
char *p, *q; const void *q = (const void *)((const char *)p + len);
p = *ptr; if (unlikely(q > end || q < p))
q = p + len; return ERR_PTR(-EFAULT);
if (q > end || q < p)
return -1;
memcpy(res, p, len); memcpy(res, p, len);
*ptr = q; return q;
return 0;
} }
static inline int static const void *
get_netobj(char **ptr, const char *end, struct xdr_netobj *res) simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res)
{ {
char *p, *q; const void *q;
p = *ptr; unsigned int len;
if (get_bytes(&p, end, &res->len, sizeof(res->len))) p = simple_get_bytes(p, end, &len, sizeof(len));
return -1; if (IS_ERR(p))
q = p + res->len; return p;
if(res->len == 0) res->len = len;
goto out_nocopy; if (len == 0) {
if (q > end || q < p) res->data = NULL;
return -1; return p;
if (!(res->data = kmalloc(res->len, GFP_KERNEL))) }
return -1; q = (const void *)((const char *)p + len);
memcpy(res->data, p, res->len); if (unlikely(q > end || q < p))
out_nocopy: return ERR_PTR(-EFAULT);
*ptr = q; res->data = kmalloc(len, GFP_KERNEL);
return 0; if (unlikely(res->data == NULL))
return ERR_PTR(-ENOMEM);
memcpy(res->data, p, len);
return q;
} }
static inline int static inline const void *
get_key(char **p, char *end, struct crypto_tfm **res, int *resalg) get_key(const void *p, const void *end, struct crypto_tfm **res, int *resalg)
{ {
struct xdr_netobj key = { struct xdr_netobj key = { 0 };
.len = 0,
.data = NULL,
};
int alg_mode,setkey = 0; int alg_mode,setkey = 0;
char *alg_name; char *alg_name;
if (get_bytes(p, end, resalg, sizeof(int))) p = simple_get_bytes(p, end, resalg, sizeof(*resalg));
if (IS_ERR(p))
goto out_err; goto out_err;
if ((get_netobj(p, end, &key))) p = simple_get_netobj(p, end, &key);
if (IS_ERR(p))
goto out_err; goto out_err;
switch (*resalg) { switch (*resalg) {
...@@ -111,10 +110,6 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg) ...@@ -111,10 +110,6 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
alg_mode = 0; alg_mode = 0;
setkey = 0; setkey = 0;
break; break;
case NID_cast5_cbc:
dprintk("RPC: SPKM3 get_key: case cast5_cbc, UNSUPPORTED \n");
goto out_err;
break;
default: default:
dprintk("RPC: SPKM3 get_key: unsupported algorithm %d", *resalg); dprintk("RPC: SPKM3 get_key: unsupported algorithm %d", *resalg);
goto out_err_free_key; goto out_err_free_key;
...@@ -128,69 +123,81 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg) ...@@ -128,69 +123,81 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
if(key.len > 0) if(key.len > 0)
kfree(key.data); kfree(key.data);
return 0; return p;
out_err_free_tfm: out_err_free_tfm:
crypto_free_tfm(*res); crypto_free_tfm(*res);
out_err_free_key: out_err_free_key:
if(key.len > 0) if(key.len > 0)
kfree(key.data); kfree(key.data);
p = ERR_PTR(-EINVAL);
out_err: out_err:
return -1; return p;
} }
static u32 static int
gss_import_sec_context_spkm3(struct xdr_netobj *inbuf, gss_import_sec_context_spkm3(const void *p, size_t len,
struct gss_ctx *ctx_id) struct gss_ctx *ctx_id)
{ {
char *p = inbuf->data; const void *end = (const void *)((const char *)p + len);
char *end = inbuf->data + inbuf->len;
struct spkm3_ctx *ctx; struct spkm3_ctx *ctx;
if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL))) if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL)))
goto out_err; goto out_err;
memset(ctx, 0, sizeof(*ctx)); memset(ctx, 0, sizeof(*ctx));
if (get_netobj(&p, end, &ctx->ctx_id)) p = simple_get_netobj(p, end, &ctx->ctx_id);
if (IS_ERR(p))
goto out_err_free_ctx; goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->qop, sizeof(ctx->qop))) p = simple_get_bytes(p, end, &ctx->qop, sizeof(ctx->qop));
if (IS_ERR(p))
goto out_err_free_ctx_id; goto out_err_free_ctx_id;
if (get_netobj(&p, end, &ctx->mech_used)) p = simple_get_netobj(p, end, &ctx->mech_used);
if (IS_ERR(p))
goto out_err_free_mech; goto out_err_free_mech;
if (get_bytes(&p, end, &ctx->ret_flags, sizeof(ctx->ret_flags))) p = simple_get_bytes(p, end, &ctx->ret_flags, sizeof(ctx->ret_flags));
if (IS_ERR(p))
goto out_err_free_mech; goto out_err_free_mech;
if (get_bytes(&p, end, &ctx->req_flags, sizeof(ctx->req_flags))) p = simple_get_bytes(p, end, &ctx->req_flags, sizeof(ctx->req_flags));
if (IS_ERR(p))
goto out_err_free_mech; goto out_err_free_mech;
if (get_netobj(&p, end, &ctx->share_key)) p = simple_get_netobj(p, end, &ctx->share_key);
if (IS_ERR(p))
goto out_err_free_s_key; goto out_err_free_s_key;
if (get_key(&p, end, &ctx->derived_conf_key, &ctx->conf_alg)) { p = get_key(p, end, &ctx->derived_conf_key, &ctx->conf_alg);
dprintk("RPC: SPKM3 confidentiality key will be NULL\n"); if (IS_ERR(p))
} goto out_err_free_s_key;
if (get_key(&p, end, &ctx->derived_integ_key, &ctx->intg_alg)) { p = get_key(p, end, &ctx->derived_integ_key, &ctx->intg_alg);
dprintk("RPC: SPKM3 integrity key will be NULL\n"); if (IS_ERR(p))
} goto out_err_free_key1;
if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg))) p = simple_get_bytes(p, end, &ctx->keyestb_alg, sizeof(ctx->keyestb_alg));
goto out_err_free_s_key; if (IS_ERR(p))
goto out_err_free_key2;
if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg))) p = simple_get_bytes(p, end, &ctx->owf_alg, sizeof(ctx->owf_alg));
goto out_err_free_s_key; if (IS_ERR(p))
goto out_err_free_key2;
if (p != end) if (p != end)
goto out_err_free_s_key; goto out_err_free_key2;
ctx_id->internal_ctx_id = ctx; ctx_id->internal_ctx_id = ctx;
dprintk("Succesfully imported new spkm context.\n"); dprintk("Succesfully imported new spkm context.\n");
return 0; return 0;
out_err_free_key2:
crypto_free_tfm(ctx->derived_integ_key);
out_err_free_key1:
crypto_free_tfm(ctx->derived_conf_key);
out_err_free_s_key: out_err_free_s_key:
kfree(ctx->share_key.data); kfree(ctx->share_key.data);
out_err_free_mech: out_err_free_mech:
...@@ -200,7 +207,7 @@ gss_import_sec_context_spkm3(struct xdr_netobj *inbuf, ...@@ -200,7 +207,7 @@ gss_import_sec_context_spkm3(struct xdr_netobj *inbuf,
out_err_free_ctx: out_err_free_ctx:
kfree(ctx); kfree(ctx);
out_err: out_err:
return GSS_S_FAILURE; return PTR_ERR(p);
} }
static void static void
......
...@@ -381,7 +381,6 @@ static int rsc_parse(struct cache_detail *cd, ...@@ -381,7 +381,6 @@ static int rsc_parse(struct cache_detail *cd,
else { else {
int N, i; int N, i;
struct gss_api_mech *gm; struct gss_api_mech *gm;
struct xdr_netobj tmp_buf;
/* gid */ /* gid */
if (get_int(&mesg, &rsci.cred.cr_gid)) if (get_int(&mesg, &rsci.cred.cr_gid))
...@@ -420,9 +419,7 @@ static int rsc_parse(struct cache_detail *cd, ...@@ -420,9 +419,7 @@ static int rsc_parse(struct cache_detail *cd,
gss_mech_put(gm); gss_mech_put(gm);
goto out; goto out;
} }
tmp_buf.len = len; if (gss_import_sec_context(buf, len, gm, &rsci.mechctx)) {
tmp_buf.data = buf;
if (gss_import_sec_context(&tmp_buf, gm, &rsci.mechctx)) {
gss_mech_put(gm); gss_mech_put(gm);
goto out; goto out;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment