Merge branch 'vendor/GREP'
[dragonfly.git] / contrib / tnftp / ssl.c
1 /*      $NetBSD: ssl.c,v 1.2 2012/12/24 22:12:28 christos Exp $ */
2
3 /*-
4  * Copyright (c) 1998-2004 Dag-Erling Coïdan Smørgrav
5  * Copyright (c) 2008, 2010 Joerg Sonnenberger <joerg@NetBSD.org>
6  * All rights reserved.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer
13  *    in this position and unchanged.
14  * 2. Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the distribution.
17  * 3. The name of the author may not be used to endorse or promote products
18  *    derived from this software without specific prior written permission
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
21  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
22  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
23  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
24  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
25  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
29  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30  *
31  * $FreeBSD: common.c,v 1.53 2007/12/19 00:26:36 des Exp $
32  */
33
34 #include <sys/cdefs.h>
35 #ifndef lint
36 __RCSID("$NetBSD: ssl.c,v 1.2 2012/12/24 22:12:28 christos Exp $");
37 #endif
38
#include <sys/param.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/uio.h>

#include <netinet/in.h>
#include <netinet/tcp.h>

#include <fcntl.h>
#include <limits.h>
#include <time.h>
#include <unistd.h>

#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>

#include "ssl.h"
57
58 extern int quit_time, verbose, ftp_debug;
59 extern FILE *ttyout;
60
61 struct fetch_connect {
62         int                      sd;            /* file/socket descriptor */
63         char                    *buf;           /* buffer */
64         size_t                   bufsize;       /* buffer size */
65         size_t                   bufpos;        /* position of buffer */
66         size_t                   buflen;        /* length of buffer contents */
67         struct {                                /* data cached after an
68                                                    interrupted read */
69                 char    *buf;
70                 size_t   size;
71                 size_t   pos;
72                 size_t   len;
73         } cache;
74         int                      issock;
75         int                      iserr;
76         int                      iseof;
77         SSL                     *ssl;           /* SSL handle */
78 };
79
80 /*
81  * Write a vector to a connection w/ timeout
82  * Note: can modify the iovec.
83  */
84 static ssize_t
85 fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
86 {
87         struct timeval now, timeout, delta;
88         fd_set writefds;
89         ssize_t len, total;
90         int r;
91
92         if (quit_time > 0) {
93                 FD_ZERO(&writefds);
94                 gettimeofday(&timeout, NULL);
95                 timeout.tv_sec += quit_time;
96         }
97
98         total = 0;
99         while (iovcnt > 0) {
100                 while (quit_time > 0 && !FD_ISSET(conn->sd, &writefds)) {
101                         FD_SET(conn->sd, &writefds);
102                         gettimeofday(&now, NULL);
103                         delta.tv_sec = timeout.tv_sec - now.tv_sec;
104                         delta.tv_usec = timeout.tv_usec - now.tv_usec;
105                         if (delta.tv_usec < 0) {
106                                 delta.tv_usec += 1000000;
107                                 delta.tv_sec--;
108                         }
109                         if (delta.tv_sec < 0) {
110                                 errno = ETIMEDOUT;
111                                 return -1;
112                         }
113                         errno = 0;
114                         r = select(conn->sd + 1, NULL, &writefds, NULL, &delta);
115                         if (r == -1) {
116                                 if (errno == EINTR)
117                                         continue;
118                                 return -1;
119                         }
120                 }
121                 errno = 0;
122                 if (conn->ssl != NULL)
123                         len = SSL_write(conn->ssl, iov->iov_base, iov->iov_len);
124                 else
125                         len = writev(conn->sd, iov, iovcnt);
126                 if (len == 0) {
127                         /* we consider a short write a failure */
128                         /* XXX perhaps we shouldn't in the SSL case */
129                         errno = EPIPE;
130                         return -1;
131                 }
132                 if (len < 0) {
133                         if (errno == EINTR)
134                                 continue;
135                         return -1;
136                 }
137                 total += len;
138                 while (iovcnt > 0 && len >= (ssize_t)iov->iov_len) {
139                         len -= iov->iov_len;
140                         iov++;
141                         iovcnt--;
142                 }
143                 if (iovcnt > 0) {
144                         iov->iov_len -= len;
145                         iov->iov_base = (char *)iov->iov_base + len;
146                 }
147         }
148         return total;
149 }
150
151 /*
152  * Write to a connection w/ timeout
153  */
154 static int
155 fetch_write(struct fetch_connect *conn, const char *str, size_t len)
156 {
157         struct iovec iov[1];
158
159         iov[0].iov_base = __DECONST(char *, str);
160         iov[0].iov_len = len;
161         return fetch_writev(conn, iov, 1);
162 }
163
/*
 * Format a message and send it over the connection.
 * Returns the result of fetch_write(), or -1 with errno set to ENOMEM if
 * the message could not be formatted.
 */
int
fetch_printf(struct fetch_connect *conn, const char *fmt, ...)
{
	va_list ap;
	char *msg;
	int len, r;

	va_start(ap, fmt);
	len = vasprintf(&msg, fmt, ap);
	va_end(ap);

	/*
	 * vasprintf() reports failure through its return value; msg is
	 * not guaranteed to be NULL (or valid) in that case on every libc.
	 */
	if (len < 0) {
		errno = ENOMEM;
		return -1;
	}

	r = fetch_write(conn, msg, (size_t)len);
	free(msg);
	return r;
}
188
189 int
190 fetch_fileno(struct fetch_connect *conn)
191 {
192
193         return conn->sd;
194 }
195
196 int
197 fetch_error(struct fetch_connect *conn)
198 {
199
200         return conn->iserr;
201 }
202
203 static void
204 fetch_clearerr(struct fetch_connect *conn)
205 {
206
207         conn->iserr = 0;
208 }
209
210 int
211 fetch_flush(struct fetch_connect *conn)
212 {
213         int v;
214
215         if (conn->issock) {
216 #ifdef TCP_NOPUSH
217                 v = 0;
218                 setsockopt(conn->sd, IPPROTO_TCP, TCP_NOPUSH, &v, sizeof(v));
219 #endif
220                 v = 1;
221                 setsockopt(conn->sd, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v));
222         }
223         return 0;
224 }
225
226 /*ARGSUSED*/
227 struct fetch_connect *
228 fetch_open(const char *fname, const char *fmode)
229 {
230         struct fetch_connect *conn;
231         int fd;
232
233         fd = open(fname, O_RDONLY); /* XXX: fmode */
234         if (fd < 0)
235                 return NULL;
236
237         if ((conn = calloc(1, sizeof(*conn))) == NULL) {
238                 close(fd);
239                 return NULL;
240         }
241
242         conn->sd = fd;
243         conn->issock = 0;
244         return conn;
245 }
246
247 /*ARGSUSED*/
248 struct fetch_connect *
249 fetch_fdopen(int sd, const char *fmode)
250 {
251         struct fetch_connect *conn;
252 #if defined(SO_NOSIGPIPE) || defined(TCP_NOPUSH)
253         int opt = 1;
254 #endif
255
256         if ((conn = calloc(1, sizeof(*conn))) == NULL)
257                 return NULL;
258
259         conn->sd = sd;
260         conn->issock = 1;
261         fcntl(sd, F_SETFD, FD_CLOEXEC);
262 #ifdef SO_NOSIGPIPE
263         setsockopt(sd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt));
264 #endif
265 #ifdef TCP_NOPUSH
266         setsockopt(sd, IPPROTO_TCP, TCP_NOPUSH, &opt, sizeof(opt));
267 #endif
268         return conn;
269 }
270
271 int
272 fetch_close(struct fetch_connect *conn)
273 {
274         int rv = 0;
275
276         if (conn != NULL) {
277                 fetch_flush(conn);
278                 SSL_free(conn->ssl);
279                 rv = close(conn->sd);
280                 if (rv < 0) {
281                         errno = rv;
282                         rv = EOF;
283                 }
284                 free(conn->cache.buf);
285                 free(conn->buf);
286                 free(conn);
287         }
288         return rv;
289 }
290
/*
 * Sentinel results of fetch_ssl_read()/fetch_nonssl_read(); both are
 * negative, so they never collide with a byte count.
 */
#define FETCH_READ_WAIT		(-2)	/* nothing yet: wait for input, retry */
#define FETCH_READ_ERROR	(-1)	/* the read failed */
293
294 static ssize_t
295 fetch_ssl_read(SSL *ssl, void *buf, size_t len)
296 {
297         ssize_t rlen;
298         int ssl_err;
299
300         rlen = SSL_read(ssl, buf, len);
301         if (rlen < 0) {
302                 ssl_err = SSL_get_error(ssl, rlen);
303                 if (ssl_err == SSL_ERROR_WANT_READ ||
304                     ssl_err == SSL_ERROR_WANT_WRITE) {
305                         return FETCH_READ_WAIT;
306                 }
307                 ERR_print_errors_fp(ttyout);
308                 return FETCH_READ_ERROR;
309         }
310         return rlen;
311 }
312
313 static ssize_t
314 fetch_nonssl_read(int sd, void *buf, size_t len)
315 {
316         ssize_t rlen;
317
318         rlen = read(sd, buf, len);
319         if (rlen < 0) {
320                 if (errno == EAGAIN || errno == EINTR)
321                         return FETCH_READ_WAIT;
322                 return FETCH_READ_ERROR;
323         }
324         return rlen;
325 }
326
327 /*
328  * Cache some data that was read from a socket but cannot be immediately
329  * returned because of an interrupted system call.
330  */
331 static int
332 fetch_cache_data(struct fetch_connect *conn, char *src, size_t nbytes)
333 {
334
335         if (conn->cache.size < nbytes) {
336                 char *tmp = realloc(conn->cache.buf, nbytes);
337                 if (tmp == NULL)
338                         return -1;
339
340                 conn->cache.buf = tmp;
341                 conn->cache.size = nbytes;
342         }
343
344         memcpy(conn->cache.buf, src, nbytes);
345         conn->cache.len = nbytes;
346         conn->cache.pos = 0;
347         return 0;
348 }
349
350 ssize_t
351 fetch_read(void *ptr, size_t size, size_t nmemb, struct fetch_connect *conn)
352 {
353         struct timeval now, timeout, delta;
354         fd_set readfds;
355         ssize_t rlen, total;
356         size_t len;
357         char *start, *buf;
358
359         if (quit_time > 0) {
360                 gettimeofday(&timeout, NULL);
361                 timeout.tv_sec += quit_time;
362         }
363
364         total = 0;
365         start = buf = ptr;
366         len = size * nmemb;
367
368         if (conn->cache.len > 0) {
369                 /*
370                  * The last invocation of fetch_read was interrupted by a
371                  * signal after some data had been read from the socket. Copy
372                  * the cached data into the supplied buffer before trying to
373                  * read from the socket again.
374                  */
375                 total = (conn->cache.len < len) ? conn->cache.len : len;
376                 memcpy(buf, conn->cache.buf, total);
377
378                 conn->cache.len -= total;
379                 conn->cache.pos += total;
380                 len -= total;
381                 buf += total;
382         }
383
384         while (len > 0) {
385                 /*
386                  * The socket is non-blocking.  Instead of the canonical
387                  * select() -> read(), we do the following:
388                  *
389                  * 1) call read() or SSL_read().
390                  * 2) if an error occurred, return -1.
391                  * 3) if we received data but we still expect more,
392                  *    update our counters and loop.
393                  * 4) if read() or SSL_read() signaled EOF, return.
394                  * 5) if we did not receive any data but we're not at EOF,
395                  *    call select().
396                  *
397                  * In the SSL case, this is necessary because if we
398                  * receive a close notification, we have to call
399                  * SSL_read() one additional time after we've read
400                  * everything we received.
401                  *
402                  * In the non-SSL case, it may improve performance (very
403                  * slightly) when reading small amounts of data.
404                  */
405                 if (conn->ssl != NULL)
406                         rlen = fetch_ssl_read(conn->ssl, buf, len);
407                 else
408                         rlen = fetch_nonssl_read(conn->sd, buf, len);
409                 if (rlen == 0) {
410                         break;
411                 } else if (rlen > 0) {
412                         len -= rlen;
413                         buf += rlen;
414                         total += rlen;
415                         continue;
416                 } else if (rlen == FETCH_READ_ERROR) {
417                         if (errno == EINTR)
418                                 fetch_cache_data(conn, start, total);
419                         return -1;
420                 }
421                 FD_ZERO(&readfds);
422                 while (!FD_ISSET(conn->sd, &readfds)) {
423                         FD_SET(conn->sd, &readfds);
424                         if (quit_time > 0) {
425                                 gettimeofday(&now, NULL);
426                                 if (!timercmp(&timeout, &now, >)) {
427                                         errno = ETIMEDOUT;
428                                         return -1;
429                                 }
430                                 timersub(&timeout, &now, &delta);
431                         }
432                         errno = 0;
433                         if (select(conn->sd + 1, &readfds, NULL, NULL,
434                                 quit_time > 0 ? &delta : NULL) < 0) {
435                                 if (errno == EINTR)
436                                         continue;
437                                 return -1;
438                         }
439                 }
440         }
441         return total;
442 }
443
444 #define MIN_BUF_SIZE 1024
445
446 /*
447  * Read a line of text from a connection w/ timeout
448  */
449 char *
450 fetch_getln(char *str, int size, struct fetch_connect *conn)
451 {
452         size_t tmpsize;
453         ssize_t len;
454         char c;
455
456         if (conn->buf == NULL) {
457                 if ((conn->buf = malloc(MIN_BUF_SIZE)) == NULL) {
458                         errno = ENOMEM;
459                         conn->iserr = 1;
460                         return NULL;
461                 }
462                 conn->bufsize = MIN_BUF_SIZE;
463         }
464
465         if (conn->iserr || conn->iseof)
466                 return NULL;
467
468         if (conn->buflen - conn->bufpos > 0)
469                 goto done;
470
471         conn->buf[0] = '\0';
472         conn->bufpos = 0;
473         conn->buflen = 0;
474         do {
475                 len = fetch_read(&c, sizeof(c), 1, conn);
476                 if (len == -1) {
477                         conn->iserr = 1;
478                         return NULL;
479                 }
480                 if (len == 0) {
481                         conn->iseof = 1;
482                         break;
483                 }
484                 conn->buf[conn->buflen++] = c;
485                 if (conn->buflen == conn->bufsize) {
486                         char *tmp = conn->buf;
487                         tmpsize = conn->bufsize * 2 + 1;
488                         if ((tmp = realloc(tmp, tmpsize)) == NULL) {
489                                 errno = ENOMEM;
490                                 conn->iserr = 1;
491                                 return NULL;
492                         }
493                         conn->buf = tmp;
494                         conn->bufsize = tmpsize;
495                 }
496         } while (c != '\n');
497
498         if (conn->buflen == 0)
499                 return NULL;
500  done:
501         tmpsize = MIN(size - 1, (int)(conn->buflen - conn->bufpos));
502         memcpy(str, conn->buf + conn->bufpos, tmpsize);
503         str[tmpsize] = '\0';
504         conn->bufpos += tmpsize;
505         return str;
506 }
507
508 int
509 fetch_getline(struct fetch_connect *conn, char *buf, size_t buflen,
510     const char **errormsg)
511 {
512         size_t len;
513         int rv;
514
515         if (fetch_getln(buf, buflen, conn) == NULL) {
516                 if (conn->iseof) {      /* EOF */
517                         rv = -2;
518                         if (errormsg)
519                                 *errormsg = "\nEOF received";
520                 } else {                /* error */
521                         rv = -1;
522                         if (errormsg)
523                                 *errormsg = "Error encountered";
524                 }
525                 fetch_clearerr(conn);
526                 return rv;
527         }
528         len = strlen(buf);
529         if (buf[len - 1] == '\n') {     /* clear any trailing newline */
530                 buf[--len] = '\0';
531         } else if (len == buflen - 1) { /* line too long */
532                 while (1) {
533                         char c;
534                         ssize_t rlen = fetch_read(&c, sizeof(c), 1, conn);
535                         if (rlen <= 0 || c == '\n')
536                                 break;
537                 }
538                 if (errormsg)
539                         *errormsg = "Input line is too long";
540                 fetch_clearerr(conn);
541                 return -3;
542         }
543         if (errormsg)
544                 *errormsg = NULL;
545         return len;
546 }
547
548 void *
549 fetch_start_ssl(int sock)
550 {
551         SSL *ssl;
552         SSL_CTX *ctx;
553         int ret, ssl_err;
554
555         /* Init the SSL library and context */
556         if (!SSL_library_init()){
557                 fprintf(ttyout, "SSL library init failed\n");
558                 return NULL;
559         }
560
561         SSL_load_error_strings();
562
563         ctx = SSL_CTX_new(SSLv23_client_method());
564         SSL_CTX_set_mode(ctx, SSL_MODE_AUTO_RETRY);
565
566         ssl = SSL_new(ctx);
567         if (ssl == NULL){
568                 fprintf(ttyout, "SSL context creation failed\n");
569                 SSL_CTX_free(ctx);
570                 return NULL;
571         }
572         SSL_set_fd(ssl, sock);
573         while ((ret = SSL_connect(ssl)) == -1) {
574                 ssl_err = SSL_get_error(ssl, ret);
575                 if (ssl_err != SSL_ERROR_WANT_READ &&
576                     ssl_err != SSL_ERROR_WANT_WRITE) {
577                         ERR_print_errors_fp(ttyout);
578                         SSL_free(ssl);
579                         return NULL;
580                 }
581         }
582
583         if (ftp_debug && verbose) {
584                 X509 *cert;
585                 X509_NAME *name;
586                 char *str;
587
588                 fprintf(ttyout, "SSL connection established using %s\n",
589                     SSL_get_cipher(ssl));
590                 cert = SSL_get_peer_certificate(ssl);
591                 name = X509_get_subject_name(cert);
592                 str = X509_NAME_oneline(name, 0, 0);
593                 fprintf(ttyout, "Certificate subject: %s\n", str);
594                 free(str);
595                 name = X509_get_issuer_name(cert);
596                 str = X509_NAME_oneline(name, 0, 0);
597                 fprintf(ttyout, "Certificate issuer: %s\n", str);
598                 free(str);
599         }
600
601         return ssl;
602 }
603
604
605 void
606 fetch_set_ssl(struct fetch_connect *conn, void *ssl)
607 {
608         conn->ssl = ssl;
609 }