Commit 5a9e20ef authored by Trond Myklebust's avatar Trond Myklebust

RPC: Initialize the GSS context upon RPC credential creation.

Signed-off-by: default avatarTrond Myklebust <Trond.Myklebust@netapp.com>
parent 4d08b43c
......@@ -365,11 +365,14 @@ rpcauth_refreshcred(struct rpc_task *task)
{
struct rpc_auth *auth = task->tk_auth;
struct rpc_cred *cred = task->tk_msg.rpc_cred;
int err;
dprintk("RPC: %4d refreshing %s cred %p\n",
task->tk_pid, auth->au_ops->au_name, cred);
task->tk_status = cred->cr_ops->crrefresh(task);
return task->tk_status;
err = cred->cr_ops->crrefresh(task);
if (err < 0)
task->tk_status = err;
return err;
}
void
......
......@@ -87,6 +87,7 @@ struct gss_auth {
struct gss_api_mech *mech;
enum rpc_gss_svc service;
struct list_head upcalls;
struct rpc_clnt *client;
struct dentry *dentry;
char path[48];
spinlock_t lock;
......@@ -294,7 +295,8 @@ struct gss_upcall_msg {
struct rpc_pipe_msg msg;
struct list_head list;
struct gss_auth *auth;
struct rpc_wait_queue waitq;
struct rpc_wait_queue rpc_waitqueue;
wait_queue_head_t waitqueue;
struct gss_cl_ctx *ctx;
};
......@@ -324,16 +326,34 @@ __gss_find_upcall(struct gss_auth *gss_auth, uid_t uid)
return NULL;
}
/* Try to add a upcall to the pipefs queue.
* If an upcall owned by our uid already exists, then we return a reference
* to that upcall instead of adding the new upcall.
*/
static inline struct gss_upcall_msg *
gss_add_msg(struct gss_auth *gss_auth, struct gss_upcall_msg *gss_msg)
{
struct gss_upcall_msg *old;
spin_lock(&gss_auth->lock);
old = __gss_find_upcall(gss_auth, gss_msg->uid);
if (old == NULL) {
atomic_inc(&gss_msg->count);
list_add(&gss_msg->list, &gss_auth->upcalls);
} else
gss_msg = old;
spin_unlock(&gss_auth->lock);
return gss_msg;
}
static void
__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
{
if (list_empty(&gss_msg->list))
return;
list_del_init(&gss_msg->list);
if (gss_msg->msg.errno < 0)
rpc_wake_up_status(&gss_msg->waitq, gss_msg->msg.errno);
else
rpc_wake_up(&gss_msg->waitq);
rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
wake_up_all(&gss_msg->waitqueue);
atomic_dec(&gss_msg->count);
}
......@@ -359,81 +379,127 @@ gss_upcall_callback(struct rpc_task *task)
gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_get_ctx(gss_msg->ctx));
else
task->tk_status = gss_msg->msg.errno;
spin_lock(&gss_msg->auth->lock);
gss_cred->gc_upcall = NULL;
rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
spin_unlock(&gss_msg->auth->lock);
gss_release_msg(gss_msg);
}
static int
gss_upcall(struct rpc_clnt *clnt, struct rpc_task *task, struct rpc_cred *cred)
static inline struct gss_upcall_msg *
gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid)
{
struct gss_upcall_msg *gss_msg;
gss_msg = kmalloc(sizeof(*gss_msg), GFP_KERNEL);
if (gss_msg != NULL) {
memset(gss_msg, 0, sizeof(*gss_msg));
INIT_LIST_HEAD(&gss_msg->list);
rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
init_waitqueue_head(&gss_msg->waitqueue);
atomic_set(&gss_msg->count, 1);
gss_msg->msg.data = &gss_msg->uid;
gss_msg->msg.len = sizeof(gss_msg->uid);
gss_msg->uid = uid;
gss_msg->auth = gss_auth;
}
return gss_msg;
}
static struct gss_upcall_msg *
gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
{
struct gss_auth *gss_auth = container_of(clnt->cl_auth,
struct gss_upcall_msg *gss_new, *gss_msg;
gss_new = gss_alloc_msg(gss_auth, cred->cr_uid);
if (gss_new == NULL)
return ERR_PTR(-ENOMEM);
gss_msg = gss_add_msg(gss_auth, gss_new);
if (gss_msg == gss_new) {
int res = rpc_queue_upcall(gss_auth->dentry->d_inode, &gss_new->msg);
if (res) {
gss_unhash_msg(gss_new);
gss_msg = ERR_PTR(res);
}
} else
gss_release_msg(gss_new);
return gss_msg;
}
static inline int
gss_refresh_upcall(struct rpc_task *task)
{
struct rpc_cred *cred = task->tk_msg.rpc_cred;
struct gss_auth *gss_auth = container_of(task->tk_client->cl_auth,
struct gss_auth, rpc_auth);
struct gss_cred *gss_cred = container_of(cred,
struct gss_cred, gc_base);
struct gss_upcall_msg *gss_msg, *gss_new = NULL;
struct rpc_pipe_msg *msg;
struct dentry *dentry = gss_auth->dentry;
uid_t uid = cred->cr_uid;
int res = 0;
dprintk("RPC: %4u gss_upcall for uid %u\n", task->tk_pid, uid);
struct gss_upcall_msg *gss_msg;
int err = 0;
retry:
spin_lock(&gss_auth->lock);
gss_msg = __gss_find_upcall(gss_auth, uid);
if (gss_msg)
goto out_sleep;
if (gss_new == NULL) {
spin_unlock(&gss_auth->lock);
gss_new = kmalloc(sizeof(*gss_new), GFP_KERNEL);
if (!gss_new) {
dprintk("RPC: %4u gss_upcall -ENOMEM\n", task->tk_pid);
return -ENOMEM;
}
goto retry;
dprintk("RPC: %4u gss_refresh_upcall for uid %u\n", task->tk_pid, cred->cr_uid);
gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
if (IS_ERR(gss_msg)) {
err = PTR_ERR(gss_msg);
goto out;
}
gss_msg = gss_new;
memset(gss_new, 0, sizeof(*gss_new));
INIT_LIST_HEAD(&gss_new->list);
rpc_init_wait_queue(&gss_new->waitq, "RPCSEC_GSS upcall waitq");
atomic_set(&gss_new->count, 2);
msg = &gss_new->msg;
msg->data = &gss_new->uid;
msg->len = sizeof(gss_new->uid);
gss_new->uid = uid;
gss_new->auth = gss_auth;
list_add(&gss_new->list, &gss_auth->upcalls);
gss_new = NULL;
/* Has someone updated the credential behind our back? */
if (!gss_cred_is_uptodate_ctx(cred)) {
/* No, so do upcall and sleep */
spin_lock(&gss_auth->lock);
if (gss_cred->gc_upcall != NULL)
rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL, NULL);
else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
task->tk_timeout = 0;
/* gss_upcall_callback will release the reference to gss_msg */
gss_cred->gc_upcall = gss_msg;
rpc_sleep_on(&gss_msg->waitq, task, gss_upcall_callback, NULL);
spin_unlock(&gss_auth->lock);
res = rpc_queue_upcall(dentry->d_inode, msg);
if (res)
gss_unhash_msg(gss_msg);
} else {
/* Yes, so cancel upcall */
__gss_unhash_msg(gss_msg);
/* gss_upcall_callback will release the reference to gss_upcall_msg */
atomic_inc(&gss_msg->count);
rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback, NULL);
} else
err = gss_msg->msg.errno;
spin_unlock(&gss_auth->lock);
gss_release_msg(gss_msg);
out:
dprintk("RPC: %4u gss_refresh_upcall for uid %u result %d\n", task->tk_pid,
cred->cr_uid, err);
return err;
}
static inline int
gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
{
struct rpc_cred *cred = &gss_cred->gc_base;
struct gss_upcall_msg *gss_msg;
DEFINE_WAIT(wait);
int err = 0;
dprintk("RPC: gss_upcall for uid %u\n", cred->cr_uid);
gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
if (IS_ERR(gss_msg)) {
err = PTR_ERR(gss_msg);
goto out;
}
for (;;) {
prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE);
spin_lock(&gss_auth->lock);
if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
spin_unlock(&gss_auth->lock);
break;
}
spin_unlock(&gss_auth->lock);
gss_release_msg(gss_msg);
if (signalled()) {
err = -ERESTARTSYS;
goto out_intr;
}
schedule();
}
dprintk("RPC: %4u gss_upcall for uid %u result %d\n", task->tk_pid,
uid, res);
return res;
out_sleep:
task->tk_timeout = 0;
/* gss_upcall_callback will release the reference to gss_msg */
gss_cred->gc_upcall = gss_msg;
rpc_sleep_on(&gss_msg->waitq, task, gss_upcall_callback, NULL);
spin_unlock(&gss_auth->lock);
dprintk("RPC: %4u gss_upcall sleeping\n", task->tk_pid);
if (gss_new)
kfree(gss_new);
return 0;
if (gss_msg->ctx)
gss_cred_set_ctx(cred, gss_get_ctx(gss_msg->ctx));
else
err = gss_msg->msg.errno;
out_intr:
finish_wait(&gss_msg->waitqueue, &wait);
gss_release_msg(gss_msg);
out:
dprintk("RPC: gss_create_upcall for uid %u result %d\n", cred->cr_uid, err);
return err;
}
static ssize_t
......@@ -600,6 +666,7 @@ gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
return NULL;
if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
goto out_dec;
gss_auth->client = clnt;
gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
if (!gss_auth->mech) {
printk(KERN_WARNING "%s: Pseudoflavor %d not found!",
......@@ -697,6 +764,7 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
{
struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
struct gss_cred *cred = NULL;
int err = -ENOMEM;
dprintk("RPC: gss_create_cred for uid %d, flavor %d\n",
acred->uid, auth->au_flavor);
......@@ -714,11 +782,14 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
cred->gc_flags = 0;
cred->gc_base.cr_ops = &gss_credops;
cred->gc_service = gss_auth->service;
err = gss_create_upcall(gss_auth, cred);
if (err < 0)
goto out_err;
return &cred->gc_base;
out_err:
dprintk("RPC: gss_create_cred failed\n");
dprintk("RPC: gss_create_cred failed with error %d\n", err);
if (cred) gss_destroy_cred(&cred->gc_base);
return NULL;
}
......@@ -804,11 +875,9 @@ gss_marshal(struct rpc_task *task, u32 *p, int ruid)
static int
gss_refresh(struct rpc_task *task)
{
struct rpc_clnt *clnt = task->tk_client;
struct rpc_cred *cred = task->tk_msg.rpc_cred;
if (!gss_cred_is_uptodate_ctx(cred))
return gss_upcall(clnt, task, cred);
if (!gss_cred_is_uptodate_ctx(task->tk_msg.rpc_cred))
return gss_refresh_upcall(task);
return 0;
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment