Commit 1df30f82 authored by Paul Moore's avatar Paul Moore Committed by Greg Kroah-Hartman

audit: fix the RCU locking for the auditd_connection structure

commit 48d0e023 upstream.

Cong Wang correctly pointed out that the RCU read locking of the
auditd_connection struct was wrong, this patch correct this by
adopting a more traditional, and correct RCU locking model.

This patch is heavily based on an earlier prototype by Cong Wang.
Reported-by: default avatarCong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: default avatarCong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: default avatarPaul Moore <paul@paul-moore.com>
Signed-off-by: default avatarGreg Kroah-Hartman <gregkh@linuxfoundation.org>
parent f17b15c8
...@@ -110,18 +110,19 @@ struct audit_net { ...@@ -110,18 +110,19 @@ struct audit_net {
* @pid: auditd PID * @pid: auditd PID
* @portid: netlink portid * @portid: netlink portid
* @net: the associated network namespace * @net: the associated network namespace
* @lock: spinlock to protect write access * @rcu: RCU head
* *
* Description: * Description:
* This struct is RCU protected; you must either hold the RCU lock for reading * This struct is RCU protected; you must either hold the RCU lock for reading
* or the included spinlock for writing. * or the associated spinlock for writing.
*/ */
static struct auditd_connection { static struct auditd_connection {
int pid; int pid;
u32 portid; u32 portid;
struct net *net; struct net *net;
spinlock_t lock; struct rcu_head rcu;
} auditd_conn; } *auditd_conn = NULL;
static DEFINE_SPINLOCK(auditd_conn_lock);
/* If audit_rate_limit is non-zero, limit the rate of sending audit records /* If audit_rate_limit is non-zero, limit the rate of sending audit records
* to that number per second. This prevents DoS attacks, but results in * to that number per second. This prevents DoS attacks, but results in
...@@ -223,14 +224,38 @@ struct audit_reply { ...@@ -223,14 +224,38 @@ struct audit_reply {
int auditd_test_task(const struct task_struct *task) int auditd_test_task(const struct task_struct *task)
{ {
int rc; int rc;
struct auditd_connection *ac;
rcu_read_lock(); rcu_read_lock();
rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0); ac = rcu_dereference(auditd_conn);
rc = (ac && ac->pid == task->tgid ? 1 : 0);
rcu_read_unlock(); rcu_read_unlock();
return rc; return rc;
} }
/**
* auditd_pid_vnr - Return the auditd PID relative to the namespace
*
* Description:
* Returns the PID in relation to the namespace, 0 on failure.
*/
static pid_t auditd_pid_vnr(void)
{
pid_t pid;
const struct auditd_connection *ac;
rcu_read_lock();
ac = rcu_dereference(auditd_conn);
if (!ac)
pid = 0;
else
pid = ac->pid;
rcu_read_unlock();
return pid;
}
/** /**
* audit_get_sk - Return the audit socket for the given network namespace * audit_get_sk - Return the audit socket for the given network namespace
* @net: the destination network namespace * @net: the destination network namespace
...@@ -426,6 +451,23 @@ static int audit_set_failure(u32 state) ...@@ -426,6 +451,23 @@ static int audit_set_failure(u32 state)
return audit_do_config_change("audit_failure", &audit_failure, state); return audit_do_config_change("audit_failure", &audit_failure, state);
} }
/**
* auditd_conn_free - RCU helper to release an auditd connection struct
* @rcu: RCU head
*
* Description:
* Drop any references inside the auditd connection tracking struct and free
* the memory.
*/
static void auditd_conn_free(struct rcu_head *rcu)
{
struct auditd_connection *ac;
ac = container_of(rcu, struct auditd_connection, rcu);
put_net(ac->net);
kfree(ac);
}
/** /**
* auditd_set - Set/Reset the auditd connection state * auditd_set - Set/Reset the auditd connection state
* @pid: auditd PID * @pid: auditd PID
...@@ -434,22 +476,33 @@ static int audit_set_failure(u32 state) ...@@ -434,22 +476,33 @@ static int audit_set_failure(u32 state)
* *
* Description: * Description:
* This function will obtain and drop network namespace references as * This function will obtain and drop network namespace references as
* necessary. * necessary. Returns zero on success, negative values on failure.
*/ */
static void auditd_set(int pid, u32 portid, struct net *net) static int auditd_set(int pid, u32 portid, struct net *net)
{ {
unsigned long flags; unsigned long flags;
struct auditd_connection *ac_old, *ac_new;
spin_lock_irqsave(&auditd_conn.lock, flags); if (!pid || !net)
auditd_conn.pid = pid; return -EINVAL;
auditd_conn.portid = portid;
if (auditd_conn.net) ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
put_net(auditd_conn.net); if (!ac_new)
if (net) return -ENOMEM;
auditd_conn.net = get_net(net); ac_new->pid = pid;
else ac_new->portid = portid;
auditd_conn.net = NULL; ac_new->net = get_net(net);
spin_unlock_irqrestore(&auditd_conn.lock, flags);
spin_lock_irqsave(&auditd_conn_lock, flags);
ac_old = rcu_dereference_protected(auditd_conn,
lockdep_is_held(&auditd_conn_lock));
rcu_assign_pointer(auditd_conn, ac_new);
spin_unlock_irqrestore(&auditd_conn_lock, flags);
if (ac_old)
call_rcu(&ac_old->rcu, auditd_conn_free);
return 0;
} }
/** /**
...@@ -544,13 +597,19 @@ static void kauditd_retry_skb(struct sk_buff *skb) ...@@ -544,13 +597,19 @@ static void kauditd_retry_skb(struct sk_buff *skb)
*/ */
static void auditd_reset(void) static void auditd_reset(void)
{ {
unsigned long flags;
struct sk_buff *skb; struct sk_buff *skb;
struct auditd_connection *ac_old;
/* if it isn't already broken, break the connection */ /* if it isn't already broken, break the connection */
rcu_read_lock(); spin_lock_irqsave(&auditd_conn_lock, flags);
if (auditd_conn.pid) ac_old = rcu_dereference_protected(auditd_conn,
auditd_set(0, 0, NULL); lockdep_is_held(&auditd_conn_lock));
rcu_read_unlock(); rcu_assign_pointer(auditd_conn, NULL);
spin_unlock_irqrestore(&auditd_conn_lock, flags);
if (ac_old)
call_rcu(&ac_old->rcu, auditd_conn_free);
/* flush all of the main and retry queues to the hold queue */ /* flush all of the main and retry queues to the hold queue */
while ((skb = skb_dequeue(&audit_retry_queue))) while ((skb = skb_dequeue(&audit_retry_queue)))
...@@ -576,6 +635,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb) ...@@ -576,6 +635,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
u32 portid; u32 portid;
struct net *net; struct net *net;
struct sock *sk; struct sock *sk;
struct auditd_connection *ac;
/* NOTE: we can't call netlink_unicast while in the RCU section so /* NOTE: we can't call netlink_unicast while in the RCU section so
* take a reference to the network namespace and grab local * take a reference to the network namespace and grab local
...@@ -585,15 +645,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb) ...@@ -585,15 +645,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
* section netlink_unicast() should safely return an error */ * section netlink_unicast() should safely return an error */
rcu_read_lock(); rcu_read_lock();
if (!auditd_conn.pid) { ac = rcu_dereference(auditd_conn);
if (!ac) {
rcu_read_unlock(); rcu_read_unlock();
rc = -ECONNREFUSED; rc = -ECONNREFUSED;
goto err; goto err;
} }
net = auditd_conn.net; net = get_net(ac->net);
get_net(net);
sk = audit_get_sk(net); sk = audit_get_sk(net);
portid = auditd_conn.portid; portid = ac->portid;
rcu_read_unlock(); rcu_read_unlock();
rc = netlink_unicast(sk, skb, portid, 0); rc = netlink_unicast(sk, skb, portid, 0);
...@@ -728,6 +788,7 @@ static int kauditd_thread(void *dummy) ...@@ -728,6 +788,7 @@ static int kauditd_thread(void *dummy)
u32 portid = 0; u32 portid = 0;
struct net *net = NULL; struct net *net = NULL;
struct sock *sk = NULL; struct sock *sk = NULL;
struct auditd_connection *ac;
#define UNICAST_RETRIES 5 #define UNICAST_RETRIES 5
...@@ -735,14 +796,14 @@ static int kauditd_thread(void *dummy) ...@@ -735,14 +796,14 @@ static int kauditd_thread(void *dummy)
while (!kthread_should_stop()) { while (!kthread_should_stop()) {
/* NOTE: see the lock comments in auditd_send_unicast_skb() */ /* NOTE: see the lock comments in auditd_send_unicast_skb() */
rcu_read_lock(); rcu_read_lock();
if (!auditd_conn.pid) { ac = rcu_dereference(auditd_conn);
if (!ac) {
rcu_read_unlock(); rcu_read_unlock();
goto main_queue; goto main_queue;
} }
net = auditd_conn.net; net = get_net(ac->net);
get_net(net);
sk = audit_get_sk(net); sk = audit_get_sk(net);
portid = auditd_conn.portid; portid = ac->portid;
rcu_read_unlock(); rcu_read_unlock();
/* attempt to flush the hold queue */ /* attempt to flush the hold queue */
...@@ -1102,9 +1163,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) ...@@ -1102,9 +1163,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
memset(&s, 0, sizeof(s)); memset(&s, 0, sizeof(s));
s.enabled = audit_enabled; s.enabled = audit_enabled;
s.failure = audit_failure; s.failure = audit_failure;
rcu_read_lock(); s.pid = auditd_pid_vnr();
s.pid = auditd_conn.pid;
rcu_read_unlock();
s.rate_limit = audit_rate_limit; s.rate_limit = audit_rate_limit;
s.backlog_limit = audit_backlog_limit; s.backlog_limit = audit_backlog_limit;
s.lost = atomic_read(&audit_lost); s.lost = atomic_read(&audit_lost);
...@@ -1143,38 +1202,44 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) ...@@ -1143,38 +1202,44 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
/* test the auditd connection */ /* test the auditd connection */
audit_replace(requesting_pid); audit_replace(requesting_pid);
rcu_read_lock(); auditd_pid = auditd_pid_vnr();
auditd_pid = auditd_conn.pid;
/* only the current auditd can unregister itself */ /* only the current auditd can unregister itself */
if ((!new_pid) && (requesting_pid != auditd_pid)) { if ((!new_pid) && (requesting_pid != auditd_pid)) {
rcu_read_unlock();
audit_log_config_change("audit_pid", new_pid, audit_log_config_change("audit_pid", new_pid,
auditd_pid, 0); auditd_pid, 0);
return -EACCES; return -EACCES;
} }
/* replacing a healthy auditd is not allowed */ /* replacing a healthy auditd is not allowed */
if (auditd_pid && new_pid) { if (auditd_pid && new_pid) {
rcu_read_unlock();
audit_log_config_change("audit_pid", new_pid, audit_log_config_change("audit_pid", new_pid,
auditd_pid, 0); auditd_pid, 0);
return -EEXIST; return -EEXIST;
} }
rcu_read_unlock();
if (audit_enabled != AUDIT_OFF)
audit_log_config_change("audit_pid", new_pid,
auditd_pid, 1);
if (new_pid) { if (new_pid) {
/* register a new auditd connection */ /* register a new auditd connection */
auditd_set(new_pid, err = auditd_set(new_pid,
NETLINK_CB(skb).portid, NETLINK_CB(skb).portid,
sock_net(NETLINK_CB(skb).sk)); sock_net(NETLINK_CB(skb).sk));
if (audit_enabled != AUDIT_OFF)
audit_log_config_change("audit_pid",
new_pid,
auditd_pid,
err ? 0 : 1);
if (err)
return err;
/* try to process any backlog */ /* try to process any backlog */
wake_up_interruptible(&kauditd_wait); wake_up_interruptible(&kauditd_wait);
} else } else {
if (audit_enabled != AUDIT_OFF)
audit_log_config_change("audit_pid",
new_pid,
auditd_pid, 1);
/* unregister the auditd connection */ /* unregister the auditd connection */
auditd_reset(); auditd_reset();
}
} }
if (s.mask & AUDIT_STATUS_RATE_LIMIT) { if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
err = audit_set_rate_limit(s.rate_limit); err = audit_set_rate_limit(s.rate_limit);
...@@ -1447,10 +1512,11 @@ static void __net_exit audit_net_exit(struct net *net) ...@@ -1447,10 +1512,11 @@ static void __net_exit audit_net_exit(struct net *net)
{ {
struct audit_net *aunet = net_generic(net, audit_net_id); struct audit_net *aunet = net_generic(net, audit_net_id);
rcu_read_lock(); /* NOTE: you would think that we would want to check the auditd
if (net == auditd_conn.net) * connection and potentially reset it here if it lives in this
auditd_reset(); * namespace, but since the auditd connection tracking struct holds a
rcu_read_unlock(); * reference to this namespace (see auditd_set()) we are only ever
* going to get here after that connection has been released */
netlink_kernel_release(aunet->sk); netlink_kernel_release(aunet->sk);
} }
...@@ -1470,9 +1536,6 @@ static int __init audit_init(void) ...@@ -1470,9 +1536,6 @@ static int __init audit_init(void)
if (audit_initialized == AUDIT_DISABLED) if (audit_initialized == AUDIT_DISABLED)
return 0; return 0;
memset(&auditd_conn, 0, sizeof(auditd_conn));
spin_lock_init(&auditd_conn.lock);
skb_queue_head_init(&audit_queue); skb_queue_head_init(&audit_queue);
skb_queue_head_init(&audit_retry_queue); skb_queue_head_init(&audit_retry_queue);
skb_queue_head_init(&audit_hold_queue); skb_queue_head_init(&audit_hold_queue);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment