Merge branch 'vendor/OPENSSH'
[dragonfly.git] / crypto / openssh / kex.c
1 /* $OpenBSD: kex.c,v 1.118 2016/05/02 10:26:04 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "includes.h"
27
28 #include <sys/param.h>  /* MAX roundup */
29
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35
36 #ifdef WITH_OPENSSL
37 #include <openssl/crypto.h>
38 #include <openssl/dh.h>
39 #endif
40
41 #include "ssh2.h"
42 #include "packet.h"
43 #include "compat.h"
44 #include "cipher.h"
45 #include "sshkey.h"
46 #include "kex.h"
47 #include "log.h"
48 #include "mac.h"
49 #include "match.h"
50 #include "misc.h"
51 #include "dispatch.h"
52 #include "monitor.h"
53
54 #include "ssherr.h"
55 #include "sshbuf.h"
56 #include "digest.h"
57
58 #if OPENSSL_VERSION_NUMBER >= 0x00907000L
59 # if defined(HAVE_EVP_SHA256)
60 # define evp_ssh_sha256 EVP_sha256
61 # else
62 extern const EVP_MD *evp_ssh_sha256(void);
63 # endif
64 #endif
65
66 /* prototype */
67 static int kex_choose_conf(struct ssh *);
68 static int kex_input_newkeys(int, u_int32_t, void *);
69
70 static const char *proposal_names[PROPOSAL_MAX] = {
71         "KEX algorithms",
72         "host key algorithms",
73         "ciphers ctos",
74         "ciphers stoc",
75         "MACs ctos",
76         "MACs stoc",
77         "compression ctos",
78         "compression stoc",
79         "languages ctos",
80         "languages stoc",
81 };
82
83 struct kexalg {
84         char *name;
85         u_int type;
86         int ec_nid;
87         int hash_alg;
88 };
89 static const struct kexalg kexalgs[] = {
90 #ifdef WITH_OPENSSL
91         { KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
92         { KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
93         { KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
94         { KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
95         { KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
96         { KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
97 #ifdef HAVE_EVP_SHA256
98         { KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
99 #endif /* HAVE_EVP_SHA256 */
100 #ifdef OPENSSL_HAS_ECC
101         { KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
102             NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
103         { KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
104             SSH_DIGEST_SHA384 },
105 # ifdef OPENSSL_HAS_NISTP521
106         { KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
107             SSH_DIGEST_SHA512 },
108 # endif /* OPENSSL_HAS_NISTP521 */
109 #endif /* OPENSSL_HAS_ECC */
110 #endif /* WITH_OPENSSL */
111 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
112         { KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
113 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
114         { NULL, -1, -1, -1},
115 };
116
117 char *
118 kex_alg_list(char sep)
119 {
120         char *ret = NULL, *tmp;
121         size_t nlen, rlen = 0;
122         const struct kexalg *k;
123
124         for (k = kexalgs; k->name != NULL; k++) {
125                 if (ret != NULL)
126                         ret[rlen++] = sep;
127                 nlen = strlen(k->name);
128                 if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
129                         free(ret);
130                         return NULL;
131                 }
132                 ret = tmp;
133                 memcpy(ret + rlen, k->name, nlen + 1);
134                 rlen += nlen;
135         }
136         return ret;
137 }
138
139 static const struct kexalg *
140 kex_alg_by_name(const char *name)
141 {
142         const struct kexalg *k;
143
144         for (k = kexalgs; k->name != NULL; k++) {
145                 if (strcmp(k->name, name) == 0)
146                         return k;
147         }
148         return NULL;
149 }
150
151 /* Validate KEX method name list */
152 int
153 kex_names_valid(const char *names)
154 {
155         char *s, *cp, *p;
156
157         if (names == NULL || strcmp(names, "") == 0)
158                 return 0;
159         if ((s = cp = strdup(names)) == NULL)
160                 return 0;
161         for ((p = strsep(&cp, ",")); p && *p != '\0';
162             (p = strsep(&cp, ","))) {
163                 if (kex_alg_by_name(p) == NULL) {
164                         error("Unsupported KEX algorithm \"%.100s\"", p);
165                         free(s);
166                         return 0;
167                 }
168         }
169         debug3("kex names ok: [%s]", names);
170         free(s);
171         return 1;
172 }
173
174 /*
175  * Concatenate algorithm names, avoiding duplicates in the process.
176  * Caller must free returned string.
177  */
178 char *
179 kex_names_cat(const char *a, const char *b)
180 {
181         char *ret = NULL, *tmp = NULL, *cp, *p;
182         size_t len;
183
184         if (a == NULL || *a == '\0')
185                 return NULL;
186         if (b == NULL || *b == '\0')
187                 return strdup(a);
188         if (strlen(b) > 1024*1024)
189                 return NULL;
190         len = strlen(a) + strlen(b) + 2;
191         if ((tmp = cp = strdup(b)) == NULL ||
192             (ret = calloc(1, len)) == NULL) {
193                 free(tmp);
194                 return NULL;
195         }
196         strlcpy(ret, a, len);
197         for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
198                 if (match_list(ret, p, NULL) != NULL)
199                         continue; /* Algorithm already present */
200                 if (strlcat(ret, ",", len) >= len ||
201                     strlcat(ret, p, len) >= len) {
202                         free(tmp);
203                         free(ret);
204                         return NULL; /* Shouldn't happen */
205                 }
206         }
207         free(tmp);
208         return ret;
209 }
210
211 /*
212  * Assemble a list of algorithms from a default list and a string from a
213  * configuration file. The user-provided string may begin with '+' to
214  * indicate that it should be appended to the default.
215  */
216 int
217 kex_assemble_names(const char *def, char **list)
218 {
219         char *ret;
220
221         if (list == NULL || *list == NULL || **list == '\0') {
222                 *list = strdup(def);
223                 return 0;
224         }
225         if (**list != '+') {
226                 return 0;
227         }
228
229         if ((ret = kex_names_cat(def, *list + 1)) == NULL)
230                 return SSH_ERR_ALLOC_FAIL;
231         free(*list);
232         *list = ret;
233         return 0;
234 }
235
236 /* put algorithm proposal into buffer */
237 int
238 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
239 {
240         u_int i;
241         int r;
242
243         sshbuf_reset(b);
244
245         /*
246          * add a dummy cookie, the cookie will be overwritten by
247          * kex_send_kexinit(), each time a kexinit is set
248          */
249         for (i = 0; i < KEX_COOKIE_LEN; i++) {
250                 if ((r = sshbuf_put_u8(b, 0)) != 0)
251                         return r;
252         }
253         for (i = 0; i < PROPOSAL_MAX; i++) {
254                 if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
255                         return r;
256         }
257         if ((r = sshbuf_put_u8(b, 0)) != 0 ||   /* first_kex_packet_follows */
258             (r = sshbuf_put_u32(b, 0)) != 0)    /* uint32 reserved */
259                 return r;
260         return 0;
261 }
262
263 /* parse buffer and return algorithm proposal */
264 int
265 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
266 {
267         struct sshbuf *b = NULL;
268         u_char v;
269         u_int i;
270         char **proposal = NULL;
271         int r;
272
273         *propp = NULL;
274         if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
275                 return SSH_ERR_ALLOC_FAIL;
276         if ((b = sshbuf_fromb(raw)) == NULL) {
277                 r = SSH_ERR_ALLOC_FAIL;
278                 goto out;
279         }
280         if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) /* skip cookie */
281                 goto out;
282         /* extract kex init proposal strings */
283         for (i = 0; i < PROPOSAL_MAX; i++) {
284                 if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0)
285                         goto out;
286                 debug2("%s: %s", proposal_names[i], proposal[i]);
287         }
288         /* first kex follows / reserved */
289         if ((r = sshbuf_get_u8(b, &v)) != 0 ||  /* first_kex_follows */
290             (r = sshbuf_get_u32(b, &i)) != 0)   /* reserved */
291                 goto out;
292         if (first_kex_follows != NULL)
293                 *first_kex_follows = v;
294         debug2("first_kex_follows %d ", v);
295         debug2("reserved %u ", i);
296         r = 0;
297         *propp = proposal;
298  out:
299         if (r != 0 && proposal != NULL)
300                 kex_prop_free(proposal);
301         sshbuf_free(b);
302         return r;
303 }
304
305 void
306 kex_prop_free(char **proposal)
307 {
308         u_int i;
309
310         if (proposal == NULL)
311                 return;
312         for (i = 0; i < PROPOSAL_MAX; i++)
313                 free(proposal[i]);
314         free(proposal);
315 }
316
317 /* ARGSUSED */
318 static int
319 kex_protocol_error(int type, u_int32_t seq, void *ctxt)
320 {
321         struct ssh *ssh = active_state; /* XXX */
322         int r;
323
324         error("kex protocol error: type %d seq %u", type, seq);
325         if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
326             (r = sshpkt_put_u32(ssh, seq)) != 0 ||
327             (r = sshpkt_send(ssh)) != 0)
328                 return r;
329         return 0;
330 }
331
332 static void
333 kex_reset_dispatch(struct ssh *ssh)
334 {
335         ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
336             SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
337         ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
338 }
339
340 static int
341 kex_send_ext_info(struct ssh *ssh)
342 {
343         int r;
344
345         if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
346             (r = sshpkt_put_u32(ssh, 1)) != 0 ||
347             (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
348             (r = sshpkt_put_cstring(ssh, "rsa-sha2-256,rsa-sha2-512")) != 0 ||
349             (r = sshpkt_send(ssh)) != 0)
350                 return r;
351         return 0;
352 }
353
354 int
355 kex_send_newkeys(struct ssh *ssh)
356 {
357         int r;
358
359         kex_reset_dispatch(ssh);
360         if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
361             (r = sshpkt_send(ssh)) != 0)
362                 return r;
363         debug("SSH2_MSG_NEWKEYS sent");
364         debug("expecting SSH2_MSG_NEWKEYS");
365         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
366         if (ssh->kex->ext_info_c)
367                 if ((r = kex_send_ext_info(ssh)) != 0)
368                         return r;
369         return 0;
370 }
371
372 int
373 kex_input_ext_info(int type, u_int32_t seq, void *ctxt)
374 {
375         struct ssh *ssh = ctxt;
376         struct kex *kex = ssh->kex;
377         u_int32_t i, ninfo;
378         char *name, *val, *found;
379         int r;
380
381         debug("SSH2_MSG_EXT_INFO received");
382         ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
383         if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
384                 return r;
385         for (i = 0; i < ninfo; i++) {
386                 if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
387                         return r;
388                 if ((r = sshpkt_get_cstring(ssh, &val, NULL)) != 0) {
389                         free(name);
390                         return r;
391                 }
392                 debug("%s: %s=<%s>", __func__, name, val);
393                 if (strcmp(name, "server-sig-algs") == 0) {
394                         found = match_list("rsa-sha2-256", val, NULL);
395                         if (found) {
396                                 kex->rsa_sha2 = 256;
397                                 free(found);
398                         }
399                         found = match_list("rsa-sha2-512", val, NULL);
400                         if (found) {
401                                 kex->rsa_sha2 = 512;
402                                 free(found);
403                         }
404                 }
405                 free(name);
406                 free(val);
407         }
408         return sshpkt_get_end(ssh);
409 }
410
411 static int
412 kex_input_newkeys(int type, u_int32_t seq, void *ctxt)
413 {
414         struct ssh *ssh = ctxt;
415         struct kex *kex = ssh->kex;
416         int r;
417
418         debug("SSH2_MSG_NEWKEYS received");
419         ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
420         if ((r = sshpkt_get_end(ssh)) != 0)
421                 return r;
422         kex->done = 1;
423         sshbuf_reset(kex->peer);
424         /* sshbuf_reset(kex->my); */
425         kex->flags &= ~KEX_INIT_SENT;
426         free(kex->name);
427         kex->name = NULL;
428         return 0;
429 }
430
431 int
432 kex_send_kexinit(struct ssh *ssh)
433 {
434         u_char *cookie;
435         struct kex *kex = ssh->kex;
436         int r;
437
438         if (kex == NULL)
439                 return SSH_ERR_INTERNAL_ERROR;
440         if (kex->flags & KEX_INIT_SENT)
441                 return 0;
442         kex->done = 0;
443
444         /* generate a random cookie */
445         if (sshbuf_len(kex->my) < KEX_COOKIE_LEN)
446                 return SSH_ERR_INVALID_FORMAT;
447         if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL)
448                 return SSH_ERR_INTERNAL_ERROR;
449         arc4random_buf(cookie, KEX_COOKIE_LEN);
450
451         if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
452             (r = sshpkt_putb(ssh, kex->my)) != 0 ||
453             (r = sshpkt_send(ssh)) != 0)
454                 return r;
455         debug("SSH2_MSG_KEXINIT sent");
456         kex->flags |= KEX_INIT_SENT;
457         return 0;
458 }
459
460 /* ARGSUSED */
461 int
462 kex_input_kexinit(int type, u_int32_t seq, void *ctxt)
463 {
464         struct ssh *ssh = ctxt;
465         struct kex *kex = ssh->kex;
466         const u_char *ptr;
467         u_int i;
468         size_t dlen;
469         int r;
470
471         debug("SSH2_MSG_KEXINIT received");
472         if (kex == NULL)
473                 return SSH_ERR_INVALID_ARGUMENT;
474
475         ptr = sshpkt_ptr(ssh, &dlen);
476         if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
477                 return r;
478
479         /* discard packet */
480         for (i = 0; i < KEX_COOKIE_LEN; i++)
481                 if ((r = sshpkt_get_u8(ssh, NULL)) != 0)
482                         return r;
483         for (i = 0; i < PROPOSAL_MAX; i++)
484                 if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0)
485                         return r;
486         /*
487          * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
488          * KEX method has the server move first, but a server might be using
489          * a custom method or one that we otherwise don't support. We should
490          * be prepared to remember first_kex_follows here so we can eat a
491          * packet later.
492          * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
493          * for cases where the server *doesn't* go first. I guess we should
494          * ignore it when it is set for these cases, which is what we do now.
495          */
496         if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||      /* first_kex_follows */
497             (r = sshpkt_get_u32(ssh, NULL)) != 0 ||     /* reserved */
498             (r = sshpkt_get_end(ssh)) != 0)
499                         return r;
500
501         if (!(kex->flags & KEX_INIT_SENT))
502                 if ((r = kex_send_kexinit(ssh)) != 0)
503                         return r;
504         if ((r = kex_choose_conf(ssh)) != 0)
505                 return r;
506
507         if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
508                 return (kex->kex[kex->kex_type])(ssh);
509
510         return SSH_ERR_INTERNAL_ERROR;
511 }
512
513 int
514 kex_new(struct ssh *ssh, char *proposal[PROPOSAL_MAX], struct kex **kexp)
515 {
516         struct kex *kex;
517         int r;
518
519         *kexp = NULL;
520         if ((kex = calloc(1, sizeof(*kex))) == NULL)
521                 return SSH_ERR_ALLOC_FAIL;
522         if ((kex->peer = sshbuf_new()) == NULL ||
523             (kex->my = sshbuf_new()) == NULL) {
524                 r = SSH_ERR_ALLOC_FAIL;
525                 goto out;
526         }
527         if ((r = kex_prop2buf(kex->my, proposal)) != 0)
528                 goto out;
529         kex->done = 0;
530         kex_reset_dispatch(ssh);
531         r = 0;
532         *kexp = kex;
533  out:
534         if (r != 0)
535                 kex_free(kex);
536         return r;
537 }
538
539 void
540 kex_free_newkeys(struct newkeys *newkeys)
541 {
542         if (newkeys == NULL)
543                 return;
544         if (newkeys->enc.key) {
545                 explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
546                 free(newkeys->enc.key);
547                 newkeys->enc.key = NULL;
548         }
549         if (newkeys->enc.iv) {
550                 explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
551                 free(newkeys->enc.iv);
552                 newkeys->enc.iv = NULL;
553         }
554         free(newkeys->enc.name);
555         explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
556         free(newkeys->comp.name);
557         explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
558         mac_clear(&newkeys->mac);
559         if (newkeys->mac.key) {
560                 explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
561                 free(newkeys->mac.key);
562                 newkeys->mac.key = NULL;
563         }
564         free(newkeys->mac.name);
565         explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
566         explicit_bzero(newkeys, sizeof(*newkeys));
567         free(newkeys);
568 }
569
570 void
571 kex_free(struct kex *kex)
572 {
573         u_int mode;
574
575 #ifdef WITH_OPENSSL
576         if (kex->dh)
577                 DH_free(kex->dh);
578 #ifdef OPENSSL_HAS_ECC
579         if (kex->ec_client_key)
580                 EC_KEY_free(kex->ec_client_key);
581 #endif /* OPENSSL_HAS_ECC */
582 #endif /* WITH_OPENSSL */
583         for (mode = 0; mode < MODE_MAX; mode++) {
584                 kex_free_newkeys(kex->newkeys[mode]);
585                 kex->newkeys[mode] = NULL;
586         }
587         sshbuf_free(kex->peer);
588         sshbuf_free(kex->my);
589         free(kex->session_id);
590         free(kex->client_version_string);
591         free(kex->server_version_string);
592         free(kex->failed_choice);
593         free(kex->hostkey_alg);
594         free(kex->name);
595         free(kex);
596 }
597
598 int
599 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
600 {
601         int r;
602
603         if ((r = kex_new(ssh, proposal, &ssh->kex)) != 0)
604                 return r;
605         if ((r = kex_send_kexinit(ssh)) != 0) {         /* we start */
606                 kex_free(ssh->kex);
607                 ssh->kex = NULL;
608                 return r;
609         }
610         return 0;
611 }
612
613 /*
614  * Request key re-exchange, returns 0 on success or a ssherr.h error
615  * code otherwise. Must not be called if KEX is incomplete or in-progress.
616  */
617 int
618 kex_start_rekex(struct ssh *ssh)
619 {
620         if (ssh->kex == NULL) {
621                 error("%s: no kex", __func__);
622                 return SSH_ERR_INTERNAL_ERROR;
623         }
624         if (ssh->kex->done == 0) {
625                 error("%s: requested twice", __func__);
626                 return SSH_ERR_INTERNAL_ERROR;
627         }
628         ssh->kex->done = 0;
629         return kex_send_kexinit(ssh);
630 }
631
632 static int
633 choose_enc(struct sshenc *enc, char *client, char *server)
634 {
635         char *name = match_list(client, server, NULL);
636
637         if (name == NULL)
638                 return SSH_ERR_NO_CIPHER_ALG_MATCH;
639         if ((enc->cipher = cipher_by_name(name)) == NULL)
640                 return SSH_ERR_INTERNAL_ERROR;
641         enc->name = name;
642         enc->enabled = 0;
643         enc->iv = NULL;
644         enc->iv_len = cipher_ivlen(enc->cipher);
645         enc->key = NULL;
646         enc->key_len = cipher_keylen(enc->cipher);
647         enc->block_size = cipher_blocksize(enc->cipher);
648         return 0;
649 }
650
651 static int
652 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
653 {
654         char *name = match_list(client, server, NULL);
655
656         if (name == NULL)
657                 return SSH_ERR_NO_MAC_ALG_MATCH;
658         if (mac_setup(mac, name) < 0)
659                 return SSH_ERR_INTERNAL_ERROR;
660         /* truncate the key */
661         if (ssh->compat & SSH_BUG_HMAC)
662                 mac->key_len = 16;
663         mac->name = name;
664         mac->key = NULL;
665         mac->enabled = 0;
666         return 0;
667 }
668
669 static int
670 choose_comp(struct sshcomp *comp, char *client, char *server)
671 {
672         char *name = match_list(client, server, NULL);
673
674         if (name == NULL)
675                 return SSH_ERR_NO_COMPRESS_ALG_MATCH;
676         if (strcmp(name, "zlib@openssh.com") == 0) {
677                 comp->type = COMP_DELAYED;
678         } else if (strcmp(name, "zlib") == 0) {
679                 comp->type = COMP_ZLIB;
680         } else if (strcmp(name, "none") == 0) {
681                 comp->type = COMP_NONE;
682         } else {
683                 return SSH_ERR_INTERNAL_ERROR;
684         }
685         comp->name = name;
686         return 0;
687 }
688
689 static int
690 choose_kex(struct kex *k, char *client, char *server)
691 {
692         const struct kexalg *kexalg;
693
694         k->name = match_list(client, server, NULL);
695
696         debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
697         if (k->name == NULL)
698                 return SSH_ERR_NO_KEX_ALG_MATCH;
699         if ((kexalg = kex_alg_by_name(k->name)) == NULL)
700                 return SSH_ERR_INTERNAL_ERROR;
701         k->kex_type = kexalg->type;
702         k->hash_alg = kexalg->hash_alg;
703         k->ec_nid = kexalg->ec_nid;
704         return 0;
705 }
706
707 static int
708 choose_hostkeyalg(struct kex *k, char *client, char *server)
709 {
710         k->hostkey_alg = match_list(client, server, NULL);
711
712         debug("kex: host key algorithm: %s",
713             k->hostkey_alg ? k->hostkey_alg : "(no match)");
714         if (k->hostkey_alg == NULL)
715                 return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
716         k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
717         if (k->hostkey_type == KEY_UNSPEC)
718                 return SSH_ERR_INTERNAL_ERROR;
719         k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
720         return 0;
721 }
722
723 static int
724 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
725 {
726         static int check[] = {
727                 PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
728         };
729         int *idx;
730         char *p;
731
732         for (idx = &check[0]; *idx != -1; idx++) {
733                 if ((p = strchr(my[*idx], ',')) != NULL)
734                         *p = '\0';
735                 if ((p = strchr(peer[*idx], ',')) != NULL)
736                         *p = '\0';
737                 if (strcmp(my[*idx], peer[*idx]) != 0) {
738                         debug2("proposal mismatch: my %s peer %s",
739                             my[*idx], peer[*idx]);
740                         return (0);
741                 }
742         }
743         debug2("proposals match");
744         return (1);
745 }
746
747 static int
748 kex_choose_conf(struct ssh *ssh)
749 {
750         struct kex *kex = ssh->kex;
751         struct newkeys *newkeys;
752         char **my = NULL, **peer = NULL;
753         char **cprop, **sprop;
754         int nenc, nmac, ncomp;
755         u_int mode, ctos, need, dh_need, authlen;
756         int r, first_kex_follows;
757
758         debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
759         if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
760                 goto out;
761         debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
762         if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
763                 goto out;
764
765         if (kex->server) {
766                 cprop=peer;
767                 sprop=my;
768         } else {
769                 cprop=my;
770                 sprop=peer;
771         }
772
773         /* Check whether client supports ext_info_c */
774         if (kex->server) {
775                 char *ext;
776
777                 ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
778                 if (ext) {
779                         kex->ext_info_c = 1;
780                         free(ext);
781                 }
782         }
783
784         /* Algorithm Negotiation */
785         if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
786             sprop[PROPOSAL_KEX_ALGS])) != 0) {
787                 kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
788                 peer[PROPOSAL_KEX_ALGS] = NULL;
789                 goto out;
790         }
791         if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
792             sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
793                 kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
794                 peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
795                 goto out;
796         }
797         for (mode = 0; mode < MODE_MAX; mode++) {
798                 if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
799                         r = SSH_ERR_ALLOC_FAIL;
800                         goto out;
801                 }
802                 kex->newkeys[mode] = newkeys;
803                 ctos = (!kex->server && mode == MODE_OUT) ||
804                     (kex->server && mode == MODE_IN);
805                 nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
806                 nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
807                 ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
808                 if ((r = choose_enc(&newkeys->enc, cprop[nenc],
809                     sprop[nenc])) != 0) {
810                         kex->failed_choice = peer[nenc];
811                         peer[nenc] = NULL;
812                         goto out;
813                 }
814                 authlen = cipher_authlen(newkeys->enc.cipher);
815                 /* ignore mac for authenticated encryption */
816                 if (authlen == 0 &&
817                     (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
818                     sprop[nmac])) != 0) {
819                         kex->failed_choice = peer[nmac];
820                         peer[nmac] = NULL;
821                         goto out;
822                 }
823                 if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
824                     sprop[ncomp])) != 0) {
825                         kex->failed_choice = peer[ncomp];
826                         peer[ncomp] = NULL;
827                         goto out;
828                 }
829                 debug("kex: %s cipher: %s MAC: %s compression: %s",
830                     ctos ? "client->server" : "server->client",
831                     newkeys->enc.name,
832                     authlen == 0 ? newkeys->mac.name : "<implicit>",
833                     newkeys->comp.name);
834         }
835         need = dh_need = 0;
836         for (mode = 0; mode < MODE_MAX; mode++) {
837                 newkeys = kex->newkeys[mode];
838                 need = MAX(need, newkeys->enc.key_len);
839                 need = MAX(need, newkeys->enc.block_size);
840                 need = MAX(need, newkeys->enc.iv_len);
841                 need = MAX(need, newkeys->mac.key_len);
842                 dh_need = MAX(dh_need, cipher_seclen(newkeys->enc.cipher));
843                 dh_need = MAX(dh_need, newkeys->enc.block_size);
844                 dh_need = MAX(dh_need, newkeys->enc.iv_len);
845                 dh_need = MAX(dh_need, newkeys->mac.key_len);
846         }
847         /* XXX need runden? */
848         kex->we_need = need;
849         kex->dh_need = dh_need;
850
851         /* ignore the next message if the proposals do not match */
852         if (first_kex_follows && !proposals_match(my, peer) &&
853             !(ssh->compat & SSH_BUG_FIRSTKEX))
854                 ssh->dispatch_skip_packets = 1;
855         r = 0;
856  out:
857         kex_prop_free(my);
858         kex_prop_free(peer);
859         return r;
860 }
861
862 static int
863 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
864     const struct sshbuf *shared_secret, u_char **keyp)
865 {
866         struct kex *kex = ssh->kex;
867         struct ssh_digest_ctx *hashctx = NULL;
868         char c = id;
869         u_int have;
870         size_t mdsz;
871         u_char *digest;
872         int r;
873
874         if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
875                 return SSH_ERR_INVALID_ARGUMENT;
876         if ((digest = calloc(1, roundup(need, mdsz))) == NULL) {
877                 r = SSH_ERR_ALLOC_FAIL;
878                 goto out;
879         }
880
881         /* K1 = HASH(K || H || "A" || session_id) */
882         if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
883             ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
884             ssh_digest_update(hashctx, hash, hashlen) != 0 ||
885             ssh_digest_update(hashctx, &c, 1) != 0 ||
886             ssh_digest_update(hashctx, kex->session_id,
887             kex->session_id_len) != 0 ||
888             ssh_digest_final(hashctx, digest, mdsz) != 0) {
889                 r = SSH_ERR_LIBCRYPTO_ERROR;
890                 goto out;
891         }
892         ssh_digest_free(hashctx);
893         hashctx = NULL;
894
895         /*
896          * expand key:
897          * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
898          * Key = K1 || K2 || ... || Kn
899          */
900         for (have = mdsz; need > have; have += mdsz) {
901                 if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
902                     ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
903                     ssh_digest_update(hashctx, hash, hashlen) != 0 ||
904                     ssh_digest_update(hashctx, digest, have) != 0 ||
905                     ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
906                         r = SSH_ERR_LIBCRYPTO_ERROR;
907                         goto out;
908                 }
909                 ssh_digest_free(hashctx);
910                 hashctx = NULL;
911         }
912 #ifdef DEBUG_KEX
913         fprintf(stderr, "key '%c'== ", c);
914         dump_digest("key", digest, need);
915 #endif
916         *keyp = digest;
917         digest = NULL;
918         r = 0;
919  out:
920         free(digest);
921         ssh_digest_free(hashctx);
922         return r;
923 }
924
925 #define NKEYS   6
926 int
927 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
928     const struct sshbuf *shared_secret)
929 {
930         struct kex *kex = ssh->kex;
931         u_char *keys[NKEYS];
932         u_int i, j, mode, ctos;
933         int r;
934
935         for (i = 0; i < NKEYS; i++) {
936                 if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
937                     shared_secret, &keys[i])) != 0) {
938                         for (j = 0; j < i; j++)
939                                 free(keys[j]);
940                         return r;
941                 }
942         }
943         for (mode = 0; mode < MODE_MAX; mode++) {
944                 ctos = (!kex->server && mode == MODE_OUT) ||
945                     (kex->server && mode == MODE_IN);
946                 kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
947                 kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
948                 kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
949         }
950         return 0;
951 }
952
953 #ifdef WITH_OPENSSL
954 int
955 kex_derive_keys_bn(struct ssh *ssh, u_char *hash, u_int hashlen,
956     const BIGNUM *secret)
957 {
958         struct sshbuf *shared_secret;
959         int r;
960
961         if ((shared_secret = sshbuf_new()) == NULL)
962                 return SSH_ERR_ALLOC_FAIL;
963         if ((r = sshbuf_put_bignum2(shared_secret, secret)) == 0)
964                 r = kex_derive_keys(ssh, hash, hashlen, shared_secret);
965         sshbuf_free(shared_secret);
966         return r;
967 }
968 #endif
969
970 #ifdef WITH_SSH1
971 int
972 derive_ssh1_session_id(BIGNUM *host_modulus, BIGNUM *server_modulus,
973     u_int8_t cookie[8], u_int8_t id[16])
974 {
975         u_int8_t hbuf[2048], sbuf[2048], obuf[SSH_DIGEST_MAX_LENGTH];
976         struct ssh_digest_ctx *hashctx = NULL;
977         size_t hlen, slen;
978         int r;
979
980         hlen = BN_num_bytes(host_modulus);
981         slen = BN_num_bytes(server_modulus);
982         if (hlen < (512 / 8) || (u_int)hlen > sizeof(hbuf) ||
983             slen < (512 / 8) || (u_int)slen > sizeof(sbuf))
984                 return SSH_ERR_KEY_BITS_MISMATCH;
985         if (BN_bn2bin(host_modulus, hbuf) <= 0 ||
986             BN_bn2bin(server_modulus, sbuf) <= 0) {
987                 r = SSH_ERR_LIBCRYPTO_ERROR;
988                 goto out;
989         }
990         if ((hashctx = ssh_digest_start(SSH_DIGEST_MD5)) == NULL) {
991                 r = SSH_ERR_ALLOC_FAIL;
992                 goto out;
993         }
994         if (ssh_digest_update(hashctx, hbuf, hlen) != 0 ||
995             ssh_digest_update(hashctx, sbuf, slen) != 0 ||
996             ssh_digest_update(hashctx, cookie, 8) != 0 ||
997             ssh_digest_final(hashctx, obuf, sizeof(obuf)) != 0) {
998                 r = SSH_ERR_LIBCRYPTO_ERROR;
999                 goto out;
1000         }
1001         memcpy(id, obuf, ssh_digest_bytes(SSH_DIGEST_MD5));
1002         r = 0;
1003  out:
1004         ssh_digest_free(hashctx);
1005         explicit_bzero(hbuf, sizeof(hbuf));
1006         explicit_bzero(sbuf, sizeof(sbuf));
1007         explicit_bzero(obuf, sizeof(obuf));
1008         return r;
1009 }
1010 #endif
1011
1012 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1013 void
1014 dump_digest(char *msg, u_char *digest, int len)
1015 {
1016         fprintf(stderr, "%s\n", msg);
1017         sshbuf_dump_data(digest, len, stderr);
1018 }
1019 #endif