libssh: Fix a remaining warning and raise WARNS to 2.
[dragonfly.git] / crypto / openssh / cipher-ctr-mt.c
1 /*
2  * OpenSSH Multi-threaded AES-CTR Cipher
3  *
4  * Author: Benjamin Bennett <ben@psc.edu>
5  * Author: Mike Tasota <tasota@gmail.com>
6  * Author: Chris Rapier <rapier@psc.edu>
7  * Copyright (c) 2008-2013 Pittsburgh Supercomputing Center. All rights reserved.
8  *
9  * Based on original OpenSSH AES-CTR cipher. Small portions remain unchanged,
10  * Copyright (c) 2003 Markus Friedl <markus@openbsd.org>
11  *
12  * Permission to use, copy, modify, and distribute this software for any
13  * purpose with or without fee is hereby granted, provided that the above
14  * copyright notice and this permission notice appear in all copies.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
17  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
18  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
19  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
20  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
21  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
22  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
23  */
24 #include "includes.h"
25
26 #include <sys/types.h>
27
28 #include <stdarg.h>
29 #include <string.h>
30
31 #include <openssl/evp.h>
32
33 #include "xmalloc.h"
34 #include "log.h"
35
36 /* compatibility with old or broken OpenSSL versions */
37 #include "openbsd-compat/openssl-compat.h"
38
39 #ifndef USE_BUILTIN_RIJNDAEL
40 #include <openssl/aes.h>
41 #endif
42
43 #include <pthread.h>
44
45 /*-------------------- TUNABLES --------------------*/
46 /* Number of pregen threads to use */
47 #define CIPHER_THREADS  2
48
49 /* Number of keystream queues */
50 #define NUMKQ           (CIPHER_THREADS + 2)
51
52 /* Length of a keystream queue */
53 #define KQLEN           4096
54
55 /* Processor cacheline length */
56 #define CACHELINE_LEN   64
57
58 /* Collect thread stats and print at cancellation when in debug mode */
59 /* #define CIPHER_THREAD_STATS */
60
61 /* Use single-byte XOR instead of 8-byte XOR */
62 /* #define CIPHER_BYTE_XOR */
63 /*-------------------- END TUNABLES --------------------*/
64
65
66 const EVP_CIPHER *evp_aes_ctr_mt(void);
67
68 #ifdef CIPHER_THREAD_STATS
69 /*
70  * Struct to collect thread stats
71  */
72 struct thread_stats {
73         u_int   fills;
74         u_int   skips;
75         u_int   waits;
76         u_int   drains;
77 };
78
79 /*
80  * Debug print the thread stats
81  * Use with pthread_cleanup_push for displaying at thread cancellation
82  */
83 static void
84 thread_loop_stats(void *x)
85 {
86         struct thread_stats *s = x;
87
88         debug("tid %lu - %u fills, %u skips, %u waits", pthread_self(),
89                         s->fills, s->skips, s->waits);
90 }
91
92  #define STATS_STRUCT(s)        struct thread_stats s
93  #define STATS_INIT(s)          { memset(&s, 0, sizeof(s)); }
94  #define STATS_FILL(s)          { s.fills++; }
95  #define STATS_SKIP(s)          { s.skips++; }
96  #define STATS_WAIT(s)          { s.waits++; }
97  #define STATS_DRAIN(s)         { s.drains++; }
98 #else
99  #define STATS_STRUCT(s)
100  #define STATS_INIT(s)
101  #define STATS_FILL(s)
102  #define STATS_SKIP(s)
103  #define STATS_WAIT(s)
104  #define STATS_DRAIN(s)
105 #endif
106
107 /* Keystream Queue state */
108 enum {
109         KQINIT,
110         KQEMPTY,
111         KQFILLING,
112         KQFULL,
113         KQDRAINING
114 };
115
116 /* Keystream Queue struct */
117 struct kq {
118         u_char          keys[KQLEN][AES_BLOCK_SIZE];
119         u_char          ctr[AES_BLOCK_SIZE];
120         u_char          pad0[CACHELINE_LEN];
121         volatile int    qstate;
122         pthread_mutex_t lock;
123         pthread_cond_t  cond;
124         u_char          pad1[CACHELINE_LEN];
125 };
126
127 /* Context struct */
128 struct ssh_aes_ctr_ctx
129 {
130         struct kq       q[NUMKQ];
131         AES_KEY         aes_ctx;
132         STATS_STRUCT(stats);
133         u_char          aes_counter[AES_BLOCK_SIZE];
134         pthread_t       tid[CIPHER_THREADS];
135         int             state;
136         int             qidx;
137         int             ridx;
138 };
139
140 /* <friedl>
141  * increment counter 'ctr',
142  * the counter is of size 'len' bytes and stored in network-byte-order.
143  * (LSB at ctr[len-1], MSB at ctr[0])
144  */
145 static void
146 ssh_ctr_inc(u_char *ctr, u_int len)
147 {
148         int i;
149
150         for (i = len - 1; i >= 0; i--)
151                 if (++ctr[i])   /* continue on overflow */
152                         return;
153 }
154
155 /*
156  * Add num to counter 'ctr'
157  */
158 static void
159 ssh_ctr_add(u_char *ctr, uint32_t num, u_int len)
160 {
161         int i;
162         uint16_t n;
163
164         for (n = 0, i = len - 1; i >= 0 && (num || n); i--) {
165                 n = ctr[i] + (num & 0xff) + n;
166                 num >>= 8;
167                 ctr[i] = n & 0xff;
168                 n >>= 8;
169         }
170 }
171
172 /*
173  * Threads may be cancelled in a pthread_cond_wait, we must free the mutex
174  */
175 static void
176 thread_loop_cleanup(void *x)
177 {
178         pthread_mutex_unlock((pthread_mutex_t *)x);
179 }
180
181 /*
182  * The life of a pregen thread:
183  *    Find empty keystream queues and fill them using their counter.
184  *    When done, update counter for the next fill.
185  */
186 static void *
187 thread_loop(void *x)
188 {
189         AES_KEY key;
190         STATS_STRUCT(stats);
191         struct ssh_aes_ctr_ctx *c = x;
192         struct kq *q;
193         int i;
194         int qidx;
195
196         /* Threads stats on cancellation */
197         STATS_INIT(stats);
198 #ifdef CIPHER_THREAD_STATS
199         pthread_cleanup_push(thread_loop_stats, &stats);
200 #endif
201
202         /* Thread local copy of AES key */
203         memcpy(&key, &c->aes_ctx, sizeof(key));
204
205         /*
206          * Handle the special case of startup, one thread must fill
207          * the first KQ then mark it as draining. Lock held throughout.
208          */
209         if (pthread_equal(pthread_self(), c->tid[0])) {
210                 q = &c->q[0];
211                 pthread_mutex_lock(&q->lock);
212                 if (q->qstate == KQINIT) {
213                         for (i = 0; i < KQLEN; i++) {
214                                 AES_encrypt(q->ctr, q->keys[i], &key);
215                                 ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
216                         }
217                         ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
218                         q->qstate = KQDRAINING;
219                         STATS_FILL(stats);
220                         pthread_cond_broadcast(&q->cond);
221                 }
222                 pthread_mutex_unlock(&q->lock);
223         }
224         else 
225                 STATS_SKIP(stats);
226
227         /*
228          * Normal case is to find empty queues and fill them, skipping over
229          * queues already filled by other threads and stopping to wait for
230          * a draining queue to become empty.
231          *
232          * Multiple threads may be waiting on a draining queue and awoken
233          * when empty.  The first thread to wake will mark it as filling,
234          * others will move on to fill, skip, or wait on the next queue.
235          */
236         for (qidx = 1;; qidx = (qidx + 1) % NUMKQ) {
237                 /* Check if I was cancelled, also checked in cond_wait */
238                 pthread_testcancel();
239
240                 /* Lock queue and block if its draining */
241                 q = &c->q[qidx];
242                 pthread_mutex_lock(&q->lock);
243                 pthread_cleanup_push(thread_loop_cleanup, &q->lock);
244                 while (q->qstate == KQDRAINING || q->qstate == KQINIT) {
245                         STATS_WAIT(stats);
246                         pthread_cond_wait(&q->cond, &q->lock);
247                 }
248                 pthread_cleanup_pop(0);
249
250                 /* If filling or full, somebody else got it, skip */
251                 if (q->qstate != KQEMPTY) {
252                         pthread_mutex_unlock(&q->lock);
253                         STATS_SKIP(stats);
254                         continue;
255                 }
256
257                 /*
258                  * Empty, let's fill it.
259                  * Queue lock is relinquished while we do this so others
260                  * can see that it's being filled.
261                  */
262                 q->qstate = KQFILLING;
263                 pthread_mutex_unlock(&q->lock);
264                 for (i = 0; i < KQLEN; i++) {
265                         AES_encrypt(q->ctr, q->keys[i], &key);
266                         ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
267                 }
268
269                 /* Re-lock, mark full and signal consumer */
270                 pthread_mutex_lock(&q->lock);
271                 ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
272                 q->qstate = KQFULL;
273                 STATS_FILL(stats);
274                 pthread_cond_signal(&q->cond);
275                 pthread_mutex_unlock(&q->lock);
276         }
277
278 #ifdef CIPHER_THREAD_STATS
279         /* Stats */
280         pthread_cleanup_pop(1);
281 #endif
282
283         return NULL;
284 }
285
286 static int
287 ssh_aes_ctr(EVP_CIPHER_CTX *ctx, u_char *dest, const u_char *src,
288     LIBCRYPTO_EVP_INL_TYPE len)
289 {
290         struct ssh_aes_ctr_ctx *c;
291         struct kq *q, *oldq;
292         int ridx;
293         u_char *buf;
294
295         if (len == 0)
296                 return (1);
297         if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL)
298                 return (0);
299
300         q = &c->q[c->qidx];
301         ridx = c->ridx;
302
303         /* src already padded to block multiple */
304         while (len > 0) {
305                 buf = q->keys[ridx];
306
307 #ifdef CIPHER_BYTE_XOR
308                 dest[0] = src[0] ^ buf[0];
309                 dest[1] = src[1] ^ buf[1];
310                 dest[2] = src[2] ^ buf[2];
311                 dest[3] = src[3] ^ buf[3];
312                 dest[4] = src[4] ^ buf[4];
313                 dest[5] = src[5] ^ buf[5];
314                 dest[6] = src[6] ^ buf[6];
315                 dest[7] = src[7] ^ buf[7];
316                 dest[8] = src[8] ^ buf[8];
317                 dest[9] = src[9] ^ buf[9];
318                 dest[10] = src[10] ^ buf[10];
319                 dest[11] = src[11] ^ buf[11];
320                 dest[12] = src[12] ^ buf[12];
321                 dest[13] = src[13] ^ buf[13];
322                 dest[14] = src[14] ^ buf[14];
323                 dest[15] = src[15] ^ buf[15];
324 #else
325                 *(uint64_t *)dest = *(uint64_t *)src ^ *(uint64_t *)buf;
326                 *(uint64_t *)(dest + 8) = *(uint64_t *)(src + 8) ^
327                                                 *(uint64_t *)(buf + 8);
328 #endif
329
330                 dest += 16;
331                 src += 16;
332                 len -= 16;
333                 ssh_ctr_inc(ctx->iv, AES_BLOCK_SIZE);
334
335                 /* Increment read index, switch queues on rollover */
336                 if ((ridx = (ridx + 1) % KQLEN) == 0) {
337                         oldq = q;
338
339                         /* Mark next queue draining, may need to wait */
340                         c->qidx = (c->qidx + 1) % NUMKQ;
341                         q = &c->q[c->qidx];
342                         pthread_mutex_lock(&q->lock);
343                         while (q->qstate != KQFULL) {
344                                 STATS_WAIT(c->stats);
345                                 pthread_cond_wait(&q->cond, &q->lock);
346                         }
347                         q->qstate = KQDRAINING;
348                         pthread_mutex_unlock(&q->lock);
349
350                         /* Mark consumed queue empty and signal producers */
351                         pthread_mutex_lock(&oldq->lock);
352                         oldq->qstate = KQEMPTY;
353                         STATS_DRAIN(c->stats);
354                         pthread_cond_broadcast(&oldq->cond);
355                         pthread_mutex_unlock(&oldq->lock);
356                 }
357         }
358         c->ridx = ridx;
359         return (1);
360 }
361
362 #define HAVE_NONE       0
363 #define HAVE_KEY        1
364 #define HAVE_IV         2
365 static int
366 ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv,
367     int enc)
368 {
369         struct ssh_aes_ctr_ctx *c;
370         int i;
371
372         if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL) {
373                 c = xmalloc(sizeof(*c));
374
375                 c->state = HAVE_NONE;
376                 for (i = 0; i < NUMKQ; i++) {
377                         pthread_mutex_init(&c->q[i].lock, NULL);
378                         pthread_cond_init(&c->q[i].cond, NULL);
379                 }
380
381                 STATS_INIT(c->stats);
382                 
383                 EVP_CIPHER_CTX_set_app_data(ctx, c);
384         }
385
386         if (c->state == (HAVE_KEY | HAVE_IV)) {
387                 /* Cancel pregen threads */
388                 for (i = 0; i < CIPHER_THREADS; i++)
389                         pthread_cancel(c->tid[i]);
390                 for (i = 0; i < CIPHER_THREADS; i++)
391                         pthread_join(c->tid[i], NULL);
392                 /* Start over getting key & iv */
393                 c->state = HAVE_NONE;
394         }
395
396         if (key != NULL) {
397                 AES_set_encrypt_key(key, EVP_CIPHER_CTX_key_length(ctx) * 8,
398                     &c->aes_ctx);
399                 c->state |= HAVE_KEY;
400         }
401
402         if (iv != NULL) {
403                 memcpy(ctx->iv, iv, AES_BLOCK_SIZE);
404                 c->state |= HAVE_IV;
405         }
406
407         if (c->state == (HAVE_KEY | HAVE_IV)) {
408                 /* Clear queues */
409                 memcpy(c->q[0].ctr, ctx->iv, AES_BLOCK_SIZE);
410                 c->q[0].qstate = KQINIT;
411                 for (i = 1; i < NUMKQ; i++) {
412                         memcpy(c->q[i].ctr, ctx->iv, AES_BLOCK_SIZE);
413                         ssh_ctr_add(c->q[i].ctr, i * KQLEN, AES_BLOCK_SIZE);
414                         c->q[i].qstate = KQEMPTY;
415                 }
416                 c->qidx = 0;
417                 c->ridx = 0;
418
419                 /* Start threads */
420                 for (i = 0; i < CIPHER_THREADS; i++) {
421                         debug("spawned a thread");
422                         pthread_create(&c->tid[i], NULL, thread_loop, c);
423                 }
424                 pthread_mutex_lock(&c->q[0].lock);
425                 while (c->q[0].qstate != KQDRAINING)
426                         pthread_cond_wait(&c->q[0].cond, &c->q[0].lock);
427                 pthread_mutex_unlock(&c->q[0].lock);
428                 
429         }
430         return (1);
431 }
432
433 /* this function is no longer used but might prove handy in the future
434  * this comment also applies to ssh_aes_ctr_thread_reconstruction
435  */
436 void
437 ssh_aes_ctr_thread_destroy(EVP_CIPHER_CTX *ctx)
438 {
439         struct ssh_aes_ctr_ctx *c;
440         int i;
441         c = EVP_CIPHER_CTX_get_app_data(ctx);
442         /* destroy threads */
443         for (i = 0; i < CIPHER_THREADS; i++) {
444                 pthread_cancel(c->tid[i]);
445         }
446         for (i = 0; i < CIPHER_THREADS; i++) {
447                 pthread_join(c->tid[i], NULL);
448         }
449 }
450
451 void
452 ssh_aes_ctr_thread_reconstruction(EVP_CIPHER_CTX *ctx)
453 {
454         struct ssh_aes_ctr_ctx *c;
455         int i;
456         c = EVP_CIPHER_CTX_get_app_data(ctx);
457         /* reconstruct threads */
458         for (i = 0; i < CIPHER_THREADS; i++) {
459                 debug("spawned a thread");
460                 pthread_create(&c->tid[i], NULL, thread_loop, c);
461         }
462 }
463
464 static int
465 ssh_aes_ctr_cleanup(EVP_CIPHER_CTX *ctx)
466 {
467         struct ssh_aes_ctr_ctx *c;
468         int i;
469
470         if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) != NULL) {
471 #ifdef CIPHER_THREAD_STATS
472                 debug("main thread: %u drains, %u waits", c->stats.drains,
473                                 c->stats.waits);
474 #endif
475                 /* Cancel pregen threads */
476                 for (i = 0; i < CIPHER_THREADS; i++)
477                         pthread_cancel(c->tid[i]);
478                 for (i = 0; i < CIPHER_THREADS; i++)
479                         pthread_join(c->tid[i], NULL);
480
481                 memset(c, 0, sizeof(*c));
482                 free(c);
483                 EVP_CIPHER_CTX_set_app_data(ctx, NULL);
484         }
485         return (1);
486 }
487
488 /* <friedl> */
489 const EVP_CIPHER *
490 evp_aes_ctr_mt(void)
491 {
492         static EVP_CIPHER aes_ctr;
493
494         memset(&aes_ctr, 0, sizeof(EVP_CIPHER));
495         aes_ctr.nid = NID_undef;
496         aes_ctr.block_size = AES_BLOCK_SIZE;
497         aes_ctr.iv_len = AES_BLOCK_SIZE;
498         aes_ctr.key_len = 16;
499         aes_ctr.init = ssh_aes_ctr_init;
500         aes_ctr.cleanup = ssh_aes_ctr_cleanup;
501         aes_ctr.do_cipher = ssh_aes_ctr;
502 #ifndef SSH_OLD_EVP
503         aes_ctr.flags = EVP_CIPH_CBC_MODE | EVP_CIPH_VARIABLE_LENGTH |
504             EVP_CIPH_ALWAYS_CALL_INIT | EVP_CIPH_CUSTOM_IV;
505 #endif
506         return (&aes_ctr);
507 }