Commit 4ff09db1 authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Alexei Starovoitov

bpf: net: Change sk_getsockopt() to take the sockptr_t argument

This patch changes sk_getsockopt() to take the sockptr_t argument
such that it can be used by bpf_getsockopt(SOL_SOCKET) in a
latter patch.

security_socket_getpeersec_stream() is not changed.  It stays
with the __user ptr (optval.user and optlen.user) to avoid changes
to other security hooks.  bpf_getsockopt(SOL_SOCKET) also does not
support SO_PEERSEC.
Signed-off-by: default avatarMartin KaFai Lau <martin.lau@kernel.org>
Link: https://lore.kernel.org/r/20220902002802.2888419-1-kafai@fb.comSigned-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parent ba74a760
...@@ -900,8 +900,7 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk); ...@@ -900,8 +900,7 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk);
int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk); int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk);
void sk_reuseport_prog_free(struct bpf_prog *prog); void sk_reuseport_prog_free(struct bpf_prog *prog);
int sk_detach_filter(struct sock *sk); int sk_detach_filter(struct sock *sk);
int sk_get_filter(struct sock *sk, struct sock_filter __user *filter, int sk_get_filter(struct sock *sk, sockptr_t optval, unsigned int len);
unsigned int len);
bool sk_filter_charge(struct sock *sk, struct sk_filter *fp); bool sk_filter_charge(struct sock *sk, struct sk_filter *fp);
void sk_filter_uncharge(struct sock *sk, struct sk_filter *fp); void sk_filter_uncharge(struct sock *sk, struct sk_filter *fp);
......
...@@ -64,6 +64,11 @@ static inline int copy_to_sockptr_offset(sockptr_t dst, size_t offset, ...@@ -64,6 +64,11 @@ static inline int copy_to_sockptr_offset(sockptr_t dst, size_t offset,
return 0; return 0;
} }
static inline int copy_to_sockptr(sockptr_t dst, const void *src, size_t size)
{
return copy_to_sockptr_offset(dst, 0, src, size);
}
static inline void *memdup_sockptr(sockptr_t src, size_t len) static inline void *memdup_sockptr(sockptr_t src, size_t len)
{ {
void *p = kmalloc_track_caller(len, GFP_USER | __GFP_NOWARN); void *p = kmalloc_track_caller(len, GFP_USER | __GFP_NOWARN);
......
...@@ -10716,8 +10716,7 @@ int sk_detach_filter(struct sock *sk) ...@@ -10716,8 +10716,7 @@ int sk_detach_filter(struct sock *sk)
} }
EXPORT_SYMBOL_GPL(sk_detach_filter); EXPORT_SYMBOL_GPL(sk_detach_filter);
int sk_get_filter(struct sock *sk, struct sock_filter __user *ubuf, int sk_get_filter(struct sock *sk, sockptr_t optval, unsigned int len)
unsigned int len)
{ {
struct sock_fprog_kern *fprog; struct sock_fprog_kern *fprog;
struct sk_filter *filter; struct sk_filter *filter;
...@@ -10748,7 +10747,7 @@ int sk_get_filter(struct sock *sk, struct sock_filter __user *ubuf, ...@@ -10748,7 +10747,7 @@ int sk_get_filter(struct sock *sk, struct sock_filter __user *ubuf,
goto out; goto out;
ret = -EFAULT; ret = -EFAULT;
if (copy_to_user(ubuf, fprog->filter, bpf_classic_proglen(fprog))) if (copy_to_sockptr(optval, fprog->filter, bpf_classic_proglen(fprog)))
goto out; goto out;
/* Instead of bytes, the API requests to return the number /* Instead of bytes, the API requests to return the number
......
...@@ -712,8 +712,8 @@ static int sock_setbindtodevice(struct sock *sk, sockptr_t optval, int optlen) ...@@ -712,8 +712,8 @@ static int sock_setbindtodevice(struct sock *sk, sockptr_t optval, int optlen)
return ret; return ret;
} }
static int sock_getbindtodevice(struct sock *sk, char __user *optval, static int sock_getbindtodevice(struct sock *sk, sockptr_t optval,
int __user *optlen, int len) sockptr_t optlen, int len)
{ {
int ret = -ENOPROTOOPT; int ret = -ENOPROTOOPT;
#ifdef CONFIG_NETDEVICES #ifdef CONFIG_NETDEVICES
...@@ -737,12 +737,12 @@ static int sock_getbindtodevice(struct sock *sk, char __user *optval, ...@@ -737,12 +737,12 @@ static int sock_getbindtodevice(struct sock *sk, char __user *optval,
len = strlen(devname) + 1; len = strlen(devname) + 1;
ret = -EFAULT; ret = -EFAULT;
if (copy_to_user(optval, devname, len)) if (copy_to_sockptr(optval, devname, len))
goto out; goto out;
zero: zero:
ret = -EFAULT; ret = -EFAULT;
if (put_user(len, optlen)) if (copy_to_sockptr(optlen, &len, sizeof(int)))
goto out; goto out;
ret = 0; ret = 0;
...@@ -1568,20 +1568,23 @@ static void cred_to_ucred(struct pid *pid, const struct cred *cred, ...@@ -1568,20 +1568,23 @@ static void cred_to_ucred(struct pid *pid, const struct cred *cred,
} }
} }
static int groups_to_user(gid_t __user *dst, const struct group_info *src) static int groups_to_user(sockptr_t dst, const struct group_info *src)
{ {
struct user_namespace *user_ns = current_user_ns(); struct user_namespace *user_ns = current_user_ns();
int i; int i;
for (i = 0; i < src->ngroups; i++) for (i = 0; i < src->ngroups; i++) {
if (put_user(from_kgid_munged(user_ns, src->gid[i]), dst + i)) gid_t gid = from_kgid_munged(user_ns, src->gid[i]);
if (copy_to_sockptr_offset(dst, i * sizeof(gid), &gid, sizeof(gid)))
return -EFAULT; return -EFAULT;
}
return 0; return 0;
} }
static int sk_getsockopt(struct sock *sk, int level, int optname, static int sk_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen) sockptr_t optval, sockptr_t optlen)
{ {
struct socket *sock = sk->sk_socket; struct socket *sock = sk->sk_socket;
...@@ -1600,7 +1603,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1600,7 +1603,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
int lv = sizeof(int); int lv = sizeof(int);
int len; int len;
if (get_user(len, optlen)) if (copy_from_sockptr(&len, optlen, sizeof(int)))
return -EFAULT; return -EFAULT;
if (len < 0) if (len < 0)
return -EINVAL; return -EINVAL;
...@@ -1735,7 +1738,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1735,7 +1738,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
cred_to_ucred(sk->sk_peer_pid, sk->sk_peer_cred, &peercred); cred_to_ucred(sk->sk_peer_pid, sk->sk_peer_cred, &peercred);
spin_unlock(&sk->sk_peer_lock); spin_unlock(&sk->sk_peer_lock);
if (copy_to_user(optval, &peercred, len)) if (copy_to_sockptr(optval, &peercred, len))
return -EFAULT; return -EFAULT;
goto lenout; goto lenout;
} }
...@@ -1753,11 +1756,11 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1753,11 +1756,11 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
if (len < n * sizeof(gid_t)) { if (len < n * sizeof(gid_t)) {
len = n * sizeof(gid_t); len = n * sizeof(gid_t);
put_cred(cred); put_cred(cred);
return put_user(len, optlen) ? -EFAULT : -ERANGE; return copy_to_sockptr(optlen, &len, sizeof(int)) ? -EFAULT : -ERANGE;
} }
len = n * sizeof(gid_t); len = n * sizeof(gid_t);
ret = groups_to_user((gid_t __user *)optval, cred->group_info); ret = groups_to_user(optval, cred->group_info);
put_cred(cred); put_cred(cred);
if (ret) if (ret)
return ret; return ret;
...@@ -1773,7 +1776,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1773,7 +1776,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
return -ENOTCONN; return -ENOTCONN;
if (lv < len) if (lv < len)
return -EINVAL; return -EINVAL;
if (copy_to_user(optval, address, len)) if (copy_to_sockptr(optval, address, len))
return -EFAULT; return -EFAULT;
goto lenout; goto lenout;
} }
...@@ -1790,7 +1793,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1790,7 +1793,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
break; break;
case SO_PEERSEC: case SO_PEERSEC:
return security_socket_getpeersec_stream(sock, optval, optlen, len); return security_socket_getpeersec_stream(sock, optval.user, optlen.user, len);
case SO_MARK: case SO_MARK:
v.val = sk->sk_mark; v.val = sk->sk_mark;
...@@ -1822,7 +1825,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1822,7 +1825,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
return sock_getbindtodevice(sk, optval, optlen, len); return sock_getbindtodevice(sk, optval, optlen, len);
case SO_GET_FILTER: case SO_GET_FILTER:
len = sk_get_filter(sk, (struct sock_filter __user *)optval, len); len = sk_get_filter(sk, optval, len);
if (len < 0) if (len < 0)
return len; return len;
...@@ -1870,7 +1873,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1870,7 +1873,7 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
sk_get_meminfo(sk, meminfo); sk_get_meminfo(sk, meminfo);
len = min_t(unsigned int, len, sizeof(meminfo)); len = min_t(unsigned int, len, sizeof(meminfo));
if (copy_to_user(optval, &meminfo, len)) if (copy_to_sockptr(optval, &meminfo, len))
return -EFAULT; return -EFAULT;
goto lenout; goto lenout;
...@@ -1939,10 +1942,10 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1939,10 +1942,10 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
if (len > lv) if (len > lv)
len = lv; len = lv;
if (copy_to_user(optval, &v, len)) if (copy_to_sockptr(optval, &v, len))
return -EFAULT; return -EFAULT;
lenout: lenout:
if (put_user(len, optlen)) if (copy_to_sockptr(optlen, &len, sizeof(int)))
return -EFAULT; return -EFAULT;
return 0; return 0;
} }
...@@ -1950,7 +1953,9 @@ static int sk_getsockopt(struct sock *sk, int level, int optname, ...@@ -1950,7 +1953,9 @@ static int sk_getsockopt(struct sock *sk, int level, int optname,
int sock_getsockopt(struct socket *sock, int level, int optname, int sock_getsockopt(struct socket *sock, int level, int optname,
char __user *optval, int __user *optlen) char __user *optval, int __user *optlen)
{ {
return sk_getsockopt(sock->sk, level, optname, optval, optlen); return sk_getsockopt(sock->sk, level, optname,
USER_SOCKPTR(optval),
USER_SOCKPTR(optlen));
} }
/* /*
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment