Merge tag 'nfs-for-6.5-1' of git://git.linux-nfs.org/projects/trondmy/linux-nfs

Pull NFS client updates from Trond Myklebust:
"Stable fixes and other bugfixes:

   - nfs: don't report STATX_BTIME in ->getattr

   - Revert 'NFSv4: Retry LOCK on OLD_STATEID during delegation return'
     since it breaks NFSv4 state recovery.

   - NFSv4.1: freeze the session table upon receiving NFS4ERR_BADSESSION

   - Fix the NFSv4.2 xattr cache shrinker_id

   - Force a ctime update after a NFSv4.2 SETXATTR call

  Features and cleanups:

   - NFS and RPC over TLS client code from Chuck Lever

   - Support for use of abstract unix socket addresses with the rpcbind
     daemon

   - Sysfs API to allow shutdown of the kernel RPC client and prevent
     umount() hangs if the server is known to be permanently down

   - XDR cleanups from Anna"

* tag 'nfs-for-6.5-1' of git://git.linux-nfs.org/projects/trondmy/linux-nfs: (33 commits)
  Revert "NFSv4: Retry LOCK on OLD_STATEID during delegation return"
  NFS: Don't cleanup sysfs superblock entry if uninitialized
  nfs: don't report STATX_BTIME in ->getattr
  NFSv4.1: freeze the session table upon receiving NFS4ERR_BADSESSION
  NFSv4.2: fix wrong shrinker_id
  NFSv4: Clean up some shutdown loops
  NFS: Cancel all existing RPC tasks when shutdown
  NFS: add sysfs shutdown knob
  NFS: add a sysfs link to the acl rpc_client
  NFS: add a sysfs link to the lockd rpc_client
  NFS: Add sysfs links to sunrpc clients for nfs_clients
  NFS: add superblock sysfs entries
  NFS: Make all of /sys/fs/nfs network-namespace unique
  NFS: Open-code the nfs_kset kset_create_and_add()
  NFS: rename nfs_client_kobj to nfs_net_kobj
  NFS: rename nfs_client_kset to nfs_kset
  NFS: Add an "xprtsec=" NFS mount option
  NFS: Have struct nfs_client carry a TLS policy field
  SUNRPC: Add a TCP-with-TLS RPC transport class
  SUNRPC: Capture CMSG metadata on client-side receive
  ...
This commit is contained in:
Linus Torvalds
2023-07-01 14:38:25 -07:00
31 changed files with 1562 additions and 431 deletions

View File

@@ -9,7 +9,7 @@ obj-$(CONFIG_SUNRPC_GSS) += auth_gss/
obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/
sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \
auth.o auth_null.o auth_unix.o \
auth.o auth_null.o auth_tls.o auth_unix.o \
svc.o svcsock.o svcauth.o svcauth_unix.o \
addr.o rpcb_clnt.o timer.o xdr.o \
sunrpc_syms.o cache.o rpc_pipe.o sysfs.o \

View File

@@ -32,7 +32,7 @@ static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
[RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
[RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
NULL, /* others can be loadable modules */
[RPC_AUTH_TLS] = (const struct rpc_authops __force __rcu *)&authtls_ops,
};
static LIST_HEAD(cred_unused);

175
net/sunrpc/auth_tls.c Normal file
View File

@@ -0,0 +1,175 @@
// SPDX-License-Identifier: GPL-2.0-only
/*
* Copyright (c) 2021, 2022 Oracle. All rights reserved.
*
* The AUTH_TLS credential is used only to probe a remote peer
* for RPC-over-TLS support.
*/
#include <linux/types.h>
#include <linux/module.h>
#include <linux/sunrpc/clnt.h>
static const char *starttls_token = "STARTTLS";
static const size_t starttls_len = 8;
static struct rpc_auth tls_auth;
static struct rpc_cred tls_cred;
static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
const void *obj)
{
}
static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
void *obj)
{
return 0;
}
static const struct rpc_procinfo rpcproc_tls_probe = {
.p_encode = tls_encode_probe,
.p_decode = tls_decode_probe,
};
static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
{
task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
rpc_call_start(task);
}
static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
{
}
static const struct rpc_call_ops rpc_tls_probe_ops = {
.rpc_call_prepare = rpc_tls_probe_call_prepare,
.rpc_call_done = rpc_tls_probe_call_done,
};
static int tls_probe(struct rpc_clnt *clnt)
{
struct rpc_message msg = {
.rpc_proc = &rpcproc_tls_probe,
};
struct rpc_task_setup task_setup_data = {
.rpc_client = clnt,
.rpc_message = &msg,
.rpc_op_cred = &tls_cred,
.callback_ops = &rpc_tls_probe_ops,
.flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
};
struct rpc_task *task;
int status;
task = rpc_run_task(&task_setup_data);
if (IS_ERR(task))
return PTR_ERR(task);
status = task->tk_status;
rpc_put_task(task);
return status;
}
static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
struct rpc_clnt *clnt)
{
refcount_inc(&tls_auth.au_count);
return &tls_auth;
}
static void tls_destroy(struct rpc_auth *auth)
{
}
static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
struct auth_cred *acred, int flags)
{
return get_rpccred(&tls_cred);
}
static void tls_destroy_cred(struct rpc_cred *cred)
{
}
static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
{
return 1;
}
static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
{
__be32 *p;
p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
if (!p)
return -EMSGSIZE;
/* Credential */
*p++ = rpc_auth_tls;
*p++ = xdr_zero;
/* Verifier */
*p++ = rpc_auth_null;
*p = xdr_zero;
return 0;
}
static int tls_refresh(struct rpc_task *task)
{
set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
return 0;
}
static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
{
__be32 *p;
void *str;
p = xdr_inline_decode(xdr, XDR_UNIT);
if (!p)
return -EIO;
if (*p != rpc_auth_null)
return -EIO;
if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
return -EIO;
if (memcmp(str, starttls_token, starttls_len))
return -EIO;
return 0;
}
const struct rpc_authops authtls_ops = {
.owner = THIS_MODULE,
.au_flavor = RPC_AUTH_TLS,
.au_name = "NULL",
.create = tls_create,
.destroy = tls_destroy,
.lookup_cred = tls_lookup_cred,
.ping = tls_probe,
};
static struct rpc_auth tls_auth = {
.au_cslack = NUL_CALLSLACK,
.au_rslack = NUL_REPLYSLACK,
.au_verfsize = NUL_REPLYSLACK,
.au_ralign = NUL_REPLYSLACK,
.au_ops = &authtls_ops,
.au_flavor = RPC_AUTH_TLS,
.au_count = REFCOUNT_INIT(1),
};
static const struct rpc_credops tls_credops = {
.cr_name = "AUTH_TLS",
.crdestroy = tls_destroy_cred,
.crmatch = tls_match,
.crmarshal = tls_marshal,
.crwrap_req = rpcauth_wrap_req_encode,
.crrefresh = tls_refresh,
.crvalidate = tls_validate,
.crunwrap_resp = rpcauth_unwrap_resp_decode,
};
static struct rpc_cred tls_cred = {
.cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru),
.cr_auth = &tls_auth,
.cr_ops = &tls_credops,
.cr_count = REFCOUNT_INIT(2),
.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
};

View File

@@ -385,6 +385,7 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args,
if (!clnt)
goto out_err;
clnt->cl_parent = parent ? : clnt;
clnt->cl_xprtsec = args->xprtsec;
err = rpc_alloc_clid(clnt);
if (err)
@@ -434,7 +435,7 @@ static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args,
if (parent)
refcount_inc(&parent->cl_count);
trace_rpc_clnt_new(clnt, xprt, program->name, args->servername);
trace_rpc_clnt_new(clnt, xprt, args);
return clnt;
out_no_path:
@@ -532,6 +533,7 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args)
.addrlen = args->addrsize,
.servername = args->servername,
.bc_xprt = args->bc_xprt,
.xprtsec = args->xprtsec,
};
char servername[48];
struct rpc_clnt *clnt;
@@ -565,8 +567,12 @@ struct rpc_clnt *rpc_create(struct rpc_create_args *args)
servername[0] = '\0';
switch (args->address->sa_family) {
case AF_LOCAL:
snprintf(servername, sizeof(servername), "%s",
sun->sun_path);
if (sun->sun_path[0])
snprintf(servername, sizeof(servername), "%s",
sun->sun_path);
else
snprintf(servername, sizeof(servername), "@%s",
sun->sun_path+1);
break;
case AF_INET:
snprintf(servername, sizeof(servername), "%pI4",
@@ -727,6 +733,7 @@ int rpc_switch_client_transport(struct rpc_clnt *clnt,
struct rpc_clnt *parent;
int err;
args->xprtsec = clnt->cl_xprtsec;
xprt = xprt_create_transport(args);
if (IS_ERR(xprt))
return PTR_ERR(xprt);
@@ -1717,6 +1724,11 @@ call_start(struct rpc_task *task)
trace_rpc_request(task);
if (task->tk_client->cl_shutdown) {
rpc_call_rpcerror(task, -EIO);
return;
}
/* Increment call count (version might not be valid for ping) */
if (clnt->cl_program->version[clnt->cl_vers])
clnt->cl_program->version[clnt->cl_vers]->counts[idx]++;
@@ -2826,6 +2838,9 @@ static int rpc_ping(struct rpc_clnt *clnt)
struct rpc_task *task;
int status;
if (clnt->cl_auth->au_ops->ping)
return clnt->cl_auth->au_ops->ping(clnt);
task = rpc_call_null_helper(clnt, NULL, NULL, 0, NULL, NULL);
if (IS_ERR(task))
return PTR_ERR(task);
@@ -3046,6 +3061,7 @@ int rpc_clnt_add_xprt(struct rpc_clnt *clnt,
if (!xprtargs->ident)
xprtargs->ident = ident;
xprtargs->xprtsec = clnt->cl_xprtsec;
xprt = xprt_create_transport(xprtargs);
if (IS_ERR(xprt)) {
ret = PTR_ERR(xprt);

View File

@@ -36,6 +36,7 @@
#include "netns.h"
#define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock"
#define RPCBIND_SOCK_ABSTRACT_NAME "\0/run/rpcbind.sock"
#define RPCBIND_PROGRAM (100000u)
#define RPCBIND_PORT (111u)
@@ -216,21 +217,22 @@ static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt,
sn->rpcb_users = 1;
}
/* Evaluate to actual length of the `sockaddr_un' structure. */
# define SUN_LEN(ptr) (offsetof(struct sockaddr_un, sun_path) \
+ 1 + strlen((ptr)->sun_path + 1))
/*
* Returns zero on success, otherwise a negative errno value
* is returned.
*/
static int rpcb_create_local_unix(struct net *net)
static int rpcb_create_af_local(struct net *net,
const struct sockaddr_un *addr)
{
static const struct sockaddr_un rpcb_localaddr_rpcbind = {
.sun_family = AF_LOCAL,
.sun_path = RPCBIND_SOCK_PATHNAME,
};
struct rpc_create_args args = {
.net = net,
.protocol = XPRT_TRANSPORT_LOCAL,
.address = (struct sockaddr *)&rpcb_localaddr_rpcbind,
.addrsize = sizeof(rpcb_localaddr_rpcbind),
.address = (struct sockaddr *)addr,
.addrsize = SUN_LEN(addr),
.servername = "localhost",
.program = &rpcb_program,
.version = RPCBVERS_2,
@@ -269,6 +271,26 @@ out:
return result;
}
static int rpcb_create_local_abstract(struct net *net)
{
static const struct sockaddr_un rpcb_localaddr_abstract = {
.sun_family = AF_LOCAL,
.sun_path = RPCBIND_SOCK_ABSTRACT_NAME,
};
return rpcb_create_af_local(net, &rpcb_localaddr_abstract);
}
static int rpcb_create_local_unix(struct net *net)
{
static const struct sockaddr_un rpcb_localaddr_unix = {
.sun_family = AF_LOCAL,
.sun_path = RPCBIND_SOCK_PATHNAME,
};
return rpcb_create_af_local(net, &rpcb_localaddr_unix);
}
/*
* Returns zero on success, otherwise a negative errno value
* is returned.
@@ -332,7 +354,8 @@ int rpcb_create_local(struct net *net)
if (rpcb_get_local(net))
goto out;
if (rpcb_create_local_unix(net) != 0)
if (rpcb_create_local_abstract(net) != 0 &&
rpcb_create_local_unix(net) != 0)
result = rpcb_create_local_net(net);
out:

View File

@@ -239,6 +239,7 @@ static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj,
if (!xprt)
return 0;
if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP ||
xprt->xprt_class->ident == XPRT_TRANSPORT_TCP_TLS ||
xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) {
xprt_put(xprt);
return -EOPNOTSUPP;

View File

@@ -5,13 +5,6 @@
#ifndef __SUNRPC_SYSFS_H
#define __SUNRPC_SYSFS_H
struct rpc_sysfs_client {
struct kobject kobject;
struct net *net;
struct rpc_clnt *clnt;
struct rpc_xprt_switch *xprt_switch;
};
struct rpc_sysfs_xprt_switch {
struct kobject kobject;
struct net *net;

View File

@@ -47,6 +47,9 @@
#include <net/checksum.h>
#include <net/udp.h>
#include <net/tcp.h>
#include <net/tls.h>
#include <net/handshake.h>
#include <linux/bvec.h>
#include <linux/highmem.h>
#include <linux/uio.h>
@@ -96,6 +99,7 @@ static struct ctl_table_header *sunrpc_table_header;
static struct xprt_class xs_local_transport;
static struct xprt_class xs_udp_transport;
static struct xprt_class xs_tcp_transport;
static struct xprt_class xs_tcp_tls_transport;
static struct xprt_class xs_bc_tcp_transport;
/*
@@ -187,6 +191,11 @@ static struct ctl_table xs_tunables_table[] = {
*/
#define XS_IDLE_DISC_TO (5U * 60 * HZ)
/*
* TLS handshake timeout.
*/
#define XS_TLS_HANDSHAKE_TO (10U * HZ)
#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
# undef RPC_DEBUG_DATA
# define RPCDBG_FACILITY RPCDBG_TRANS
@@ -253,7 +262,12 @@ static void xs_format_common_peer_addresses(struct rpc_xprt *xprt)
switch (sap->sa_family) {
case AF_LOCAL:
sun = xs_addr_un(xprt);
strscpy(buf, sun->sun_path, sizeof(buf));
if (sun->sun_path[0]) {
strscpy(buf, sun->sun_path, sizeof(buf));
} else {
buf[0] = '@';
strscpy(buf+1, sun->sun_path+1, sizeof(buf)-1);
}
xprt->address_strings[RPC_DISPLAY_ADDR] =
kstrdup(buf, GFP_KERNEL);
break;
@@ -342,13 +356,56 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp)
return want;
}
static int
xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
struct cmsghdr *cmsg, int ret)
{
if (cmsg->cmsg_level == SOL_TLS &&
cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
u8 content_type = *((u8 *)CMSG_DATA(cmsg));
switch (content_type) {
case TLS_RECORD_TYPE_DATA:
/* TLS sets EOR at the end of each application data
* record, even though there might be more frames
* waiting to be decrypted.
*/
msg->msg_flags &= ~MSG_EOR;
break;
case TLS_RECORD_TYPE_ALERT:
ret = -ENOTCONN;
break;
default:
ret = -EAGAIN;
}
}
return ret;
}
static int
xs_sock_recv_cmsg(struct socket *sock, struct msghdr *msg, int flags)
{
union {
struct cmsghdr cmsg;
u8 buf[CMSG_SPACE(sizeof(u8))];
} u;
int ret;
msg->msg_control = &u;
msg->msg_controllen = sizeof(u);
ret = sock_recvmsg(sock, msg, flags);
if (msg->msg_controllen != sizeof(u))
ret = xs_sock_process_cmsg(sock, msg, &u.cmsg, ret);
return ret;
}
static ssize_t
xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek)
{
ssize_t ret;
if (seek != 0)
iov_iter_advance(&msg->msg_iter, seek);
ret = sock_recvmsg(sock, msg, flags);
ret = xs_sock_recv_cmsg(sock, msg, flags);
return ret > 0 ? ret + seek : ret;
}
@@ -374,7 +431,7 @@ xs_read_discard(struct socket *sock, struct msghdr *msg, int flags,
size_t count)
{
iov_iter_discard(&msg->msg_iter, ITER_DEST, count);
return sock_recvmsg(sock, msg, flags);
return xs_sock_recv_cmsg(sock, msg, flags);
}
#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
@@ -695,6 +752,8 @@ static void xs_poll_check_readable(struct sock_xprt *transport)
{
clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state);
if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
return;
if (!xs_poll_socket_readable(transport))
return;
if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state))
@@ -1191,6 +1250,8 @@ static void xs_reset_transport(struct sock_xprt *transport)
if (atomic_read(&transport->xprt.swapper))
sk_clear_memalloc(sk);
tls_handshake_cancel(sk);
kernel_sock_shutdown(sock, SHUT_RDWR);
mutex_lock(&transport->recv_mutex);
@@ -1380,6 +1441,10 @@ static void xs_data_ready(struct sock *sk)
trace_xs_data_ready(xprt);
transport->old_data_ready(sk);
if (test_bit(XPRT_SOCK_IGNORE_RECV, &transport->sock_state))
return;
/* Any data means we had a useful conversation, so
* then we don't need to delay the next reconnect
*/
@@ -2360,6 +2425,267 @@ out_unlock:
current_restore_flags(pflags, PF_MEMALLOC);
}
/*
* Transfer the connected socket to @upper_transport, then mark that
* xprt CONNECTED.
*/
static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt,
struct sock_xprt *upper_transport)
{
struct sock_xprt *lower_transport =
container_of(lower_xprt, struct sock_xprt, xprt);
struct rpc_xprt *upper_xprt = &upper_transport->xprt;
if (!upper_transport->inet) {
struct socket *sock = lower_transport->sock;
struct sock *sk = sock->sk;
/* Avoid temporary address, they are bad for long-lived
* connections such as NFS mounts.
* RFC4941, section 3.6 suggests that:
* Individual applications, which have specific
* knowledge about the normal duration of connections,
* MAY override this as appropriate.
*/
if (xs_addr(upper_xprt)->sa_family == PF_INET6)
ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC);
xs_tcp_set_socket_timeouts(upper_xprt, sock);
tcp_sock_set_nodelay(sk);
lock_sock(sk);
/* @sk is already connected, so it now has the RPC callbacks.
* Reach into @lower_transport to save the original ones.
*/
upper_transport->old_data_ready = lower_transport->old_data_ready;
upper_transport->old_state_change = lower_transport->old_state_change;
upper_transport->old_write_space = lower_transport->old_write_space;
upper_transport->old_error_report = lower_transport->old_error_report;
sk->sk_user_data = upper_xprt;
/* socket options */
sock_reset_flag(sk, SOCK_LINGER);
xprt_clear_connected(upper_xprt);
upper_transport->sock = sock;
upper_transport->inet = sk;
upper_transport->file = lower_transport->file;
release_sock(sk);
/* Reset lower_transport before shutting down its clnt */
mutex_lock(&lower_transport->recv_mutex);
lower_transport->inet = NULL;
lower_transport->sock = NULL;
lower_transport->file = NULL;
xprt_clear_connected(lower_xprt);
xs_sock_reset_connection_flags(lower_xprt);
xs_stream_reset_connect(lower_transport);
mutex_unlock(&lower_transport->recv_mutex);
}
if (!xprt_bound(upper_xprt))
return -ENOTCONN;
xs_set_memalloc(upper_xprt);
if (!xprt_test_and_set_connected(upper_xprt)) {
upper_xprt->connect_cookie++;
clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
xprt_clear_connecting(upper_xprt);
upper_xprt->stat.connect_count++;
upper_xprt->stat.connect_time += (long)jiffies -
upper_xprt->stat.connect_start;
xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
}
return 0;
}
/**
* xs_tls_handshake_done - TLS handshake completion handler
* @data: address of xprt to wake
* @status: status of handshake
* @peerid: serial number of key containing the remote's identity
*
*/
static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
{
struct rpc_xprt *lower_xprt = data;
struct sock_xprt *lower_transport =
container_of(lower_xprt, struct sock_xprt, xprt);
lower_transport->xprt_err = status ? -EACCES : 0;
complete(&lower_transport->handshake_done);
xprt_put(lower_xprt);
}
static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec)
{
struct sock_xprt *lower_transport =
container_of(lower_xprt, struct sock_xprt, xprt);
struct tls_handshake_args args = {
.ta_sock = lower_transport->sock,
.ta_done = xs_tls_handshake_done,
.ta_data = xprt_get(lower_xprt),
.ta_peername = lower_xprt->servername,
};
struct sock *sk = lower_transport->inet;
int rc;
init_completion(&lower_transport->handshake_done);
set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
lower_transport->xprt_err = -ETIMEDOUT;
switch (xprtsec->policy) {
case RPC_XPRTSEC_TLS_ANON:
rc = tls_client_hello_anon(&args, GFP_KERNEL);
if (rc)
goto out_put_xprt;
break;
case RPC_XPRTSEC_TLS_X509:
args.ta_my_cert = xprtsec->cert_serial;
args.ta_my_privkey = xprtsec->privkey_serial;
rc = tls_client_hello_x509(&args, GFP_KERNEL);
if (rc)
goto out_put_xprt;
break;
default:
rc = -EACCES;
goto out_put_xprt;
}
rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
XS_TLS_HANDSHAKE_TO);
if (rc <= 0) {
if (!tls_handshake_cancel(sk)) {
if (rc == 0)
rc = -ETIMEDOUT;
goto out_put_xprt;
}
}
rc = lower_transport->xprt_err;
out:
xs_stream_reset_connect(lower_transport);
clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
return rc;
out_put_xprt:
xprt_put(lower_xprt);
goto out;
}
/**
* xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket
* @work: queued work item
*
* Invoked by a work queue tasklet.
*
* For RPC-with-TLS, there is a two-stage connection process.
*
* The "upper-layer xprt" is visible to the RPC consumer. Once it has
* been marked connected, the consumer knows that a TCP connection and
* a TLS session have been established.
*
* A "lower-layer xprt", created in this function, handles the mechanics
* of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and
* then driving the TLS handshake. Once all that is complete, the upper
* layer xprt is marked connected.
*/
static void xs_tcp_tls_setup_socket(struct work_struct *work)
{
struct sock_xprt *upper_transport =
container_of(work, struct sock_xprt, connect_worker.work);
struct rpc_clnt *upper_clnt = upper_transport->clnt;
struct rpc_xprt *upper_xprt = &upper_transport->xprt;
struct rpc_create_args args = {
.net = upper_xprt->xprt_net,
.protocol = upper_xprt->prot,
.address = (struct sockaddr *)&upper_xprt->addr,
.addrsize = upper_xprt->addrlen,
.timeout = upper_clnt->cl_timeout,
.servername = upper_xprt->servername,
.program = upper_clnt->cl_program,
.prognumber = upper_clnt->cl_prog,
.version = upper_clnt->cl_vers,
.authflavor = RPC_AUTH_TLS,
.cred = upper_clnt->cl_cred,
.xprtsec = {
.policy = RPC_XPRTSEC_NONE,
},
};
unsigned int pflags = current->flags;
struct rpc_clnt *lower_clnt;
struct rpc_xprt *lower_xprt;
int status;
if (atomic_read(&upper_xprt->swapper))
current->flags |= PF_MEMALLOC;
xs_stream_start_connect(upper_transport);
/* This implicitly sends an RPC_AUTH_TLS probe */
lower_clnt = rpc_create(&args);
if (IS_ERR(lower_clnt)) {
trace_rpc_tls_unavailable(upper_clnt, upper_xprt);
clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
xprt_clear_connecting(upper_xprt);
xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt));
xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
goto out_unlock;
}
/* RPC_AUTH_TLS probe was successful. Try a TLS handshake on
* the lower xprt.
*/
rcu_read_lock();
lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
rcu_read_unlock();
status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
if (status) {
trace_rpc_tls_not_started(upper_clnt, upper_xprt);
goto out_close;
}
status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
if (status)
goto out_close;
trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
if (!xprt_test_and_set_connected(upper_xprt)) {
upper_xprt->connect_cookie++;
clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
xprt_clear_connecting(upper_xprt);
upper_xprt->stat.connect_count++;
upper_xprt->stat.connect_time += (long)jiffies -
upper_xprt->stat.connect_start;
xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
}
rpc_shutdown_client(lower_clnt);
out_unlock:
current_restore_flags(pflags, PF_MEMALLOC);
upper_transport->clnt = NULL;
xprt_unlock_connect(upper_xprt, upper_transport);
return;
out_close:
rpc_shutdown_client(lower_clnt);
/* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
* Wake them first here to ensure they get our tk_status code.
*/
xprt_wake_pending_tasks(upper_xprt, status);
xs_tcp_force_close(upper_xprt);
xprt_clear_connecting(upper_xprt);
goto out_unlock;
}
/**
* xs_connect - connect a socket to a remote endpoint
* @xprt: pointer to transport structure
@@ -2391,6 +2717,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task)
} else
dprintk("RPC: xs_connect scheduled xprt %p\n", xprt);
transport->clnt = task->tk_client;
queue_delayed_work(xprtiod_workqueue,
&transport->connect_worker,
delay);
@@ -2858,7 +3185,7 @@ static struct rpc_xprt *xs_setup_local(struct xprt_create *args)
switch (sun->sun_family) {
case AF_LOCAL:
if (sun->sun_path[0] != '/') {
if (sun->sun_path[0] != '/' && sun->sun_path[0] != '\0') {
dprintk("RPC: bad AF_LOCAL address: %s\n",
sun->sun_path);
ret = ERR_PTR(-EINVAL);
@@ -3044,6 +3371,94 @@ out_err:
return ret;
}
/**
* xs_setup_tcp_tls - Set up transport to use a TCP with TLS
* @args: rpc transport creation arguments
*
*/
static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args)
{
struct sockaddr *addr = args->dstaddr;
struct rpc_xprt *xprt;
struct sock_xprt *transport;
struct rpc_xprt *ret;
unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries;
if (args->flags & XPRT_CREATE_INFINITE_SLOTS)
max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT;
xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
max_slot_table_size);
if (IS_ERR(xprt))
return xprt;
transport = container_of(xprt, struct sock_xprt, xprt);
xprt->prot = IPPROTO_TCP;
xprt->xprt_class = &xs_tcp_transport;
xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
xprt->bind_timeout = XS_BIND_TO;
xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
xprt->idle_timeout = XS_IDLE_DISC_TO;
xprt->ops = &xs_tcp_ops;
xprt->timeout = &xs_tcp_default_timeout;
xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
xprt->connect_timeout = xprt->timeout->to_initval *
(xprt->timeout->to_retries + 1);
INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
INIT_WORK(&transport->error_worker, xs_error_handle);
switch (args->xprtsec.policy) {
case RPC_XPRTSEC_TLS_ANON:
case RPC_XPRTSEC_TLS_X509:
xprt->xprtsec = args->xprtsec;
INIT_DELAYED_WORK(&transport->connect_worker,
xs_tcp_tls_setup_socket);
break;
default:
ret = ERR_PTR(-EACCES);
goto out_err;
}
switch (addr->sa_family) {
case AF_INET:
if (((struct sockaddr_in *)addr)->sin_port != htons(0))
xprt_set_bound(xprt);
xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
break;
case AF_INET6:
if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
xprt_set_bound(xprt);
xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
break;
default:
ret = ERR_PTR(-EAFNOSUPPORT);
goto out_err;
}
if (xprt_bound(xprt))
dprintk("RPC: set up xprt to %s (port %s) via %s\n",
xprt->address_strings[RPC_DISPLAY_ADDR],
xprt->address_strings[RPC_DISPLAY_PORT],
xprt->address_strings[RPC_DISPLAY_PROTO]);
else
dprintk("RPC: set up xprt to %s (autobind) via %s\n",
xprt->address_strings[RPC_DISPLAY_ADDR],
xprt->address_strings[RPC_DISPLAY_PROTO]);
if (try_module_get(THIS_MODULE))
return xprt;
ret = ERR_PTR(-EINVAL);
out_err:
xs_xprt_free(xprt);
return ret;
}
/**
* xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket
* @args: rpc transport creation arguments
@@ -3153,6 +3568,15 @@ static struct xprt_class xs_tcp_transport = {
.netid = { "tcp", "tcp6", "" },
};
static struct xprt_class xs_tcp_tls_transport = {
.list = LIST_HEAD_INIT(xs_tcp_tls_transport.list),
.name = "tcp-with-tls",
.owner = THIS_MODULE,
.ident = XPRT_TRANSPORT_TCP_TLS,
.setup = xs_setup_tcp_tls,
.netid = { "tcp", "tcp6", "" },
};
static struct xprt_class xs_bc_tcp_transport = {
.list = LIST_HEAD_INIT(xs_bc_tcp_transport.list),
.name = "tcp NFSv4.1 backchannel",
@@ -3174,6 +3598,7 @@ int init_socket_xprt(void)
xprt_register_transport(&xs_local_transport);
xprt_register_transport(&xs_udp_transport);
xprt_register_transport(&xs_tcp_transport);
xprt_register_transport(&xs_tcp_tls_transport);
xprt_register_transport(&xs_bc_tcp_transport);
return 0;
@@ -3193,6 +3618,7 @@ void cleanup_socket_xprt(void)
xprt_unregister_transport(&xs_local_transport);
xprt_unregister_transport(&xs_udp_transport);
xprt_unregister_transport(&xs_tcp_transport);
xprt_unregister_transport(&xs_tcp_tls_transport);
xprt_unregister_transport(&xs_bc_tcp_transport);
}