uipc: Fix various races on unp_connect() path.
author Sepherosa Ziehau <sephe@dragonflybsd.org>
Tue, 25 Aug 2015 14:05:14 +0000 (22:05 +0800)
committer Sepherosa Ziehau <sephe@dragonflybsd.org>
Wed, 26 Aug 2015 02:06:56 +0000 (10:06 +0800)
Also factor out unp_find_lockref(), which will later be used to avoid
abusing unpcb.unp_conn in uipc_send() for unconnected DGRAM
unix sockets.

sys/kern/uipc_usrreq.c

index e909fbc..c65af71 100644 (file)
 #include <sys/msgport2.h>
 
 #define UNP_DETACHED           UNP_PRIVATE1
+/*
+ * Set on a pcb while unp_connect() is in progress; a second connect on
+ * the same pcb that sees this flag (or an existing unp_conn) fails with
+ * EISCONN.  Cleared again before unp_connect() returns.
+ */
+#define UNP_CONNECTING         UNP_PRIVATE2
 
 #define UNP_ISATTACHED(unp)    \
     ((unp) != NULL && ((unp)->unp_flags & UNP_DETACHED) == 0)
 
+/*
+ * Assert that the current thread holds the pool token covering 'unp'
+ * (presumably the one taken by unp_getsocktoken() -- confirm).  Expands
+ * to nothing when INVARIANTS is not configured.
+ */
+#ifdef INVARIANTS
+#define UNP_ASSERT_TOKEN_HELD(unp) \
+    ASSERT_LWKT_TOKEN_HELD(lwkt_token_pool_lookup((unp)))
+#else  /* !INVARIANTS */
+#define UNP_ASSERT_TOKEN_HELD(unp)
+#endif /* INVARIANTS */
+
 typedef struct unp_defdiscard {
        struct unp_defdiscard *next;
        struct file *fp;
@@ -110,6 +118,9 @@ static void    unp_discard (struct file *, void *);
 static int     unp_internalize (struct mbuf *, struct thread *);
 static int     unp_listen (struct unpcb *, struct thread *);
 static void    unp_fp_externalize(struct lwp *lp, struct file *fp, int fd);
+/* Resolve a unix-domain address to its pcb, returned locked and referenced. */
+static int     unp_find_lockref(struct sockaddr *nam, struct thread *td,
+                  short type, struct unpcb **unp_ret);
+/* Link two locked pcbs of the same socket type. */
+static void    unp_connect_pair(struct unpcb *unp, struct unpcb *unp2);
 
 /*
  * SMP Considerations:
@@ -271,33 +282,19 @@ uipc_bind(netmsg_t msg)
 static void
 uipc_connect(netmsg_t msg)
 {
-       struct unpcb *unp;
        int error;
 
-       unp = msg->base.nm_so->so_pcb;
-       if (UNP_ISATTACHED(unp)) {
-               error = unp_connect(msg->base.nm_so,
-                                   msg->connect.nm_nam,
-                                   msg->connect.nm_td);
-       } else {
-               error = EINVAL;
-       }
+       /*
+        * No attachment pre-check here: unp_connect() checks that the
+        * socket's pcb is attached after taking the pcb token, and
+        * returns EINVAL if it is not.
+        */
+       error = unp_connect(msg->base.nm_so, msg->connect.nm_nam,
+           msg->connect.nm_td);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
 static void
 uipc_connect2(netmsg_t msg)
 {
-       struct unpcb *unp;
        int error;
 
-       unp = msg->connect2.nm_so1->so_pcb;
-       if (UNP_ISATTACHED(unp)) {
-               error = unp_connect2(msg->connect2.nm_so1,
-                                    msg->connect2.nm_so2);
-       } else {
-               error = EINVAL;
-       }
+       /*
+        * unp_connect2() itself checks, under both pcb tokens, that the
+        * local pcb (EINVAL) and the peer pcb (ECONNREFUSED) are attached.
+        */
+       error = unp_connect2(msg->connect2.nm_so1, msg->connect2.nm_so2);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
@@ -990,68 +987,74 @@ failed:
 static int
 unp_connect(struct socket *so, struct sockaddr *nam, struct thread *td)
 {
-       struct proc *p = td->td_proc;
-       struct sockaddr_un *soun = (struct sockaddr_un *)nam;
-       struct vnode *vp;
-       struct socket *so2, *so3;
-       struct unpcb *unp, *unp2, *unp3;
-       int error, len;
-       struct nlookupdata nd;
-       char buf[SOCK_MAXADDRLEN];
+       struct unpcb *unp, *unp2;
+       int error, flags = 0;
 
+       /*
+        * Connect 'so' to the unix socket bound at 'nam'.
+        *
+        * Locking/reference protocol as implemented below:
+        * - unp_token is held for the whole call.
+        * - so's pcb (unp) is locked via unp_getsocktoken(), must be
+        *   attached (EINVAL), and is marked UNP_CONNECTING so that a
+        *   concurrent connect on it fails with EISCONN.
+        * - unp_find_lockref() returns the peer pcb (unp2) locked and
+        *   referenced; both are dropped at 'done'.
+        * - For PR_CONNREQUIRED protocols a new socket (so3) is spawned
+        *   from the listening peer and unp is paired with its pcb
+        *   (unp3); otherwise unp is paired directly with unp2.
+        *
+        * NOTE(review): unp_find_lockref() does a path lookup and vput(),
+        * and sonewconn_faddr() allocates a socket.  If any of these can
+        * block, the LWKT tokens held here may be released in the meantime
+        * (confirm against LWKT token semantics).  unp, and unp2 in the
+        * non-PR_CONNREQUIRED path, are then not re-checked with
+        * UNP_ISATTACHED() before unp_connect_pair().  Consider re-validating
+        * right before pairing.
+        */
        lwkt_gettoken(&unp_token);
 
-       len = nam->sa_len - offsetof(struct sockaddr_un, sun_path);
-       if (len <= 0) {
+       unp = unp_getsocktoken(so);
+       if (!UNP_ISATTACHED(unp)) {
                error = EINVAL;
                goto failed;
        }
-       strncpy(buf, soun->sun_path, len);
-       buf[len] = 0;
 
-       vp = NULL;
-       error = nlookup_init(&nd, buf, UIO_SYSSPACE, NLC_FOLLOW);
-       if (error == 0)
-               error = nlookup(&nd);
-       if (error == 0)
-               error = cache_vget(&nd.nl_nch, nd.nl_cred, LK_EXCLUSIVE, &vp);
-       nlookup_done(&nd);
-       if (error)
+       if ((unp->unp_flags & UNP_CONNECTING) || unp->unp_conn != NULL) {
+               error = EISCONN;
                goto failed;
-
-       if (vp->v_type != VSOCK) {
-               error = ENOTSOCK;
-               goto bad;
        }
-       error = VOP_EACCESS(vp, VWRITE, p->p_ucred);
+
+       flags = UNP_CONNECTING;
+       unp_setflags(unp, flags);
+
+       error = unp_find_lockref(nam, td, so->so_type, &unp2);
        if (error)
-               goto bad;
-       so2 = vp->v_socket;
-       if (so2 == NULL) {
-               error = ECONNREFUSED;
-               goto bad;
-       }
-       if (so->so_type != so2->so_type) {
-               error = EPROTOTYPE;
-               goto bad;
-       }
+               goto failed;
+       /*
+        * NOTE:
+        * unp2 is locked and referenced.
+        */
+
        if (so->so_proto->pr_flags & PR_CONNREQUIRED) {
+               struct socket *so2, *so3;
+               struct unpcb *unp3;
+
+               so2 = unp2->unp_socket;
                if (!(so2->so_options & SO_ACCEPTCONN) ||
-                   (so3 = sonewconn(so2, 0)) == NULL) {
+                   (so3 = sonewconn_faddr(so2, 0, NULL,
+                    TRUE /* keep ref */)) == NULL) {
                        error = ECONNREFUSED;
-                       goto bad;
+                       goto done;
                }
-               unp = so->so_pcb;
-               if (unp->unp_conn) {    /* race, already connected! */
-                       error = EISCONN;
+               /* so3 has a socket reference. */
+
+               unp3 = unp_getsocktoken(so3);
+               if (!UNP_ISATTACHED(unp3)) {
+                       unp_reltoken(unp3);
+                       /*
+                        * Already aborted; we only need to drop the
+                        * socket reference held by sonewconn_faddr().
+                        */
                        sofree(so3);
-                       goto bad;
+                       error = ECONNREFUSED;
+                       goto done;
                }
-               unp2 = so2->so_pcb;
-               unp3 = so3->so_pcb;
-               if (unp2->unp_addr)
+               unp_reference(unp3);
+               /*
+                * NOTE:
+                * unp3 is locked and referenced.
+                */
+
+               /*
+                * Release so3 socket reference held by sonewconn_faddr().
+                * Since we have referenced unp3, neither unp3 nor so3 will
+                * be destroyed here.
+                */
+               sofree(so3);
+
+               if (unp2->unp_addr != NULL) {
                        unp3->unp_addr = (struct sockaddr_un *)
-                               dup_sockaddr((struct sockaddr *)unp2->unp_addr);
+                           dup_sockaddr((struct sockaddr *)unp2->unp_addr);
+               }
 
                /*
                 * unp_peercred management:
@@ -1060,7 +1063,7 @@ unp_connect(struct socket *so, struct sockaddr *nam, struct thread *td)
                 * from its process structure at the time of connect()
                 * (which is now).
                 */
-               cru2x(p->p_ucred, &unp3->unp_peercred);
+               cru2x(td->td_proc->p_ucred, &unp3->unp_peercred);
                unp_setflags(unp3, UNP_HAVEPC);
                /*
                 * The receiver's (server's) credentials are copied
@@ -1075,12 +1078,22 @@ unp_connect(struct socket *so, struct sockaddr *nam, struct thread *td)
                    sizeof(unp->unp_peercred));
                unp_setflags(unp, UNP_HAVEPC);
 
-               so2 = so3;
+               unp_connect_pair(unp, unp3);
+
+               /* Done with unp3 */
+               unp_free(unp3);
+               unp_reltoken(unp3);
+       } else {
+               /* Not connection-oriented: pair directly with the peer pcb. */
+               unp_connect_pair(unp, unp2);
        }
-       error = unp_connect2(so, so2);
-bad:
-       vput(vp);
+done:
+       unp_free(unp2);
+       unp_reltoken(unp2);
 failed:
+       /* flags is non-zero only once UNP_CONNECTING has been set above. */
+       if (flags)
+               unp_clrflags(unp, flags);
+       unp_reltoken(unp);
+
        lwkt_reltoken(&unp_token);
        return (error);
 }
@@ -1094,8 +1107,8 @@ failed:
 int
 unp_connect2(struct socket *so, struct socket *so2)
 {
-       struct unpcb *unp;
-       struct unpcb *unp2;
+       struct unpcb *unp, *unp2;
+       int error;
 
        lwkt_gettoken(&unp_token);
        if (so2->so_type != so->so_type) {
@@ -1105,29 +1118,32 @@ unp_connect2(struct socket *so, struct socket *so2)
        unp = unp_getsocktoken(so);
        unp2 = unp_getsocktoken(so2);
 
-       unp->unp_conn = unp2;
-
-       switch (so->so_type) {
-       case SOCK_DGRAM:
-               LIST_INSERT_HEAD(&unp2->unp_refs, unp, unp_reflink);
-               soisconnected(so);
-               break;
-
-       case SOCK_STREAM:
-       case SOCK_SEQPACKET:
-               unp2->unp_conn = unp;
-               soisconnected(so);
-               soisconnected(so2);
-               break;
+       /*
+        * Validate under both pcb tokens before linking:
+        * - the local pcb must be attached (EINVAL);
+        * - the peer pcb must be attached (ECONNREFUSED);
+        * - the local pcb must not already be connected (EISCONN);
+        * - for STREAM/SEQPACKET the peer must not be connected either
+        *   (EISCONN), since unp_connect_pair() links those both ways.
+        *   A DGRAM peer may be referenced by several pcbs via unp_refs.
+        */
+       if (!UNP_ISATTACHED(unp)) {
+               error = EINVAL;
+               goto done;
+       }
+       if (!UNP_ISATTACHED(unp2)) {
+               error = ECONNREFUSED;
+               goto done;
+       }
 
-       default:
-               panic("unp_connect2");
+       if (unp->unp_conn != NULL) {
+               error = EISCONN;
+               goto done;
+       }
+       if ((so->so_type == SOCK_STREAM || so->so_type == SOCK_SEQPACKET) &&
+           unp2->unp_conn != NULL) {
+               error = EISCONN;
+               goto done;
        }
 
+       unp_connect_pair(unp, unp2);
+       error = 0;
+done:
        unp_reltoken(unp2);
        unp_reltoken(unp);
        lwkt_reltoken(&unp_token);
-       return (0);
+       return (error);
 }
 
 /*
@@ -2041,3 +2057,106 @@ unp_discard(struct file *fp, void *data __unused)
        }
 }
 
+/*
+ * Resolve the unix-domain address 'nam' to the pcb of the socket bound
+ * at that path.
+ *
+ * On success 0 is returned and *unp_ret holds the peer pcb with its pool
+ * token held and an extra reference taken.  The caller drops both, with
+ * unp_free() followed by unp_reltoken().
+ *
+ * On failure *unp_ret is NULL and one of these is returned:
+ * - EINVAL: empty path;
+ * - any error from the path lookup;
+ * - ENOTSOCK: the vnode is not a socket;
+ * - any error from the VWRITE access check;
+ * - ECONNREFUSED: nothing bound, or the peer pcb is detached;
+ * - EPROTOTYPE: the peer's socket type differs from 'type'.
+ */
+static int
+unp_find_lockref(struct sockaddr *nam, struct thread *td, short type,
+    struct unpcb **unp_ret)
+{
+       struct proc *proc = td->td_proc;
+       struct sockaddr_un *saddr = (struct sockaddr_un *)nam;
+       struct nlookupdata nd;
+       struct vnode *vp = NULL;
+       struct socket *peer_so = NULL;
+       struct unpcb *peer;
+       char path[SOCK_MAXADDRLEN];
+       int pathlen, error;
+
+       *unp_ret = NULL;
+
+       /* Copy out a NUL-terminated path; an empty path is invalid. */
+       pathlen = nam->sa_len - offsetof(struct sockaddr_un, sun_path);
+       if (pathlen <= 0)
+               return (EINVAL);
+       strncpy(path, saddr->sun_path, pathlen);
+       path[pathlen] = '\0';
+
+       /* Look the path up and obtain its vnode locked and referenced. */
+       error = nlookup_init(&nd, path, UIO_SYSSPACE, NLC_FOLLOW);
+       if (error == 0)
+               error = nlookup(&nd);
+       if (error == 0)
+               error = cache_vget(&nd.nl_nch, nd.nl_cred, LK_EXCLUSIVE, &vp);
+       nlookup_done(&nd);
+       if (error != 0)
+               return (error);
+
+       /* The vnode must be a writable socket of the requested type. */
+       if (vp->v_type != VSOCK)
+               error = ENOTSOCK;
+       else
+               error = VOP_EACCESS(vp, VWRITE, proc->p_ucred);
+       if (error == 0) {
+               peer_so = vp->v_socket;
+               if (peer_so == NULL)
+                       error = ECONNREFUSED;
+               else if (peer_so->so_type != type)
+                       error = EPROTOTYPE;
+       }
+
+       /* Lock the peer pcb; reference it only if still attached. */
+       if (error == 0) {
+               peer = unp_getsocktoken(peer_so);
+               if (UNP_ISATTACHED(peer)) {
+                       unp_reference(peer);
+                       *unp_ret = peer;
+               } else {
+                       unp_reltoken(peer);
+                       error = ECONNREFUSED;
+               }
+       }
+
+       vput(vp);
+       return (error);
+}
+
+/*
+ * Link the pcbs 'unp' and 'unp2'.
+ *
+ * Both pool tokens must be held, both sockets must have the same type,
+ * and 'unp' must not already be connected; these are asserted under
+ * INVARIANTS.
+ *
+ * - SOCK_DGRAM: the link is one-way.  unp points at unp2 and joins
+ *   unp2's unp_refs list; only 'unp's socket is marked connected.
+ * - SOCK_STREAM / SOCK_SEQPACKET: the link is mutual (unp2 must not be
+ *   connected) and both sockets are marked connected.
+ * - Any other type panics.
+ *
+ * NOTE(review): neither pcb's UNP_ISATTACHED() state is re-checked here.
+ * Callers are expected to have validated it under the tokens now held.
+ * Confirm that no caller can block between that check and this call.
+ */
+static void
+unp_connect_pair(struct unpcb *unp, struct unpcb *unp2)
+{
+       struct socket *so, *so2;
+
+       so = unp->unp_socket;
+       so2 = unp2->unp_socket;
+
+       UNP_ASSERT_TOKEN_HELD(unp);
+       UNP_ASSERT_TOKEN_HELD(unp2);
+
+       KASSERT(so->so_type == so2->so_type,
+           ("socket type mismatch, so %d, so2 %d", so->so_type, so2->so_type));
+       KASSERT(unp->unp_conn == NULL, ("unp is already connected"));
+
+       /* The local side always points at its peer, whatever the type. */
+       unp->unp_conn = unp2;
+
+       if (so->so_type == SOCK_DGRAM) {
+               LIST_INSERT_HEAD(&unp2->unp_refs, unp, unp_reflink);
+               soisconnected(so);
+       } else if (so->so_type == SOCK_STREAM ||
+           so->so_type == SOCK_SEQPACKET) {
+               KASSERT(unp2->unp_conn == NULL, ("unp2 is already connected"));
+               unp2->unp_conn = unp;
+               soisconnected(so);
+               soisconnected(so2);
+       } else {
+               panic("unp_connect_pair: unknown socket type %d", so->so_type);
+       }
+}