tcplay - bring in
[dragonfly.git] / lib / libtcplay / crypto-dev.c
1 /*
2  * Copyright (c) 2011 Alex Hornung <alex@alexhornung.com>.
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in
13  *    the documentation and/or other materials provided with the
14  *    distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
19  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
20  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
21  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
22  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
25  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
26  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27  * SUCH DAMAGE.
28  */
29 #include <sys/types.h>
30 #include <sys/param.h>
31 #include <sys/ioctl.h>
32 #include <sys/sysctl.h>
33 #include <crypto/cryptodev.h>
34
35 #include <fcntl.h>
36 #include <unistd.h>
37 #include <errno.h>
38 #include <string.h>
39 #include <openssl/evp.h>
40
41 #include "crc32.h"
42 #include "tcplay.h"
43
44 static
45 int
46 getallowsoft(void)
47 {
48         int old;
49         size_t olen;
50
51         olen = sizeof(old);
52
53         if (sysctlbyname("kern.cryptodevallowsoft", &old, &olen, NULL, 0) < 0) {
54                 perror("accessing sysctl kern.cryptodevallowsoft failed");
55         }
56
57         return old;
58 }
59
60 static
61 void
62 setallowsoft(int new)
63 {
64         int old;
65         size_t olen, nlen;
66
67         olen = nlen = sizeof(new);
68
69         if (sysctlbyname("kern.cryptodevallowsoft", &old, &olen, &new, nlen) < 0) {
70                 perror("accessing sysctl kern.cryptodevallowsoft failed");
71         }
72 }
73
74 static
75 int
76 syscrypt(int cipher, unsigned char *key, size_t klen, unsigned char *iv,
77     unsigned char *in, unsigned char *out, size_t len, int do_encrypt)
78 {
79         struct session_op session;
80         struct crypt_op cryp;
81         int cryptodev_fd = -1, fd = -1;
82
83         if ((cryptodev_fd = open("/dev/crypto", O_RDWR, 0)) < 0) {
84                 perror("Could not open /dev/crypto");
85                 goto err;
86         }
87         if (ioctl(cryptodev_fd, CRIOGET, &fd) == -1) {
88                 perror("CRIOGET failed");
89                 goto err;
90         }
91         memset(&session, 0, sizeof(session));
92         session.cipher = cipher;
93         session.key = (caddr_t) key;
94         session.keylen = klen;
95         if (ioctl(fd, CIOCGSESSION, &session) == -1) {
96                 perror("CIOCGSESSION failed");
97                 goto err;
98         }
99         memset(&cryp, 0, sizeof(cryp));
100         cryp.ses = session.ses;
101         cryp.op = do_encrypt ? COP_ENCRYPT : COP_DECRYPT;
102         cryp.flags = 0;
103         cryp.len = len;
104         cryp.src = (caddr_t) in;
105         cryp.dst = (caddr_t) out;
106         cryp.iv = (caddr_t) iv;
107         cryp.mac = 0;
108         if (ioctl(fd, CIOCCRYPT, &cryp) == -1) {
109                 perror("CIOCCRYPT failed");
110                 goto err;
111         }
112         if (ioctl(fd, CIOCFSESSION, &session.ses) == -1) {
113                 perror("CIOCFSESSION failed");
114                 goto err;
115         }
116         close(fd);
117         close(cryptodev_fd);
118         return (0);
119
120 err:
121         if (fd != -1)
122                 close(fd);
123         if (cryptodev_fd != -1)
124                 close(cryptodev_fd);
125         return (-1);
126 }
127
128 static
129 int
130 get_cryptodev_cipher_id(struct tc_crypto_algo *cipher)
131 {
132         if      (strcmp(cipher->name, "AES-128-XTS") == 0)
133                 return CRYPTO_AES_XTS;
134         else if (strcmp(cipher->name, "AES-256-XTS") == 0)
135                 return CRYPTO_AES_XTS;
136         else if (strcmp(cipher->name, "TWOFISH-128-XTS") == 0)
137                 return CRYPTO_TWOFISH_XTS;
138         else if (strcmp(cipher->name, "TWOFISH-256-XTS") == 0)
139                 return CRYPTO_TWOFISH_XTS;
140         else if (strcmp(cipher->name, "SERPENT-128-XTS") == 0)
141                 return CRYPTO_SERPENT_XTS;
142         else if (strcmp(cipher->name, "SERPENT-256-XTS") == 0)
143                 return CRYPTO_SERPENT_XTS;
144         else
145                 return -1;
146 }
147
148 int
149 tc_crypto_init(void)
150 {
151         int allowed;
152
153         OpenSSL_add_all_algorithms();
154
155         allowed = getallowsoft();
156         if (allowed == 0)
157                 setallowsoft(1);
158
159         return 0;
160 }
161
162 int
163 tc_cipher_chain_populate_keys(struct tc_cipher_chain *cipher_chain,
164     unsigned char *key)
165 {
166         int total_key_bytes, used_key_bytes;
167         struct tc_cipher_chain *dummy_chain;
168
169         /*
170          * We need to determine the total key bytes as the key locations
171          * depend on it.
172          */
173         total_key_bytes = 0;
174         for (dummy_chain = cipher_chain;
175             dummy_chain != NULL;
176             dummy_chain = dummy_chain->next) {
177                 total_key_bytes += dummy_chain->cipher->klen;
178         }
179
180         /*
181          * Now we need to get prepare the keys, as the keys are in
182          * forward order with respect to the cipher cascade, but
183          * the actual decryption is in reverse cipher cascade order.
184          */
185         used_key_bytes = 0;
186         for (dummy_chain = cipher_chain;
187             dummy_chain != NULL;
188             dummy_chain = dummy_chain->next) {
189                 dummy_chain->key = alloc_safe_mem(dummy_chain->cipher->klen);
190                 if (dummy_chain->key == NULL) {
191                         tc_log(1, "tc_decrypt: Could not allocate key "
192                             "memory\n");
193                         return ENOMEM;
194                 }
195
196                 /* XXX: here we assume XTS operation! */
197                 memcpy(dummy_chain->key,
198                     key + used_key_bytes/2,
199                     dummy_chain->cipher->klen/2);
200                 memcpy(dummy_chain->key + dummy_chain->cipher->klen/2,
201                     key + (total_key_bytes/2) + used_key_bytes/2,
202                     dummy_chain->cipher->klen/2);
203
204                 /* Remember how many key bytes we've seen */
205                 used_key_bytes += dummy_chain->cipher->klen;
206         }
207
208         return 0;
209 }
210
211 int
212 tc_encrypt(struct tc_cipher_chain *cipher_chain, unsigned char *key,
213     unsigned char *iv,
214     unsigned char *in, int in_len, unsigned char *out)
215 {
216         int cipher_id;
217         int err;
218
219         if ((err = tc_cipher_chain_populate_keys(cipher_chain, key)))
220                 return err;
221
222 #ifdef DEBUG
223         printf("tc_encrypt: starting chain\n");
224 #endif
225
226         /*
227          * Now process the actual decryption, in forward cascade order.
228          */
229         for (;
230             cipher_chain != NULL;
231             cipher_chain = cipher_chain->next) {
232                 cipher_id = get_cryptodev_cipher_id(cipher_chain->cipher);
233                 if (cipher_id < 0) {
234                         tc_log(1, "Cipher %s not found\n",
235                             cipher_chain->cipher->name);
236                         return ENOENT;
237                 }
238
239 #ifdef DEBUG
240                 printf("tc_encrypt: Currently using cipher %s\n",
241                     cipher_chain->cipher->name);
242 #endif
243
244                 err = syscrypt(cipher_id, cipher_chain->key,
245                     cipher_chain->cipher->klen, iv, in, out, in_len, 1);
246
247                 /* Deallocate this key, since we won't need it anymore */
248                 free_safe_mem(cipher_chain->key);
249
250                 if (err != 0)
251                         return err;
252
253                 /* Set next input buffer as current output buffer */
254                 in = out;
255         }
256
257         return 0;
258 }
259
260 int
261 tc_decrypt(struct tc_cipher_chain *cipher_chain, unsigned char *key,
262     unsigned char *iv,
263     unsigned char *in, int in_len, unsigned char *out)
264 {
265         int cipher_id;
266         int err;
267
268         if ((err = tc_cipher_chain_populate_keys(cipher_chain, key)))
269                 return err;
270
271 #ifdef DEBUG
272         printf("tc_decrypt: starting chain!\n");
273 #endif
274
275         /*
276          * Now process the actual decryption, in reverse cascade order; so
277          * first find the last element in the chain.
278          */
279         for (; cipher_chain->next != NULL; cipher_chain = cipher_chain->next)
280                 ;
281         for (;
282             cipher_chain != NULL;
283             cipher_chain = cipher_chain->prev) {
284                 cipher_id = get_cryptodev_cipher_id(cipher_chain->cipher);
285                 if (cipher_id < 0) {
286                         tc_log(1, "Cipher %s not found\n",
287                             cipher_chain->cipher->name);
288                         return ENOENT;
289                 }
290
291 #ifdef DEBUG
292                 printf("tc_decrypt: Currently using cipher %s\n",
293                     cipher_chain->cipher->name);
294 #endif
295
296                 err = syscrypt(cipher_id, cipher_chain->key,
297                     cipher_chain->cipher->klen, iv, in, out, in_len, 0);
298
299                 /* Deallocate this key, since we won't need it anymore */
300                 free_safe_mem(cipher_chain->key);
301
302                 if (err != 0)
303                         return err;
304
305                 /* Set next input buffer as current output buffer */
306                 in = out;
307         }
308
309         return 0;
310 }
311
312 int
313 pbkdf2(const char *pass, int passlen, const unsigned char *salt, int saltlen,
314     int iter, const char *hash_name, int keylen, unsigned char *out)
315 {
316         const EVP_MD *md;
317         int r;
318
319         md = EVP_get_digestbyname(hash_name);
320         if (md == NULL) {
321                 printf("Hash %s not found\n", hash_name);
322                 return ENOENT;
323         }
324         r = PKCS5_PBKDF2_HMAC(pass, passlen, salt, saltlen, iter, md,
325             keylen, out);
326
327         if (r == 0) {
328                 printf("Error in PBKDF2\n");
329                 return EINVAL;
330         }
331
332         return 0;
333 }
334
335 int
336 apply_keyfiles(unsigned char *pass, size_t pass_memsz, const char *keyfiles[],
337     int nkeyfiles)
338 {
339         int pl, k;
340         unsigned char *kpool;
341         unsigned char *kdata;
342         int kpool_idx;
343         size_t i, kdata_sz;
344         uint32_t crc;
345
346         if (pass_memsz < MAX_PASSSZ) {
347                 tc_log(1, "Not enough memory for password manipluation\n");
348                 return ENOMEM;
349         }
350
351         pl = strlen(pass);
352         memset(pass+pl, 0, MAX_PASSSZ-pl);
353
354         if ((kpool = alloc_safe_mem(KPOOL_SZ)) == NULL) {
355                 tc_log(1, "Error allocating memory for keyfile pool\n");
356                 return ENOMEM;
357         }
358
359         memset(kpool, 0, KPOOL_SZ);
360
361         for (k = 0; k < nkeyfiles; k++) {
362 #ifdef DEBUG
363                 printf("Loading keyfile %s into kpool\n", keyfiles[k]);
364 #endif
365                 kpool_idx = 0;
366                 crc = ~0U;
367                 kdata_sz = MAX_KFILE_SZ;
368
369                 if ((kdata = read_to_safe_mem(keyfiles[k], 0, &kdata_sz)) == NULL) {
370                         tc_log(1, "Error reading keyfile %s content\n",
371                             keyfiles[k]);
372                         free_safe_mem(kpool);
373                         return EIO;
374                 }
375
376                 for (i = 0; i < kdata_sz; i++) {
377                         crc = crc32_intermediate(crc, kdata[i]);
378
379                         kpool[kpool_idx++] += (unsigned char)(crc >> 24);
380                         kpool[kpool_idx++] += (unsigned char)(crc >> 16);
381                         kpool[kpool_idx++] += (unsigned char)(crc >> 8);
382                         kpool[kpool_idx++] += (unsigned char)(crc);
383
384                         /* Wrap around */
385                         if (kpool_idx == KPOOL_SZ)
386                                 kpool_idx = 0;
387                 }
388
389                 free_safe_mem(kdata);
390         }
391
392 #ifdef DEBUG
393         printf("Applying kpool to passphrase\n");
394 #endif
395         /* Apply keyfile pool to passphrase */
396         for (i = 0; i < KPOOL_SZ; i++)
397                 pass[i] += kpool[i];
398
399         free_safe_mem(kpool);
400
401         return 0;
402 }