libfetch: Fix hang due to SSL server closing before read completes
author John Marino <draco@marino.st>
Thu, 1 Nov 2012 19:31:08 +0000 (20:31 +0100)
committer John Marino <draco@marino.st>
Thu, 1 Nov 2012 21:56:18 +0000 (22:56 +0100)
If the server sends a close notification before an SSL read
operation is complete, fetch will hang.  Fix this by reworking
fetch_read() to use non-blocking sockets.

Taken-From: FreeBSD SVN 210568 (28 JUL 2010)
Taken-From: FreeBSD SVN 214256 (23 OCT 2010)

lib/libfetch/common.c

index d6fe226..bf2000e 100644 (file)
@@ -37,6 +37,7 @@
 
 #include <ctype.h>
 #include <errno.h>
+#include <fcntl.h>
 #include <netdb.h>
 #include <pwd.h>
 #include <stdarg.h>
@@ -294,7 +295,8 @@ fetch_connect(const char *host, int port, int af, int verbose)
                        close(sd);
                        continue;
                }
-               if (connect(sd, res->ai_addr, res->ai_addrlen) == 0)
+               if (connect(sd, res->ai_addr, res->ai_addrlen) == 0 &&
+                    fcntl(sd, F_SETFL, O_NONBLOCK) == 0)
                        break;
                close(sd);
        }
@@ -318,6 +320,7 @@ fetch_connect(const char *host, int port, int af, int verbose)
 int
 fetch_ssl(conn_t *conn, int verbose)
 {
+        int ret, ssl_err;
 
 #ifdef WITH_SSL
        /* Init the SSL library and context */
@@ -338,9 +341,13 @@ fetch_ssl(conn_t *conn, int verbose)
                return (-1);
        }
        SSL_set_fd(conn->ssl, conn->sd);
-       if (SSL_connect(conn->ssl) == -1){
-               ERR_print_errors_fp(stderr);
-               return (-1);
+       while ((ret = SSL_connect(conn->ssl)) == -1) {
+               ssl_err = SSL_get_error(conn->ssl, ret);
+               if (ssl_err != SSL_ERROR_WANT_READ &&
+                   ssl_err != SSL_ERROR_WANT_WRITE) {
+                       ERR_print_errors_fp(stderr);
+                       return (-1);
+               }
        }
 
        if (verbose) {
@@ -369,6 +376,46 @@ fetch_ssl(conn_t *conn, int verbose)
 #endif
 }
 
+#define FETCH_READ_WAIT                -2
+#define FETCH_READ_ERROR       -1
+#define FETCH_READ_DONE                 0
+
+#ifdef WITH_SSL
+static ssize_t
+fetch_ssl_read(SSL *ssl, char *buf, size_t len)
+{
+       ssize_t rlen;
+       int ssl_err;
+
+       rlen = SSL_read(ssl, buf, len);
+       if (rlen < 0) {
+               ssl_err = SSL_get_error(ssl, rlen);
+               if (ssl_err == SSL_ERROR_WANT_READ ||
+                   ssl_err == SSL_ERROR_WANT_WRITE) {
+                       return (FETCH_READ_WAIT);
+               } else {
+                       ERR_print_errors_fp(stderr);
+                       return (FETCH_READ_ERROR);
+               }
+       }
+       return (rlen);
+}
+#endif
+
+static ssize_t
+fetch_socket_read(int sd, char *buf, size_t len)
+{
+       ssize_t rlen;
+
+       rlen = read(sd, buf, len);
+       if (rlen < 0) {
+               if (errno == EAGAIN || (errno == EINTR && fetchRestartCalls))
+                       return (FETCH_READ_WAIT);
+               else
+                       return (FETCH_READ_ERROR);
+       }
+       return (rlen);
+}
 
 /*
  * Read a character from a connection w/ timeout
@@ -389,6 +436,43 @@ fetch_read(conn_t *conn, char *buf, size_t len)
 
        total = 0;
        while (len > 0) {
+               /*
+                * The socket is non-blocking.  Instead of the canonical
+                * select() -> read(), we do the following:
+                *
+                * 1) call read() or SSL_read().
+                * 2) if an error occurred, return -1.
+                * 3) if we received data but we still expect more,
+                *    update our counters and loop.
+                * 4) if read() or SSL_read() signaled EOF, return.
+                * 5) if we did not receive any data but we're not at EOF,
+                *    call select().
+                *
+                * In the SSL case, this is necessary because if we
+                * receive a close notification, we have to call
+                * SSL_read() one additional time after we've read
+                * everything we received.
+                *
+                * In the non-SSL case, it may improve performance (very
+                * slightly) when reading small amounts of data.
+                */
+#ifdef WITH_SSL
+               if (conn->ssl != NULL)
+                       rlen = fetch_ssl_read(conn->ssl, buf, len);
+               else
+#endif
+                       rlen = fetch_socket_read(conn->sd, buf, len);
+               if (rlen == 0) {
+                       break;
+               } else if (rlen > 0) {
+                       len -= rlen;
+                       buf += rlen;
+                       total += rlen;
+                       continue;
+               } else if (rlen == FETCH_READ_ERROR) {
+                       return (-1);
+               }
+               // assert(rlen == FETCH_READ_WAIT);
                while (fetchTimeout && !FD_ISSET(conn->sd, &readfds)) {
                        FD_SET(conn->sd, &readfds);
                        gettimeofday(&now, NULL);
@@ -412,22 +496,6 @@ fetch_read(conn_t *conn, char *buf, size_t len)
                                return (-1);
                        }
                }
-#ifdef WITH_SSL
-               if (conn->ssl != NULL)
-                       rlen = SSL_read(conn->ssl, buf, len);
-               else
-#endif
-                       rlen = read(conn->sd, buf, len);
-               if (rlen == 0)
-                       break;
-               if (rlen < 0) {
-                       if (errno == EINTR && fetchRestartCalls)
-                               continue;
-                       return (-1);
-               }
-               len -= rlen;
-               buf += rlen;
-               total += rlen;
        }
        return (total);
 }
@@ -547,6 +615,7 @@ fetch_writev(conn_t *conn, struct iovec *iov, int iovcnt)
                        wlen = writev(conn->sd, iov, iovcnt);
                if (wlen == 0) {
                        /* we consider a short write a failure */
+                       /* XXX perhaps we shouldn't in the SSL case */
                        errno = EPIPE;
                        fetch_syserr();
                        return (-1);