ec95e59116cbf360ddf8b7f718f9aadb05af7bd8
[dragonfly.git] / lib / libc / sysvipc / sockets.c
1 /**
2  * Copyright (c) 2013 Larisa Grigore.  All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  * 3. The name of the author may not be used to endorse or promote products
13  *    derived from this software without specific prior written permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
17  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
18  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
19  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
20  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
24  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  */
26
27 #include <sys/param.h>
28 #include <sys/un.h>
29 #include <sys/uio.h>
30 #include <sys/types.h>
31 #include <sys/stat.h>
32 #include <err.h>
33 #include <errno.h>
34 #include <fcntl.h>
35 #include <signal.h>
36 #include <stdio.h>
37 #include <stdlib.h>
38 #include <string.h>
39 #include <unistd.h>
40
41 #include "sysvipc_utils.h"
42 #include "sysvipc_sockets.h"
43
44 #define MAX_CONN        10
45
46 int
47 init_socket(const char *sockfile)
48 {
49         struct sockaddr_un un_addr;
50         int sock;
51
52         /* create server socket */
53         if ( (sock = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) {
54                 sysv_print_err("init socket");
55                 return (-1);
56         }
57
58         /* bind it */
59         memset(&un_addr, 0, sizeof(un_addr));
60         un_addr.sun_len = sizeof(un_addr);
61         un_addr.sun_family = AF_UNIX;
62         strcpy(un_addr.sun_path, sockfile);
63
64         unlink(un_addr.sun_path);
65
66         if (bind(sock, (struct sockaddr *)&un_addr, sizeof(un_addr)) < 0) {
67                 close(sock);
68                 sysv_print_err("bind");
69                 return (-1);
70         }
71
72         if (listen(sock, MAX_CONN) < 0) {
73                 close(sock);
74                 sysv_print_err("listen");
75                 return (-1);
76         }
77
78         /* turn on credentials passing */
79         return (sock);
80 }
81
82 int
83 handle_new_connection(int sock)
84 {
85         int fd, flags;
86
87         do {
88                 fd = accept(sock, NULL, NULL);
89         } while (fd < 0 && errno == EINTR);
90
91         if (fd < 0) {
92                 sysv_print_err("accept");
93                 return (-1);
94         }
95
96         flags = fcntl(fd, F_GETFL, 0);
97         fcntl(fd, F_SETFL, flags & ~O_NONBLOCK);
98
99         return (fd);
100 }
101
102 int
103 connect_to_daemon(const char *sockfile)
104 {
105         int sock, flags;
106         struct sockaddr_un serv_addr;
107
108         if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) < 0) {
109                 sysv_print_err("socket(%d)\n", sock);
110                 return (-1);
111         }
112
113         flags = fcntl(sock, F_GETFL, 0);
114         fcntl(sock, F_SETFL, flags & ~O_NONBLOCK);
115
116         memset(&serv_addr, 0, sizeof(serv_addr));
117         serv_addr.sun_family = AF_UNIX;
118         strcpy(serv_addr.sun_path, sockfile);
119
120         if (connect(sock, (struct sockaddr *)&serv_addr,
121                                 sizeof(serv_addr)) < 0) {
122                 close(sock);
123                 sysv_print_err("connect(%d)\n", sock);
124                 return (-1);
125         }
126
127         return (sock);
128 }
129
130 int
131 send_fd(int sock, int fd)
132 {
133         struct msghdr msg;
134         struct iovec vec;
135 #ifndef HAVE_ACCRIGHTS_IN_MSGHDR
136         union {
137                 struct cmsghdr hdr;
138                 char buf[CMSG_SPACE(sizeof(int))];
139         } cmsgbuf;
140         struct cmsghdr *cmsg;
141 #endif
142         int result = 0;
143         ssize_t n;
144
145         memset(&msg, 0, sizeof(msg));
146
147         if (fd < 0)
148                 result = errno;
149         else {
150 #ifdef HAVE_ACCRIGHTS_IN_MSGHDR
151                 msg.msg_accrights = (caddr_t)&fd;
152                 msg.msg_accrightslen = sizeof(fd);
153 #else
154                 msg.msg_control = (caddr_t)cmsgbuf.buf;
155                 msg.msg_controllen = sizeof(cmsgbuf.buf);
156                 cmsg = CMSG_FIRSTHDR(&msg);
157                 cmsg->cmsg_len = CMSG_LEN(sizeof(int));
158                 cmsg->cmsg_level = SOL_SOCKET;
159                 cmsg->cmsg_type = SCM_RIGHTS;
160                 *(int *)CMSG_DATA(cmsg) = fd;
161 #endif
162         }
163
164         vec.iov_base = (caddr_t)&result;
165         vec.iov_len = sizeof(int);
166         msg.msg_iov = &vec;
167         msg.msg_iovlen = 1;
168
169         if ((n = sendmsg(sock, &msg, 0)) == -1) {
170                 sysv_print_err("sendmsg(%d)\n",
171                                 sock, getpid());
172                 return (-1);
173         }
174         if (n != sizeof(int)) {
175                 sysv_print_err("sendmsg: expected sent 1 got %ld\n",
176                                 (long)n);
177                 return (-1);
178         }
179
180         return (0);
181 }
182
183 /**/
184 int
185 receive_fd(int sock)
186 {
187         struct msghdr msg;
188         struct iovec vec;
189 #ifndef HAVE_ACCRIGHTS_IN_MSGHDR
190         union {
191                 struct cmsghdr hdr;
192                 char buf[CMSG_SPACE(sizeof(int))];
193         } cmsgbuf;
194         struct cmsghdr *cmsg;
195 #endif
196         ssize_t n;
197         int result;
198         int fd;
199
200         memset(&msg, 0, sizeof(msg));
201         vec.iov_base = (caddr_t)&result;
202         vec.iov_len = sizeof(int);
203         msg.msg_iov = &vec;
204         msg.msg_iovlen = 1;
205
206 #ifdef HAVE_ACCRIGHTS_IN_MSGHDR
207         msg.msg_accrights = (caddr_t)&fd;
208         msg.msg_accrightslen = sizeof(fd);
209 #else
210         msg.msg_control = &cmsgbuf.buf;
211         msg.msg_controllen = sizeof(cmsgbuf.buf);
212 #endif
213
214         if ((n = recvmsg(sock, &msg, 0)) == -1)
215                 sysv_print_err("recvmsg\n");
216         if (n != sizeof(int)) {
217                 sysv_print_err("recvmsg: expected received 1 got %ld\n",
218                                 (long)n);
219         }
220         if (result == 0) {
221                 cmsg = CMSG_FIRSTHDR(&msg);
222                 if (cmsg == NULL) {
223                         sysv_print_err("no message header\n");
224                         return (-1);
225                 }
226                 if (cmsg->cmsg_type != SCM_RIGHTS)
227                         sysv_print_err("expected type %d got %d\n",
228                                         SCM_RIGHTS, cmsg->cmsg_type);
229
230                 fd = (*(int *)CMSG_DATA(cmsg));
231                 return (fd);
232         } else {
233                 errno = result;
234                 return (-1);
235         }
236 }
237
238 static void
239 close_fds(int *fds, int num_fds) {
240         int i;
241
242         for (i=0; i < num_fds; i++)
243                 close(fds[i]);
244 }
245
246 /* Send with the message, credentials too. */
247 int
248 send_msg_with_cred(int sock, char *buffer, size_t size) {
249         struct msghdr msg;
250         struct iovec vec;
251         ssize_t n;
252         
253         struct {
254                 struct cmsghdr hdr;
255                 char cred[CMSG_SPACE(sizeof(struct cmsgcred))];
256         } cmsg;
257
258         memset(&cmsg, 0, sizeof(cmsg));
259         cmsg.hdr.cmsg_len =  CMSG_LEN(sizeof(struct cmsgcred));
260         cmsg.hdr.cmsg_level = SOL_SOCKET;
261         cmsg.hdr.cmsg_type = SCM_CREDS;
262
263         memset(&msg, 0, sizeof(struct msghdr));
264         msg.msg_iov = &vec;
265         msg.msg_iovlen = 1;
266         msg.msg_control = (caddr_t)&cmsg;
267         msg.msg_controllen = CMSG_SPACE(sizeof(struct cmsgcred));
268
269         vec.iov_base = buffer;
270         vec.iov_len = size;
271
272         if ((n = sendmsg(sock, &msg, 0)) == -1) {
273                 sysv_print_err("sendmsg on fd %d\n", sock);
274                 return (-1);
275         }
276
277         return (0);
278 }
279
280 /* Receive a message and the credentials of the sender. */
281 int
282 receive_msg_with_cred(int sock, char *buffer, size_t size,
283                 struct cmsgcred *cred) {
284         struct msghdr msg = {0};
285         struct iovec vec;
286         ssize_t n;
287         int result;
288         struct cmsghdr *cmp;
289         struct {
290                 struct cmsghdr hdr;
291                 char cred[CMSG_SPACE(sizeof(struct cmsgcred))];
292         } cmsg;
293
294         memset(&msg, 0, sizeof(msg));
295         vec.iov_base = buffer;
296         vec.iov_len = size;
297         msg.msg_iov = &vec;
298         msg.msg_iovlen = 1;
299
300         msg.msg_control = &cmsg;
301         msg.msg_controllen = sizeof(cmsg);
302
303         do {
304                 n = recvmsg(sock, &msg, 0);
305         } while (n < 0 && errno == EINTR);
306
307         if (n < 0) {
308                 sysv_print_err("recvmsg on fd %d\n", sock);
309                 return (-1);
310         }
311
312         if (n == 0) {
313                 return (-1);
314         }
315
316         result = -1;
317         cmp = CMSG_FIRSTHDR(&msg);
318
319         while(cmp != NULL) {
320                 if (cmp->cmsg_level == SOL_SOCKET
321                                 && cmp->cmsg_type  == SCM_CREDS) {
322                         if (cred)
323                                 memcpy(cred, CMSG_DATA(cmp), sizeof(*cred));
324                         result = n;
325                 } else if (cmp->cmsg_level == SOL_SOCKET
326                                 && cmp->cmsg_type  == SCM_RIGHTS) {
327                         close_fds((int *) CMSG_DATA(cmp),
328                                         (cmp->cmsg_len - CMSG_LEN(0))
329                                         / sizeof(int));
330                 }
331                 cmp = CMSG_NXTHDR(&msg, cmp);
332         }
333
334         return (result);
335 }