AF_UNIX: Hold a reference of the unp_conn before executing blocking code
authorSepherosa Ziehau <sephe@dragonflybsd.org>
Wed, 27 Apr 2011 07:10:03 +0000 (15:10 +0800)
committerSepherosa Ziehau <sephe@dragonflybsd.org>
Wed, 27 Apr 2011 07:10:03 +0000 (15:10 +0800)
Since unp_token will be automatically released upon execution of blocking
code, close of unp_conn could race any code path that references unp_conn
after executing blocking code.  To fix these races, we simply increment
the reference count of the unp_conn before executing any possibly blocking
code and drop that reference afterwards, which may release unp_conn
itself.  This _currently_ does not suffer from a 0-ref race, since unp_token
is always being held.

sys/kern/uipc_usrreq.c
sys/sys/unpcb.h

index fad2f23..19a526b 100644 (file)
@@ -111,6 +111,29 @@ static int     unp_internalize (struct mbuf *, struct thread *);
 static int     unp_listen (struct unpcb *, struct thread *);
 static void    unp_fp_externalize(struct lwp *lp, struct file *fp, int fd);
 
+/*
+ * NOTE:
+ * Since unp_token will be automatically released upon execution of
+ * blocking code, we need to reference unp_conn before any possible
+ * blocking code to prevent it from being ripped behind our back.
+ */
+
+/* NOTE: unp_token MUST be held */
+static __inline void
+unp_reference(struct unpcb *unp)
+{
+       atomic_add_int(&unp->unp_refcnt, 1);
+}
+
+/* NOTE: unp_token MUST be held */
+static __inline void
+unp_free(struct unpcb *unp)
+{
+       KKASSERT(unp->unp_refcnt > 0);
+       if (atomic_fetchadd_int(&unp->unp_refcnt, -1) == 1)
+               unp_detach(unp);
+}
+
 /*
  * NOTE: (so) is referenced from soabort*() and netmsg_pru_abort()
  *      will sofree() it when we return.
@@ -125,7 +148,7 @@ uipc_abort(netmsg_t msg)
        unp = msg->base.nm_so->so_pcb;
        if (unp) {
                unp_drop(unp, ECONNABORTED);
-               unp_detach(unp);
+               unp_free(unp);
                error = 0;
        } else {
                error = EINVAL;
@@ -146,14 +169,18 @@ uipc_accept(netmsg_t msg)
        if (unp == NULL) {
                error = EINVAL;
        } else {
+               struct unpcb *unp2 = unp->unp_conn;
+
                /*
                 * Pass back name of connected socket,
                 * if it was bound and we are still connected
                 * (our peer may have closed already!).
                 */
-               if (unp->unp_conn && unp->unp_conn->unp_addr) {
+               if (unp2 && unp2->unp_addr) {
+                       unp_reference(unp2);
                        *msg->accept.nm_nam = dup_sockaddr(
-                               (struct sockaddr *)unp->unp_conn->unp_addr);
+                               (struct sockaddr *)unp2->unp_addr);
+                       unp_free(unp2);
                } else {
                        *msg->accept.nm_nam = dup_sockaddr(
                                (struct sockaddr *)&sun_noname);
@@ -244,7 +271,7 @@ uipc_detach(netmsg_t msg)
        lwkt_gettoken(&unp_token);
        unp = msg->base.nm_so->so_pcb;
        if (unp) {
-               unp_detach(unp);
+               unp_free(unp);
                error = 0;
        } else {
                error = EINVAL;
@@ -298,8 +325,12 @@ uipc_peeraddr(netmsg_t msg)
        if (unp == NULL) {
                error = EINVAL;
        } else if (unp->unp_conn && unp->unp_conn->unp_addr) {
+               struct unpcb *unp2 = unp->unp_conn;
+
+               unp_reference(unp2);
                *msg->peeraddr.nm_nam = dup_sockaddr(
-                               (struct sockaddr *)unp->unp_conn->unp_addr);
+                               (struct sockaddr *)unp2->unp_addr);
+               unp_free(unp2);
                error = 0;
        } else {
                /*
@@ -318,7 +349,7 @@ uipc_peeraddr(netmsg_t msg)
 static void
 uipc_rcvd(netmsg_t msg)
 {
-       struct unpcb *unp;
+       struct unpcb *unp, *unp2;
        struct socket *so;
        struct socket *so2;
        int error;
@@ -339,17 +370,22 @@ uipc_rcvd(netmsg_t msg)
        case SOCK_SEQPACKET:
                if (unp->unp_conn == NULL)
                        break;
+               unp2 = unp->unp_conn;
+
                /*
                 * Because we are transfering mbufs directly to the
                 * peer socket we have to use SSB_STOP on the sender
                 * to prevent it from building up infinite mbufs.
                 */
-               so2 = unp->unp_conn->unp_socket;
+               so2 = unp2->unp_socket;
                if (so->so_rcv.ssb_cc < so2->so_snd.ssb_hiwat &&
                    so->so_rcv.ssb_mbcnt < so2->so_snd.ssb_mbmax
                ) {
                        atomic_clear_int(&so2->so_snd.ssb_flags, SSB_STOP);
+
+                       unp_reference(unp2);
                        sowwakeup(so2);
+                       unp_free(unp2);
                }
                break;
        default:
@@ -367,7 +403,7 @@ done:
 static void
 uipc_send(netmsg_t msg)
 {
-       struct unpcb *unp;
+       struct unpcb *unp, *unp2;
        struct socket *so;
        struct socket *so2;
        struct mbuf *control;
@@ -413,12 +449,15 @@ uipc_send(netmsg_t msg)
                                break;
                        }
                }
-               so2 = unp->unp_conn->unp_socket;
+               unp2 = unp->unp_conn;
+               so2 = unp2->unp_socket;
                if (unp->unp_addr)
                        from = (struct sockaddr *)unp->unp_addr;
                else
                        from = &sun_noname;
 
+               unp_reference(unp2);
+
                lwkt_gettoken(&so2->so_rcv.ssb_token);
                if (ssb_appendaddr(&so2->so_rcv, from, m, control)) {
                        sorwakeup(so2);
@@ -430,6 +469,8 @@ uipc_send(netmsg_t msg)
                if (msg->send.nm_addr)
                        unp_disconnect(unp);
                lwkt_reltoken(&so2->so_rcv.ssb_token);
+
+               unp_free(unp2);
                break;
        }
 
@@ -459,7 +500,11 @@ uipc_send(netmsg_t msg)
                }
                if (unp->unp_conn == NULL)
                        panic("uipc_send connected but no connection?");
-               so2 = unp->unp_conn->unp_socket;
+               unp2 = unp->unp_conn;
+               so2 = unp2->unp_socket;
+
+               unp_reference(unp2);
+
                /*
                 * Send to paired receive port, and then reduce
                 * send buffer hiwater marks to maintain backpressure.
@@ -491,6 +536,8 @@ uipc_send(netmsg_t msg)
                }
                lwkt_reltoken(&so2->so_rcv.ssb_token);
                sorwakeup(so2);
+
+               unp_free(unp2);
                break;
 
        default:
@@ -731,6 +778,7 @@ unp_attach(struct socket *so, struct pru_attach_info *ai)
                error = ENOBUFS;
                goto failed;
        }
+       unp->unp_refcnt = 1;
        unp->unp_gencnt = ++unp_gencnt;
        unp_count++;
        LIST_INIT(&unp->unp_refs);
@@ -988,11 +1036,16 @@ unp_disconnect(struct unpcb *unp)
                LIST_REMOVE(unp, unp_reflink);
                soclrstate(unp->unp_socket, SS_ISCONNECTED);
                break;
+
        case SOCK_STREAM:
        case SOCK_SEQPACKET:
-               soisdisconnected(unp->unp_socket);
+               unp_reference(unp2);
                unp2->unp_conn = NULL;
+
+               soisdisconnected(unp->unp_socket);
                soisdisconnected(unp2->unp_socket);
+
+               unp_free(unp2);
                break;
        }
        lwkt_reltoken(&unp_token);
@@ -1003,7 +1056,7 @@ void
 unp_abort(struct unpcb *unp)
 {
        lwkt_gettoken(&unp_token);
-       unp_detach(unp);
+       unp_free(unp);
        lwkt_reltoken(&unp_token);
 }
 #endif
index bb39f0e..6bd1681 100644 (file)
@@ -79,7 +79,7 @@ struct        unpcb {
        struct  unp_head unp_refs;      /* referencing socket linked list */
        LIST_ENTRY(unpcb) unp_reflink;  /* link in unp_refs list */
        struct  sockaddr_un *unp_addr;  /* bound address of socket */
-       int     unused01;
+       int     unp_refcnt;             /* reference count */
        int     unused02;
        unp_gen_t unp_gencnt;           /* generation count of this instance */
        int     unp_flags;              /* flags */