Commit 5a9e20ef authored by Trond Myklebust's avatar Trond Myklebust

RPC: Initialize the GSS context upon RPC credential creation.

Signed-off-by: default avatarTrond Myklebust <Trond.Myklebust@netapp.com>
parent 4d08b43c
...@@ -365,11 +365,14 @@ rpcauth_refreshcred(struct rpc_task *task) ...@@ -365,11 +365,14 @@ rpcauth_refreshcred(struct rpc_task *task)
{ {
struct rpc_auth *auth = task->tk_auth; struct rpc_auth *auth = task->tk_auth;
struct rpc_cred *cred = task->tk_msg.rpc_cred; struct rpc_cred *cred = task->tk_msg.rpc_cred;
int err;
dprintk("RPC: %4d refreshing %s cred %p\n", dprintk("RPC: %4d refreshing %s cred %p\n",
task->tk_pid, auth->au_ops->au_name, cred); task->tk_pid, auth->au_ops->au_name, cred);
task->tk_status = cred->cr_ops->crrefresh(task); err = cred->cr_ops->crrefresh(task);
return task->tk_status; if (err < 0)
task->tk_status = err;
return err;
} }
void void
......
...@@ -87,6 +87,7 @@ struct gss_auth { ...@@ -87,6 +87,7 @@ struct gss_auth {
struct gss_api_mech *mech; struct gss_api_mech *mech;
enum rpc_gss_svc service; enum rpc_gss_svc service;
struct list_head upcalls; struct list_head upcalls;
struct rpc_clnt *client;
struct dentry *dentry; struct dentry *dentry;
char path[48]; char path[48];
spinlock_t lock; spinlock_t lock;
...@@ -294,7 +295,8 @@ struct gss_upcall_msg { ...@@ -294,7 +295,8 @@ struct gss_upcall_msg {
struct rpc_pipe_msg msg; struct rpc_pipe_msg msg;
struct list_head list; struct list_head list;
struct gss_auth *auth; struct gss_auth *auth;
struct rpc_wait_queue waitq; struct rpc_wait_queue rpc_waitqueue;
wait_queue_head_t waitqueue;
struct gss_cl_ctx *ctx; struct gss_cl_ctx *ctx;
}; };
...@@ -324,16 +326,34 @@ __gss_find_upcall(struct gss_auth *gss_auth, uid_t uid) ...@@ -324,16 +326,34 @@ __gss_find_upcall(struct gss_auth *gss_auth, uid_t uid)
return NULL; return NULL;
} }
/* Try to add a upcall to the pipefs queue.
* If an upcall owned by our uid already exists, then we return a reference
* to that upcall instead of adding the new upcall.
*/
static inline struct gss_upcall_msg *
gss_add_msg(struct gss_auth *gss_auth, struct gss_upcall_msg *gss_msg)
{
struct gss_upcall_msg *old;
spin_lock(&gss_auth->lock);
old = __gss_find_upcall(gss_auth, gss_msg->uid);
if (old == NULL) {
atomic_inc(&gss_msg->count);
list_add(&gss_msg->list, &gss_auth->upcalls);
} else
gss_msg = old;
spin_unlock(&gss_auth->lock);
return gss_msg;
}
static void static void
__gss_unhash_msg(struct gss_upcall_msg *gss_msg) __gss_unhash_msg(struct gss_upcall_msg *gss_msg)
{ {
if (list_empty(&gss_msg->list)) if (list_empty(&gss_msg->list))
return; return;
list_del_init(&gss_msg->list); list_del_init(&gss_msg->list);
if (gss_msg->msg.errno < 0) rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
rpc_wake_up_status(&gss_msg->waitq, gss_msg->msg.errno); wake_up_all(&gss_msg->waitqueue);
else
rpc_wake_up(&gss_msg->waitq);
atomic_dec(&gss_msg->count); atomic_dec(&gss_msg->count);
} }
...@@ -359,81 +379,127 @@ gss_upcall_callback(struct rpc_task *task) ...@@ -359,81 +379,127 @@ gss_upcall_callback(struct rpc_task *task)
gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_get_ctx(gss_msg->ctx)); gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_get_ctx(gss_msg->ctx));
else else
task->tk_status = gss_msg->msg.errno; task->tk_status = gss_msg->msg.errno;
spin_lock(&gss_msg->auth->lock);
gss_cred->gc_upcall = NULL; gss_cred->gc_upcall = NULL;
rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
spin_unlock(&gss_msg->auth->lock);
gss_release_msg(gss_msg); gss_release_msg(gss_msg);
} }
static int static inline struct gss_upcall_msg *
gss_upcall(struct rpc_clnt *clnt, struct rpc_task *task, struct rpc_cred *cred) gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid)
{
struct gss_upcall_msg *gss_msg;
gss_msg = kmalloc(sizeof(*gss_msg), GFP_KERNEL);
if (gss_msg != NULL) {
memset(gss_msg, 0, sizeof(*gss_msg));
INIT_LIST_HEAD(&gss_msg->list);
rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
init_waitqueue_head(&gss_msg->waitqueue);
atomic_set(&gss_msg->count, 1);
gss_msg->msg.data = &gss_msg->uid;
gss_msg->msg.len = sizeof(gss_msg->uid);
gss_msg->uid = uid;
gss_msg->auth = gss_auth;
}
return gss_msg;
}
static struct gss_upcall_msg *
gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
{ {
struct gss_auth *gss_auth = container_of(clnt->cl_auth, struct gss_upcall_msg *gss_new, *gss_msg;
gss_new = gss_alloc_msg(gss_auth, cred->cr_uid);
if (gss_new == NULL)
return ERR_PTR(-ENOMEM);
gss_msg = gss_add_msg(gss_auth, gss_new);
if (gss_msg == gss_new) {
int res = rpc_queue_upcall(gss_auth->dentry->d_inode, &gss_new->msg);
if (res) {
gss_unhash_msg(gss_new);
gss_msg = ERR_PTR(res);
}
} else
gss_release_msg(gss_new);
return gss_msg;
}
static inline int
gss_refresh_upcall(struct rpc_task *task)
{
struct rpc_cred *cred = task->tk_msg.rpc_cred;
struct gss_auth *gss_auth = container_of(task->tk_client->cl_auth,
struct gss_auth, rpc_auth); struct gss_auth, rpc_auth);
struct gss_cred *gss_cred = container_of(cred, struct gss_cred *gss_cred = container_of(cred,
struct gss_cred, gc_base); struct gss_cred, gc_base);
struct gss_upcall_msg *gss_msg, *gss_new = NULL; struct gss_upcall_msg *gss_msg;
struct rpc_pipe_msg *msg; int err = 0;
struct dentry *dentry = gss_auth->dentry;
uid_t uid = cred->cr_uid;
int res = 0;
dprintk("RPC: %4u gss_upcall for uid %u\n", task->tk_pid, uid);
retry: dprintk("RPC: %4u gss_refresh_upcall for uid %u\n", task->tk_pid, cred->cr_uid);
spin_lock(&gss_auth->lock); gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
gss_msg = __gss_find_upcall(gss_auth, uid); if (IS_ERR(gss_msg)) {
if (gss_msg) err = PTR_ERR(gss_msg);
goto out_sleep; goto out;
if (gss_new == NULL) {
spin_unlock(&gss_auth->lock);
gss_new = kmalloc(sizeof(*gss_new), GFP_KERNEL);
if (!gss_new) {
dprintk("RPC: %4u gss_upcall -ENOMEM\n", task->tk_pid);
return -ENOMEM;
}
goto retry;
} }
gss_msg = gss_new; spin_lock(&gss_auth->lock);
memset(gss_new, 0, sizeof(*gss_new)); if (gss_cred->gc_upcall != NULL)
INIT_LIST_HEAD(&gss_new->list); rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL, NULL);
rpc_init_wait_queue(&gss_new->waitq, "RPCSEC_GSS upcall waitq"); else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
atomic_set(&gss_new->count, 2);
msg = &gss_new->msg;
msg->data = &gss_new->uid;
msg->len = sizeof(gss_new->uid);
gss_new->uid = uid;
gss_new->auth = gss_auth;
list_add(&gss_new->list, &gss_auth->upcalls);
gss_new = NULL;
/* Has someone updated the credential behind our back? */
if (!gss_cred_is_uptodate_ctx(cred)) {
/* No, so do upcall and sleep */
task->tk_timeout = 0; task->tk_timeout = 0;
/* gss_upcall_callback will release the reference to gss_msg */
gss_cred->gc_upcall = gss_msg; gss_cred->gc_upcall = gss_msg;
rpc_sleep_on(&gss_msg->waitq, task, gss_upcall_callback, NULL); /* gss_upcall_callback will release the reference to gss_upcall_msg */
spin_unlock(&gss_auth->lock); atomic_inc(&gss_msg->count);
res = rpc_queue_upcall(dentry->d_inode, msg); rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback, NULL);
if (res) } else
gss_unhash_msg(gss_msg); err = gss_msg->msg.errno;
} else { spin_unlock(&gss_auth->lock);
/* Yes, so cancel upcall */ gss_release_msg(gss_msg);
__gss_unhash_msg(gss_msg); out:
dprintk("RPC: %4u gss_refresh_upcall for uid %u result %d\n", task->tk_pid,
cred->cr_uid, err);
return err;
}
static inline int
gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
{
struct rpc_cred *cred = &gss_cred->gc_base;
struct gss_upcall_msg *gss_msg;
DEFINE_WAIT(wait);
int err = 0;
dprintk("RPC: gss_upcall for uid %u\n", cred->cr_uid);
gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
if (IS_ERR(gss_msg)) {
err = PTR_ERR(gss_msg);
goto out;
}
for (;;) {
prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE);
spin_lock(&gss_auth->lock);
if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
spin_unlock(&gss_auth->lock);
break;
}
spin_unlock(&gss_auth->lock); spin_unlock(&gss_auth->lock);
gss_release_msg(gss_msg); if (signalled()) {
err = -ERESTARTSYS;
goto out_intr;
}
schedule();
} }
dprintk("RPC: %4u gss_upcall for uid %u result %d\n", task->tk_pid, if (gss_msg->ctx)
uid, res); gss_cred_set_ctx(cred, gss_get_ctx(gss_msg->ctx));
return res; else
out_sleep: err = gss_msg->msg.errno;
task->tk_timeout = 0; out_intr:
/* gss_upcall_callback will release the reference to gss_msg */ finish_wait(&gss_msg->waitqueue, &wait);
gss_cred->gc_upcall = gss_msg; gss_release_msg(gss_msg);
rpc_sleep_on(&gss_msg->waitq, task, gss_upcall_callback, NULL); out:
spin_unlock(&gss_auth->lock); dprintk("RPC: gss_create_upcall for uid %u result %d\n", cred->cr_uid, err);
dprintk("RPC: %4u gss_upcall sleeping\n", task->tk_pid); return err;
if (gss_new)
kfree(gss_new);
return 0;
} }
static ssize_t static ssize_t
...@@ -600,6 +666,7 @@ gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) ...@@ -600,6 +666,7 @@ gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
return NULL; return NULL;
if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL))) if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
goto out_dec; goto out_dec;
gss_auth->client = clnt;
gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor); gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
if (!gss_auth->mech) { if (!gss_auth->mech) {
printk(KERN_WARNING "%s: Pseudoflavor %d not found!", printk(KERN_WARNING "%s: Pseudoflavor %d not found!",
...@@ -697,6 +764,7 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags) ...@@ -697,6 +764,7 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
{ {
struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
struct gss_cred *cred = NULL; struct gss_cred *cred = NULL;
int err = -ENOMEM;
dprintk("RPC: gss_create_cred for uid %d, flavor %d\n", dprintk("RPC: gss_create_cred for uid %d, flavor %d\n",
acred->uid, auth->au_flavor); acred->uid, auth->au_flavor);
...@@ -714,11 +782,14 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags) ...@@ -714,11 +782,14 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
cred->gc_flags = 0; cred->gc_flags = 0;
cred->gc_base.cr_ops = &gss_credops; cred->gc_base.cr_ops = &gss_credops;
cred->gc_service = gss_auth->service; cred->gc_service = gss_auth->service;
err = gss_create_upcall(gss_auth, cred);
if (err < 0)
goto out_err;
return &cred->gc_base; return &cred->gc_base;
out_err: out_err:
dprintk("RPC: gss_create_cred failed\n"); dprintk("RPC: gss_create_cred failed with error %d\n", err);
if (cred) gss_destroy_cred(&cred->gc_base); if (cred) gss_destroy_cred(&cred->gc_base);
return NULL; return NULL;
} }
...@@ -804,11 +875,9 @@ gss_marshal(struct rpc_task *task, u32 *p, int ruid) ...@@ -804,11 +875,9 @@ gss_marshal(struct rpc_task *task, u32 *p, int ruid)
static int static int
gss_refresh(struct rpc_task *task) gss_refresh(struct rpc_task *task)
{ {
struct rpc_clnt *clnt = task->tk_client;
struct rpc_cred *cred = task->tk_msg.rpc_cred;
if (!gss_cred_is_uptodate_ctx(cred)) if (!gss_cred_is_uptodate_ctx(task->tk_msg.rpc_cred))
return gss_upcall(clnt, task, cred); return gss_refresh_upcall(task);
return 0; return 0;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment