tcp: Create separate local port spaces for laddr/faddr/fport triple
author Sepherosa Ziehau <sephe@dragonflybsd.org>
Tue, 12 Apr 2011 08:27:46 +0000 (16:27 +0800)
committer Sepherosa Ziehau <sephe@dragonflybsd.org>
Tue, 12 Apr 2011 08:27:46 +0000 (16:27 +0800)
TCP local ports could easily get depleted under stress due to
TIME_WAIT state.  Since TIME_WAIT state is used to protect
laddr/lport/faddr/fport, we actually just need to make sure that
lport is not duplicated for each laddr/faddr/fport triple instead of
making sure that lport is not duplicated globally.

Add sysctl net.inet.tcp.lportext to enable separate local port spaces
for each laddr/faddr/fport triple; it is disabled by default.

sys/netinet/in_pcb.c
sys/netinet/in_pcb.h
sys/netinet/tcp_usrreq.c

index 7d4d269..4986979 100644 (file)
@@ -515,6 +515,193 @@ done:
        return error;
 }
 
+static struct inpcb *
+in_pcblookup_addrport(struct inpcbinfo *pcbinfo, struct in_addr laddr,
+    u_short lport, struct in_addr faddr, u_short fport, struct ucred *cred)
+{
+       struct inpcb *inp;
+       struct inpcbporthead *porthash;
+       struct inpcbport *phd;
+       struct inpcb *match = NULL;
+
+       /*
+        * Find a PCB already bound to lport that would collide with a
+        * connection using laddr/lport/faddr/fport.  Returns that PCB,
+        * or NULL when there is none.  lport is compared as passed in
+        * against phd_port; the caller in this file passes the htons()
+        * form.
+        *
+        * If the porthashbase is shared across several cpus we need
+        * to lock.
+        */
+       if (pcbinfo->porttoken)
+               lwkt_gettoken(pcbinfo->porttoken);
+
+       /*
+        * First-match PCB lookup: the first PCB that passes every
+        * filter below is returned, no ranking is done.
+        *
+        * First see if this local port is in use by looking on the
+        * port hash list.  phd is expected to be NULL afterwards when
+        * no entry on the chain carries lport (tested right below).
+        */
+       porthash = &pcbinfo->porthashbase[
+                       INP_PCBPORTHASH(lport, pcbinfo->porthashmask)];
+       LIST_FOREACH(phd, porthash, phd_hash) {
+               if (phd->phd_port == lport)
+                       break;
+       }
+       if (phd != NULL) {
+               LIST_FOREACH(inp, &phd->phd_pcblist, inp_portlist) {
+#ifdef INET6
+                       if ((inp->inp_vflag & INP_IPV4) == 0)
+                               continue;       /* INP_IPV4 not set */
+#endif
+                       if (inp->inp_laddr.s_addr != INADDR_ANY &&
+                           inp->inp_laddr.s_addr != laddr.s_addr)
+                               continue;       /* other, non-wildcard laddr */
+
+                       if (inp->inp_faddr.s_addr != INADDR_ANY &&
+                           inp->inp_faddr.s_addr != faddr.s_addr)
+                               continue;       /* other, non-wildcard faddr */
+
+                       if (inp->inp_fport != 0 && inp->inp_fport != fport)
+                               continue;       /* other, non-zero fport */
+
+                       if (cred == NULL ||
+                           cred->cr_prison ==
+                           inp->inp_socket->so_cred->cr_prison) {
+                               match = inp;    /* no cred given, or same prison */
+                               break;
+                       }
+               }
+       }
+       if (pcbinfo->porttoken)
+               lwkt_reltoken(pcbinfo->porttoken);
+       return (match);
+}
+
+int
+in_pcbconn_bind(struct inpcb *inp, const struct sockaddr *nam,
+    struct thread *td)
+{
+       struct proc *p = td->td_proc;
+       unsigned short *lastport;
+       const struct sockaddr_in *sin = (const struct sockaddr_in *)nam;
+       struct sockaddr_in jsin;
+       struct inpcbinfo *pcbinfo = inp->inp_pcbinfo;
+       struct ucred *cred = NULL;
+       u_short lport = 0;
+       ushort first, last;
+       int count, error;
+
+       if (TAILQ_EMPTY(&in_ifaddrheads[mycpuid])) /* XXX broken! */
+               return (EADDRNOTAVAIL);
+
+       KKASSERT(inp->inp_laddr.s_addr != INADDR_ANY);  /* laddr must be set by the caller */
+       if (inp->inp_lport != 0)
+               return (EINVAL);        /* already bound */
+
+       KKASSERT(p);
+       cred = p->p_ucred;
+
+       /*
+        * Choose an automatic local port for inp such that
+        * in_pcblookup_addrport() finds no PCB for inp's laddr, the
+        * candidate port and the foreign address/port taken from nam
+        * (sin_addr/sin_port), then register inp through
+        * in_pcbinsporthash().  Returns 0 on success or an errno value;
+        * every failure path from here on resets inp_laddr to
+        * INADDR_ANY before returning.
+        *
+        * This has to be atomic.  If the porthash is shared across multiple
+        * protocol threads (aka tcp) then the token will be non-NULL.
+        */
+       if (pcbinfo->porttoken)
+               lwkt_gettoken(pcbinfo->porttoken);
+
+       jsin.sin_family = AF_INET;
+       jsin.sin_addr.s_addr = inp->inp_laddr.s_addr;
+       if (!prison_replace_wildcards(td, (struct sockaddr *)&jsin)) {
+               inp->inp_laddr.s_addr = INADDR_ANY;
+               error = EINVAL;
+               goto done;
+       }
+       inp->inp_laddr.s_addr = jsin.sin_addr.s_addr;
+
+       inp->inp_flags |= INP_ANONPORT;
+
+       if (inp->inp_flags & INP_HIGHPORT) {
+               first = ipport_hifirstauto;     /* sysctl */
+               last  = ipport_hilastauto;
+               lastport = &pcbinfo->lasthi;
+       } else if (inp->inp_flags & INP_LOWPORT) {
+               if (cred &&
+                   (error = priv_check_cred(cred, PRIV_NETINET_RESERVEDPORT, 0))) {
+                       inp->inp_laddr.s_addr = INADDR_ANY;
+                       goto done;
+               }
+               first = ipport_lowfirstauto;    /* 1023 */
+               last  = ipport_lowlastauto;     /* 600 */
+               lastport = &pcbinfo->lastlow;
+       } else {
+               first = ipport_firstauto;       /* sysctl */
+               last  = ipport_lastauto;
+               lastport = &pcbinfo->lastport;
+       }
+       /*
+        * Simple check to ensure all ports are not used up causing
+        * a deadlock here.
+        *
+        * We split the two cases (up and down) so that the direction
+        * is not being tested on each round of the loop.
+        *
+        * *lastport is the per-range cursor kept in pcbinfo; it is
+        * advanced on every attempt, so the search resumes from where
+        * the previous allocation in the same range stopped.  count
+        * bounds the number of attempts by the size of the range.
+        */
+       if (first > last) {
+               /*
+                * counting down
+                */
+               count = first - last;
+
+               do {
+                       if (count-- < 0) {      /* completely used? */
+                               inp->inp_laddr.s_addr = INADDR_ANY;
+                               error = EADDRNOTAVAIL;
+                               goto done;
+                       }
+                       --*lastport;
+                       if (*lastport > first || *lastport < last)
+                               *lastport = first;      /* left [last, first]: restart */
+                       lport = htons(*lastport);
+               } while (in_pcblookup_addrport(pcbinfo, inp->inp_laddr, lport,
+                               sin->sin_addr, sin->sin_port, cred));
+       } else {
+               /*
+                * counting up
+                */
+               count = last - first;
+
+               do {
+                       if (count-- < 0) {      /* completely used? */
+                               inp->inp_laddr.s_addr = INADDR_ANY;
+                               error = EADDRNOTAVAIL;
+                               goto done;
+                       }
+                       ++*lastport;
+                       if (*lastport < first || *lastport > last)
+                               *lastport = first;      /* left [first, last]: restart */
+                       lport = htons(*lastport);
+               } while (in_pcblookup_addrport(pcbinfo, inp->inp_laddr, lport,
+                               sin->sin_addr, sin->sin_port, cred));
+       }
+       inp->inp_lport = lport;
+
+       jsin.sin_family = AF_INET;      /* NOTE(review): repeats the wildcard check done before the port search */
+       jsin.sin_addr.s_addr = inp->inp_laddr.s_addr;
+       if (!prison_replace_wildcards(td, (struct sockaddr*)&jsin)) {
+               inp->inp_laddr.s_addr = INADDR_ANY;
+               inp->inp_lport = 0;
+               error = EINVAL;
+               goto done;
+       }
+       inp->inp_laddr.s_addr = jsin.sin_addr.s_addr;
+
+       if (in_pcbinsporthash(inp) != 0) {
+               inp->inp_laddr.s_addr = INADDR_ANY;
+               inp->inp_lport = 0;     /* undo the binding made above */
+               error = EAGAIN;
+               goto done;
+       }
+       error = 0;
+done:
+       if (pcbinfo->porttoken)
+               lwkt_reltoken(pcbinfo->porttoken);
+       return error;
+}
+
 /*
  *   Transform old in_pcbconnect() into an inner subroutine for new
  *   in_pcbconnect(): Do some validity-checking on the remote
index 43e2bf2..26715bb 100644 (file)
@@ -412,6 +412,8 @@ int in_pcballoc (struct socket *, struct inpcbinfo *);
 void   in_pcbunlink (struct inpcb *, struct inpcbinfo *);
 void   in_pcblink (struct inpcb *, struct inpcbinfo *);
 int    in_pcbbind (struct inpcb *, struct sockaddr *, struct thread *);
+int    in_pcbconn_bind(struct inpcb *, const struct sockaddr *,
+           struct thread *);
 int    in_pcbconnect (struct inpcb *, struct sockaddr *, struct thread *);
 void   in_pcbdetach (struct inpcb *);
 void   in_pcbdisconnect (struct inpcb *);
index ca07f26..41bdd71 100644 (file)
@@ -159,6 +159,11 @@ static struct tcpcb *
 #define        TCPDEBUG2(req)
 #endif
 
+static int     tcp_lport_extension = 0;        /* non-zero: tcp_connect() binds via in_pcbconn_bind() */
+
+SYSCTL_INT(_net_inet_tcp, OID_AUTO, lportext, CTLFLAG_RW,
+    &tcp_lport_extension, 0, "");
+
 /*
  * TCP attaches to socket via pru_attach(), reserving space,
  * and an internet control block.  This is likely occuring on
@@ -1055,24 +1060,40 @@ tcp_connect(netmsg_t msg)
                in_pcblink(so->so_pcb, &tcbinfo[mycpu->gd_cpuid]);
        }
 
-       /*
-        * Bind if we have to
-        */
-       if (inp->inp_lport == 0) {
-               error = in_pcbbind(inp, NULL, td);
+       if (tcp_lport_extension) {
+               if (inp->inp_lport == 0) {
+                       KKASSERT(inp->inp_laddr.s_addr == INADDR_ANY);
+
+                       error = in_pcbladdr(inp, nam, &if_sin, td);
+                       if (error)
+                               goto out;
+                       inp->inp_laddr.s_addr = if_sin->sin_addr.s_addr;
+
+                       error = in_pcbconn_bind(inp, nam, td);
+                       if (error)
+                               goto out;
+               }
+       } else {
+               /*
+                * Bind if we have to
+                */
+               if (inp->inp_lport == 0) {
+                       error = in_pcbbind(inp, NULL, td);
+                       if (error)
+                               goto out;
+               }
+
+               /*
+                * Calculate the correct protocol processing thread.  The
+                * connect operation must run there.  Set the forwarding
+                * port before we forward the message or it will get bounced
+                * right back to us.
+                */
+               error = in_pcbladdr(inp, nam, &if_sin, td);
                if (error)
                        goto out;
        }
 
-       /*
-        * Calculate the correct protocol processing thread.  The connect
-        * operation must run there.  Set the forwarding port before we
-        * forward the message or it will get bounced right back to us.
-        */
-       error = in_pcbladdr(inp, nam, &if_sin, td);
-       if (error)
-               goto out;
-
        KKASSERT(inp->inp_socket == so);
 
 #ifdef SMP