kernel - Use pool tokens to protect unix domain PCBs
authorMatthew Dillon <dillon@apollo.backplane.com>
Fri, 14 Sep 2012 16:10:06 +0000 (09:10 -0700)
committerMatthew Dillon <dillon@apollo.backplane.com>
Fri, 14 Sep 2012 16:10:06 +0000 (09:10 -0700)
* The read, status, and write paths now use per-pcb pool tokens
  instead of the global unp_token.  The global token is still used
  for accept, connect, disconnect, etc.

* General semantics for making this SMP safe is to obtain a pointer
  to the unp from so->so_pcb, then obtain the related pool token,
  then re-check that so->so_pcb still equals unp.

* Pool token protects the peer pointer, unp->unp_conn.  Any change
  to unp->unp_conn requires both the pool token and the global token.

* This should improve concurrent reading and writing w/unix domain
  sockets.

sys/kern/uipc_usrreq.c

index e8edb31..d780bb9 100644 (file)
@@ -111,10 +111,17 @@ static int     unp_listen (struct unpcb *, struct thread *);
 static void    unp_fp_externalize(struct lwp *lp, struct file *fp, int fd);
 
 /*
- * NOTE:
- * Since unp_token will be automaticly released upon execution of
- * blocking code, we need to reference unp_conn before any possible
- * blocking code to prevent it from being ripped behind our back.
+ * SMP Considerations:
+ *
+ *     Since unp_token will be automaticly released upon execution of
+ *     blocking code, we need to reference unp_conn before any possible
+ *     blocking code to prevent it from being ripped behind our back.
+ *
+ *     Any adjustment to unp->unp_conn requires both the global unp_token
+ *     AND the per-unp token (lwkt_token_pool_lookup(unp)) to be held.
+ *
+ *     Any access to so_pcb to obtain unp requires the pool token for
+ *     unp to be held.
  */
 
 /* NOTE: unp_token MUST be held */
@@ -227,7 +234,6 @@ uipc_connect(netmsg_t msg)
        struct unpcb *unp;
        int error;
 
-       lwkt_gettoken(&unp_token);
        unp = msg->base.nm_so->so_pcb;
        if (unp) {
                error = unp_connect(msg->base.nm_so,
@@ -236,7 +242,6 @@ uipc_connect(netmsg_t msg)
        } else {
                error = EINVAL;
        }
-       lwkt_reltoken(&unp_token);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
@@ -246,7 +251,6 @@ uipc_connect2(netmsg_t msg)
        struct unpcb *unp;
        int error;
 
-       lwkt_gettoken(&unp_token);
        unp = msg->connect2.nm_so1->so_pcb;
        if (unp) {
                error = unp_connect2(msg->connect2.nm_so1,
@@ -254,7 +258,6 @@ uipc_connect2(netmsg_t msg)
        } else {
                error = EINVAL;
        }
-       lwkt_reltoken(&unp_token);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
@@ -351,13 +354,23 @@ uipc_rcvd(netmsg_t msg)
        struct socket *so2;
        int error;
 
-       lwkt_gettoken(&unp_token);
+       /*
+        * so_pcb is only modified with both the global and the unp
+        * pool token held.  The unp pointer is invalid until we verify
+        * that it is good by re-checking so_pcb AFTER obtaining the token.
+        */
        so = msg->base.nm_so;
-       unp = so->so_pcb;
+       while ((unp = so->so_pcb) != NULL) {
+               lwkt_getpooltoken(unp);
+               if (unp == so->so_pcb)
+                       break;
+               lwkt_relpooltoken(unp);
+       }
        if (unp == NULL) {
                error = EINVAL;
                goto done;
        }
+       /* pool token held */
 
        switch (so->so_type) {
        case SOCK_DGRAM:
@@ -367,31 +380,37 @@ uipc_rcvd(netmsg_t msg)
        case SOCK_SEQPACKET:
                if (unp->unp_conn == NULL)
                        break;
-               unp2 = unp->unp_conn;
+               unp2 = unp->unp_conn;   /* protected by pool token */
 
                /*
                 * Because we are transfering mbufs directly to the
                 * peer socket we have to use SSB_STOP on the sender
                 * to prevent it from building up infinite mbufs.
+                *
+                * As in several places in this module w ehave to ref unp2
+                * to ensure that it does not get ripped out from under us
+                * if we block on the so2 token or in sowwakeup().
                 */
                so2 = unp2->unp_socket;
+               unp_reference(unp2);
+               lwkt_gettoken(&so2->so_rcv.ssb_token);
                if (so->so_rcv.ssb_cc < so2->so_snd.ssb_hiwat &&
                    so->so_rcv.ssb_mbcnt < so2->so_snd.ssb_mbmax
                ) {
                        atomic_clear_int(&so2->so_snd.ssb_flags, SSB_STOP);
 
-                       unp_reference(unp2);
                        sowwakeup(so2);
-                       unp_free(unp2);
                }
+               lwkt_reltoken(&so2->so_rcv.ssb_token);
+               unp_free(unp2);
                break;
        default:
                panic("uipc_rcvd unknown socktype");
                /*NOTREACHED*/
        }
        error = 0;
+       lwkt_relpooltoken(unp);
 done:
-       lwkt_reltoken(&unp_token);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
@@ -407,16 +426,28 @@ uipc_send(netmsg_t msg)
        struct mbuf *m;
        int error = 0;
 
-       lwkt_gettoken(&unp_token);
        so = msg->base.nm_so;
        control = msg->send.nm_control;
        m = msg->send.nm_m;
-       unp = so->so_pcb;
 
+       /*
+        * so_pcb is only modified with both the global and the unp
+        * pool token held.  The unp pointer is invalid until we verify
+        * that it is good by re-checking so_pcb AFTER obtaining the token.
+        */
+       so = msg->base.nm_so;
+       while ((unp = so->so_pcb) != NULL) {
+               lwkt_getpooltoken(unp);
+               if (unp == so->so_pcb)
+                       break;
+               lwkt_relpooltoken(unp);
+       }
        if (unp == NULL) {
                error = EINVAL;
-               goto release;
+               goto done;
        }
+       /* pool token held */
+
        if (msg->send.nm_flags & PRUS_OOB) {
                error = EOPNOTSUPP;
                goto release;
@@ -553,7 +584,8 @@ uipc_send(netmsg_t msg)
                unp_dispose(control);
 
 release:
-       lwkt_reltoken(&unp_token);
+       lwkt_relpooltoken(unp);
+done:
 
        if (control)
                m_freem(control);
@@ -573,14 +605,26 @@ uipc_sense(netmsg_t msg)
        struct stat *sb;
        int error;
 
-       lwkt_gettoken(&unp_token);
        so = msg->base.nm_so;
        sb = msg->sense.nm_stat;
-       unp = so->so_pcb;
+
+       /*
+        * so_pcb is only modified with both the global and the unp
+        * pool token held.  The unp pointer is invalid until we verify
+        * that it is good by re-checking so_pcb AFTER obtaining the token.
+        */
+       while ((unp = so->so_pcb) != NULL) {
+               lwkt_getpooltoken(unp);
+               if (unp == so->so_pcb)
+                       break;
+               lwkt_relpooltoken(unp);
+       }
        if (unp == NULL) {
                error = EINVAL;
                goto done;
        }
+       /* pool token held */
+
        sb->st_blksize = so->so_snd.ssb_hiwat;
        sb->st_dev = NOUDEV;
        if (unp->unp_ino == 0) {        /* make up a non-zero inode number */
@@ -590,8 +634,8 @@ uipc_sense(netmsg_t msg)
        }
        sb->st_ino = unp->unp_ino;
        error = 0;
+       lwkt_relpooltoken(unp);
 done:
-       lwkt_reltoken(&unp_token);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
@@ -602,38 +646,60 @@ uipc_shutdown(netmsg_t msg)
        struct unpcb *unp;
        int error;
 
-       lwkt_gettoken(&unp_token);
+       /*
+        * so_pcb is only modified with both the global and the unp
+        * pool token held.  The unp pointer is invalid until we verify
+        * that it is good by re-checking so_pcb AFTER obtaining the token.
+        */
        so = msg->base.nm_so;
-       unp = so->so_pcb;
+       while ((unp = so->so_pcb) != NULL) {
+               lwkt_getpooltoken(unp);
+               if (unp == so->so_pcb)
+                       break;
+               lwkt_relpooltoken(unp);
+       }
        if (unp) {
+               /* pool token held */
                socantsendmore(so);
                unp_shutdown(unp);
+               lwkt_relpooltoken(unp);
                error = 0;
        } else {
                error = EINVAL;
        }
-       lwkt_reltoken(&unp_token);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
 static void
 uipc_sockaddr(netmsg_t msg)
 {
+       struct socket *so;
        struct unpcb *unp;
        int error;
 
-       lwkt_gettoken(&unp_token);
-       unp = msg->base.nm_so->so_pcb;
+       /*
+        * so_pcb is only modified with both the global and the unp
+        * pool token held.  The unp pointer is invalid until we verify
+        * that it is good by re-checking so_pcb AFTER obtaining the token.
+        */
+       so = msg->base.nm_so;
+       while ((unp = so->so_pcb) != NULL) {
+               lwkt_getpooltoken(unp);
+               if (unp == so->so_pcb)
+                       break;
+               lwkt_relpooltoken(unp);
+       }
        if (unp) {
+               /* pool token held */
                if (unp->unp_addr) {
                        *msg->sockaddr.nm_nam =
                                dup_sockaddr((struct sockaddr *)unp->unp_addr);
                }
+               lwkt_relpooltoken(unp);
                error = 0;
        } else {
                error = EINVAL;
        }
-       lwkt_reltoken(&unp_token);
        lwkt_replymsg(&msg->lmsg, error);
 }
 
@@ -797,8 +863,9 @@ unp_detach(struct unpcb *unp)
        struct socket *so;
 
        lwkt_gettoken(&unp_token);
+       lwkt_getpooltoken(unp);
 
-       LIST_REMOVE(unp, unp_link);
+       LIST_REMOVE(unp, unp_link);     /* both tokens required */
        unp->unp_gencnt = ++unp_gencnt;
        --unp_count;
        if (unp->unp_vnode) {
@@ -812,8 +879,9 @@ unp_detach(struct unpcb *unp)
                unp_drop(LIST_FIRST(&unp->unp_refs), ECONNRESET);
        soisdisconnected(unp->unp_socket);
        so = unp->unp_socket;
-       soreference(so);        /* for delayed sorflush */
-       so->so_pcb = NULL;
+       soreference(so);                /* for delayed sorflush */
+       KKASSERT(so->so_pcb == unp);
+       so->so_pcb = NULL;              /* both tokens required */
        unp->unp_socket = NULL;
        sofree(so);             /* remove pcb ref */
 
@@ -829,6 +897,7 @@ unp_detach(struct unpcb *unp)
                unp_gc();
        }
        sofree(so);
+       lwkt_relpooltoken(unp);
        lwkt_reltoken(&unp_token);
 
        if (unp->unp_addr)
@@ -873,10 +942,15 @@ unp_bind(struct unpcb *unp, struct sockaddr *nam, struct thread *td)
        vattr.va_mode = (ACCESSPERMS & ~p->p_fd->fd_cmask);
        error = VOP_NCREATE(&nd.nl_nch, nd.nl_dvp, &vp, nd.nl_cred, &vattr);
        if (error == 0) {
-               vp->v_socket = unp->unp_socket;
-               unp->unp_vnode = vp;
-               unp->unp_addr = (struct sockaddr_un *)dup_sockaddr(nam);
-               vn_unlock(vp);
+               if (unp->unp_vnode == NULL) {
+                       vp->v_socket = unp->unp_socket;
+                       unp->unp_vnode = vp;
+                       unp->unp_addr = (struct sockaddr_un *)dup_sockaddr(nam);
+                       vn_unlock(vp);
+               } else {
+                       vput(vp);               /* late race */
+                       error = EINVAL;
+               }
        }
 done:
        nlookup_done(&nd);
@@ -940,6 +1014,11 @@ unp_connect(struct socket *so, struct sockaddr *nam, struct thread *td)
                        goto bad;
                }
                unp = so->so_pcb;
+               if (unp->unp_conn) {    /* race, already connected! */
+                       error = EISCONN;
+                       sofree(so3);
+                       goto bad;
+               }
                unp2 = so2->so_pcb;
                unp3 = so3->so_pcb;
                if (unp2->unp_addr)
@@ -978,6 +1057,12 @@ failed:
        return (error);
 }
 
+/*
+ * Connect two unix domain sockets together.
+ *
+ * NOTE: Semantics for any change to unp_conn requires that the per-unp
+ *      pool token also be held.
+ */
 int
 unp_connect2(struct socket *so, struct socket *so2)
 {
@@ -991,6 +1076,9 @@ unp_connect2(struct socket *so, struct socket *so2)
                return (EPROTOTYPE);
        }
        unp2 = so2->so_pcb;
+       lwkt_getpooltoken(unp);
+       lwkt_getpooltoken(unp2);
+
        unp->unp_conn = unp2;
 
        switch (so->so_type) {
@@ -1009,22 +1097,34 @@ unp_connect2(struct socket *so, struct socket *so2)
        default:
                panic("unp_connect2");
        }
+       lwkt_relpooltoken(unp2);
+       lwkt_relpooltoken(unp);
        lwkt_reltoken(&unp_token);
        return (0);
 }
 
+/*
+ * Disconnect a unix domain socket pair.
+ *
+ * NOTE: Semantics for any change to unp_conn requires that the per-unp
+ *      pool token also be held.
+ */
 static void
 unp_disconnect(struct unpcb *unp)
 {
        struct unpcb *unp2;
 
        lwkt_gettoken(&unp_token);
+       lwkt_getpooltoken(unp);
 
-       unp2 = unp->unp_conn;
-       if (unp2 == NULL) {
-               lwkt_reltoken(&unp_token);
-               return;
+       while ((unp2 = unp->unp_conn) != NULL) {
+               lwkt_getpooltoken(unp2);
+               if (unp2 == unp->unp_conn)
+                       break;
+               lwkt_relpooltoken(unp2);
        }
+       if (unp2 == NULL)
+               goto done;
 
        unp->unp_conn = NULL;
 
@@ -1045,6 +1145,9 @@ unp_disconnect(struct unpcb *unp)
                unp_free(unp2);
                break;
        }
+       lwkt_relpooltoken(unp2);
+       lwkt_relpooltoken(unp);
+done:
        lwkt_reltoken(&unp_token);
 }