Commit 41e3d179 authored by Trond Myklebust's avatar Trond Myklebust

RPC: clean up the RPCSEC_GSS kerberos and spkm3 context import functions

Signed-off-by: default avatarTrond Myklebust <Trond.Myklebust@netapp.com>
parent 25326faa
......@@ -33,8 +33,9 @@ struct gss_ctx {
/* gss-api prototypes; note that these are somewhat simplified versions of
* the prototypes specified in RFC 2744. */
u32 gss_import_sec_context(
struct xdr_netobj *input_token,
int gss_import_sec_context(
const void* input_token,
size_t bufsize,
struct gss_api_mech *mech,
struct gss_ctx **ctx_id);
u32 gss_get_mic(
......@@ -80,8 +81,9 @@ struct gss_api_mech {
/* and must provide the following operations: */
struct gss_api_ops {
u32 (*gss_import_sec_context)(
struct xdr_netobj *input_token,
int (*gss_import_sec_context)(
const void *input_token,
size_t bufsize,
struct gss_ctx *ctx_id);
u32 (*gss_get_mic)(
struct gss_ctx *ctx_id,
......
......@@ -272,7 +272,7 @@ gss_parse_init_downcall(struct gss_api_mech *gm, struct xdr_netobj *buf,
goto err_free_wire_ctx;
if (p != end)
goto err_free_wire_ctx;
if (gss_import_sec_context(&tmp_buf, gm, &ctx->gc_gss_ctx))
if (gss_import_sec_context(tmp_buf.data, tmp_buf.len, gm, &ctx->gc_gss_ctx))
goto err_free_wire_ctx;
*gc = ctx;
return 0;
......
......@@ -48,46 +48,48 @@
# define RPCDBG_FACILITY RPCDBG_AUTH
#endif
static inline int
get_bytes(char **ptr, const char *end, void *res, int len)
static const void *
simple_get_bytes(const void *p, const void *end, void *res, int len)
{
char *p, *q;
p = *ptr;
q = p + len;
if (q > end || q < p)
return -1;
const void *q = (const void *)((const char *)p + len);
if (unlikely(q > end || q < p))
return ERR_PTR(-EFAULT);
memcpy(res, p, len);
*ptr = q;
return 0;
return q;
}
static inline int
get_netobj(char **ptr, const char *end, struct xdr_netobj *res)
static const void *
simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res)
{
char *p, *q;
p = *ptr;
if (get_bytes(&p, end, &res->len, sizeof(res->len)))
return -1;
q = p + res->len;
if (q > end || q < p)
return -1;
if (!(res->data = kmalloc(res->len, GFP_KERNEL)))
return -1;
memcpy(res->data, p, res->len);
*ptr = q;
return 0;
const void *q;
unsigned int len;
p = simple_get_bytes(p, end, &len, sizeof(len));
if (IS_ERR(p))
return p;
q = (const void *)((const char *)p + len);
if (unlikely(q > end || q < p))
return ERR_PTR(-EFAULT);
res->data = kmalloc(len, GFP_KERNEL);
if (unlikely(res->data == NULL))
return ERR_PTR(-ENOMEM);
memcpy(res->data, p, len);
res->len = len;
return q;
}
static inline int
get_key(char **p, char *end, struct crypto_tfm **res)
static inline const void *
get_key(const void *p, const void *end, struct crypto_tfm **res)
{
struct xdr_netobj key;
int alg, alg_mode;
char *alg_name;
if (get_bytes(p, end, &alg, sizeof(alg)))
p = simple_get_bytes(p, end, &alg, sizeof(alg));
if (IS_ERR(p))
goto out_err;
if ((get_netobj(p, end, &key)))
p = simple_get_netobj(p, end, &key);
if (IS_ERR(p))
goto out_err;
switch (alg) {
......@@ -105,50 +107,63 @@ get_key(char **p, char *end, struct crypto_tfm **res)
goto out_err_free_tfm;
kfree(key.data);
return 0;
return p;
out_err_free_tfm:
crypto_free_tfm(*res);
out_err_free_key:
kfree(key.data);
p = ERR_PTR(-EINVAL);
out_err:
return -1;
return p;
}
static u32
gss_import_sec_context_kerberos(struct xdr_netobj *inbuf,
static int
gss_import_sec_context_kerberos(const void *p,
size_t len,
struct gss_ctx *ctx_id)
{
char *p = inbuf->data;
char *end = inbuf->data + inbuf->len;
const void *end = (const void *)((const char *)p + len);
struct krb5_ctx *ctx;
if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL)))
goto out_err;
memset(ctx, 0, sizeof(*ctx));
if (get_bytes(&p, end, &ctx->initiate, sizeof(ctx->initiate)))
p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate));
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->seed_init, sizeof(ctx->seed_init)))
p = simple_get_bytes(p, end, &ctx->seed_init, sizeof(ctx->seed_init));
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_bytes(&p, end, ctx->seed, sizeof(ctx->seed)))
p = simple_get_bytes(p, end, ctx->seed, sizeof(ctx->seed));
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->signalg, sizeof(ctx->signalg)))
p = simple_get_bytes(p, end, &ctx->signalg, sizeof(ctx->signalg));
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->sealalg, sizeof(ctx->sealalg)))
p = simple_get_bytes(p, end, &ctx->sealalg, sizeof(ctx->sealalg));
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->endtime, sizeof(ctx->endtime)))
p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime));
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->seq_send, sizeof(ctx->seq_send)))
p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send));
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_netobj(&p, end, &ctx->mech_used))
p = simple_get_netobj(p, end, &ctx->mech_used);
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_key(&p, end, &ctx->enc))
p = get_key(p, end, &ctx->enc);
if (IS_ERR(p))
goto out_err_free_mech;
if (get_key(&p, end, &ctx->seq))
p = get_key(p, end, &ctx->seq);
if (IS_ERR(p))
goto out_err_free_key1;
if (p != end)
if (p != end) {
p = ERR_PTR(-EFAULT);
goto out_err_free_key2;
}
ctx_id->internal_ctx_id = ctx;
dprintk("RPC: Succesfully imported new context.\n");
......@@ -163,7 +178,7 @@ gss_import_sec_context_kerberos(struct xdr_netobj *inbuf,
out_err_free_ctx:
kfree(ctx);
out_err:
return GSS_S_FAILURE;
return PTR_ERR(p);
}
static void
......
......@@ -233,8 +233,8 @@ EXPORT_SYMBOL(gss_mech_put);
/* The mech could probably be determined from the token instead, but it's just
* as easy for now to pass it in. */
u32
gss_import_sec_context(struct xdr_netobj *input_token,
int
gss_import_sec_context(const void *input_token, size_t bufsize,
struct gss_api_mech *mech,
struct gss_ctx **ctx_id)
{
......@@ -244,7 +244,7 @@ gss_import_sec_context(struct xdr_netobj *input_token,
(*ctx_id)->mech_type = gss_mech_get(mech);
return mech->gm_ops
->gss_import_sec_context(input_token, *ctx_id);
->gss_import_sec_context(input_token, bufsize, *ctx_id);
}
/* gss_get_mic: compute a mic over message and return mic_token. */
......
......@@ -49,52 +49,51 @@
# define RPCDBG_FACILITY RPCDBG_AUTH
#endif
static inline int
get_bytes(char **ptr, const char *end, void *res, int len)
static const void *
simple_get_bytes(const void *p, const void *end, void *res, int len)
{
char *p, *q;
p = *ptr;
q = p + len;
if (q > end || q < p)
return -1;
const void *q = (const void *)((const char *)p + len);
if (unlikely(q > end || q < p))
return ERR_PTR(-EFAULT);
memcpy(res, p, len);
*ptr = q;
return 0;
return q;
}
static inline int
get_netobj(char **ptr, const char *end, struct xdr_netobj *res)
static const void *
simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res)
{
char *p, *q;
p = *ptr;
if (get_bytes(&p, end, &res->len, sizeof(res->len)))
return -1;
q = p + res->len;
if(res->len == 0)
goto out_nocopy;
if (q > end || q < p)
return -1;
if (!(res->data = kmalloc(res->len, GFP_KERNEL)))
return -1;
memcpy(res->data, p, res->len);
out_nocopy:
*ptr = q;
return 0;
const void *q;
unsigned int len;
p = simple_get_bytes(p, end, &len, sizeof(len));
if (IS_ERR(p))
return p;
res->len = len;
if (len == 0) {
res->data = NULL;
return p;
}
q = (const void *)((const char *)p + len);
if (unlikely(q > end || q < p))
return ERR_PTR(-EFAULT);
res->data = kmalloc(len, GFP_KERNEL);
if (unlikely(res->data == NULL))
return ERR_PTR(-ENOMEM);
memcpy(res->data, p, len);
return q;
}
static inline int
get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
static inline const void *
get_key(const void *p, const void *end, struct crypto_tfm **res, int *resalg)
{
struct xdr_netobj key = {
.len = 0,
.data = NULL,
};
struct xdr_netobj key = { 0 };
int alg_mode,setkey = 0;
char *alg_name;
if (get_bytes(p, end, resalg, sizeof(int)))
p = simple_get_bytes(p, end, resalg, sizeof(*resalg));
if (IS_ERR(p))
goto out_err;
if ((get_netobj(p, end, &key)))
p = simple_get_netobj(p, end, &key);
if (IS_ERR(p))
goto out_err;
switch (*resalg) {
......@@ -111,10 +110,6 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
alg_mode = 0;
setkey = 0;
break;
case NID_cast5_cbc:
dprintk("RPC: SPKM3 get_key: case cast5_cbc, UNSUPPORTED \n");
goto out_err;
break;
default:
dprintk("RPC: SPKM3 get_key: unsupported algorithm %d", *resalg);
goto out_err_free_key;
......@@ -128,69 +123,81 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
if(key.len > 0)
kfree(key.data);
return 0;
return p;
out_err_free_tfm:
crypto_free_tfm(*res);
out_err_free_key:
if(key.len > 0)
kfree(key.data);
p = ERR_PTR(-EINVAL);
out_err:
return -1;
return p;
}
static u32
gss_import_sec_context_spkm3(struct xdr_netobj *inbuf,
static int
gss_import_sec_context_spkm3(const void *p, size_t len,
struct gss_ctx *ctx_id)
{
char *p = inbuf->data;
char *end = inbuf->data + inbuf->len;
const void *end = (const void *)((const char *)p + len);
struct spkm3_ctx *ctx;
if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL)))
goto out_err;
memset(ctx, 0, sizeof(*ctx));
if (get_netobj(&p, end, &ctx->ctx_id))
p = simple_get_netobj(p, end, &ctx->ctx_id);
if (IS_ERR(p))
goto out_err_free_ctx;
if (get_bytes(&p, end, &ctx->qop, sizeof(ctx->qop)))
p = simple_get_bytes(p, end, &ctx->qop, sizeof(ctx->qop));
if (IS_ERR(p))
goto out_err_free_ctx_id;
if (get_netobj(&p, end, &ctx->mech_used))
p = simple_get_netobj(p, end, &ctx->mech_used);
if (IS_ERR(p))
goto out_err_free_mech;
if (get_bytes(&p, end, &ctx->ret_flags, sizeof(ctx->ret_flags)))
p = simple_get_bytes(p, end, &ctx->ret_flags, sizeof(ctx->ret_flags));
if (IS_ERR(p))
goto out_err_free_mech;
if (get_bytes(&p, end, &ctx->req_flags, sizeof(ctx->req_flags)))
p = simple_get_bytes(p, end, &ctx->req_flags, sizeof(ctx->req_flags));
if (IS_ERR(p))
goto out_err_free_mech;
if (get_netobj(&p, end, &ctx->share_key))
p = simple_get_netobj(p, end, &ctx->share_key);
if (IS_ERR(p))
goto out_err_free_s_key;
if (get_key(&p, end, &ctx->derived_conf_key, &ctx->conf_alg)) {
dprintk("RPC: SPKM3 confidentiality key will be NULL\n");
}
p = get_key(p, end, &ctx->derived_conf_key, &ctx->conf_alg);
if (IS_ERR(p))
goto out_err_free_s_key;
if (get_key(&p, end, &ctx->derived_integ_key, &ctx->intg_alg)) {
dprintk("RPC: SPKM3 integrity key will be NULL\n");
}
p = get_key(p, end, &ctx->derived_integ_key, &ctx->intg_alg);
if (IS_ERR(p))
goto out_err_free_key1;
if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg)))
goto out_err_free_s_key;
p = simple_get_bytes(p, end, &ctx->keyestb_alg, sizeof(ctx->keyestb_alg));
if (IS_ERR(p))
goto out_err_free_key2;
if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg)))
goto out_err_free_s_key;
p = simple_get_bytes(p, end, &ctx->owf_alg, sizeof(ctx->owf_alg));
if (IS_ERR(p))
goto out_err_free_key2;
if (p != end)
goto out_err_free_s_key;
goto out_err_free_key2;
ctx_id->internal_ctx_id = ctx;
dprintk("Succesfully imported new spkm context.\n");
return 0;
out_err_free_key2:
crypto_free_tfm(ctx->derived_integ_key);
out_err_free_key1:
crypto_free_tfm(ctx->derived_conf_key);
out_err_free_s_key:
kfree(ctx->share_key.data);
out_err_free_mech:
......@@ -200,7 +207,7 @@ gss_import_sec_context_spkm3(struct xdr_netobj *inbuf,
out_err_free_ctx:
kfree(ctx);
out_err:
return GSS_S_FAILURE;
return PTR_ERR(p);
}
static void
......
......@@ -381,7 +381,6 @@ static int rsc_parse(struct cache_detail *cd,
else {
int N, i;
struct gss_api_mech *gm;
struct xdr_netobj tmp_buf;
/* gid */
if (get_int(&mesg, &rsci.cred.cr_gid))
......@@ -420,9 +419,7 @@ static int rsc_parse(struct cache_detail *cd,
gss_mech_put(gm);
goto out;
}
tmp_buf.len = len;
tmp_buf.data = buf;
if (gss_import_sec_context(&tmp_buf, gm, &rsci.mechctx)) {
if (gss_import_sec_context(buf, len, gm, &rsci.mechctx)) {
gss_mech_put(gm);
goto out;
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment