From FreeBSD:
[dragonfly.git] / lib / libcaps / caps_misc.c
1 /*
2  * CAPS_MISC.C
3  *
4  * Copyright (c) 2003 Matthew Dillon <dillon@backplane.com>
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26  * SUCH DAMAGE.
27  *
28  * $DragonFly: src/lib/libcaps/Attic/caps_misc.c,v 1.1 2003/11/24 21:15:58 dillon Exp $
29  */
30
31 #include "defs.h"
32
33 caps_port_t
34 caps_mkport(
35     enum caps_type type,
36     int (*cs_putport)(lwkt_port_t port, lwkt_msg_t msg),
37     void *(*cs_waitport)(lwkt_port_t port, lwkt_msg_t msg),
38     void (*cs_replyport)(lwkt_port_t port, lwkt_msg_t msg)
39 ) {
40     caps_port_t port;
41
42     port = malloc(sizeof(*port));
43     bzero(port, sizeof(*port));
44
45     lwkt_initport(&port->lport, curthread);
46     port->lport.mp_putport = cs_putport;
47     port->lport.mp_waitport = cs_waitport;
48     port->lport.mp_replyport = cs_replyport;
49     port->lport.mp_refs = 1;
50     port->type = type;
51     port->kqfd = -1;            /* kqueue descriptor */
52     port->lfd = -1;             /* listen socket descriptor */
53     port->cfd = -1;             /* client socket descriptor */
54     TAILQ_INIT(&port->clist);   /* server connections */
55     TAILQ_INIT(&port->wlist);   /* writes in progress */
56     TAILQ_INIT(&port->mlist);   /* written messages waiting for reply */
57     port->cred.pid = (pid_t)-1;
58     port->cred.uid = (uid_t)-1;
59     port->cred.gid = (gid_t)-1;
60
61     port->rmsg = &port->rmsg_static;
62
63     return(port);
64 }
65
66 void
67 caps_shutdown(caps_port_t port)
68 {
69     caps_port_t scan;
70     lwkt_msg_t msg;
71
72     port->flags |= CAPPF_SHUTDOWN;
73     if (port->flags & CAPPF_ONLIST) {
74         --port->lport.mp_refs;
75         port->flags &= ~CAPPF_ONLIST;
76         TAILQ_REMOVE(&port->server->clist, port, centry);
77     }
78     if (port->kqfd >= 0) {
79         close(port->kqfd);
80         port->kqfd = -1;
81     }
82     if (port->lfd >= 0) {
83         close(port->lfd);
84         port->lfd = -1;
85     }
86     if (port->cfd >= 0) {
87         close(port->cfd);
88         port->cfd = -1;
89     }
90     port->rbytes = 0;
91     port->wbytes = 0;
92     while ((msg = TAILQ_FIRST(&port->wlist)) != NULL) {
93         TAILQ_REMOVE(&port->wlist, msg, ms_node);
94         msg->ms_flags &= ~MSGF_QUEUED;
95         if (port->type == CAPT_CLIENT)
96             lwkt_replymsg(msg, EIO);
97         else
98             free(msg);
99     }
100     while ((msg = TAILQ_FIRST(&port->mlist)) != NULL) {
101         TAILQ_REMOVE(&port->mlist, msg, ms_node);
102         msg->ms_flags &= ~MSGF_QUEUED;
103         lwkt_replymsg(msg, EIO);
104     }
105     if ((msg = port->rmsg) != NULL) {
106         port->rmsg = &port->rmsg_static;
107         if (msg != &port->rmsg_static)
108             free(msg);
109     }
110     while ((scan = TAILQ_FIRST(&port->clist)) != NULL) {
111         caps_shutdown(scan);
112     }
113     assert(port->lport.mp_refs >= 0);
114     if (port->lport.mp_refs == 0)
115         free(port);
116 }
117
118 void
119 caps_close(caps_port_t port)
120 {
121     --port->lport.mp_refs;
122     assert(port->lport.mp_refs >= 0);
123     caps_shutdown(port);
124 }
125
126 /*
127  * Start writing a new message to the socket and/or continue writing
128  * previously queued messages to the socket.
129  */
130 void
131 caps_kev_write(caps_port_t port, lwkt_msg_t msg)
132 {
133     struct kevent kev;
134     int n;
135
136     /*
137      * Add new messages to the queue
138      */
139     if (msg) {
140         msg->ms_flags |= MSGF_QUEUED;
141         TAILQ_INSERT_TAIL(&port->wlist, msg, ms_node);
142     }
143
144     /*
145      * Continue writing out the existing queue.  The message in
146      * progress is msg->ms_msgsize bytes long.  The opaque field in
147      * the over-the-wire version of the message contains a pointer
148      * to the message so we can match up replies.
149      */
150     ++port->lport.mp_refs;
151     while ((msg = TAILQ_FIRST(&port->wlist)) != NULL) {
152         lwkt_msg_t save;
153
154         /*
155          * Kinda messy
156          */
157         save = msg->opaque.ms_umsg;
158         if ((msg->ms_flags & MSGF_REPLY) == 0)
159             msg->opaque.ms_umsg = msg;
160         n = write(port->cfd, (char *)msg + port->wbytes, 
161                     msg->ms_msgsize - port->wbytes);
162         msg->opaque.ms_umsg = save;
163
164         DBPRINTF(("write %d/%d bytes\n" , n, msg->ms_msgsize - port->wbytes));
165         /* XXX handle failures.  Let the read side deal with it */
166         if (n <= 0)
167             break;
168         port->wbytes += n;
169         if (port->wbytes != msg->ms_msgsize)
170             break;
171         port->wbytes = 0;
172         TAILQ_REMOVE(&port->wlist, msg, ms_node);
173         msg->ms_flags &= ~MSGF_QUEUED;
174         if (msg->ms_flags & MSGF_REPLY) {
175             /*
176              * Finished writing reply, throw the message away.
177              */
178             free(msg);
179         } else {
180             /*
181              * Finished sending request, place message on mlist.
182              */
183             msg->ms_flags |= MSGF_QUEUED;
184             TAILQ_INSERT_TAIL(&port->mlist, msg, ms_node);
185         }
186     }
187
188     /*
189      * Do we need to wait for a write-availability event?   Note that
190      * the kevent calls can fail if the descriptor is no longer valid.
191      */
192     msg = TAILQ_FIRST(&port->wlist);
193     if (msg && (port->flags & CAPPF_WREQUESTED) == 0) {
194         port->flags |= CAPPF_WREQUESTED;
195         EV_SET(&kev, port->cfd, EVFILT_WRITE, EV_ADD|EV_ENABLE, 0, 0, port);
196         kevent(port->kqfd, &kev, 1, NULL, 0, NULL);
197     } else if (port->flags & CAPPF_WREQUESTED) {
198         port->flags &= ~CAPPF_WREQUESTED;
199         EV_SET(&kev, port->cfd, EVFILT_WRITE, EV_ADD|EV_DISABLE, 0, 0, port);
200         kevent(port->kqfd, &kev, 1, NULL, 0, NULL);
201     }
202     --port->lport.mp_refs;
203 }
204
205 /*
206  * Read a new message from the socket or continue reading messages from the
207  * socket.  If the message represents a reply it must be matched up against
208  * messages on the mlist, copied, and the mlist message returned instead.
209  */
210 lwkt_msg_t
211 caps_kev_read(caps_port_t port)
212 {
213     lwkt_msg_t msg;
214     int n;
215
216     /*
217      * If we are waiting for a cred the only permissable message is a
218      * creds message.
219      */
220     if (port->flags & CAPPF_WAITCRED) {
221         struct msghdr msghdr;
222         struct caps_creds_cmsg cmsg;
223
224         bzero(&msghdr, sizeof(msghdr));
225         bzero(&cmsg, sizeof(cmsg));
226         msghdr.msg_control = &cmsg;
227         msghdr.msg_controllen = sizeof(cmsg);
228         cmsg.cmsg.cmsg_len = sizeof(cmsg);
229         cmsg.cmsg.cmsg_type = 0;
230         if ((n = recvmsg(port->cfd, &msghdr, MSG_EOR)) < 0) {
231             if (errno == EINTR)
232                 return(NULL);
233         }
234         if (cmsg.cmsg.cmsg_type != SCM_CREDS) {
235             DBPRINTF(("server: expected SCM_CREDS\n"));
236             goto failed;
237         }
238         DBPRINTF(("server: connect from pid %d uid %d\n",
239                 (int)cmsg.cred.cmcred_pid, (int)cmsg.cred.cmcred_uid));
240         port->cred.pid = cmsg.cred.cmcred_pid;
241         port->cred.uid = cmsg.cred.cmcred_uid;
242         port->cred.euid = cmsg.cred.cmcred_euid;
243         port->cred.gid = cmsg.cred.cmcred_gid;
244         if ((port->cred.ngroups = cmsg.cred.cmcred_ngroups) > CAPS_MAXGROUPS)
245             port->cred.ngroups = CAPS_MAXGROUPS;
246         if (port->cred.ngroups < 0)
247             port->cred.ngroups = 0;
248         bcopy(cmsg.cred.cmcred_groups, port->cred.groups, 
249                 sizeof(gid_t) * port->cred.ngroups);
250         port->flags &= ~CAPPF_WAITCRED;
251         return(NULL);
252     }
253
254     /*
255      * Read or continue reading the next packet.  Use the static message
256      * while we are pulling in the header.
257      */
258     if (port->rmsg == &port->rmsg_static) {
259         n = read(port->cfd, (char *)port->rmsg + port->rbytes,
260                 sizeof(port->rmsg_static) - port->rbytes);
261         DBPRINTF(("read %d bytes\n" , n));
262         if (n <= 0) {
263                 if (errno == EINTR || errno == EAGAIN)
264                     return(NULL);
265                 goto failed;
266         }
267         port->rbytes += n;
268         if (port->rbytes != sizeof(port->rmsg_static))
269             return(NULL);
270         if (port->rmsg_static.ms_msgsize > port->rmsg_static.ms_maxsize ||
271             port->rmsg_static.ms_msgsize < sizeof(struct lwkt_msg) ||
272             port->rmsg_static.ms_maxsize > CAPMSG_MAXSIZE
273         ) {
274             goto failed;
275         }
276         port->rmsg = malloc(port->rmsg_static.ms_maxsize);
277         bcopy(&port->rmsg_static, port->rmsg, port->rbytes);
278     }
279     if (port->rbytes != port->rmsg->ms_msgsize) {
280         n = read(port->cfd, (char *)port->rmsg + port->rbytes,
281                 port->rmsg->ms_msgsize - port->rbytes);
282         if (n <= 0) {
283             if (errno == EINTR || errno == EAGAIN)
284                 return(NULL);
285             goto failed;
286         }
287         port->rbytes += n;
288         if (port->rbytes != port->rmsg->ms_msgsize)
289             return(NULL);
290     }
291     msg = port->rmsg;
292     port->rmsg = &port->rmsg_static;
293     port->rbytes = 0;
294
295     /*
296      * Setup the target port and the reply port
297      */
298     msg->ms_reply_port = &port->lport;
299     if (port->type == CAPT_REMOTE && port->server)
300         msg->ms_target_port = &port->server->lport;
301     else
302         msg->ms_target_port = &port->lport;
303
304     if (msg->ms_flags & MSGF_REPLY) {
305         /*
306          * If the message represents a reply we have to match it up against
307          * the original.
308          */
309         lwkt_msg_t scan;
310         lwkt_msg_t save_msg;
311         lwkt_port_t save_port1;
312         lwkt_port_t save_port2;
313
314         TAILQ_FOREACH(scan, &port->mlist, ms_node) {
315             if (msg->opaque.ms_umsg == scan)
316                 break;
317         }
318         DBPRINTF(("matchup: %p against %p\n", msg->opaque.ms_umsg, scan));
319         if (scan == NULL)
320             goto failed;
321         if (msg->ms_msgsize > scan->ms_maxsize)
322             goto failed;
323         TAILQ_REMOVE(&port->mlist, scan, ms_node);
324         save_msg = scan->opaque.ms_umsg;
325         save_port1 = scan->ms_target_port;
326         save_port2 = scan->ms_reply_port;
327         bcopy(msg, scan, msg->ms_msgsize);
328         scan->opaque.ms_umsg = save_msg;
329         scan->ms_target_port = save_port1;
330         scan->ms_reply_port = save_port2;
331         free(msg);
332         msg = scan;
333     } else {
334         /*
335          * New messages ref the port so it cannot go away until the last
336          * message has been replied.
337          */
338         ++port->lport.mp_refs;
339     }
340     return(msg);
341 failed:
342     caps_shutdown(port);
343     return(NULL);
344 }
345