b34c864693f1dce71e8165677140485f8053243e
[dragonfly.git] / contrib / tnftp / ssl.c
1 /*      $NetBSD: ssl.c,v 1.2 2012/12/24 22:12:28 christos Exp $ */
2
3 /*-
4  * Copyright (c) 1998-2004 Dag-Erling Coïdan Smørgrav
5  * Copyright (c) 2008, 2010 Joerg Sonnenberger <joerg@NetBSD.org>
6  * All rights reserved.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer
13  *    in this position and unchanged.
14  * 2. Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the distribution.
17  * 3. The name of the author may not be used to endorse or promote products
18  *    derived from this software without specific prior written permission
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
21  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
22  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
23  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
24  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
25  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
29  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30  *
31  * $FreeBSD: common.c,v 1.53 2007/12/19 00:26:36 des Exp $
32  */
33
34 #include <sys/cdefs.h>
35 #ifndef lint
36 __RCSID("$NetBSD: ssl.c,v 1.2 2012/12/24 22:12:28 christos Exp $");
37 #endif
38
#include <sys/param.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/uio.h>

#include <netinet/in.h>
#include <netinet/tcp.h>

#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <openssl/crypto.h>
#include <openssl/x509.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/err.h>

#include "ssl.h"
56
57 extern int quit_time, verbose, ftp_debug;
58 extern FILE *ttyout;
59
60 struct fetch_connect {
61         int                      sd;            /* file/socket descriptor */
62         char                    *buf;           /* buffer */
63         size_t                   bufsize;       /* buffer size */
64         size_t                   bufpos;        /* position of buffer */
65         size_t                   buflen;        /* length of buffer contents */
66         struct {                                /* data cached after an
67                                                    interrupted read */
68                 char    *buf;
69                 size_t   size;
70                 size_t   pos;
71                 size_t   len;
72         } cache;
73         int                      issock;
74         int                      iserr;
75         int                      iseof;
76         SSL                     *ssl;           /* SSL handle */
77 };
78
79 /*
80  * Write a vector to a connection w/ timeout
81  * Note: can modify the iovec.
82  */
83 static ssize_t
84 fetch_writev(struct fetch_connect *conn, struct iovec *iov, int iovcnt)
85 {
86         struct timeval now, timeout, delta;
87         fd_set writefds;
88         ssize_t len, total;
89         int r;
90
91         if (quit_time > 0) {
92                 FD_ZERO(&writefds);
93                 gettimeofday(&timeout, NULL);
94                 timeout.tv_sec += quit_time;
95         }
96
97         total = 0;
98         while (iovcnt > 0) {
99                 while (quit_time > 0 && !FD_ISSET(conn->sd, &writefds)) {
100                         FD_SET(conn->sd, &writefds);
101                         gettimeofday(&now, NULL);
102                         delta.tv_sec = timeout.tv_sec - now.tv_sec;
103                         delta.tv_usec = timeout.tv_usec - now.tv_usec;
104                         if (delta.tv_usec < 0) {
105                                 delta.tv_usec += 1000000;
106                                 delta.tv_sec--;
107                         }
108                         if (delta.tv_sec < 0) {
109                                 errno = ETIMEDOUT;
110                                 return -1;
111                         }
112                         errno = 0;
113                         r = select(conn->sd + 1, NULL, &writefds, NULL, &delta);
114                         if (r == -1) {
115                                 if (errno == EINTR)
116                                         continue;
117                                 return -1;
118                         }
119                 }
120                 errno = 0;
121                 if (conn->ssl != NULL)
122                         len = SSL_write(conn->ssl, iov->iov_base, iov->iov_len);
123                 else
124                         len = writev(conn->sd, iov, iovcnt);
125                 if (len == 0) {
126                         /* we consider a short write a failure */
127                         /* XXX perhaps we shouldn't in the SSL case */
128                         errno = EPIPE;
129                         return -1;
130                 }
131                 if (len < 0) {
132                         if (errno == EINTR)
133                                 continue;
134                         return -1;
135                 }
136                 total += len;
137                 while (iovcnt > 0 && len >= (ssize_t)iov->iov_len) {
138                         len -= iov->iov_len;
139                         iov++;
140                         iovcnt--;
141                 }
142                 if (iovcnt > 0) {
143                         iov->iov_len -= len;
144                         iov->iov_base = (char *)iov->iov_base + len;
145                 }
146         }
147         return total;
148 }
149
/*
 * Write len bytes of str to a connection w/ timeout.
 * Thin wrapper that hands a one-element vector to fetch_writev();
 * returns its result (bytes written, or -1 with errno set).
 */
static int
fetch_write(struct fetch_connect *conn, const char *str, size_t len)
{
	struct iovec vec;

	vec.iov_len = len;
	vec.iov_base = (char *)__UNCONST(str);
	return fetch_writev(conn, &vec, 1);
}
162
/*
 * Format a string printf-style and send it on a connection.
 * Nothing is echoed to the terminal.
 *
 * Returns the result of fetch_write() (bytes written, or -1 with errno
 * set), or -1 with errno = ENOMEM if the message cannot be formatted.
 */
int
fetch_printf(struct fetch_connect *conn, const char *fmt, ...)
{
	va_list ap;
	char *msg;
	int len, r;

	va_start(ap, fmt);
	len = vasprintf(&msg, fmt, ap);
	va_end(ap);

	/*
	 * vasprintf() signals failure through its return value; msg is
	 * not guaranteed to be NULL (or even initialised) on every
	 * platform after an error, so it must not be inspected or freed.
	 */
	if (len < 0) {
		errno = ENOMEM;
		return -1;
	}

	r = fetch_write(conn, msg, (size_t)len);
	free(msg);
	return r;
}
187
188 int
189 fetch_fileno(struct fetch_connect *conn)
190 {
191
192         return conn->sd;
193 }
194
195 int
196 fetch_error(struct fetch_connect *conn)
197 {
198
199         return conn->iserr;
200 }
201
202 static void
203 fetch_clearerr(struct fetch_connect *conn)
204 {
205
206         conn->iserr = 0;
207 }
208
209 int
210 fetch_flush(struct fetch_connect *conn)
211 {
212         int v;
213
214         if (conn->issock) {
215 #ifdef TCP_NOPUSH
216                 v = 0;
217                 setsockopt(conn->sd, IPPROTO_TCP, TCP_NOPUSH, &v, sizeof(v));
218 #endif
219                 v = 1;
220                 setsockopt(conn->sd, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v));
221         }
222         return 0;
223 }
224
225 /*ARGSUSED*/
226 struct fetch_connect *
227 fetch_open(const char *fname, const char *fmode)
228 {
229         struct fetch_connect *conn;
230         int fd;
231
232         fd = open(fname, O_RDONLY); /* XXX: fmode */
233         if (fd < 0)
234                 return NULL;
235
236         if ((conn = calloc(1, sizeof(*conn))) == NULL) {
237                 close(fd);
238                 return NULL;
239         }
240
241         conn->sd = fd;
242         conn->issock = 0;
243         return conn;
244 }
245
246 /*ARGSUSED*/
247 struct fetch_connect *
248 fetch_fdopen(int sd, const char *fmode)
249 {
250         struct fetch_connect *conn;
251 #if defined(SO_NOSIGPIPE) || defined(TCP_NOPUSH)
252         int opt = 1;
253 #endif
254
255         if ((conn = calloc(1, sizeof(*conn))) == NULL)
256                 return NULL;
257
258         conn->sd = sd;
259         conn->issock = 1;
260         fcntl(sd, F_SETFD, FD_CLOEXEC);
261 #ifdef SO_NOSIGPIPE
262         setsockopt(sd, SOL_SOCKET, SO_NOSIGPIPE, &opt, sizeof(opt));
263 #endif
264 #ifdef TCP_NOPUSH
265         setsockopt(sd, IPPROTO_TCP, TCP_NOPUSH, &opt, sizeof(opt));
266 #endif
267         return conn;
268 }
269
270 int
271 fetch_close(struct fetch_connect *conn)
272 {
273         int rv = 0;
274
275         if (conn != NULL) {
276                 fetch_flush(conn);
277                 SSL_free(conn->ssl);
278                 rv = close(conn->sd);
279                 if (rv < 0) {
280                         errno = rv;
281                         rv = EOF;
282                 }
283                 free(conn->cache.buf);
284                 free(conn->buf);
285                 free(conn);
286         }
287         return rv;
288 }
289
/*
 * Sentinel results of fetch_ssl_read() / fetch_nonssl_read(); both are
 * negative so they can never be confused with a byte count.
 */
enum {
	FETCH_READ_WAIT = -2,	/* nothing yet: select() and retry */
	FETCH_READ_ERROR = -1	/* hard failure */
};
292
293 static ssize_t
294 fetch_ssl_read(SSL *ssl, void *buf, size_t len)
295 {
296         ssize_t rlen;
297         int ssl_err;
298
299         rlen = SSL_read(ssl, buf, len);
300         if (rlen < 0) {
301                 ssl_err = SSL_get_error(ssl, rlen);
302                 if (ssl_err == SSL_ERROR_WANT_READ ||
303                     ssl_err == SSL_ERROR_WANT_WRITE) {
304                         return FETCH_READ_WAIT;
305                 }
306                 ERR_print_errors_fp(ttyout);
307                 return FETCH_READ_ERROR;
308         }
309         return rlen;
310 }
311
312 static ssize_t
313 fetch_nonssl_read(int sd, void *buf, size_t len)
314 {
315         ssize_t rlen;
316
317         rlen = read(sd, buf, len);
318         if (rlen < 0) {
319                 if (errno == EAGAIN || errno == EINTR)
320                         return FETCH_READ_WAIT;
321                 return FETCH_READ_ERROR;
322         }
323         return rlen;
324 }
325
326 /*
327  * Cache some data that was read from a socket but cannot be immediately
328  * returned because of an interrupted system call.
329  */
330 static int
331 fetch_cache_data(struct fetch_connect *conn, char *src, size_t nbytes)
332 {
333
334         if (conn->cache.size < nbytes) {
335                 char *tmp = realloc(conn->cache.buf, nbytes);
336                 if (tmp == NULL)
337                         return -1;
338
339                 conn->cache.buf = tmp;
340                 conn->cache.size = nbytes;
341         }
342
343         memcpy(conn->cache.buf, src, nbytes);
344         conn->cache.len = nbytes;
345         conn->cache.pos = 0;
346         return 0;
347 }
348
349 ssize_t
350 fetch_read(void *ptr, size_t size, size_t nmemb, struct fetch_connect *conn)
351 {
352         struct timeval now, timeout, delta;
353         fd_set readfds;
354         ssize_t rlen, total;
355         size_t len;
356         char *start, *buf;
357
358         if (quit_time > 0) {
359                 gettimeofday(&timeout, NULL);
360                 timeout.tv_sec += quit_time;
361         }
362
363         total = 0;
364         start = buf = ptr;
365         len = size * nmemb;
366
367         if (conn->cache.len > 0) {
368                 /*
369                  * The last invocation of fetch_read was interrupted by a
370                  * signal after some data had been read from the socket. Copy
371                  * the cached data into the supplied buffer before trying to
372                  * read from the socket again.
373                  */
374                 total = (conn->cache.len < len) ? conn->cache.len : len;
375                 memcpy(buf, conn->cache.buf, total);
376
377                 conn->cache.len -= total;
378                 conn->cache.pos += total;
379                 len -= total;
380                 buf += total;
381         }
382
383         while (len > 0) {
384                 /*
385                  * The socket is non-blocking.  Instead of the canonical
386                  * select() -> read(), we do the following:
387                  *
388                  * 1) call read() or SSL_read().
389                  * 2) if an error occurred, return -1.
390                  * 3) if we received data but we still expect more,
391                  *    update our counters and loop.
392                  * 4) if read() or SSL_read() signaled EOF, return.
393                  * 5) if we did not receive any data but we're not at EOF,
394                  *    call select().
395                  *
396                  * In the SSL case, this is necessary because if we
397                  * receive a close notification, we have to call
398                  * SSL_read() one additional time after we've read
399                  * everything we received.
400                  *
401                  * In the non-SSL case, it may improve performance (very
402                  * slightly) when reading small amounts of data.
403                  */
404                 if (conn->ssl != NULL)
405                         rlen = fetch_ssl_read(conn->ssl, buf, len);
406                 else
407                         rlen = fetch_nonssl_read(conn->sd, buf, len);
408                 if (rlen == 0) {
409                         break;
410                 } else if (rlen > 0) {
411                         len -= rlen;
412                         buf += rlen;
413                         total += rlen;
414                         continue;
415                 } else if (rlen == FETCH_READ_ERROR) {
416                         if (errno == EINTR)
417                                 fetch_cache_data(conn, start, total);
418                         return -1;
419                 }
420                 FD_ZERO(&readfds);
421                 while (!FD_ISSET(conn->sd, &readfds)) {
422                         FD_SET(conn->sd, &readfds);
423                         if (quit_time > 0) {
424                                 gettimeofday(&now, NULL);
425                                 if (!timercmp(&timeout, &now, >)) {
426                                         errno = ETIMEDOUT;
427                                         return -1;
428                                 }
429                                 timersub(&timeout, &now, &delta);
430                         }
431                         errno = 0;
432                         if (select(conn->sd + 1, &readfds, NULL, NULL,
433                                 quit_time > 0 ? &delta : NULL) < 0) {
434                                 if (errno == EINTR)
435                                         continue;
436                                 return -1;
437                         }
438                 }
439         }
440         return total;
441 }
442
443 #define MIN_BUF_SIZE 1024
444
445 /*
446  * Read a line of text from a connection w/ timeout
447  */
448 char *
449 fetch_getln(char *str, int size, struct fetch_connect *conn)
450 {
451         size_t tmpsize;
452         ssize_t len;
453         char c;
454
455         if (conn->buf == NULL) {
456                 if ((conn->buf = malloc(MIN_BUF_SIZE)) == NULL) {
457                         errno = ENOMEM;
458                         conn->iserr = 1;
459                         return NULL;
460                 }
461                 conn->bufsize = MIN_BUF_SIZE;
462         }
463
464         if (conn->iserr || conn->iseof)
465                 return NULL;
466
467         if (conn->buflen - conn->bufpos > 0)
468                 goto done;
469
470         conn->buf[0] = '\0';
471         conn->bufpos = 0;
472         conn->buflen = 0;
473         do {
474                 len = fetch_read(&c, sizeof(c), 1, conn);
475                 if (len == -1) {
476                         conn->iserr = 1;
477                         return NULL;
478                 }
479                 if (len == 0) {
480                         conn->iseof = 1;
481                         break;
482                 }
483                 conn->buf[conn->buflen++] = c;
484                 if (conn->buflen == conn->bufsize) {
485                         char *tmp = conn->buf;
486                         tmpsize = conn->bufsize * 2 + 1;
487                         if ((tmp = realloc(tmp, tmpsize)) == NULL) {
488                                 errno = ENOMEM;
489                                 conn->iserr = 1;
490                                 return NULL;
491                         }
492                         conn->buf = tmp;
493                         conn->bufsize = tmpsize;
494                 }
495         } while (c != '\n');
496
497         if (conn->buflen == 0)
498                 return NULL;
499  done:
500         tmpsize = MIN(size - 1, (int)(conn->buflen - conn->bufpos));
501         memcpy(str, conn->buf + conn->bufpos, tmpsize);
502         str[tmpsize] = '\0';
503         conn->bufpos += tmpsize;
504         return str;
505 }
506
507 int
508 fetch_getline(struct fetch_connect *conn, char *buf, size_t buflen,
509     const char **errormsg)
510 {
511         size_t len;
512         int rv;
513
514         if (fetch_getln(buf, buflen, conn) == NULL) {
515                 if (conn->iseof) {      /* EOF */
516                         rv = -2;
517                         if (errormsg)
518                                 *errormsg = "\nEOF received";
519                 } else {                /* error */
520                         rv = -1;
521                         if (errormsg)
522                                 *errormsg = "Error encountered";
523                 }
524                 fetch_clearerr(conn);
525                 return rv;
526         }
527         len = strlen(buf);
528         if (buf[len - 1] == '\n') {     /* clear any trailing newline */
529                 buf[--len] = '\0';
530         } else if (len == buflen - 1) { /* line too long */
531                 while (1) {
532                         char c;
533                         ssize_t rlen = fetch_read(&c, sizeof(c), 1, conn);
534                         if (rlen <= 0 || c == '\n')
535                                 break;
536                 }
537                 if (errormsg)
538                         *errormsg = "Input line is too long";
539                 fetch_clearerr(conn);
540                 return -3;
541         }
542         if (errormsg)
543                 *errormsg = NULL;
544         return len;
545 }
546
547 void *
548 fetch_start_ssl(int sock)
549 {
550         SSL *ssl;
551         SSL_CTX *ctx;
552         int ret, ssl_err;
553
554         /* Init the SSL library and context */
555         if (!SSL_library_init()){
556                 fprintf(ttyout, "SSL library init failed\n");
557                 return NULL;
558         }
559
560         SSL_load_error_strings();
561
562         ctx = SSL_CTX_new(SSLv23_client_method());
563         SSL_CTX_set_mode(ctx, SSL_MODE_AUTO_RETRY);
564
565         ssl = SSL_new(ctx);
566         if (ssl == NULL){
567                 fprintf(ttyout, "SSL context creation failed\n");
568                 SSL_CTX_free(ctx);
569                 return NULL;
570         }
571         SSL_set_fd(ssl, sock);
572         while ((ret = SSL_connect(ssl)) == -1) {
573                 ssl_err = SSL_get_error(ssl, ret);
574                 if (ssl_err != SSL_ERROR_WANT_READ &&
575                     ssl_err != SSL_ERROR_WANT_WRITE) {
576                         ERR_print_errors_fp(ttyout);
577                         SSL_free(ssl);
578                         return NULL;
579                 }
580         }
581
582         if (ftp_debug && verbose) {
583                 X509 *cert;
584                 X509_NAME *name;
585                 char *str;
586
587                 fprintf(ttyout, "SSL connection established using %s\n",
588                     SSL_get_cipher(ssl));
589                 cert = SSL_get_peer_certificate(ssl);
590                 name = X509_get_subject_name(cert);
591                 str = X509_NAME_oneline(name, 0, 0);
592                 fprintf(ttyout, "Certificate subject: %s\n", str);
593                 free(str);
594                 name = X509_get_issuer_name(cert);
595                 str = X509_NAME_oneline(name, 0, 0);
596                 fprintf(ttyout, "Certificate issuer: %s\n", str);
597                 free(str);
598         }
599
600         return ssl;
601 }
602
603
604 void
605 fetch_set_ssl(struct fetch_connect *conn, void *ssl)
606 {
607         conn->ssl = ssl;
608 }