From 203bf8e27fe3863327c957b658ab30c9e392ec33 Mon Sep 17 00:00:00 2001 From: Sepherosa Ziehau Date: Tue, 29 Aug 2017 13:51:59 +0800 Subject: [PATCH] Revert "select: Don't allow unwanted/leftover fds being returned." This reverts commit ce4975442fa0524017fb3c1aef93bbe6880ae770. --- sys/kern/sys_generic.c | 106 ++++++++++++++--------------------------- 1 file changed, 37 insertions(+), 69 deletions(-) diff --git a/sys/kern/sys_generic.c b/sys/kern/sys_generic.c index 950424e506..e1fe938811 100644 --- a/sys/kern/sys_generic.c +++ b/sys/kern/sys_generic.c @@ -87,9 +87,6 @@ struct select_kevent_copyin_args { kfd_set *read_set; kfd_set *write_set; kfd_set *except_set; - kfd_set *oread_set; /* orig set, points into read_set */ - kfd_set *owrite_set; /* orig set, points into write_set */ - kfd_set *oexcept_set; /* orig set, points into except_set */ int active_set; /* One of select_copyin_states */ struct lwp *lwp; /* Pointer to our lwp */ int num_fds; /* Number of file descriptors (syscall arg) */ @@ -993,48 +990,18 @@ select_copyout(void *arg, struct kevent *kevp, int count, int *res) skap = (struct select_kevent_copyin_args *)arg; for (i = 0; i < count; ++i) { - kfd_set *fd_set = NULL; - /* * Filter out and delete spurious events */ if ((u_int)(uintptr_t)kevp[i].udata != skap->lwp->lwp_kqueue_serial) { - if (nseldebug) { - kprintf("select fd %ju mismatched serial %d\n", - (uintmax_t)kevp[i].ident, - skap->lwp->lwp_kqueue_serial); - } - } else { - switch (kevp[i].filter) { - case EVFILT_READ: - if (__predict_true(skap->oread_set != NULL && - FD_ISSET(kevp[i].ident, skap->oread_set))) - fd_set = skap->read_set; - break; - - case EVFILT_WRITE: - if (__predict_true(skap->owrite_set != NULL && - FD_ISSET(kevp[i].ident, skap->owrite_set))) - fd_set = skap->write_set; - break; - - case EVFILT_EXCEPT: - if (__predict_true(skap->oexcept_set != NULL && - FD_ISSET(kevp[i].ident, skap->oexcept_set))) - fd_set = skap->except_set; - break; - } - if (__predict_false(fd_set == NULL) && nseldebug) { - kprintf("select leftover fd %ju, " - "serial wrapped\n", - (uintmax_t)kevp[i].ident); - } - } - if (fd_set == NULL) { kev = kevp[i]; kev.flags = EV_DISABLE|EV_DELETE; kqueue_register(&skap->lwp->lwp_kqueue, &kev); + if (nseldebug) + kprintf("select fd %ju mismatched serial %d\n", + (uintmax_t)kevp[i].ident, + skap->lwp->lwp_kqueue_serial); continue; } @@ -1085,7 +1052,18 @@ select_copyout(void *arg, struct kevent *kevp, int count, int *res) kevp[i].filter, error); continue; } - FD_SET(kevp[i].ident, fd_set); + + switch (kevp[i].filter) { + case EVFILT_READ: + FD_SET(kevp[i].ident, skap->read_set); + break; + case EVFILT_WRITE: + FD_SET(kevp[i].ident, skap->write_set); + break; + case EVFILT_EXCEPT: + FD_SET(kevp[i].ident, skap->except_set); + break; + } ++*res; } @@ -1098,26 +1076,20 @@ select_copyout(void *arg, struct kevent *kevp, int count, int *res) * set is large. */ static int -getbits(int bytes, fd_set *in_set, kfd_set **out_set0, kfd_set **orig_set0, - kfd_set *tmp_set) +getbits(int bytes, fd_set *in_set, kfd_set **out_set, kfd_set *tmp_set) { - kfd_set *out_set = NULL, *orig_set = NULL; - int error = 0; + int error; if (in_set) { - if (bytes < sizeof(kfd_set)) { - out_set = tmp_set; - orig_set = tmp_set + 1; - } else { - out_set = kmalloc(bytes * 2, M_SELECT, M_WAITOK); - orig_set = (kfd_set *)(((uint8_t *)out_set) + bytes); - } - error = copyin(in_set, out_set, bytes); - if (!error) - memcpy(orig_set, out_set, bytes); + if (bytes < sizeof(*tmp_set)) + *out_set = tmp_set; + else + *out_set = kmalloc(bytes, M_SELECT, M_WAITOK); + error = copyin(in_set, *out_set, bytes); + } else { + *out_set = NULL; + error = 0; } - *out_set0 = out_set; - *orig_set0 = orig_set; return (error); } @@ -1157,9 +1129,9 @@ doselect(int nd, fd_set *read, fd_set *write, fd_set *except, struct proc *p = curproc; struct select_kevent_copyin_args *kap, ka; int bytes, error; - kfd_set read_tmp[2]; - kfd_set write_tmp[2]; - kfd_set except_tmp[2]; + kfd_set read_tmp; + kfd_set write_tmp; + kfd_set except_tmp; *res = 0; if (nd < 0) @@ -1187,15 +1159,11 @@ doselect(int nd, fd_set *read, fd_set *write, fd_set *except, kap->write_set = NULL; kap->except_set = NULL; - error = getbits(bytes, read, &kap->read_set, &kap->oread_set, read_tmp); - if (error == 0) { - error = getbits(bytes, write, &kap->write_set, &kap->owrite_set, - write_tmp); - } - if (error == 0) { - error = getbits(bytes, except, &kap->except_set, - &kap->oexcept_set, except_tmp); - } + error = getbits(bytes, read, &kap->read_set, &read_tmp); + if (error == 0) + error = getbits(bytes, write, &kap->write_set, &write_tmp); + if (error == 0) + error = getbits(bytes, except, &kap->except_set, &except_tmp); if (error) goto done; @@ -1226,11 +1194,11 @@ doselect(int nd, fd_set *read, fd_set *write, fd_set *except, * Clean up. */ done: - if (kap->read_set && kap->read_set != read_tmp) + if (kap->read_set && kap->read_set != &read_tmp) kfree(kap->read_set, M_SELECT); - if (kap->write_set && kap->write_set != write_tmp) + if (kap->write_set && kap->write_set != &write_tmp) kfree(kap->write_set, M_SELECT); - if (kap->except_set && kap->except_set != except_tmp) + if (kap->except_set && kap->except_set != &except_tmp) kfree(kap->except_set, M_SELECT); kap->lwp->lwp_kqueue_serial += kap->num_fds; -- 2.41.0