Commit eb555cb5 authored by Linus Torvalds's avatar Linus Torvalds

Merge tag '5.20-rc-ksmbd-server-fixes' of git://git.samba.org/ksmbd

Pull ksmbd updates from Steve French:

 - fixes for memory access bugs (out of bounds access, oops, leak)

 - multichannel fixes

 - session disconnect performance improvement, and session register
   improvement

 - cleanup

* tag '5.20-rc-ksmbd-server-fixes' of git://git.samba.org/ksmbd:
  ksmbd: fix heap-based overflow in set_ntacl_dacl()
  ksmbd: prevent out of bound read for SMB2_TREE_CONNNECT
  ksmbd: prevent out of bound read for SMB2_WRITE
  ksmbd: fix use-after-free bug in smb2_tree_disconect
  ksmbd: fix memory leak in smb2_handle_negotiate
  ksmbd: fix racy issue while destroying session on multichannel
  ksmbd: use wait_event instead of schedule_timeout()
  ksmbd: fix kernel oops from idr_remove()
  ksmbd: add channel rwlock
  ksmbd: replace sessions list in connection with xarray
  MAINTAINERS: ksmbd: add entry for documentation
  ksmbd: remove unused ksmbd_share_configs_cleanup function
parents f30adc0d 8f054118
...@@ -11063,6 +11063,7 @@ R: Sergey Senozhatsky <senozhatsky@chromium.org> ...@@ -11063,6 +11063,7 @@ R: Sergey Senozhatsky <senozhatsky@chromium.org>
L: linux-cifs@vger.kernel.org L: linux-cifs@vger.kernel.org
S: Maintained S: Maintained
T: git git://git.samba.org/ksmbd.git T: git git://git.samba.org/ksmbd.git
F: Documentation/filesystems/cifs/ksmbd.rst
F: fs/ksmbd/ F: fs/ksmbd/
F: fs/smbfs_common/ F: fs/smbfs_common/
......
...@@ -121,8 +121,8 @@ static int ksmbd_gen_sess_key(struct ksmbd_session *sess, char *hash, ...@@ -121,8 +121,8 @@ static int ksmbd_gen_sess_key(struct ksmbd_session *sess, char *hash,
return rc; return rc;
} }
static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash, static int calc_ntlmv2_hash(struct ksmbd_conn *conn, struct ksmbd_session *sess,
char *dname) char *ntlmv2_hash, char *dname)
{ {
int ret, len, conv_len; int ret, len, conv_len;
wchar_t *domain = NULL; wchar_t *domain = NULL;
...@@ -158,7 +158,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash, ...@@ -158,7 +158,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
} }
conv_len = smb_strtoUTF16(uniname, user_name(sess->user), len, conv_len = smb_strtoUTF16(uniname, user_name(sess->user), len,
sess->conn->local_nls); conn->local_nls);
if (conv_len < 0 || conv_len > len) { if (conv_len < 0 || conv_len > len) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out;
...@@ -182,7 +182,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash, ...@@ -182,7 +182,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
} }
conv_len = smb_strtoUTF16((__le16 *)domain, dname, len, conv_len = smb_strtoUTF16((__le16 *)domain, dname, len,
sess->conn->local_nls); conn->local_nls);
if (conv_len < 0 || conv_len > len) { if (conv_len < 0 || conv_len > len) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out;
...@@ -215,8 +215,9 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash, ...@@ -215,8 +215,9 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
* *
* Return: 0 on success, error number on error * Return: 0 on success, error number on error
*/ */
int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2, int ksmbd_auth_ntlmv2(struct ksmbd_conn *conn, struct ksmbd_session *sess,
int blen, char *domain_name, char *cryptkey) struct ntlmv2_resp *ntlmv2, int blen, char *domain_name,
char *cryptkey)
{ {
char ntlmv2_hash[CIFS_ENCPWD_SIZE]; char ntlmv2_hash[CIFS_ENCPWD_SIZE];
char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE]; char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE];
...@@ -230,7 +231,7 @@ int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2, ...@@ -230,7 +231,7 @@ int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
return -ENOMEM; return -ENOMEM;
} }
rc = calc_ntlmv2_hash(sess, ntlmv2_hash, domain_name); rc = calc_ntlmv2_hash(conn, sess, ntlmv2_hash, domain_name);
if (rc) { if (rc) {
ksmbd_debug(AUTH, "could not get v2 hash rc %d\n", rc); ksmbd_debug(AUTH, "could not get v2 hash rc %d\n", rc);
goto out; goto out;
...@@ -333,7 +334,8 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob, ...@@ -333,7 +334,8 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
/* process NTLMv2 authentication */ /* process NTLMv2 authentication */
ksmbd_debug(AUTH, "decode_ntlmssp_authenticate_blob dname%s\n", ksmbd_debug(AUTH, "decode_ntlmssp_authenticate_blob dname%s\n",
domain_name); domain_name);
ret = ksmbd_auth_ntlmv2(sess, (struct ntlmv2_resp *)((char *)authblob + nt_off), ret = ksmbd_auth_ntlmv2(conn, sess,
(struct ntlmv2_resp *)((char *)authblob + nt_off),
nt_len - CIFS_ENCPWD_SIZE, nt_len - CIFS_ENCPWD_SIZE,
domain_name, conn->ntlmssp.cryptkey); domain_name, conn->ntlmssp.cryptkey);
kfree(domain_name); kfree(domain_name);
...@@ -659,8 +661,9 @@ struct derivation { ...@@ -659,8 +661,9 @@ struct derivation {
bool binding; bool binding;
}; };
static int generate_key(struct ksmbd_session *sess, struct kvec label, static int generate_key(struct ksmbd_conn *conn, struct ksmbd_session *sess,
struct kvec context, __u8 *key, unsigned int key_size) struct kvec label, struct kvec context, __u8 *key,
unsigned int key_size)
{ {
unsigned char zero = 0x0; unsigned char zero = 0x0;
__u8 i[4] = {0, 0, 0, 1}; __u8 i[4] = {0, 0, 0, 1};
...@@ -720,8 +723,8 @@ static int generate_key(struct ksmbd_session *sess, struct kvec label, ...@@ -720,8 +723,8 @@ static int generate_key(struct ksmbd_session *sess, struct kvec label,
goto smb3signkey_ret; goto smb3signkey_ret;
} }
if (sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM || if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L256, 4); rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L256, 4);
else else
rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L128, 4); rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L128, 4);
...@@ -756,17 +759,17 @@ static int generate_smb3signingkey(struct ksmbd_session *sess, ...@@ -756,17 +759,17 @@ static int generate_smb3signingkey(struct ksmbd_session *sess,
if (!chann) if (!chann)
return 0; return 0;
if (sess->conn->dialect >= SMB30_PROT_ID && signing->binding) if (conn->dialect >= SMB30_PROT_ID && signing->binding)
key = chann->smb3signingkey; key = chann->smb3signingkey;
else else
key = sess->smb3signingkey; key = sess->smb3signingkey;
rc = generate_key(sess, signing->label, signing->context, key, rc = generate_key(conn, sess, signing->label, signing->context, key,
SMB3_SIGN_KEY_SIZE); SMB3_SIGN_KEY_SIZE);
if (rc) if (rc)
return rc; return rc;
if (!(sess->conn->dialect >= SMB30_PROT_ID && signing->binding)) if (!(conn->dialect >= SMB30_PROT_ID && signing->binding))
memcpy(chann->smb3signingkey, key, SMB3_SIGN_KEY_SIZE); memcpy(chann->smb3signingkey, key, SMB3_SIGN_KEY_SIZE);
ksmbd_debug(AUTH, "dumping generated AES signing keys\n"); ksmbd_debug(AUTH, "dumping generated AES signing keys\n");
...@@ -820,30 +823,31 @@ struct derivation_twin { ...@@ -820,30 +823,31 @@ struct derivation_twin {
struct derivation decryption; struct derivation decryption;
}; };
static int generate_smb3encryptionkey(struct ksmbd_session *sess, static int generate_smb3encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess,
const struct derivation_twin *ptwin) const struct derivation_twin *ptwin)
{ {
int rc; int rc;
rc = generate_key(sess, ptwin->encryption.label, rc = generate_key(conn, sess, ptwin->encryption.label,
ptwin->encryption.context, sess->smb3encryptionkey, ptwin->encryption.context, sess->smb3encryptionkey,
SMB3_ENC_DEC_KEY_SIZE); SMB3_ENC_DEC_KEY_SIZE);
if (rc) if (rc)
return rc; return rc;
rc = generate_key(sess, ptwin->decryption.label, rc = generate_key(conn, sess, ptwin->decryption.label,
ptwin->decryption.context, ptwin->decryption.context,
sess->smb3decryptionkey, SMB3_ENC_DEC_KEY_SIZE); sess->smb3decryptionkey, SMB3_ENC_DEC_KEY_SIZE);
if (rc) if (rc)
return rc; return rc;
ksmbd_debug(AUTH, "dumping generated AES encryption keys\n"); ksmbd_debug(AUTH, "dumping generated AES encryption keys\n");
ksmbd_debug(AUTH, "Cipher type %d\n", sess->conn->cipher_type); ksmbd_debug(AUTH, "Cipher type %d\n", conn->cipher_type);
ksmbd_debug(AUTH, "Session Id %llu\n", sess->id); ksmbd_debug(AUTH, "Session Id %llu\n", sess->id);
ksmbd_debug(AUTH, "Session Key %*ph\n", ksmbd_debug(AUTH, "Session Key %*ph\n",
SMB2_NTLMV2_SESSKEY_SIZE, sess->sess_key); SMB2_NTLMV2_SESSKEY_SIZE, sess->sess_key);
if (sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM || if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) { conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
ksmbd_debug(AUTH, "ServerIn Key %*ph\n", ksmbd_debug(AUTH, "ServerIn Key %*ph\n",
SMB3_GCM256_CRYPTKEY_SIZE, sess->smb3encryptionkey); SMB3_GCM256_CRYPTKEY_SIZE, sess->smb3encryptionkey);
ksmbd_debug(AUTH, "ServerOut Key %*ph\n", ksmbd_debug(AUTH, "ServerOut Key %*ph\n",
...@@ -857,7 +861,8 @@ static int generate_smb3encryptionkey(struct ksmbd_session *sess, ...@@ -857,7 +861,8 @@ static int generate_smb3encryptionkey(struct ksmbd_session *sess,
return 0; return 0;
} }
int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess) int ksmbd_gen_smb30_encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess)
{ {
struct derivation_twin twin; struct derivation_twin twin;
struct derivation *d; struct derivation *d;
...@@ -874,10 +879,11 @@ int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess) ...@@ -874,10 +879,11 @@ int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess)
d->context.iov_base = "ServerIn "; d->context.iov_base = "ServerIn ";
d->context.iov_len = 10; d->context.iov_len = 10;
return generate_smb3encryptionkey(sess, &twin); return generate_smb3encryptionkey(conn, sess, &twin);
} }
int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess) int ksmbd_gen_smb311_encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess)
{ {
struct derivation_twin twin; struct derivation_twin twin;
struct derivation *d; struct derivation *d;
...@@ -894,7 +900,7 @@ int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess) ...@@ -894,7 +900,7 @@ int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess)
d->context.iov_base = sess->Preauth_HashValue; d->context.iov_base = sess->Preauth_HashValue;
d->context.iov_len = 64; d->context.iov_len = 64;
return generate_smb3encryptionkey(sess, &twin); return generate_smb3encryptionkey(conn, sess, &twin);
} }
int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf, int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf,
......
...@@ -38,8 +38,9 @@ struct kvec; ...@@ -38,8 +38,9 @@ struct kvec;
int ksmbd_crypt_message(struct ksmbd_conn *conn, struct kvec *iov, int ksmbd_crypt_message(struct ksmbd_conn *conn, struct kvec *iov,
unsigned int nvec, int enc); unsigned int nvec, int enc);
void ksmbd_copy_gss_neg_header(void *buf); void ksmbd_copy_gss_neg_header(void *buf);
int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2, int ksmbd_auth_ntlmv2(struct ksmbd_conn *conn, struct ksmbd_session *sess,
int blen, char *domain_name, char *cryptkey); struct ntlmv2_resp *ntlmv2, int blen, char *domain_name,
char *cryptkey);
int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob, int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
int blob_len, struct ksmbd_conn *conn, int blob_len, struct ksmbd_conn *conn,
struct ksmbd_session *sess); struct ksmbd_session *sess);
...@@ -58,8 +59,10 @@ int ksmbd_gen_smb30_signingkey(struct ksmbd_session *sess, ...@@ -58,8 +59,10 @@ int ksmbd_gen_smb30_signingkey(struct ksmbd_session *sess,
struct ksmbd_conn *conn); struct ksmbd_conn *conn);
int ksmbd_gen_smb311_signingkey(struct ksmbd_session *sess, int ksmbd_gen_smb311_signingkey(struct ksmbd_session *sess,
struct ksmbd_conn *conn); struct ksmbd_conn *conn);
int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess); int ksmbd_gen_smb30_encryptionkey(struct ksmbd_conn *conn,
int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess); struct ksmbd_session *sess);
int ksmbd_gen_smb311_encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess);
int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf, int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf,
__u8 *pi_hash); __u8 *pi_hash);
int ksmbd_gen_sd_hash(struct ksmbd_conn *conn, char *sd_buf, int len, int ksmbd_gen_sd_hash(struct ksmbd_conn *conn, char *sd_buf, int len,
......
...@@ -36,6 +36,7 @@ void ksmbd_conn_free(struct ksmbd_conn *conn) ...@@ -36,6 +36,7 @@ void ksmbd_conn_free(struct ksmbd_conn *conn)
list_del(&conn->conns_list); list_del(&conn->conns_list);
write_unlock(&conn_list_lock); write_unlock(&conn_list_lock);
xa_destroy(&conn->sessions);
kvfree(conn->request_buf); kvfree(conn->request_buf);
kfree(conn->preauth_info); kfree(conn->preauth_info);
kfree(conn); kfree(conn);
...@@ -65,13 +66,14 @@ struct ksmbd_conn *ksmbd_conn_alloc(void) ...@@ -65,13 +66,14 @@ struct ksmbd_conn *ksmbd_conn_alloc(void)
conn->outstanding_credits = 0; conn->outstanding_credits = 0;
init_waitqueue_head(&conn->req_running_q); init_waitqueue_head(&conn->req_running_q);
init_waitqueue_head(&conn->r_count_q);
INIT_LIST_HEAD(&conn->conns_list); INIT_LIST_HEAD(&conn->conns_list);
INIT_LIST_HEAD(&conn->sessions);
INIT_LIST_HEAD(&conn->requests); INIT_LIST_HEAD(&conn->requests);
INIT_LIST_HEAD(&conn->async_requests); INIT_LIST_HEAD(&conn->async_requests);
spin_lock_init(&conn->request_lock); spin_lock_init(&conn->request_lock);
spin_lock_init(&conn->credits_lock); spin_lock_init(&conn->credits_lock);
ida_init(&conn->async_ida); ida_init(&conn->async_ida);
xa_init(&conn->sessions);
spin_lock_init(&conn->llist_lock); spin_lock_init(&conn->llist_lock);
INIT_LIST_HEAD(&conn->lock_list); INIT_LIST_HEAD(&conn->lock_list);
...@@ -164,7 +166,6 @@ int ksmbd_conn_write(struct ksmbd_work *work) ...@@ -164,7 +166,6 @@ int ksmbd_conn_write(struct ksmbd_work *work)
struct kvec iov[3]; struct kvec iov[3];
int iov_idx = 0; int iov_idx = 0;
ksmbd_conn_try_dequeue_request(work);
if (!work->response_buf) { if (!work->response_buf) {
pr_err("NULL response header\n"); pr_err("NULL response header\n");
return -EINVAL; return -EINVAL;
...@@ -346,8 +347,8 @@ int ksmbd_conn_handler_loop(void *p) ...@@ -346,8 +347,8 @@ int ksmbd_conn_handler_loop(void *p)
out: out:
/* Wait till all reference dropped to the Server object*/ /* Wait till all reference dropped to the Server object*/
while (atomic_read(&conn->r_count) > 0) wait_event(conn->r_count_q, atomic_read(&conn->r_count) == 0);
schedule_timeout(HZ);
unload_nls(conn->local_nls); unload_nls(conn->local_nls);
if (default_conn_ops.terminate_fn) if (default_conn_ops.terminate_fn)
......
...@@ -20,13 +20,6 @@ ...@@ -20,13 +20,6 @@
#define KSMBD_SOCKET_BACKLOG 16 #define KSMBD_SOCKET_BACKLOG 16
/*
* WARNING
*
* This is nothing but a HACK. Session status should move to channel
* or to session. As of now we have 1 tcp_conn : 1 ksmbd_session, but
* we need to change it to 1 tcp_conn : N ksmbd_sessions.
*/
enum { enum {
KSMBD_SESS_NEW = 0, KSMBD_SESS_NEW = 0,
KSMBD_SESS_GOOD, KSMBD_SESS_GOOD,
...@@ -55,7 +48,7 @@ struct ksmbd_conn { ...@@ -55,7 +48,7 @@ struct ksmbd_conn {
struct nls_table *local_nls; struct nls_table *local_nls;
struct list_head conns_list; struct list_head conns_list;
/* smb session 1 per user */ /* smb session 1 per user */
struct list_head sessions; struct xarray sessions;
unsigned long last_active; unsigned long last_active;
/* How many request are running currently */ /* How many request are running currently */
atomic_t req_running; atomic_t req_running;
...@@ -65,6 +58,7 @@ struct ksmbd_conn { ...@@ -65,6 +58,7 @@ struct ksmbd_conn {
unsigned int outstanding_credits; unsigned int outstanding_credits;
spinlock_t credits_lock; spinlock_t credits_lock;
wait_queue_head_t req_running_q; wait_queue_head_t req_running_q;
wait_queue_head_t r_count_q;
/* Lock to protect requests list*/ /* Lock to protect requests list*/
spinlock_t request_lock; spinlock_t request_lock;
struct list_head requests; struct list_head requests;
......
...@@ -222,17 +222,3 @@ bool ksmbd_share_veto_filename(struct ksmbd_share_config *share, ...@@ -222,17 +222,3 @@ bool ksmbd_share_veto_filename(struct ksmbd_share_config *share,
} }
return false; return false;
} }
void ksmbd_share_configs_cleanup(void)
{
struct ksmbd_share_config *share;
struct hlist_node *tmp;
int i;
down_write(&shares_table_lock);
hash_for_each_safe(shares_table, i, tmp, share, hlist) {
hash_del(&share->hlist);
kill_share(share);
}
up_write(&shares_table_lock);
}
...@@ -76,6 +76,4 @@ static inline void ksmbd_share_config_put(struct ksmbd_share_config *share) ...@@ -76,6 +76,4 @@ static inline void ksmbd_share_config_put(struct ksmbd_share_config *share)
struct ksmbd_share_config *ksmbd_share_config_get(char *name); struct ksmbd_share_config *ksmbd_share_config_get(char *name);
bool ksmbd_share_veto_filename(struct ksmbd_share_config *share, bool ksmbd_share_veto_filename(struct ksmbd_share_config *share,
const char *filename); const char *filename);
void ksmbd_share_configs_cleanup(void);
#endif /* __SHARE_CONFIG_MANAGEMENT_H__ */ #endif /* __SHARE_CONFIG_MANAGEMENT_H__ */
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include "user_session.h" #include "user_session.h"
struct ksmbd_tree_conn_status struct ksmbd_tree_conn_status
ksmbd_tree_conn_connect(struct ksmbd_session *sess, char *share_name) ksmbd_tree_conn_connect(struct ksmbd_conn *conn, struct ksmbd_session *sess,
char *share_name)
{ {
struct ksmbd_tree_conn_status status = {-EINVAL, NULL}; struct ksmbd_tree_conn_status status = {-EINVAL, NULL};
struct ksmbd_tree_connect_response *resp = NULL; struct ksmbd_tree_connect_response *resp = NULL;
...@@ -41,7 +42,7 @@ ksmbd_tree_conn_connect(struct ksmbd_session *sess, char *share_name) ...@@ -41,7 +42,7 @@ ksmbd_tree_conn_connect(struct ksmbd_session *sess, char *share_name)
goto out_error; goto out_error;
} }
peer_addr = KSMBD_TCP_PEER_SOCKADDR(sess->conn); peer_addr = KSMBD_TCP_PEER_SOCKADDR(conn);
resp = ksmbd_ipc_tree_connect_request(sess, resp = ksmbd_ipc_tree_connect_request(sess,
sc, sc,
tree_conn, tree_conn,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
struct ksmbd_share_config; struct ksmbd_share_config;
struct ksmbd_user; struct ksmbd_user;
struct ksmbd_conn;
struct ksmbd_tree_connect { struct ksmbd_tree_connect {
int id; int id;
...@@ -40,7 +41,8 @@ static inline int test_tree_conn_flag(struct ksmbd_tree_connect *tree_conn, ...@@ -40,7 +41,8 @@ static inline int test_tree_conn_flag(struct ksmbd_tree_connect *tree_conn,
struct ksmbd_session; struct ksmbd_session;
struct ksmbd_tree_conn_status struct ksmbd_tree_conn_status
ksmbd_tree_conn_connect(struct ksmbd_session *sess, char *share_name); ksmbd_tree_conn_connect(struct ksmbd_conn *conn, struct ksmbd_session *sess,
char *share_name);
int ksmbd_tree_conn_disconnect(struct ksmbd_session *sess, int ksmbd_tree_conn_disconnect(struct ksmbd_session *sess,
struct ksmbd_tree_connect *tree_conn); struct ksmbd_tree_connect *tree_conn);
......
...@@ -32,11 +32,13 @@ static void free_channel_list(struct ksmbd_session *sess) ...@@ -32,11 +32,13 @@ static void free_channel_list(struct ksmbd_session *sess)
{ {
struct channel *chann, *tmp; struct channel *chann, *tmp;
write_lock(&sess->chann_lock);
list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list, list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list,
chann_list) { chann_list) {
list_del(&chann->chann_list); list_del(&chann->chann_list);
kfree(chann); kfree(chann);
} }
write_unlock(&sess->chann_lock);
} }
static void __session_rpc_close(struct ksmbd_session *sess, static void __session_rpc_close(struct ksmbd_session *sess,
...@@ -149,11 +151,6 @@ void ksmbd_session_destroy(struct ksmbd_session *sess) ...@@ -149,11 +151,6 @@ void ksmbd_session_destroy(struct ksmbd_session *sess)
if (!sess) if (!sess)
return; return;
if (!atomic_dec_and_test(&sess->refcnt))
return;
list_del(&sess->sessions_entry);
down_write(&sessions_table_lock); down_write(&sessions_table_lock);
hash_del(&sess->hlist); hash_del(&sess->hlist);
up_write(&sessions_table_lock); up_write(&sessions_table_lock);
...@@ -181,53 +178,70 @@ static struct ksmbd_session *__session_lookup(unsigned long long id) ...@@ -181,53 +178,70 @@ static struct ksmbd_session *__session_lookup(unsigned long long id)
return NULL; return NULL;
} }
void ksmbd_session_register(struct ksmbd_conn *conn, int ksmbd_session_register(struct ksmbd_conn *conn,
struct ksmbd_session *sess) struct ksmbd_session *sess)
{ {
sess->conn = conn; sess->dialect = conn->dialect;
list_add(&sess->sessions_entry, &conn->sessions); memcpy(sess->ClientGUID, conn->ClientGUID, SMB2_CLIENT_GUID_SIZE);
return xa_err(xa_store(&conn->sessions, sess->id, sess, GFP_KERNEL));
} }
void ksmbd_sessions_deregister(struct ksmbd_conn *conn) static int ksmbd_chann_del(struct ksmbd_conn *conn, struct ksmbd_session *sess)
{ {
struct ksmbd_session *sess; struct channel *chann, *tmp;
while (!list_empty(&conn->sessions)) {
sess = list_entry(conn->sessions.next,
struct ksmbd_session,
sessions_entry);
ksmbd_session_destroy(sess); write_lock(&sess->chann_lock);
list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list,
chann_list) {
if (chann->conn == conn) {
list_del(&chann->chann_list);
kfree(chann);
write_unlock(&sess->chann_lock);
return 0;
}
} }
} write_unlock(&sess->chann_lock);
static bool ksmbd_session_id_match(struct ksmbd_session *sess, return -ENOENT;
unsigned long long id)
{
return sess->id == id;
} }
struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn, void ksmbd_sessions_deregister(struct ksmbd_conn *conn)
unsigned long long id)
{ {
struct ksmbd_session *sess = NULL; struct ksmbd_session *sess;
list_for_each_entry(sess, &conn->sessions, sessions_entry) { if (conn->binding) {
if (ksmbd_session_id_match(sess, id)) int bkt;
return sess;
down_write(&sessions_table_lock);
hash_for_each(sessions_table, bkt, sess, hlist) {
if (!ksmbd_chann_del(conn, sess)) {
up_write(&sessions_table_lock);
goto sess_destroy;
}
}
up_write(&sessions_table_lock);
} else {
unsigned long id;
xa_for_each(&conn->sessions, id, sess) {
if (!ksmbd_chann_del(conn, sess))
goto sess_destroy;
}
} }
return NULL;
}
int get_session(struct ksmbd_session *sess) return;
{
return atomic_inc_not_zero(&sess->refcnt); sess_destroy:
if (list_empty(&sess->ksmbd_chann_list)) {
xa_erase(&conn->sessions, sess->id);
ksmbd_session_destroy(sess);
}
} }
void put_session(struct ksmbd_session *sess) struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn,
unsigned long long id)
{ {
if (atomic_dec_and_test(&sess->refcnt)) return xa_load(&conn->sessions, id);
pr_err("get/%s seems to be mismatched.", __func__);
} }
struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id) struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id)
...@@ -236,10 +250,6 @@ struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id) ...@@ -236,10 +250,6 @@ struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id)
down_read(&sessions_table_lock); down_read(&sessions_table_lock);
sess = __session_lookup(id); sess = __session_lookup(id);
if (sess) {
if (!get_session(sess))
sess = NULL;
}
up_read(&sessions_table_lock); up_read(&sessions_table_lock);
return sess; return sess;
...@@ -253,6 +263,8 @@ struct ksmbd_session *ksmbd_session_lookup_all(struct ksmbd_conn *conn, ...@@ -253,6 +263,8 @@ struct ksmbd_session *ksmbd_session_lookup_all(struct ksmbd_conn *conn,
sess = ksmbd_session_lookup(conn, id); sess = ksmbd_session_lookup(conn, id);
if (!sess && conn->binding) if (!sess && conn->binding)
sess = ksmbd_session_lookup_slowpath(id); sess = ksmbd_session_lookup_slowpath(id);
if (sess && sess->state != SMB2_SESSION_VALID)
sess = NULL;
return sess; return sess;
} }
...@@ -314,12 +326,11 @@ static struct ksmbd_session *__session_create(int protocol) ...@@ -314,12 +326,11 @@ static struct ksmbd_session *__session_create(int protocol)
goto error; goto error;
set_session_flag(sess, protocol); set_session_flag(sess, protocol);
INIT_LIST_HEAD(&sess->sessions_entry);
xa_init(&sess->tree_conns); xa_init(&sess->tree_conns);
INIT_LIST_HEAD(&sess->ksmbd_chann_list); INIT_LIST_HEAD(&sess->ksmbd_chann_list);
INIT_LIST_HEAD(&sess->rpc_handle_list); INIT_LIST_HEAD(&sess->rpc_handle_list);
sess->sequence_number = 1; sess->sequence_number = 1;
atomic_set(&sess->refcnt, 1); rwlock_init(&sess->chann_lock);
switch (protocol) { switch (protocol) {
case CIFDS_SESSION_FLAG_SMB2: case CIFDS_SESSION_FLAG_SMB2:
......
...@@ -33,8 +33,10 @@ struct preauth_session { ...@@ -33,8 +33,10 @@ struct preauth_session {
struct ksmbd_session { struct ksmbd_session {
u64 id; u64 id;
__u16 dialect;
char ClientGUID[SMB2_CLIENT_GUID_SIZE];
struct ksmbd_user *user; struct ksmbd_user *user;
struct ksmbd_conn *conn;
unsigned int sequence_number; unsigned int sequence_number;
unsigned int flags; unsigned int flags;
...@@ -48,6 +50,7 @@ struct ksmbd_session { ...@@ -48,6 +50,7 @@ struct ksmbd_session {
char sess_key[CIFS_KEY_SIZE]; char sess_key[CIFS_KEY_SIZE];
struct hlist_node hlist; struct hlist_node hlist;
rwlock_t chann_lock;
struct list_head ksmbd_chann_list; struct list_head ksmbd_chann_list;
struct xarray tree_conns; struct xarray tree_conns;
struct ida tree_conn_ida; struct ida tree_conn_ida;
...@@ -57,9 +60,7 @@ struct ksmbd_session { ...@@ -57,9 +60,7 @@ struct ksmbd_session {
__u8 smb3decryptionkey[SMB3_ENC_DEC_KEY_SIZE]; __u8 smb3decryptionkey[SMB3_ENC_DEC_KEY_SIZE];
__u8 smb3signingkey[SMB3_SIGN_KEY_SIZE]; __u8 smb3signingkey[SMB3_SIGN_KEY_SIZE];
struct list_head sessions_entry;
struct ksmbd_file_table file_table; struct ksmbd_file_table file_table;
atomic_t refcnt;
}; };
static inline int test_session_flag(struct ksmbd_session *sess, int bit) static inline int test_session_flag(struct ksmbd_session *sess, int bit)
...@@ -84,8 +85,8 @@ void ksmbd_session_destroy(struct ksmbd_session *sess); ...@@ -84,8 +85,8 @@ void ksmbd_session_destroy(struct ksmbd_session *sess);
struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id); struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id);
struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn, struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn,
unsigned long long id); unsigned long long id);
void ksmbd_session_register(struct ksmbd_conn *conn, int ksmbd_session_register(struct ksmbd_conn *conn,
struct ksmbd_session *sess); struct ksmbd_session *sess);
void ksmbd_sessions_deregister(struct ksmbd_conn *conn); void ksmbd_sessions_deregister(struct ksmbd_conn *conn);
struct ksmbd_session *ksmbd_session_lookup_all(struct ksmbd_conn *conn, struct ksmbd_session *ksmbd_session_lookup_all(struct ksmbd_conn *conn,
unsigned long long id); unsigned long long id);
...@@ -100,6 +101,4 @@ void ksmbd_release_tree_conn_id(struct ksmbd_session *sess, int id); ...@@ -100,6 +101,4 @@ void ksmbd_release_tree_conn_id(struct ksmbd_session *sess, int id);
int ksmbd_session_rpc_open(struct ksmbd_session *sess, char *rpc_name); int ksmbd_session_rpc_open(struct ksmbd_session *sess, char *rpc_name);
void ksmbd_session_rpc_close(struct ksmbd_session *sess, int id); void ksmbd_session_rpc_close(struct ksmbd_session *sess, int id);
int ksmbd_session_rpc_method(struct ksmbd_session *sess, int id); int ksmbd_session_rpc_method(struct ksmbd_session *sess, int id);
int get_session(struct ksmbd_session *sess);
void put_session(struct ksmbd_session *sess);
#endif /* __USER_SESSION_MANAGEMENT_H__ */ #endif /* __USER_SESSION_MANAGEMENT_H__ */
...@@ -30,6 +30,7 @@ static DEFINE_RWLOCK(lease_list_lock); ...@@ -30,6 +30,7 @@ static DEFINE_RWLOCK(lease_list_lock);
static struct oplock_info *alloc_opinfo(struct ksmbd_work *work, static struct oplock_info *alloc_opinfo(struct ksmbd_work *work,
u64 id, __u16 Tid) u64 id, __u16 Tid)
{ {
struct ksmbd_conn *conn = work->conn;
struct ksmbd_session *sess = work->sess; struct ksmbd_session *sess = work->sess;
struct oplock_info *opinfo; struct oplock_info *opinfo;
...@@ -38,7 +39,7 @@ static struct oplock_info *alloc_opinfo(struct ksmbd_work *work, ...@@ -38,7 +39,7 @@ static struct oplock_info *alloc_opinfo(struct ksmbd_work *work,
return NULL; return NULL;
opinfo->sess = sess; opinfo->sess = sess;
opinfo->conn = sess->conn; opinfo->conn = conn;
opinfo->level = SMB2_OPLOCK_LEVEL_NONE; opinfo->level = SMB2_OPLOCK_LEVEL_NONE;
opinfo->op_state = OPLOCK_STATE_NONE; opinfo->op_state = OPLOCK_STATE_NONE;
opinfo->pending_break = 0; opinfo->pending_break = 0;
...@@ -615,18 +616,13 @@ static void __smb2_oplock_break_noti(struct work_struct *wk) ...@@ -615,18 +616,13 @@ static void __smb2_oplock_break_noti(struct work_struct *wk)
struct ksmbd_file *fp; struct ksmbd_file *fp;
fp = ksmbd_lookup_durable_fd(br_info->fid); fp = ksmbd_lookup_durable_fd(br_info->fid);
if (!fp) { if (!fp)
atomic_dec(&conn->r_count); goto out;
ksmbd_free_work_struct(work);
return;
}
if (allocate_oplock_break_buf(work)) { if (allocate_oplock_break_buf(work)) {
pr_err("smb2_allocate_rsp_buf failed! "); pr_err("smb2_allocate_rsp_buf failed! ");
atomic_dec(&conn->r_count);
ksmbd_fd_put(work, fp); ksmbd_fd_put(work, fp);
ksmbd_free_work_struct(work); goto out;
return;
} }
rsp_hdr = smb2_get_msg(work->response_buf); rsp_hdr = smb2_get_msg(work->response_buf);
...@@ -667,8 +663,16 @@ static void __smb2_oplock_break_noti(struct work_struct *wk) ...@@ -667,8 +663,16 @@ static void __smb2_oplock_break_noti(struct work_struct *wk)
ksmbd_fd_put(work, fp); ksmbd_fd_put(work, fp);
ksmbd_conn_write(work); ksmbd_conn_write(work);
out:
ksmbd_free_work_struct(work); ksmbd_free_work_struct(work);
atomic_dec(&conn->r_count); /*
* Checking waitqueue to dropping pending requests on
* disconnection. waitqueue_active is safe because it
* uses atomic operation for condition.
*/
if (!atomic_dec_return(&conn->r_count) && waitqueue_active(&conn->r_count_q))
wake_up(&conn->r_count_q);
} }
/** /**
...@@ -731,9 +735,7 @@ static void __smb2_lease_break_noti(struct work_struct *wk) ...@@ -731,9 +735,7 @@ static void __smb2_lease_break_noti(struct work_struct *wk)
if (allocate_oplock_break_buf(work)) { if (allocate_oplock_break_buf(work)) {
ksmbd_debug(OPLOCK, "smb2_allocate_rsp_buf failed! "); ksmbd_debug(OPLOCK, "smb2_allocate_rsp_buf failed! ");
ksmbd_free_work_struct(work); goto out;
atomic_dec(&conn->r_count);
return;
} }
rsp_hdr = smb2_get_msg(work->response_buf); rsp_hdr = smb2_get_msg(work->response_buf);
...@@ -771,8 +773,16 @@ static void __smb2_lease_break_noti(struct work_struct *wk) ...@@ -771,8 +773,16 @@ static void __smb2_lease_break_noti(struct work_struct *wk)
inc_rfc1001_len(work->response_buf, 44); inc_rfc1001_len(work->response_buf, 44);
ksmbd_conn_write(work); ksmbd_conn_write(work);
out:
ksmbd_free_work_struct(work); ksmbd_free_work_struct(work);
atomic_dec(&conn->r_count); /*
* Checking waitqueue to dropping pending requests on
* disconnection. waitqueue_active is safe because it
* uses atomic operation for condition.
*/
if (!atomic_dec_return(&conn->r_count) && waitqueue_active(&conn->r_count_q))
wake_up(&conn->r_count_q);
} }
/** /**
...@@ -972,7 +982,7 @@ int find_same_lease_key(struct ksmbd_session *sess, struct ksmbd_inode *ci, ...@@ -972,7 +982,7 @@ int find_same_lease_key(struct ksmbd_session *sess, struct ksmbd_inode *ci,
} }
list_for_each_entry(lb, &lease_table_list, l_entry) { list_for_each_entry(lb, &lease_table_list, l_entry) {
if (!memcmp(lb->client_guid, sess->conn->ClientGUID, if (!memcmp(lb->client_guid, sess->ClientGUID,
SMB2_CLIENT_GUID_SIZE)) SMB2_CLIENT_GUID_SIZE))
goto found; goto found;
} }
...@@ -988,7 +998,7 @@ int find_same_lease_key(struct ksmbd_session *sess, struct ksmbd_inode *ci, ...@@ -988,7 +998,7 @@ int find_same_lease_key(struct ksmbd_session *sess, struct ksmbd_inode *ci,
rcu_read_unlock(); rcu_read_unlock();
if (opinfo->o_fp->f_ci == ci) if (opinfo->o_fp->f_ci == ci)
goto op_next; goto op_next;
err = compare_guid_key(opinfo, sess->conn->ClientGUID, err = compare_guid_key(opinfo, sess->ClientGUID,
lctx->lease_key); lctx->lease_key);
if (err) { if (err) {
err = -EINVAL; err = -EINVAL;
...@@ -1122,7 +1132,7 @@ int smb_grant_oplock(struct ksmbd_work *work, int req_op_level, u64 pid, ...@@ -1122,7 +1132,7 @@ int smb_grant_oplock(struct ksmbd_work *work, int req_op_level, u64 pid,
struct oplock_info *m_opinfo; struct oplock_info *m_opinfo;
/* is lease already granted ? */ /* is lease already granted ? */
m_opinfo = same_client_has_lease(ci, sess->conn->ClientGUID, m_opinfo = same_client_has_lease(ci, sess->ClientGUID,
lctx); lctx);
if (m_opinfo) { if (m_opinfo) {
copy_lease(m_opinfo, opinfo); copy_lease(m_opinfo, opinfo);
...@@ -1240,7 +1250,7 @@ void smb_break_all_levII_oplock(struct ksmbd_work *work, struct ksmbd_file *fp, ...@@ -1240,7 +1250,7 @@ void smb_break_all_levII_oplock(struct ksmbd_work *work, struct ksmbd_file *fp,
{ {
struct oplock_info *op, *brk_op; struct oplock_info *op, *brk_op;
struct ksmbd_inode *ci; struct ksmbd_inode *ci;
struct ksmbd_conn *conn = work->sess->conn; struct ksmbd_conn *conn = work->conn;
if (!test_share_config_flag(work->tcon->share_conf, if (!test_share_config_flag(work->tcon->share_conf,
KSMBD_SHARE_FLAG_OPLOCKS)) KSMBD_SHARE_FLAG_OPLOCKS))
......
...@@ -261,7 +261,13 @@ static void handle_ksmbd_work(struct work_struct *wk) ...@@ -261,7 +261,13 @@ static void handle_ksmbd_work(struct work_struct *wk)
ksmbd_conn_try_dequeue_request(work); ksmbd_conn_try_dequeue_request(work);
ksmbd_free_work_struct(work); ksmbd_free_work_struct(work);
atomic_dec(&conn->r_count); /*
* Checking waitqueue to dropping pending requests on
* disconnection. waitqueue_active is safe because it
* uses atomic operation for condition.
*/
if (!atomic_dec_return(&conn->r_count) && waitqueue_active(&conn->r_count_q))
wake_up(&conn->r_count_q);
} }
/** /**
......
...@@ -90,11 +90,6 @@ static int smb2_get_data_area_len(unsigned int *off, unsigned int *len, ...@@ -90,11 +90,6 @@ static int smb2_get_data_area_len(unsigned int *off, unsigned int *len,
*off = 0; *off = 0;
*len = 0; *len = 0;
/* error reqeusts do not have data area */
if (hdr->Status && hdr->Status != STATUS_MORE_PROCESSING_REQUIRED &&
(((struct smb2_err_rsp *)hdr)->StructureSize) == SMB2_ERROR_STRUCTURE_SIZE2_LE)
return ret;
/* /*
* Following commands have data areas so we have to get the location * Following commands have data areas so we have to get the location
* of the data buffer offset and data buffer length for the particular * of the data buffer offset and data buffer length for the particular
...@@ -136,8 +131,11 @@ static int smb2_get_data_area_len(unsigned int *off, unsigned int *len, ...@@ -136,8 +131,11 @@ static int smb2_get_data_area_len(unsigned int *off, unsigned int *len,
*len = le16_to_cpu(((struct smb2_read_req *)hdr)->ReadChannelInfoLength); *len = le16_to_cpu(((struct smb2_read_req *)hdr)->ReadChannelInfoLength);
break; break;
case SMB2_WRITE: case SMB2_WRITE:
if (((struct smb2_write_req *)hdr)->DataOffset) { if (((struct smb2_write_req *)hdr)->DataOffset ||
*off = le16_to_cpu(((struct smb2_write_req *)hdr)->DataOffset); ((struct smb2_write_req *)hdr)->Length) {
*off = max_t(unsigned int,
le16_to_cpu(((struct smb2_write_req *)hdr)->DataOffset),
offsetof(struct smb2_write_req, Buffer));
*len = le32_to_cpu(((struct smb2_write_req *)hdr)->Length); *len = le32_to_cpu(((struct smb2_write_req *)hdr)->Length);
break; break;
} }
......
This diff is collapsed.
...@@ -421,7 +421,7 @@ struct smb_version_ops { ...@@ -421,7 +421,7 @@ struct smb_version_ops {
int (*check_sign_req)(struct ksmbd_work *work); int (*check_sign_req)(struct ksmbd_work *work);
void (*set_sign_rsp)(struct ksmbd_work *work); void (*set_sign_rsp)(struct ksmbd_work *work);
int (*generate_signingkey)(struct ksmbd_session *sess, struct ksmbd_conn *conn); int (*generate_signingkey)(struct ksmbd_session *sess, struct ksmbd_conn *conn);
int (*generate_encryptionkey)(struct ksmbd_session *sess); int (*generate_encryptionkey)(struct ksmbd_conn *conn, struct ksmbd_session *sess);
bool (*is_transform_hdr)(void *buf); bool (*is_transform_hdr)(void *buf);
int (*decrypt_req)(struct ksmbd_work *work); int (*decrypt_req)(struct ksmbd_work *work);
int (*encrypt_resp)(struct ksmbd_work *work); int (*encrypt_resp)(struct ksmbd_work *work);
......
...@@ -690,6 +690,7 @@ static void set_posix_acl_entries_dacl(struct user_namespace *user_ns, ...@@ -690,6 +690,7 @@ static void set_posix_acl_entries_dacl(struct user_namespace *user_ns,
static void set_ntacl_dacl(struct user_namespace *user_ns, static void set_ntacl_dacl(struct user_namespace *user_ns,
struct smb_acl *pndacl, struct smb_acl *pndacl,
struct smb_acl *nt_dacl, struct smb_acl *nt_dacl,
unsigned int aces_size,
const struct smb_sid *pownersid, const struct smb_sid *pownersid,
const struct smb_sid *pgrpsid, const struct smb_sid *pgrpsid,
struct smb_fattr *fattr) struct smb_fattr *fattr)
...@@ -703,9 +704,19 @@ static void set_ntacl_dacl(struct user_namespace *user_ns, ...@@ -703,9 +704,19 @@ static void set_ntacl_dacl(struct user_namespace *user_ns,
if (nt_num_aces) { if (nt_num_aces) {
ntace = (struct smb_ace *)((char *)nt_dacl + sizeof(struct smb_acl)); ntace = (struct smb_ace *)((char *)nt_dacl + sizeof(struct smb_acl));
for (i = 0; i < nt_num_aces; i++) { for (i = 0; i < nt_num_aces; i++) {
memcpy((char *)pndace + size, ntace, le16_to_cpu(ntace->size)); unsigned short nt_ace_size;
size += le16_to_cpu(ntace->size);
ntace = (struct smb_ace *)((char *)ntace + le16_to_cpu(ntace->size)); if (offsetof(struct smb_ace, access_req) > aces_size)
break;
nt_ace_size = le16_to_cpu(ntace->size);
if (nt_ace_size > aces_size)
break;
memcpy((char *)pndace + size, ntace, nt_ace_size);
size += nt_ace_size;
aces_size -= nt_ace_size;
ntace = (struct smb_ace *)((char *)ntace + nt_ace_size);
num_aces++; num_aces++;
} }
} }
...@@ -878,7 +889,7 @@ int parse_sec_desc(struct user_namespace *user_ns, struct smb_ntsd *pntsd, ...@@ -878,7 +889,7 @@ int parse_sec_desc(struct user_namespace *user_ns, struct smb_ntsd *pntsd,
/* Convert permission bits from mode to equivalent CIFS ACL */ /* Convert permission bits from mode to equivalent CIFS ACL */
int build_sec_desc(struct user_namespace *user_ns, int build_sec_desc(struct user_namespace *user_ns,
struct smb_ntsd *pntsd, struct smb_ntsd *ppntsd, struct smb_ntsd *pntsd, struct smb_ntsd *ppntsd,
int addition_info, __u32 *secdesclen, int ppntsd_size, int addition_info, __u32 *secdesclen,
struct smb_fattr *fattr) struct smb_fattr *fattr)
{ {
int rc = 0; int rc = 0;
...@@ -938,15 +949,25 @@ int build_sec_desc(struct user_namespace *user_ns, ...@@ -938,15 +949,25 @@ int build_sec_desc(struct user_namespace *user_ns,
if (!ppntsd) { if (!ppntsd) {
set_mode_dacl(user_ns, dacl_ptr, fattr); set_mode_dacl(user_ns, dacl_ptr, fattr);
} else if (!ppntsd->dacloffset) {
goto out;
} else { } else {
struct smb_acl *ppdacl_ptr; struct smb_acl *ppdacl_ptr;
unsigned int dacl_offset = le32_to_cpu(ppntsd->dacloffset);
int ppdacl_size, ntacl_size = ppntsd_size - dacl_offset;
if (!dacl_offset ||
(dacl_offset + sizeof(struct smb_acl) > ppntsd_size))
goto out;
ppdacl_ptr = (struct smb_acl *)((char *)ppntsd + dacl_offset);
ppdacl_size = le16_to_cpu(ppdacl_ptr->size);
if (ppdacl_size > ntacl_size ||
ppdacl_size < sizeof(struct smb_acl))
goto out;
ppdacl_ptr = (struct smb_acl *)((char *)ppntsd +
le32_to_cpu(ppntsd->dacloffset));
set_ntacl_dacl(user_ns, dacl_ptr, ppdacl_ptr, set_ntacl_dacl(user_ns, dacl_ptr, ppdacl_ptr,
nowner_sid_ptr, ngroup_sid_ptr, fattr); ntacl_size - sizeof(struct smb_acl),
nowner_sid_ptr, ngroup_sid_ptr,
fattr);
} }
pntsd->dacloffset = cpu_to_le32(offset); pntsd->dacloffset = cpu_to_le32(offset);
offset += le16_to_cpu(dacl_ptr->size); offset += le16_to_cpu(dacl_ptr->size);
...@@ -980,24 +1001,31 @@ int smb_inherit_dacl(struct ksmbd_conn *conn, ...@@ -980,24 +1001,31 @@ int smb_inherit_dacl(struct ksmbd_conn *conn,
struct smb_sid owner_sid, group_sid; struct smb_sid owner_sid, group_sid;
struct dentry *parent = path->dentry->d_parent; struct dentry *parent = path->dentry->d_parent;
struct user_namespace *user_ns = mnt_user_ns(path->mnt); struct user_namespace *user_ns = mnt_user_ns(path->mnt);
int inherited_flags = 0, flags = 0, i, ace_cnt = 0, nt_size = 0; int inherited_flags = 0, flags = 0, i, ace_cnt = 0, nt_size = 0, pdacl_size;
int rc = 0, num_aces, dacloffset, pntsd_type, acl_len; int rc = 0, num_aces, dacloffset, pntsd_type, pntsd_size, acl_len, aces_size;
char *aces_base; char *aces_base;
bool is_dir = S_ISDIR(d_inode(path->dentry)->i_mode); bool is_dir = S_ISDIR(d_inode(path->dentry)->i_mode);
acl_len = ksmbd_vfs_get_sd_xattr(conn, user_ns, pntsd_size = ksmbd_vfs_get_sd_xattr(conn, user_ns,
parent, &parent_pntsd); parent, &parent_pntsd);
if (acl_len <= 0) if (pntsd_size <= 0)
return -ENOENT; return -ENOENT;
dacloffset = le32_to_cpu(parent_pntsd->dacloffset); dacloffset = le32_to_cpu(parent_pntsd->dacloffset);
if (!dacloffset) { if (!dacloffset || (dacloffset + sizeof(struct smb_acl) > pntsd_size)) {
rc = -EINVAL; rc = -EINVAL;
goto free_parent_pntsd; goto free_parent_pntsd;
} }
parent_pdacl = (struct smb_acl *)((char *)parent_pntsd + dacloffset); parent_pdacl = (struct smb_acl *)((char *)parent_pntsd + dacloffset);
acl_len = pntsd_size - dacloffset;
num_aces = le32_to_cpu(parent_pdacl->num_aces); num_aces = le32_to_cpu(parent_pdacl->num_aces);
pntsd_type = le16_to_cpu(parent_pntsd->type); pntsd_type = le16_to_cpu(parent_pntsd->type);
pdacl_size = le16_to_cpu(parent_pdacl->size);
if (pdacl_size > acl_len || pdacl_size < sizeof(struct smb_acl)) {
rc = -EINVAL;
goto free_parent_pntsd;
}
aces_base = kmalloc(sizeof(struct smb_ace) * num_aces * 2, GFP_KERNEL); aces_base = kmalloc(sizeof(struct smb_ace) * num_aces * 2, GFP_KERNEL);
if (!aces_base) { if (!aces_base) {
...@@ -1008,11 +1036,23 @@ int smb_inherit_dacl(struct ksmbd_conn *conn, ...@@ -1008,11 +1036,23 @@ int smb_inherit_dacl(struct ksmbd_conn *conn,
aces = (struct smb_ace *)aces_base; aces = (struct smb_ace *)aces_base;
parent_aces = (struct smb_ace *)((char *)parent_pdacl + parent_aces = (struct smb_ace *)((char *)parent_pdacl +
sizeof(struct smb_acl)); sizeof(struct smb_acl));
aces_size = acl_len - sizeof(struct smb_acl);
if (pntsd_type & DACL_AUTO_INHERITED) if (pntsd_type & DACL_AUTO_INHERITED)
inherited_flags = INHERITED_ACE; inherited_flags = INHERITED_ACE;
for (i = 0; i < num_aces; i++) { for (i = 0; i < num_aces; i++) {
int pace_size;
if (offsetof(struct smb_ace, access_req) > aces_size)
break;
pace_size = le16_to_cpu(parent_aces->size);
if (pace_size > aces_size)
break;
aces_size -= pace_size;
flags = parent_aces->flags; flags = parent_aces->flags;
if (!smb_inherit_flags(flags, is_dir)) if (!smb_inherit_flags(flags, is_dir))
goto pass; goto pass;
...@@ -1057,8 +1097,7 @@ int smb_inherit_dacl(struct ksmbd_conn *conn, ...@@ -1057,8 +1097,7 @@ int smb_inherit_dacl(struct ksmbd_conn *conn,
aces = (struct smb_ace *)((char *)aces + le16_to_cpu(aces->size)); aces = (struct smb_ace *)((char *)aces + le16_to_cpu(aces->size));
ace_cnt++; ace_cnt++;
pass: pass:
parent_aces = parent_aces = (struct smb_ace *)((char *)parent_aces + pace_size);
(struct smb_ace *)((char *)parent_aces + le16_to_cpu(parent_aces->size));
} }
if (nt_size > 0) { if (nt_size > 0) {
...@@ -1153,7 +1192,7 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path, ...@@ -1153,7 +1192,7 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path,
struct smb_ntsd *pntsd = NULL; struct smb_ntsd *pntsd = NULL;
struct smb_acl *pdacl; struct smb_acl *pdacl;
struct posix_acl *posix_acls; struct posix_acl *posix_acls;
int rc = 0, acl_size; int rc = 0, pntsd_size, acl_size, aces_size, pdacl_size, dacl_offset;
struct smb_sid sid; struct smb_sid sid;
int granted = le32_to_cpu(*pdaccess & ~FILE_MAXIMAL_ACCESS_LE); int granted = le32_to_cpu(*pdaccess & ~FILE_MAXIMAL_ACCESS_LE);
struct smb_ace *ace; struct smb_ace *ace;
...@@ -1162,37 +1201,33 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path, ...@@ -1162,37 +1201,33 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path,
struct smb_ace *others_ace = NULL; struct smb_ace *others_ace = NULL;
struct posix_acl_entry *pa_entry; struct posix_acl_entry *pa_entry;
unsigned int sid_type = SIDOWNER; unsigned int sid_type = SIDOWNER;
char *end_of_acl; unsigned short ace_size;
ksmbd_debug(SMB, "check permission using windows acl\n"); ksmbd_debug(SMB, "check permission using windows acl\n");
acl_size = ksmbd_vfs_get_sd_xattr(conn, user_ns, pntsd_size = ksmbd_vfs_get_sd_xattr(conn, user_ns,
path->dentry, &pntsd); path->dentry, &pntsd);
if (acl_size <= 0 || !pntsd || !pntsd->dacloffset) { if (pntsd_size <= 0 || !pntsd)
kfree(pntsd); goto err_out;
return 0;
} dacl_offset = le32_to_cpu(pntsd->dacloffset);
if (!dacl_offset ||
(dacl_offset + sizeof(struct smb_acl) > pntsd_size))
goto err_out;
pdacl = (struct smb_acl *)((char *)pntsd + le32_to_cpu(pntsd->dacloffset)); pdacl = (struct smb_acl *)((char *)pntsd + le32_to_cpu(pntsd->dacloffset));
end_of_acl = ((char *)pntsd) + acl_size; acl_size = pntsd_size - dacl_offset;
if (end_of_acl <= (char *)pdacl) { pdacl_size = le16_to_cpu(pdacl->size);
kfree(pntsd);
return 0;
}
if (end_of_acl < (char *)pdacl + le16_to_cpu(pdacl->size) || if (pdacl_size > acl_size || pdacl_size < sizeof(struct smb_acl))
le16_to_cpu(pdacl->size) < sizeof(struct smb_acl)) { goto err_out;
kfree(pntsd);
return 0;
}
if (!pdacl->num_aces) { if (!pdacl->num_aces) {
if (!(le16_to_cpu(pdacl->size) - sizeof(struct smb_acl)) && if (!(pdacl_size - sizeof(struct smb_acl)) &&
*pdaccess & ~(FILE_READ_CONTROL_LE | FILE_WRITE_DAC_LE)) { *pdaccess & ~(FILE_READ_CONTROL_LE | FILE_WRITE_DAC_LE)) {
rc = -EACCES; rc = -EACCES;
goto err_out; goto err_out;
} }
kfree(pntsd); goto err_out;
return 0;
} }
if (*pdaccess & FILE_MAXIMAL_ACCESS_LE) { if (*pdaccess & FILE_MAXIMAL_ACCESS_LE) {
...@@ -1200,11 +1235,16 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path, ...@@ -1200,11 +1235,16 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path,
DELETE; DELETE;
ace = (struct smb_ace *)((char *)pdacl + sizeof(struct smb_acl)); ace = (struct smb_ace *)((char *)pdacl + sizeof(struct smb_acl));
aces_size = acl_size - sizeof(struct smb_acl);
for (i = 0; i < le32_to_cpu(pdacl->num_aces); i++) { for (i = 0; i < le32_to_cpu(pdacl->num_aces); i++) {
if (offsetof(struct smb_ace, access_req) > aces_size)
break;
ace_size = le16_to_cpu(ace->size);
if (ace_size > aces_size)
break;
aces_size -= ace_size;
granted |= le32_to_cpu(ace->access_req); granted |= le32_to_cpu(ace->access_req);
ace = (struct smb_ace *)((char *)ace + le16_to_cpu(ace->size)); ace = (struct smb_ace *)((char *)ace + le16_to_cpu(ace->size));
if (end_of_acl < (char *)ace)
goto err_out;
} }
if (!pdacl->num_aces) if (!pdacl->num_aces)
...@@ -1216,7 +1256,15 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path, ...@@ -1216,7 +1256,15 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path,
id_to_sid(uid, sid_type, &sid); id_to_sid(uid, sid_type, &sid);
ace = (struct smb_ace *)((char *)pdacl + sizeof(struct smb_acl)); ace = (struct smb_ace *)((char *)pdacl + sizeof(struct smb_acl));
aces_size = acl_size - sizeof(struct smb_acl);
for (i = 0; i < le32_to_cpu(pdacl->num_aces); i++) { for (i = 0; i < le32_to_cpu(pdacl->num_aces); i++) {
if (offsetof(struct smb_ace, access_req) > aces_size)
break;
ace_size = le16_to_cpu(ace->size);
if (ace_size > aces_size)
break;
aces_size -= ace_size;
if (!compare_sids(&sid, &ace->sid) || if (!compare_sids(&sid, &ace->sid) ||
!compare_sids(&sid_unix_NFS_mode, &ace->sid)) { !compare_sids(&sid_unix_NFS_mode, &ace->sid)) {
found = 1; found = 1;
...@@ -1226,8 +1274,6 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path, ...@@ -1226,8 +1274,6 @@ int smb_check_perm_dacl(struct ksmbd_conn *conn, struct path *path,
others_ace = ace; others_ace = ace;
ace = (struct smb_ace *)((char *)ace + le16_to_cpu(ace->size)); ace = (struct smb_ace *)((char *)ace + le16_to_cpu(ace->size));
if (end_of_acl < (char *)ace)
goto err_out;
} }
if (*pdaccess & FILE_MAXIMAL_ACCESS_LE && found) { if (*pdaccess & FILE_MAXIMAL_ACCESS_LE && found) {
......
...@@ -193,7 +193,7 @@ struct posix_acl_state { ...@@ -193,7 +193,7 @@ struct posix_acl_state {
int parse_sec_desc(struct user_namespace *user_ns, struct smb_ntsd *pntsd, int parse_sec_desc(struct user_namespace *user_ns, struct smb_ntsd *pntsd,
int acl_len, struct smb_fattr *fattr); int acl_len, struct smb_fattr *fattr);
int build_sec_desc(struct user_namespace *user_ns, struct smb_ntsd *pntsd, int build_sec_desc(struct user_namespace *user_ns, struct smb_ntsd *pntsd,
struct smb_ntsd *ppntsd, int addition_info, struct smb_ntsd *ppntsd, int ppntsd_size, int addition_info,
__u32 *secdesclen, struct smb_fattr *fattr); __u32 *secdesclen, struct smb_fattr *fattr);
int init_acl_state(struct posix_acl_state *state, int cnt); int init_acl_state(struct posix_acl_state *state, int cnt);
void free_acl_state(struct posix_acl_state *state); void free_acl_state(struct posix_acl_state *state);
......
...@@ -481,12 +481,11 @@ int ksmbd_vfs_write(struct ksmbd_work *work, struct ksmbd_file *fp, ...@@ -481,12 +481,11 @@ int ksmbd_vfs_write(struct ksmbd_work *work, struct ksmbd_file *fp,
char *buf, size_t count, loff_t *pos, bool sync, char *buf, size_t count, loff_t *pos, bool sync,
ssize_t *written) ssize_t *written)
{ {
struct ksmbd_session *sess = work->sess;
struct file *filp; struct file *filp;
loff_t offset = *pos; loff_t offset = *pos;
int err = 0; int err = 0;
if (sess->conn->connection_type) { if (work->conn->connection_type) {
if (!(fp->daccess & FILE_WRITE_DATA_LE)) { if (!(fp->daccess & FILE_WRITE_DATA_LE)) {
pr_err("no right to write(%pd)\n", pr_err("no right to write(%pd)\n",
fp->filp->f_path.dentry); fp->filp->f_path.dentry);
...@@ -1540,6 +1539,11 @@ int ksmbd_vfs_get_sd_xattr(struct ksmbd_conn *conn, ...@@ -1540,6 +1539,11 @@ int ksmbd_vfs_get_sd_xattr(struct ksmbd_conn *conn,
} }
*pntsd = acl.sd_buf; *pntsd = acl.sd_buf;
if (acl.sd_size < sizeof(struct smb_ntsd)) {
pr_err("sd size is invalid\n");
goto out_free;
}
(*pntsd)->osidoffset = cpu_to_le32(le32_to_cpu((*pntsd)->osidoffset) - (*pntsd)->osidoffset = cpu_to_le32(le32_to_cpu((*pntsd)->osidoffset) -
NDR_NTSD_OFFSETOF); NDR_NTSD_OFFSETOF);
(*pntsd)->gsidoffset = cpu_to_le32(le32_to_cpu((*pntsd)->gsidoffset) - (*pntsd)->gsidoffset = cpu_to_le32(le32_to_cpu((*pntsd)->gsidoffset) -
......
...@@ -569,7 +569,7 @@ struct ksmbd_file *ksmbd_open_fd(struct ksmbd_work *work, struct file *filp) ...@@ -569,7 +569,7 @@ struct ksmbd_file *ksmbd_open_fd(struct ksmbd_work *work, struct file *filp)
atomic_set(&fp->refcount, 1); atomic_set(&fp->refcount, 1);
fp->filp = filp; fp->filp = filp;
fp->conn = work->sess->conn; fp->conn = work->conn;
fp->tcon = work->tcon; fp->tcon = work->tcon;
fp->volatile_id = KSMBD_NO_FID; fp->volatile_id = KSMBD_NO_FID;
fp->persistent_id = KSMBD_NO_FID; fp->persistent_id = KSMBD_NO_FID;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment