Commit 49573ff7 authored by Jakub Kicinski's avatar Jakub Kicinski

Merge branch 'tls-splice_read-fixes'

Jakub Kicinski says:

====================
tls: splice_read fixes

As I work my way to unlocked and zero-copy TLS Rx the obvious bugs
in the splice_read implementation get harder and harder to ignore.
This is to say the fixes here are discovered by code inspection,
I'm not aware of anyone actually using splice_read.
====================

Link: https://lore.kernel.org/r/20211124232557.2039757-1-kuba@kernel.orgSigned-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 9dbe33cf f884a342
...@@ -61,7 +61,7 @@ static DEFINE_MUTEX(tcpv6_prot_mutex); ...@@ -61,7 +61,7 @@ static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot; static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex); static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops; static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
const struct proto *base); const struct proto *base);
...@@ -71,6 +71,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx) ...@@ -71,6 +71,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx)
WRITE_ONCE(sk->sk_prot, WRITE_ONCE(sk->sk_prot,
&tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
WRITE_ONCE(sk->sk_socket->ops,
&tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
} }
int wait_on_pending_writer(struct sock *sk, long *timeo) int wait_on_pending_writer(struct sock *sk, long *timeo)
...@@ -669,8 +671,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, ...@@ -669,8 +671,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
if (tx) { if (tx) {
ctx->sk_write_space = sk->sk_write_space; ctx->sk_write_space = sk->sk_write_space;
sk->sk_write_space = tls_write_space; sk->sk_write_space = tls_write_space;
} else {
sk->sk_socket->ops = &tls_sw_proto_ops;
} }
goto out; goto out;
...@@ -728,6 +728,39 @@ struct tls_context *tls_ctx_create(struct sock *sk) ...@@ -728,6 +728,39 @@ struct tls_context *tls_ctx_create(struct sock *sk)
return ctx; return ctx;
} }
static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
const struct proto_ops *base)
{
ops[TLS_BASE][TLS_BASE] = *base;
ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;
ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE];
ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read;
ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE];
ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read;
#ifdef CONFIG_TLS_DEVICE
ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL;
ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ];
ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL;
ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ];
ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ];
ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ];
ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL;
#endif
#ifdef CONFIG_TLS_TOE
ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}
static void tls_build_proto(struct sock *sk) static void tls_build_proto(struct sock *sk)
{ {
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
...@@ -739,6 +772,8 @@ static void tls_build_proto(struct sock *sk) ...@@ -739,6 +772,8 @@ static void tls_build_proto(struct sock *sk)
mutex_lock(&tcpv6_prot_mutex); mutex_lock(&tcpv6_prot_mutex);
if (likely(prot != saved_tcpv6_prot)) { if (likely(prot != saved_tcpv6_prot)) {
build_protos(tls_prots[TLSV6], prot); build_protos(tls_prots[TLSV6], prot);
build_proto_ops(tls_proto_ops[TLSV6],
sk->sk_socket->ops);
smp_store_release(&saved_tcpv6_prot, prot); smp_store_release(&saved_tcpv6_prot, prot);
} }
mutex_unlock(&tcpv6_prot_mutex); mutex_unlock(&tcpv6_prot_mutex);
...@@ -749,6 +784,8 @@ static void tls_build_proto(struct sock *sk) ...@@ -749,6 +784,8 @@ static void tls_build_proto(struct sock *sk)
mutex_lock(&tcpv4_prot_mutex); mutex_lock(&tcpv4_prot_mutex);
if (likely(prot != saved_tcpv4_prot)) { if (likely(prot != saved_tcpv4_prot)) {
build_protos(tls_prots[TLSV4], prot); build_protos(tls_prots[TLSV4], prot);
build_proto_ops(tls_proto_ops[TLSV4],
sk->sk_socket->ops);
smp_store_release(&saved_tcpv4_prot, prot); smp_store_release(&saved_tcpv4_prot, prot);
} }
mutex_unlock(&tcpv4_prot_mutex); mutex_unlock(&tcpv4_prot_mutex);
...@@ -959,10 +996,6 @@ static int __init tls_register(void) ...@@ -959,10 +996,6 @@ static int __init tls_register(void)
if (err) if (err)
return err; return err;
tls_sw_proto_ops = inet_stream_ops;
tls_sw_proto_ops.splice_read = tls_sw_splice_read;
tls_sw_proto_ops.sendpage_locked = tls_sw_sendpage_locked;
tls_device_init(); tls_device_init();
tcp_register_ulp(&tcp_tls_ulp_ops); tcp_register_ulp(&tcp_tls_ulp_ops);
......
...@@ -2005,6 +2005,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2005,6 +2005,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct sk_buff *skb; struct sk_buff *skb;
ssize_t copied = 0; ssize_t copied = 0;
bool from_queue;
int err = 0; int err = 0;
long timeo; long timeo;
int chunk; int chunk;
...@@ -2014,25 +2015,28 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2014,25 +2015,28 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK); timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err); from_queue = !skb_queue_empty(&ctx->rx_list);
if (!skb) if (from_queue) {
goto splice_read_end; skb = __skb_dequeue(&ctx->rx_list);
} else {
if (!ctx->decrypted) { skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); &err);
if (!skb)
/* splice does not support reading control messages */
if (ctx->control != TLS_RECORD_TYPE_DATA) {
err = -EINVAL;
goto splice_read_end; goto splice_read_end;
}
err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
if (err < 0) { if (err < 0) {
tls_err_abort(sk, -EBADMSG); tls_err_abort(sk, -EBADMSG);
goto splice_read_end; goto splice_read_end;
} }
ctx->decrypted = 1;
} }
/* splice does not support reading control messages */
if (ctx->control != TLS_RECORD_TYPE_DATA) {
err = -EINVAL;
goto splice_read_end;
}
rxm = strp_msg(skb); rxm = strp_msg(skb);
chunk = min_t(unsigned int, rxm->full_len, len); chunk = min_t(unsigned int, rxm->full_len, len);
...@@ -2040,7 +2044,17 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, ...@@ -2040,7 +2044,17 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
if (copied < 0) if (copied < 0)
goto splice_read_end; goto splice_read_end;
tls_sw_advance_skb(sk, skb, copied); if (!from_queue) {
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
}
if (chunk < rxm->full_len) {
__skb_queue_head(&ctx->rx_list, skb);
rxm->offset += len;
rxm->full_len -= len;
} else {
consume_skb(skb);
}
splice_read_end: splice_read_end:
release_sock(sk); release_sock(sk);
......
...@@ -78,26 +78,21 @@ static void memrnd(void *s, size_t n) ...@@ -78,26 +78,21 @@ static void memrnd(void *s, size_t n)
*byte++ = rand(); *byte++ = rand();
} }
FIXTURE(tls_basic) static void ulp_sock_pair(struct __test_metadata *_metadata,
{ int *fd, int *cfd, bool *notls)
int fd, cfd;
bool notls;
};
FIXTURE_SETUP(tls_basic)
{ {
struct sockaddr_in addr; struct sockaddr_in addr;
socklen_t len; socklen_t len;
int sfd, ret; int sfd, ret;
self->notls = false; *notls = false;
len = sizeof(addr); len = sizeof(addr);
addr.sin_family = AF_INET; addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY); addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = 0; addr.sin_port = 0;
self->fd = socket(AF_INET, SOCK_STREAM, 0); *fd = socket(AF_INET, SOCK_STREAM, 0);
sfd = socket(AF_INET, SOCK_STREAM, 0); sfd = socket(AF_INET, SOCK_STREAM, 0);
ret = bind(sfd, &addr, sizeof(addr)); ret = bind(sfd, &addr, sizeof(addr));
...@@ -108,26 +103,96 @@ FIXTURE_SETUP(tls_basic) ...@@ -108,26 +103,96 @@ FIXTURE_SETUP(tls_basic)
ret = getsockname(sfd, &addr, &len); ret = getsockname(sfd, &addr, &len);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
ret = connect(self->fd, &addr, sizeof(addr)); ret = connect(*fd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
self->cfd = accept(sfd, &addr, &len); *cfd = accept(sfd, &addr, &len);
ASSERT_GE(self->cfd, 0); ASSERT_GE(*cfd, 0);
close(sfd); close(sfd);
ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")); ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
if (ret != 0) { if (ret != 0) {
ASSERT_EQ(errno, ENOENT); ASSERT_EQ(errno, ENOENT);
self->notls = true; *notls = true;
printf("Failure setting TCP_ULP, testing without tls\n"); printf("Failure setting TCP_ULP, testing without tls\n");
return; return;
} }
ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")); ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
} }
/* Produce a basic cmsg */
static int tls_send_cmsg(int fd, unsigned char record_type,
void *data, size_t len, int flags)
{
char cbuf[CMSG_SPACE(sizeof(char))];
int cmsg_len = sizeof(char);
struct cmsghdr *cmsg;
struct msghdr msg;
struct iovec vec;
vec.iov_base = data;
vec.iov_len = len;
memset(&msg, 0, sizeof(struct msghdr));
msg.msg_iov = &vec;
msg.msg_iovlen = 1;
msg.msg_control = cbuf;
msg.msg_controllen = sizeof(cbuf);
cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_TLS;
/* test sending non-record types. */
cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
cmsg->cmsg_len = CMSG_LEN(cmsg_len);
*CMSG_DATA(cmsg) = record_type;
msg.msg_controllen = cmsg->cmsg_len;
return sendmsg(fd, &msg, flags);
}
static int tls_recv_cmsg(struct __test_metadata *_metadata,
int fd, unsigned char record_type,
void *data, size_t len, int flags)
{
char cbuf[CMSG_SPACE(sizeof(char))];
struct cmsghdr *cmsg;
unsigned char ctype;
struct msghdr msg;
struct iovec vec;
int n;
vec.iov_base = data;
vec.iov_len = len;
memset(&msg, 0, sizeof(struct msghdr));
msg.msg_iov = &vec;
msg.msg_iovlen = 1;
msg.msg_control = cbuf;
msg.msg_controllen = sizeof(cbuf);
n = recvmsg(fd, &msg, flags);
cmsg = CMSG_FIRSTHDR(&msg);
EXPECT_NE(cmsg, NULL);
EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
ctype = *((unsigned char *)CMSG_DATA(cmsg));
EXPECT_EQ(ctype, record_type);
return n;
}
FIXTURE(tls_basic)
{
int fd, cfd;
bool notls;
};
FIXTURE_SETUP(tls_basic)
{
ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
}
FIXTURE_TEARDOWN(tls_basic) FIXTURE_TEARDOWN(tls_basic)
{ {
close(self->fd); close(self->fd);
...@@ -199,60 +264,21 @@ FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm) ...@@ -199,60 +264,21 @@ FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
FIXTURE_SETUP(tls) FIXTURE_SETUP(tls)
{ {
struct tls_crypto_info_keys tls12; struct tls_crypto_info_keys tls12;
struct sockaddr_in addr; int ret;
socklen_t len;
int sfd, ret;
self->notls = false;
len = sizeof(addr);
tls_crypto_info_init(variant->tls_version, variant->cipher_type, tls_crypto_info_init(variant->tls_version, variant->cipher_type,
&tls12); &tls12);
addr.sin_family = AF_INET; ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = 0;
self->fd = socket(AF_INET, SOCK_STREAM, 0); if (self->notls)
sfd = socket(AF_INET, SOCK_STREAM, 0); return;
ret = bind(sfd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0);
ret = listen(sfd, 10);
ASSERT_EQ(ret, 0);
ret = getsockname(sfd, &addr, &len); ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
ret = connect(self->fd, &addr, sizeof(addr)); ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
if (ret != 0) {
self->notls = true;
printf("Failure setting TCP_ULP, testing without tls\n");
}
if (!self->notls) {
ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12,
tls12.len);
ASSERT_EQ(ret, 0);
}
self->cfd = accept(sfd, &addr, &len);
ASSERT_GE(self->cfd, 0);
if (!self->notls) {
ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls",
sizeof("tls"));
ASSERT_EQ(ret, 0);
ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12,
tls12.len);
ASSERT_EQ(ret, 0);
}
close(sfd);
} }
FIXTURE_TEARDOWN(tls) FIXTURE_TEARDOWN(tls)
...@@ -613,6 +639,95 @@ TEST_F(tls, splice_to_pipe) ...@@ -613,6 +639,95 @@ TEST_F(tls, splice_to_pipe)
EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0); EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
} }
TEST_F(tls, splice_cmsg_to_pipe)
{
char *test_str = "test_read";
char record_type = 100;
int send_len = 10;
char buf[10];
int p[2];
ASSERT_GE(pipe(p), 0);
EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
EXPECT_EQ(errno, EINVAL);
EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
EXPECT_EQ(errno, EIO);
EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
buf, sizeof(buf), MSG_WAITALL),
send_len);
EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
}
TEST_F(tls, splice_dec_cmsg_to_pipe)
{
char *test_str = "test_read";
char record_type = 100;
int send_len = 10;
char buf[10];
int p[2];
ASSERT_GE(pipe(p), 0);
EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
EXPECT_EQ(errno, EIO);
EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
EXPECT_EQ(errno, EINVAL);
EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
buf, sizeof(buf), MSG_WAITALL),
send_len);
EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
}
TEST_F(tls, recv_and_splice)
{
int send_len = TLS_PAYLOAD_MAX_LEN;
char mem_send[TLS_PAYLOAD_MAX_LEN];
char mem_recv[TLS_PAYLOAD_MAX_LEN];
int half = send_len / 2;
int p[2];
ASSERT_GE(pipe(p), 0);
EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
/* Recv hald of the record, splice the other half */
EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
half);
EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
}
TEST_F(tls, peek_and_splice)
{
int send_len = TLS_PAYLOAD_MAX_LEN;
char mem_send[TLS_PAYLOAD_MAX_LEN];
char mem_recv[TLS_PAYLOAD_MAX_LEN];
int chunk = TLS_PAYLOAD_MAX_LEN / 4;
int n, i, p[2];
memrnd(mem_send, sizeof(mem_send));
ASSERT_GE(pipe(p), 0);
for (i = 0; i < 4; i++)
EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
chunk);
EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
MSG_WAITALL | MSG_PEEK),
chunk * 5 / 2);
EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
n = 0;
while (n < send_len) {
i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
EXPECT_GT(i, 0);
n += i;
}
EXPECT_EQ(n, send_len);
EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
}
TEST_F(tls, recvmsg_single) TEST_F(tls, recvmsg_single)
{ {
char const *test_str = "test_recvmsg_single"; char const *test_str = "test_recvmsg_single";
...@@ -1193,60 +1308,30 @@ TEST_F(tls, mutliproc_sendpage_writers) ...@@ -1193,60 +1308,30 @@ TEST_F(tls, mutliproc_sendpage_writers)
TEST_F(tls, control_msg) TEST_F(tls, control_msg)
{ {
if (self->notls) char *test_str = "test_read";
return;
char cbuf[CMSG_SPACE(sizeof(char))];
char const *test_str = "test_read";
int cmsg_len = sizeof(char);
char record_type = 100; char record_type = 100;
struct cmsghdr *cmsg;
struct msghdr msg;
int send_len = 10; int send_len = 10;
struct iovec vec;
char buf[10]; char buf[10];
vec.iov_base = (char *)test_str; if (self->notls)
vec.iov_len = 10; SKIP(return, "no TLS support");
memset(&msg, 0, sizeof(struct msghdr));
msg.msg_iov = &vec;
msg.msg_iovlen = 1;
msg.msg_control = cbuf;
msg.msg_controllen = sizeof(cbuf);
cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_TLS;
/* test sending non-record types. */
cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
cmsg->cmsg_len = CMSG_LEN(cmsg_len);
*CMSG_DATA(cmsg) = record_type;
msg.msg_controllen = cmsg->cmsg_len;
EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len); EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
send_len);
/* Should fail because we didn't provide a control message */ /* Should fail because we didn't provide a control message */
EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1); EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
vec.iov_base = buf; EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL | MSG_PEEK), send_len); buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
send_len);
cmsg = CMSG_FIRSTHDR(&msg);
EXPECT_NE(cmsg, NULL);
EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
record_type = *((unsigned char *)CMSG_DATA(cmsg));
EXPECT_EQ(record_type, 100);
EXPECT_EQ(memcmp(buf, test_str, send_len), 0); EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
/* Recv the message again without MSG_PEEK */ /* Recv the message again without MSG_PEEK */
record_type = 0;
memset(buf, 0, sizeof(buf)); memset(buf, 0, sizeof(buf));
EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL), send_len); EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
cmsg = CMSG_FIRSTHDR(&msg); buf, sizeof(buf), MSG_WAITALL),
EXPECT_NE(cmsg, NULL); send_len);
EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
record_type = *((unsigned char *)CMSG_DATA(cmsg));
EXPECT_EQ(record_type, 100);
EXPECT_EQ(memcmp(buf, test_str, send_len), 0); EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
} }
...@@ -1301,6 +1386,160 @@ TEST_F(tls, shutdown_reuse) ...@@ -1301,6 +1386,160 @@ TEST_F(tls, shutdown_reuse)
EXPECT_EQ(errno, EISCONN); EXPECT_EQ(errno, EISCONN);
} }
FIXTURE(tls_err)
{
int fd, cfd;
int fd2, cfd2;
bool notls;
};
FIXTURE_VARIANT(tls_err)
{
uint16_t tls_version;
};
FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
{
.tls_version = TLS_1_2_VERSION,
};
FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
{
.tls_version = TLS_1_3_VERSION,
};
FIXTURE_SETUP(tls_err)
{
struct tls_crypto_info_keys tls12;
int ret;
tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
&tls12);
ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
if (self->notls)
return;
ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
ASSERT_EQ(ret, 0);
ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
ASSERT_EQ(ret, 0);
}
FIXTURE_TEARDOWN(tls_err)
{
close(self->fd);
close(self->cfd);
close(self->fd2);
close(self->cfd2);
}
TEST_F(tls_err, bad_rec)
{
char buf[64];
if (self->notls)
SKIP(return, "no TLS support");
memset(buf, 0x55, sizeof(buf));
EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
EXPECT_EQ(errno, EMSGSIZE);
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
EXPECT_EQ(errno, EAGAIN);
}
TEST_F(tls_err, bad_auth)
{
char buf[128];
int n;
if (self->notls)
SKIP(return, "no TLS support");
memrnd(buf, sizeof(buf) / 2);
EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
n = recv(self->cfd, buf, sizeof(buf), 0);
EXPECT_GT(n, sizeof(buf) / 2);
buf[n - 1]++;
EXPECT_EQ(send(self->fd2, buf, n, 0), n);
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
EXPECT_EQ(errno, EBADMSG);
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
EXPECT_EQ(errno, EBADMSG);
}
TEST_F(tls_err, bad_in_large_read)
{
char txt[3][64];
char cip[3][128];
char buf[3 * 128];
int i, n;
if (self->notls)
SKIP(return, "no TLS support");
/* Put 3 records in the sockets */
for (i = 0; i < 3; i++) {
memrnd(txt[i], sizeof(txt[i]));
EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
sizeof(txt[i]));
n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
EXPECT_GT(n, sizeof(txt[i]));
/* Break the third message */
if (i == 2)
cip[2][n - 1]++;
EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
}
/* We should be able to receive the first two messages */
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
/* Third mesasge is bad */
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
EXPECT_EQ(errno, EBADMSG);
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
EXPECT_EQ(errno, EBADMSG);
}
TEST_F(tls_err, bad_cmsg)
{
char *test_str = "test_read";
int send_len = 10;
char cip[128];
char buf[128];
char txt[64];
int n;
if (self->notls)
SKIP(return, "no TLS support");
/* Queue up one data record */
memrnd(txt, sizeof(txt));
EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
n = recv(self->cfd, cip, sizeof(cip), 0);
EXPECT_GT(n, sizeof(txt));
EXPECT_EQ(send(self->fd2, cip, n, 0), n);
EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
n = recv(self->cfd, cip, sizeof(cip), 0);
cip[n - 1]++; /* Break it */
EXPECT_GT(n, send_len);
EXPECT_EQ(send(self->fd2, cip, n, 0), n);
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
EXPECT_EQ(errno, EBADMSG);
EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
EXPECT_EQ(errno, EBADMSG);
}
TEST(non_established) { TEST(non_established) {
struct tls12_crypto_info_aes_gcm_256 tls12; struct tls12_crypto_info_aes_gcm_256 tls12;
struct sockaddr_in addr; struct sockaddr_in addr;
...@@ -1355,64 +1594,82 @@ TEST(non_established) { ...@@ -1355,64 +1594,82 @@ TEST(non_established) {
TEST(keysizes) { TEST(keysizes) {
struct tls12_crypto_info_aes_gcm_256 tls12; struct tls12_crypto_info_aes_gcm_256 tls12;
struct sockaddr_in addr; int ret, fd, cfd;
int sfd, ret, fd, cfd;
socklen_t len;
bool notls; bool notls;
notls = false;
len = sizeof(addr);
memset(&tls12, 0, sizeof(tls12)); memset(&tls12, 0, sizeof(tls12));
tls12.info.version = TLS_1_2_VERSION; tls12.info.version = TLS_1_2_VERSION;
tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256; tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
addr.sin_family = AF_INET; ulp_sock_pair(_metadata, &fd, &cfd, &notls);
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = 0;
fd = socket(AF_INET, SOCK_STREAM, 0); if (!notls) {
sfd = socket(AF_INET, SOCK_STREAM, 0); ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
sizeof(tls12));
EXPECT_EQ(ret, 0);
ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
sizeof(tls12));
EXPECT_EQ(ret, 0);
}
close(fd);
close(cfd);
}
TEST(tls_v6ops) {
struct tls_crypto_info_keys tls12;
struct sockaddr_in6 addr, addr2;
int sfd, ret, fd;
socklen_t len, len2;
tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);
addr.sin6_family = AF_INET6;
addr.sin6_addr = in6addr_any;
addr.sin6_port = 0;
fd = socket(AF_INET6, SOCK_STREAM, 0);
sfd = socket(AF_INET6, SOCK_STREAM, 0);
ret = bind(sfd, &addr, sizeof(addr)); ret = bind(sfd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
ret = listen(sfd, 10); ret = listen(sfd, 10);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
len = sizeof(addr);
ret = getsockname(sfd, &addr, &len); ret = getsockname(sfd, &addr, &len);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
ret = connect(fd, &addr, sizeof(addr)); ret = connect(fd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
len = sizeof(addr);
ret = getsockname(fd, &addr, &len);
ASSERT_EQ(ret, 0);
ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")); ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
if (ret != 0) { if (ret) {
notls = true; ASSERT_EQ(errno, ENOENT);
printf("Failure setting TCP_ULP, testing without tls\n"); SKIP(return, "no TLS support");
} }
ASSERT_EQ(ret, 0);
if (!notls) { ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, ASSERT_EQ(ret, 0);
sizeof(tls12));
EXPECT_EQ(ret, 0);
}
cfd = accept(sfd, &addr, &len); ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
ASSERT_GE(cfd, 0); ASSERT_EQ(ret, 0);
if (!notls) { len2 = sizeof(addr2);
ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls", ret = getsockname(fd, &addr2, &len2);
sizeof("tls")); ASSERT_EQ(ret, 0);
EXPECT_EQ(ret, 0);
ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12, EXPECT_EQ(len2, len);
sizeof(tls12)); EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
EXPECT_EQ(ret, 0);
}
close(sfd);
close(fd); close(fd);
close(cfd); close(sfd);
} }
TEST_HARNESS_MAIN TEST_HARNESS_MAIN
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment