select: Don't allow unwanted/leftover fds from being returned.
author Sepherosa Ziehau <sephe@dragonflybsd.org>
Mon, 28 Aug 2017 13:49:00 +0000 (21:49 +0800)
committer Sepherosa Ziehau <sephe@dragonflybsd.org>
Mon, 28 Aug 2017 13:55:28 +0000 (21:55 +0800)
The root cause is that lwp_kqueue_serial wraps quite quickly (within
6 seconds on my laptop) when select(2) is polling, either because of a
heavy workload or because of a 0 timeout.  The POC test:
https://leaf.dragonflybsd.org/~sephe/select_wrap.c

Fix this issue by saving the original fd_sets and doing additional
kevent filtering before returning the fd to userland.

poll(2) suffers from a similar issue and will be fixed in a later commit.

Reported-by: many
sys/kern/sys_generic.c

index 5cc2fd1..f10b06a 100644 (file)
@@ -87,6 +87,9 @@ struct select_kevent_copyin_args {
        kfd_set         *read_set;
        kfd_set         *write_set;
        kfd_set         *except_set;
+       kfd_set         *oread_set;     /* orig set, points into read_set */
+       kfd_set         *owrite_set;    /* orig set, points into write_set */
+       kfd_set         *oexcept_set;   /* orig set, points into except_set */
        int             active_set;     /* One of select_copyin_states */
        struct lwp      *lwp;           /* Pointer to our lwp */
        int             num_fds;        /* Number of file descriptors (syscall arg) */
@@ -990,18 +993,48 @@ select_copyout(void *arg, struct kevent *kevp, int count, int *res)
        skap = (struct select_kevent_copyin_args *)arg;
 
        for (i = 0; i < count; ++i) {
+               kfd_set *fd_set = NULL;
+
                /*
                 * Filter out and delete spurious events
                 */
                if ((u_int)(uintptr_t)kevp[i].udata !=
                    skap->lwp->lwp_kqueue_serial) {
-                       kev = kevp[i];
-                       kev.flags = EV_DISABLE|EV_DELETE;
-                       kqueue_register(&skap->lwp->lwp_kqueue, &kev);
-                       if (nseldebug)
+                       if (nseldebug) {
                                kprintf("select fd %ju mismatched serial %d\n",
                                        (uintmax_t)kevp[i].ident,
                                        skap->lwp->lwp_kqueue_serial);
+                       }
+               } else {
+                       switch (kevp[i].filter) {
+                       case EVFILT_READ:
+                               if (__predict_true(skap->oread_set != NULL &&
+                                   FD_ISSET(kevp[i].ident, skap->oread_set)))
+                                       fd_set = skap->read_set;
+                               break;
+
+                       case EVFILT_WRITE:
+                               if (__predict_true(skap->owrite_set != NULL &&
+                                   FD_ISSET(kevp[i].ident, skap->owrite_set)))
+                                       fd_set = skap->write_set;
+                               break;
+
+                       case EVFILT_EXCEPT:
+                               if (__predict_true(skap->oexcept_set != NULL &&
+                                   FD_ISSET(kevp[i].ident, skap->oexcept_set)))
+                                       fd_set = skap->except_set;
+                               break;
+                       }
+                       if (__predict_false(fd_set == NULL) && nseldebug) {
+                               kprintf("select leftover fd %ju, "
+                                   "serial wrapped\n",
+                                   (uintmax_t)kevp[i].ident);
+                       }
+               }
+               if (fd_set == NULL) {
+                       kev = kevp[i];
+                       kev.flags = EV_DISABLE|EV_DELETE;
+                       kqueue_register(&skap->lwp->lwp_kqueue, &kev);
                        continue;
                }
 
@@ -1052,18 +1085,7 @@ select_copyout(void *arg, struct kevent *kevp, int count, int *res)
                                        kevp[i].filter, error);
                        continue;
                }
-
-               switch (kevp[i].filter) {
-               case EVFILT_READ:
-                       FD_SET(kevp[i].ident, skap->read_set);
-                       break;
-               case EVFILT_WRITE:
-                       FD_SET(kevp[i].ident, skap->write_set);
-                       break;
-               case EVFILT_EXCEPT:
-                       FD_SET(kevp[i].ident, skap->except_set);
-                       break;
-               }
+               FD_SET(kevp[i].ident, fd_set);
 
                ++*res;
        }
@@ -1076,20 +1098,26 @@ select_copyout(void *arg, struct kevent *kevp, int count, int *res)
  * set is large.
  */
 static int
-getbits(int bytes, fd_set *in_set, kfd_set **out_set, kfd_set *tmp_set)
+getbits(int bytes, fd_set *in_set, kfd_set **out_set0, kfd_set **orig_set0,
+    kfd_set *tmp_set)
 {
-       int error;
+       kfd_set *out_set = NULL, *orig_set = NULL;
+       int error = 0;
 
        if (in_set) {
-               if (bytes < sizeof(*tmp_set))
-                       *out_set = tmp_set;
-               else
-                       *out_set = kmalloc(bytes, M_SELECT, M_WAITOK);
-               error = copyin(in_set, *out_set, bytes);
-       } else {
-               *out_set = NULL;
-               error = 0;
+               if (bytes < sizeof(kfd_set)) {
+                       out_set = tmp_set;
+                       orig_set = tmp_set + 1;
+               } else {
+                       out_set = kmalloc(bytes * 2, M_SELECT, M_WAITOK);
+                       orig_set = (kfd_set *)(((uint8_t *)out_set) + bytes);
+               }
+               error = copyin(in_set, out_set, bytes);
+               if (!error)
+                       memcpy(orig_set, out_set, bytes);
        }
+       *out_set0 = out_set;
+       *orig_set0 = orig_set;
        return (error);
 }
 
@@ -1129,9 +1157,9 @@ doselect(int nd, fd_set *read, fd_set *write, fd_set *except,
        struct proc *p = curproc;
        struct select_kevent_copyin_args *kap, ka;
        int bytes, error;
-       kfd_set read_tmp;
-       kfd_set write_tmp;
-       kfd_set except_tmp;
+       kfd_set read_tmp[2];
+       kfd_set write_tmp[2];
+       kfd_set except_tmp[2];
 
        *res = 0;
        if (nd < 0)
@@ -1159,11 +1187,15 @@ doselect(int nd, fd_set *read, fd_set *write, fd_set *except,
        kap->write_set = NULL;
        kap->except_set = NULL;
 
-       error = getbits(bytes, read, &kap->read_set, &read_tmp);
-       if (error == 0)
-               error = getbits(bytes, write, &kap->write_set, &write_tmp);
-       if (error == 0)
-               error = getbits(bytes, except, &kap->except_set, &except_tmp);
+       error = getbits(bytes, read, &kap->read_set, &kap->oread_set, read_tmp);
+       if (error == 0) {
+               error = getbits(bytes, write, &kap->write_set, &kap->owrite_set,
+                   write_tmp);
+       }
+       if (error == 0) {
+               error = getbits(bytes, except, &kap->except_set,
+                   &kap->oexcept_set, except_tmp);
+       }
        if (error)
                goto done;
 
@@ -1194,11 +1226,11 @@ doselect(int nd, fd_set *read, fd_set *write, fd_set *except,
         * Clean up.
         */
 done:
-       if (kap->read_set && kap->read_set != &read_tmp)
+       if (kap->read_set && kap->read_set != read_tmp)
                kfree(kap->read_set, M_SELECT);
-       if (kap->write_set && kap->write_set != &write_tmp)
+       if (kap->write_set && kap->write_set != write_tmp)
                kfree(kap->write_set, M_SELECT);
-       if (kap->except_set && kap->except_set != &except_tmp)
+       if (kap->except_set && kap->except_set != except_tmp)
                kfree(kap->except_set, M_SELECT);
 
        kap->lwp->lwp_kqueue_serial += kap->num_fds;