b065a5287361de09056c2f09e03226dffbeaac83
[dragonfly.git] / sys / net / ipfw3_nat / ip_fw3_nat.c
1 /*
2  * Copyright (c) 2014 - 2018 The DragonFly Project.  All rights reserved.
3  *
4  * This code is derived from software contributed to The DragonFly Project
5  * by Bill Yuan <bycn82@dragonflybsd.org>
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  * 3. Neither the name of The DragonFly Project nor the names of its
18  *    contributors may be used to endorse or promote products derived
19  *    from this software without specific, prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
25  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
27  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
29  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
31  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32  * SUCH DAMAGE.
33  */
34
35 #include "opt_ipfw.h"
36 #include "opt_inet.h"
37 #ifndef INET
38 #error IPFIREWALL3 requires INET.
39 #endif /* INET */
40
41 #include <sys/param.h>
42 #include <sys/kernel.h>
43 #include <sys/malloc.h>
44 #include <sys/mbuf.h>
45 #include <sys/socketvar.h>
46 #include <sys/sysctl.h>
47 #include <sys/systimer.h>
48 #include <sys/thread2.h>
49 #include <sys/in_cksum.h>
50 #include <sys/systm.h>
51 #include <sys/proc.h>
52 #include <sys/socket.h>
53 #include <sys/syslog.h>
54 #include <sys/ucred.h>
55 #include <sys/lock.h>
56 #include <sys/mplock2.h>
57
58 #include <net/ethernet.h>
59 #include <net/netmsg2.h>
60 #include <net/netisr2.h>
61 #include <net/route.h>
62 #include <net/if.h>
63
64 #include <netinet/in.h>
65 #include <netinet/ip.h>
66 #include <netinet/ip_icmp.h>
67 #include <netinet/tcp.h>
68 #include <netinet/tcp_timer.h>
69 #include <netinet/tcp_var.h>
70 #include <netinet/tcpip.h>
71 #include <netinet/udp.h>
72 #include <netinet/udp_var.h>
73 #include <netinet/in_systm.h>
74 #include <netinet/in_var.h>
75 #include <netinet/in_pcb.h>
76 #include <netinet/ip_var.h>
77 #include <netinet/ip_divert.h>
78 #include <net/ipfw3/ip_fw.h>
79
80 #include "ip_fw3_nat.h"
81
MALLOC_DEFINE(M_IPFW3_NAT, "IP_FW3_NAT", "ipfw3_nat module");

/*
 * Highspeed Lockless Kernel NAT
 *
 * Kernel NAT
 * The network address translation (NAT) will replace the `src` of the packet
 * with an `alias` (alias_addr & alias_port). According to the configuration,
 * the alias will be randomly picked from the configured range.
 *
 * Highspeed
 * The first outgoing packet should trigger the creation of the `nat_state`,
 * and the `nat_state` will be kept in an RB-Tree for the subsequent outgoing
 * packets.
 * The first outgoing packet also prepares the `nat_state2` used for the
 * returning packets, which is stored in per-alias, per-protocol arrays of
 * pointers (to nat_state2) indexed by the alias port (ICMP: the alias id).
 *
 * Lockless
 * The `nat_state` for outgoing packets will be stored in the nat_context of
 * the current CPU. But due to the nature of NAT, the returning packet may be
 * handled by another CPU. Hence, the `nat_state2` for the returning packet
 * will be prepared and handed (via a netmsg) to the nat_context of the
 * right CPU.
 */

/* Per-CPU NAT contexts, indexed by cpu id. */
struct ip_fw3_nat_context       *ip_fw3_nat_ctx[MAXCPU];
static struct callout           ip_fw3_nat_cleanup_callout;
extern struct ipfw3_context     *fw3_ctx[MAXCPU];
extern ip_fw_ctl_t              *ip_fw3_ctl_nat_ptr;

/*
 * Tunables under net.inet.ip.fw3_nat.  The *_timeout values are added to a
 * state's timestamp and compared against time_uptime (see
 * ip_fw3_ctl_nat_get_record), i.e. they are seconds.
 * NOTE(review): cleanup_interval is presumably the period of
 * ip_fw3_nat_cleanup_callout (not visible here); its sysctl description
 * "default life time" looks mislabeled -- confirm.
 */
static int                      sysctl_var_cleanup_interval = 1;
static int                      sysctl_var_icmp_timeout = 10;
static int                      sysctl_var_tcp_timeout = 60;
static int                      sysctl_var_udp_timeout = 30;

SYSCTL_NODE(_net_inet_ip, OID_AUTO, fw3_nat, CTLFLAG_RW, 0, "ipfw3 NAT");
SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, cleanup_interval, CTLFLAG_RW,
                &sysctl_var_cleanup_interval, 0, "default life time");
SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, icmp_timeout, CTLFLAG_RW,
                &sysctl_var_icmp_timeout, 0, "default icmp state life time");
SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, tcp_timeout, CTLFLAG_RW,
                &sysctl_var_tcp_timeout, 0, "default tcp state life time");
SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, udp_timeout, CTLFLAG_RW,
                &sysctl_var_udp_timeout, 0, "default udp state life time");

/* RB-tree of outgoing nat_state entries, ordered by ip_fw3_nat_state_cmp. */
RB_PROTOTYPE(state_tree, nat_state, entries, ip_fw3_nat_state_cmp);
RB_GENERATE(state_tree, nat_state, entries, ip_fw3_nat_state_cmp);
128
/*
 * fix_cksum: incrementally adjust a 16-bit Internet checksum after one
 * 16-bit word of the covered data changed from old_info to new_info.
 *
 * When is_udp is set, a zero checksum ("no checksum" for UDP) is returned
 * unchanged, and a result that folds to zero is sent as 0xFFFF instead.
 */
static __inline uint16_t
fix_cksum(uint16_t cksum, uint16_t old_info, uint16_t new_info, uint8_t is_udp)
{
	uint32_t folded;

	if (is_udp != 0 && cksum == 0)
		return (0x0000);

	folded = (uint32_t)(cksum + old_info - new_info);
	/* fold the carry/borrow back into the low 16 bits */
	folded = (folded & 0xFFFF) + (folded >> 16);
	folded &= 0xFFFF;

	return (is_udp != 0 && folded == 0) ? 0xFFFF : (uint16_t)folded;
}
143
144 void
145 check_nat(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
146                 struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
147 {
148         if ((*args)->eh != NULL) {
149                 *cmd_ctl = IP_FW_CTL_NO;
150                 *cmd_val = IP_FW_NOT_MATCH;
151                 return;
152         }
153
154         struct ip_fw3_nat_context *nat_ctx;
155         struct cfg_nat *nat;
156         int nat_id;
157
158         nat_ctx = ip_fw3_nat_ctx[mycpuid];
159         (*args)->rule = *f;
160         nat = ((ipfw_insn_nat *)cmd)->nat;
161         if (nat == NULL) {
162                 nat_id = cmd->arg1;
163                 nat = nat_ctx->nats[nat_id - 1];
164                 if (nat == NULL) {
165                         *cmd_val = IP_FW_DENY;
166                         *cmd_ctl = IP_FW_CTL_DONE;
167                         return;
168                 }
169                 ((ipfw_insn_nat *)cmd)->nat = nat;
170         }
171         *cmd_val = ip_fw3_nat(*args, nat, (*args)->m);
172         *cmd_ctl = IP_FW_CTL_NAT;
173 }
174
175 int
176 ip_fw3_nat(struct ip_fw_args *args, struct cfg_nat *nat, struct mbuf *m)
177 {
178         struct state_tree *tree_out = NULL;
179         struct nat_state *s = NULL, *dup, *k, key;
180         struct nat_state2 *s2 = NULL;
181         struct ip *ip = mtod(m, struct ip *);
182         struct in_addr *old_addr = NULL, new_addr;
183         uint16_t *old_port = NULL, new_port;
184         uint16_t *csum = NULL, dlen = 0;
185         uint8_t udp = 0;
186         boolean_t pseudo = FALSE, need_return_state = FALSE;
187         struct cfg_alias *alias;
188         int i = 0, rand_n = 0;
189
190         k = &key;
191         memset(k, 0, LEN_NAT_STATE);
192         if (args->oif == NULL) {
193                 old_addr = &ip->ip_dst;
194                 k->dst_addr = ntohl(args->f_id.dst_ip);
195                 LIST_FOREACH(alias, &nat->alias, next) {
196                         if (alias->ip.s_addr == ntohl(args->f_id.dst_ip)) {
197                                 break;
198                         }
199                 }
200                 if (alias == NULL) {
201                         goto oops;
202                 }
203                 switch (ip->ip_p) {
204                 case IPPROTO_TCP:
205                         old_port = &L3HDR(struct tcphdr, ip)->th_dport;
206                         s2 = alias->tcp_in[*old_port - ALIAS_BEGIN];
207                         csum = &L3HDR(struct tcphdr, ip)->th_sum;
208                         break;
209                 case IPPROTO_UDP:
210                         old_port = &L3HDR(struct udphdr, ip)->uh_dport;
211                         s2 = alias->udp_in[*old_port - ALIAS_BEGIN];
212                         csum = &L3HDR(struct udphdr, ip)->uh_sum;
213                         udp = 1;
214                         break;
215                 case IPPROTO_ICMP:
216                         old_port = &L3HDR(struct icmp, ip)->icmp_id;
217                         s2 = alias->icmp_in[*old_port];
218                         csum = &L3HDR(struct icmp, ip)->icmp_cksum;
219                         break;
220                 default:
221                         panic("ipfw3: unsupported proto %u", ip->ip_p);
222                 }
223                 if (s2 == NULL) {
224                         goto oops;
225                 }
226         } else {
227                 old_addr = &ip->ip_src;
228                 k->src_addr = args->f_id.src_ip;
229                 k->dst_addr = args->f_id.dst_ip;
230                 switch (ip->ip_p) {
231                 case IPPROTO_TCP:
232                         k->src_port = args->f_id.src_port;
233                         k->dst_port = args->f_id.dst_port;
234                         m->m_pkthdr.csum_flags = CSUM_TCP;
235                         tree_out = &nat->rb_tcp_out;
236                         old_port = &L3HDR(struct tcphdr, ip)->th_sport;
237                         csum = &L3HDR(struct tcphdr, ip)->th_sum;
238                         break;
239                 case IPPROTO_UDP:
240                         k->src_port = args->f_id.src_port;
241                         k->dst_port = args->f_id.dst_port;
242                         m->m_pkthdr.csum_flags = CSUM_UDP;
243                         tree_out = &nat->rb_udp_out;
244                         old_port = &L3HDR(struct udphdr, ip)->uh_sport;
245                         csum = &L3HDR(struct udphdr, ip)->uh_sum;
246                         udp = 1;
247                         break;
248                 case IPPROTO_ICMP:
249                         k->src_port = L3HDR(struct icmp, ip)->icmp_id;
250                         k->dst_port = k->src_port;
251                         tree_out = &nat->rb_icmp_out;
252                         old_port = &L3HDR(struct icmp, ip)->icmp_id;
253                         csum = &L3HDR(struct icmp, ip)->icmp_cksum;
254                         break;
255                 default:
256                         panic("ipfw3: unsupported proto %u", ip->ip_p);
257                 }
258                 s = RB_FIND(state_tree, tree_out, k);
259                 if (s == NULL) {
260                         /* pick an alias ip randomly when there are multiple */
261                         if (nat->count > 1) {
262                                 rand_n = krandom() % nat->count;
263                         }
264                         LIST_FOREACH(alias, &nat->alias, next) {
265                                 if (i++ == rand_n) {
266                                         break;
267                                 }
268                         }
269                         switch  (ip->ip_p) {
270                         case IPPROTO_TCP:
271                                 m->m_pkthdr.csum_flags = CSUM_TCP;
272                                 s = kmalloc(LEN_NAT_STATE, M_IPFW3_NAT,
273                                                 M_INTWAIT | M_NULLOK | M_ZERO);
274
275                                 s->src_addr = args->f_id.src_ip;
276                                 s->src_port = args->f_id.src_port;
277
278                                 s->dst_addr = args->f_id.dst_ip;
279                                 s->dst_port = args->f_id.dst_port;
280
281                                 s->alias_addr = alias->ip.s_addr;
282                                 pick_alias_port(s, tree_out);
283                                 dup = RB_INSERT(state_tree, tree_out, s);
284                                 need_return_state = TRUE;
285                                 break;
286                         case IPPROTO_UDP:
287                                 m->m_pkthdr.csum_flags = CSUM_UDP;
288                                 s = kmalloc(LEN_NAT_STATE, M_IPFW3_NAT,
289                                                 M_INTWAIT | M_NULLOK | M_ZERO);
290
291                                 s->src_addr = args->f_id.src_ip;
292                                 s->src_port = args->f_id.src_port;
293
294                                 s->dst_addr = args->f_id.dst_ip;
295                                 s->dst_port = args->f_id.dst_port;
296
297                                 s->alias_addr = alias->ip.s_addr;
298                                 pick_alias_port(s, tree_out);
299                                 dup = RB_INSERT(state_tree, tree_out, s);
300                                 need_return_state = TRUE;
301                                 break;
302                         case IPPROTO_ICMP:
303                                 s = kmalloc(LEN_NAT_STATE, M_IPFW3_NAT,
304                                                 M_INTWAIT | M_NULLOK | M_ZERO);
305                                 s->src_addr = args->f_id.src_ip;
306                                 s->dst_addr = args->f_id.dst_ip;
307
308                                 s->src_port = *old_port;
309                                 s->dst_port = *old_port;
310
311                                 s->alias_addr = alias->ip.s_addr;
312                                 s->alias_port = htons(s->src_addr % ALIAS_RANGE);
313                                 dup = RB_INSERT(state_tree, tree_out, s);
314
315                                 s2 = kmalloc(LEN_NAT_STATE2, M_IPFW3_NAT,
316                                                 M_INTWAIT | M_NULLOK | M_ZERO);
317
318                                 s2->src_addr = args->f_id.dst_ip;
319                                 s2->dst_addr = alias->ip.s_addr;
320
321                                 s2->src_port = s->alias_port;
322                                 s2->dst_port = s->alias_port;
323
324                                 s2->alias_addr = htonl(args->f_id.src_ip);
325                                 s2->alias_port = *old_port;
326
327                                 alias->icmp_in[s->alias_port] = s2;
328                                 break;
329                         default :
330                                 goto oops;
331                         }
332                 }
333         }
334         if (args->oif == NULL) {
335                 new_addr.s_addr = s2->src_addr;
336                 new_port = s2->src_port;
337                 s2->timestamp = time_uptime;
338         } else {
339                 new_addr.s_addr = s->alias_addr;
340                 new_port = s->alias_port;
341                 s->timestamp = time_uptime;
342         }
343
344         /* replace src/dst and fix the checksum */
345         if (m->m_pkthdr.csum_flags & (CSUM_UDP | CSUM_TCP | CSUM_TSO)) {
346                 if ((m->m_pkthdr.csum_flags & CSUM_TSO) == 0) {
347                         dlen = ip->ip_len - (ip->ip_hl << 2);
348                 }
349                 pseudo = TRUE;
350         }
351         if (!pseudo) {
352                 const uint16_t *oaddr, *naddr;
353                 oaddr = (const uint16_t *)&old_addr->s_addr;
354                 naddr = (const uint16_t *)&new_addr.s_addr;
355                 ip->ip_sum = fix_cksum(ip->ip_sum, oaddr[0], naddr[0], 0);
356                 ip->ip_sum = fix_cksum(ip->ip_sum, oaddr[1], naddr[1], 0);
357                 if (ip->ip_p != IPPROTO_ICMP) {
358                         *csum = fix_cksum(*csum, oaddr[0], naddr[0], udp);
359                         *csum = fix_cksum(*csum, oaddr[1], naddr[1], udp);
360                 }
361         }
362         old_addr->s_addr = new_addr.s_addr;
363         if (!pseudo) {
364                 *csum = fix_cksum(*csum, *old_port, new_port, udp);
365         }
366         *old_port = new_port;
367
368         if (pseudo) {
369                 *csum = in_pseudo(ip->ip_src.s_addr,
370                                 ip->ip_dst.s_addr, htons(dlen + ip->ip_p));
371         }
372
373         /* prepare the state for return traffic */
374         if (need_return_state) {
375                 ip->ip_len = htons(ip->ip_len);
376                 ip->ip_off = htons(ip->ip_off);
377
378                 m->m_flags &= ~M_HASH;
379                 ip_hashfn(&m, 0);
380
381                 ip->ip_len = ntohs(ip->ip_len);
382                 ip->ip_off = ntohs(ip->ip_off);
383
384                 int nextcpu = netisr_hashcpu(m->m_pkthdr.hash);
385                 if (nextcpu != mycpuid) {
386                         struct netmsg_nat_state_add *msg;
387                         msg = kmalloc(LEN_NMSG_NAT_STATE_ADD,
388                                         M_LWKTMSG, M_NOWAIT | M_ZERO);
389                         netmsg_init(&msg->base, NULL, &curthread->td_msgport,
390                                         0, nat_state_add_dispatch);
391                         s2 = kmalloc(LEN_NAT_STATE2, M_IPFW3_NAT,
392                                         M_INTWAIT | M_NULLOK | M_ZERO);
393
394                         s2->src_addr = args->f_id.dst_ip;
395                         s2->src_port = args->f_id.dst_port;
396
397                         s2->dst_addr = alias->ip.s_addr;
398                         s2->dst_port = s->alias_port;
399
400                         s2->src_addr = htonl(args->f_id.src_ip);
401                         s2->src_port = htons(args->f_id.src_port);
402
403                         s2->timestamp = s->timestamp;
404                         msg->alias_addr.s_addr = alias->ip.s_addr;
405                         msg->alias_port = s->alias_port;
406                         msg->state = s2;
407                         msg->nat_id = nat->id;
408                         msg->proto = ip->ip_p;
409                         netisr_sendmsg(&msg->base, nextcpu);
410                 } else {
411                         s2 = kmalloc(LEN_NAT_STATE2, M_IPFW3_NAT,
412                                         M_INTWAIT | M_NULLOK | M_ZERO);
413
414                         s2->src_addr = args->f_id.dst_ip;
415                         s2->dst_addr = alias->ip.s_addr;
416
417                         s2->src_port = s->alias_port;
418                         s2->dst_port = s->alias_port;
419
420                         s2->src_addr = htonl(args->f_id.src_ip);
421                         s2->src_port = htons(args->f_id.src_port);
422
423                         s2->timestamp = s->timestamp;
424                         if (ip->ip_p == IPPROTO_TCP) {
425                                 alias->tcp_in[s->alias_port - ALIAS_BEGIN] = s2;
426                         } else {
427                                 alias->udp_in[s->alias_port - ALIAS_BEGIN] = s2;
428                         }
429                 }
430         }
431         return IP_FW_NAT;
432 oops:
433         IPFW3_DEBUG1("oops\n");
434         return IP_FW_DENY;
435 }
436
437 void
438 pick_alias_port(struct nat_state *s, struct state_tree *tree)
439 {
440         do {
441                 s->alias_port = htons(krandom() % ALIAS_RANGE + ALIAS_BEGIN);
442         } while (RB_FIND(state_tree, tree, s) != NULL);
443 }
444
445 int
446 ip_fw3_nat_state_cmp(struct nat_state *s1, struct nat_state *s2)
447 {
448         if (s1->src_addr > s2->src_addr)
449                 return 1;
450         if (s1->src_addr < s2->src_addr)
451                 return -1;
452
453         if (s1->dst_addr > s2->dst_addr)
454                 return 1;
455         if (s1->dst_addr < s2->dst_addr)
456                 return -1;
457
458         if (s1->src_port > s2->src_port)
459                 return 1;
460         if (s1->src_port < s2->src_port)
461                 return -1;
462
463         if (s1->dst_port > s2->dst_port)
464                 return 1;
465         if (s1->dst_port < s2->dst_port)
466                 return -1;
467
468         return 0;
469 }
470
471 int
472 ip_fw3_ctl_nat_get_cfg(struct sockopt *sopt)
473 {
474         struct ip_fw3_nat_context *nat_ctx;
475         struct ioc_nat *ioc;
476         struct cfg_nat *nat;
477         struct cfg_alias *alias;
478         struct in_addr *ip;
479         size_t valsize;
480         int i, len;
481
482         len = 0;
483         nat_ctx = ip_fw3_nat_ctx[mycpuid];
484         valsize = sopt->sopt_valsize;
485         ioc = (struct ioc_nat *)sopt->sopt_val;
486
487         for (i = 0; i < NAT_ID_MAX; i++) {
488                 nat = nat_ctx->nats[i];
489                 if (nat != NULL) {
490                         len += LEN_IOC_NAT;
491                         if (len >= valsize) {
492                                 goto nospace;
493                         }
494                         ioc->id = nat->id;
495                         ioc->count = nat->count;
496                         ip = &ioc->ip;
497                         LIST_FOREACH(alias, &nat->alias, next) {
498                                 len += LEN_IN_ADDR;
499                                 if (len > valsize) {
500                                         goto nospace;
501                                 }
502                                 bcopy(&alias->ip, ip, LEN_IN_ADDR);
503                                 ip++;
504                         }
505                 }
506         }
507         sopt->sopt_valsize = len;
508         return 0;
509 nospace:
510         bzero(sopt->sopt_val, sopt->sopt_valsize);
511         sopt->sopt_valsize = 0;
512         return 0;
513 }
514
515 int
516 ip_fw3_ctl_nat_get_record(struct sockopt *sopt)
517 {
518         struct ip_fw3_nat_context *nat_ctx;
519         struct cfg_nat *the;
520         size_t sopt_size, total_len = 0;
521         struct ioc_nat_state *ioc;
522         int ioc_nat_id, i, n, cpu;
523         struct nat_state        *s;
524         struct nat_state2       *s2;
525         struct cfg_alias        *a1;
526
527         ioc_nat_id = *((int *)(sopt->sopt_val));
528         sopt_size = sopt->sopt_valsize;
529         ioc = (struct ioc_nat_state *)sopt->sopt_val;
530         /* icmp states only in CPU 0 */
531         cpu = 0;
532         nat_ctx = ip_fw3_nat_ctx[cpu];
533         for (n = 0; n < NAT_ID_MAX; n++) {
534                 if (ioc_nat_id == 0 || ioc_nat_id == n + 1) {
535                         if (nat_ctx->nats[n] == NULL)
536                                 break;
537                         the = nat_ctx->nats[n];
538                         RB_FOREACH(s, state_tree, &the->rb_icmp_out) {
539                                 total_len += LEN_IOC_NAT_STATE;
540                                 if (total_len > sopt_size)
541                                         goto nospace;
542                                 ioc->src_addr.s_addr = ntohl(s->src_addr);
543                                 ioc->dst_addr.s_addr = s->dst_addr;
544                                 ioc->alias_addr.s_addr = s->alias_addr;
545                                 ioc->src_port = s->src_port;
546                                 ioc->dst_port = s->dst_port;
547                                 ioc->alias_port = s->alias_port;
548                                 ioc->nat_id = n + 1;
549                                 ioc->cpu_id = cpu;
550                                 ioc->proto = IPPROTO_ICMP;
551                                 ioc->direction = 1;
552                                 ioc->life = s->timestamp +
553                                         sysctl_var_icmp_timeout - time_uptime;
554                                 ioc++;
555                         }
556
557                         LIST_FOREACH(a1, &the->alias, next) {
558                         for (i = 0; i < ALIAS_RANGE; i++) {
559                                 s2 = a1->icmp_in[i];
560                                 if (s2 == NULL) {
561                                         continue;
562                                 }
563
564                                 total_len += LEN_IOC_NAT_STATE;
565                                 if (total_len > sopt_size)
566                                         goto nospace;
567
568                                 ioc->src_addr.s_addr = ntohl(s2->src_addr);
569                                 ioc->dst_addr.s_addr = s2->dst_addr;
570                                 ioc->alias_addr.s_addr = s2->alias_addr;
571                                 ioc->src_port = s2->src_port;
572                                 ioc->dst_port = s2->dst_port;
573                                 ioc->alias_port = s2->alias_port;
574                                 ioc->nat_id = n + 1;
575                                 ioc->cpu_id = cpu;
576                                 ioc->proto = IPPROTO_ICMP;
577                                 ioc->direction = 0;
578                                 ioc->life = s2->timestamp +
579                                         sysctl_var_icmp_timeout - time_uptime;
580                                 ioc++;
581                         }
582                         }
583                 }
584         }
585
586         /* tcp states */
587         for (cpu = 0; cpu < ncpus; cpu++) {
588                 nat_ctx = ip_fw3_nat_ctx[cpu];
589                 for (n = 0; n < NAT_ID_MAX; n++) {
590                         if (ioc_nat_id == 0 || ioc_nat_id == n + 1) {
591                                 if (nat_ctx->nats[n] == NULL)
592                                         break;
593                                 the = nat_ctx->nats[n];
594                                 RB_FOREACH(s, state_tree, &the->rb_tcp_out) {
595                                         total_len += LEN_IOC_NAT_STATE;
596                                         if (total_len > sopt_size)
597                                                 goto nospace;
598                                         ioc->src_addr.s_addr = ntohl(s->src_addr);
599                                         ioc->dst_addr.s_addr = ntohl(s->dst_addr);
600                                         ioc->alias_addr.s_addr = s->alias_addr;
601                                         ioc->src_port = ntohs(s->src_port);
602                                         ioc->dst_port = ntohs(s->dst_port);
603                                         ioc->alias_port = s->alias_port;
604                                         ioc->nat_id = n + 1;
605                                         ioc->cpu_id = cpu;
606                                         ioc->proto = IPPROTO_TCP;
607                                         ioc->direction = 1;
608                                         ioc->life = s->timestamp +
609                                                 sysctl_var_tcp_timeout - time_uptime;
610                                         ioc++;
611                                 }
612                                 LIST_FOREACH(a1, &the->alias, next) {
613                                         for (i = 0; i < ALIAS_RANGE; i++) {
614                                                 s2 = a1->tcp_in[i];
615                                                 if (s2 == NULL) {
616                                                         continue;
617                                                 }
618
619                                                 total_len += LEN_IOC_NAT_STATE;
620                                                 if (total_len > sopt_size)
621                                                         goto nospace;
622
623                                                 ioc->src_addr.s_addr = ntohl(s2->src_addr);
624                                                 ioc->dst_addr.s_addr = s2->dst_addr;
625                                                 ioc->alias_addr.s_addr = s2->alias_addr;
626                                                 ioc->src_port = s2->src_port;
627                                                 ioc->dst_port = s2->dst_port;
628                                                 ioc->alias_port = s2->alias_port;
629                                                 ioc->nat_id = n + 1;
630                                                 ioc->cpu_id = cpu;
631                                                 ioc->proto = IPPROTO_TCP;
632                                                 ioc->direction = 0;
633                                                 ioc->life = s2->timestamp +
634                                                         sysctl_var_icmp_timeout - time_uptime;
635                                                 ioc++;
636                                         }
637                                 }
638                         }
639                 }
640         }
641
642         /* udp states */
643         for (cpu = 0; cpu < ncpus; cpu++) {
644                 nat_ctx = ip_fw3_nat_ctx[cpu];
645                 for (n = 0; n < NAT_ID_MAX; n++) {
646                         if (ioc_nat_id == 0 || ioc_nat_id == n + 1) {
647                                 if (nat_ctx->nats[n] == NULL)
648                                         break;
649                                 the = nat_ctx->nats[n];
650                                 RB_FOREACH(s, state_tree, &the->rb_udp_out) {
651                                         total_len += LEN_IOC_NAT_STATE;
652                                         if (total_len > sopt_size)
653                                                 goto nospace;
654                                         ioc->src_addr.s_addr = ntohl(s->src_addr);
655                                         ioc->dst_addr.s_addr = s->dst_addr;
656                                         ioc->alias_addr.s_addr = s->alias_addr;
657                                         ioc->src_port = s->src_port;
658                                         ioc->dst_port = s->dst_port;
659                                         ioc->alias_port = s->alias_port;
660                                         ioc->nat_id = n + 1;
661                                         ioc->cpu_id = cpu;
662                                         ioc->proto = IPPROTO_UDP;
663                                         ioc->direction = 1;
664                                         ioc->life = s->timestamp +
665                                                 sysctl_var_udp_timeout - time_uptime;
666                                         ioc++;
667                                 }
668                                 LIST_FOREACH(a1, &the->alias, next) {
669                                         for (i = 0; i < ALIAS_RANGE; i++) {
670                                                 s2 = a1->udp_in[i];
671                                                 if (s2 == NULL) {
672                                                         continue;
673                                                 }
674
675                                                 total_len += LEN_IOC_NAT_STATE;
676                                                 if (total_len > sopt_size)
677                                                         goto nospace;
678
679                                                 ioc->src_addr.s_addr = ntohl(s2->src_addr);
680                                                 ioc->dst_addr.s_addr = s2->dst_addr;
681                                                 ioc->alias_addr.s_addr = s2->alias_addr;
682                                                 ioc->src_port = s2->src_port;
683                                                 ioc->dst_port = s2->dst_port;
684                                                 ioc->alias_port = s2->alias_port;
685                                                 ioc->nat_id = n + 1;
686                                                 ioc->cpu_id = cpu;
687                                                 ioc->proto = IPPROTO_UDP;
688                                                 ioc->direction = 0;
689                                                 ioc->life = s2->timestamp +
690                                                         sysctl_var_icmp_timeout - time_uptime;
691                                                 ioc++;
692                                         }
693                                 }
694                         }
695                 }
696         }
697         sopt->sopt_valsize = total_len;
698         return 0;
699 nospace:
700         return 0;
701 }
702
703 void
704 nat_state_add_dispatch(netmsg_t add_msg)
705 {
706         struct ip_fw3_nat_context *nat_ctx;
707         struct netmsg_nat_state_add *msg;
708         struct cfg_nat *nat;
709         struct nat_state2 *s2;
710         struct cfg_alias *alias;
711
712         nat_ctx = ip_fw3_nat_ctx[mycpuid];
713         msg = (struct netmsg_nat_state_add *)add_msg;
714         nat = nat_ctx->nats[msg->nat_id - 1];
715
716         LIST_FOREACH(alias, &nat->alias, next) {
717                 if (alias->ip.s_addr == msg->alias_addr.s_addr) {
718                         break;
719                 }
720         }
721         s2 = msg->state;
722         if (msg->proto == IPPROTO_TCP) {
723                 alias->tcp_in[msg->alias_port - ALIAS_BEGIN] = s2;
724         } else {
725                 alias->udp_in[msg->alias_port - ALIAS_BEGIN] = s2;
726         }
727 }
728
729 /*
730  * Init the RB trees only when the NAT is configured.
731  */
732 void
733 nat_add_dispatch(netmsg_t nat_add_msg)
734 {
735         struct ip_fw3_nat_context *nat_ctx;
736         struct netmsg_nat_add *msg;
737         struct ioc_nat *ioc;
738         struct cfg_nat *nat;
739         struct cfg_alias *alias;
740         struct in_addr *ip;
741         int n;
742
743         msg = (struct netmsg_nat_add *)nat_add_msg;
744         ioc = &msg->ioc_nat;
745         nat_ctx = ip_fw3_nat_ctx[mycpuid];
746
747         if (nat_ctx->nats[ioc->id - 1] == NULL) {
748                 /* op = set, and nat not exists */
749                 nat = kmalloc(LEN_CFG_NAT, M_IPFW3_NAT, M_WAITOK | M_ZERO);
750                 LIST_INIT(&nat->alias);
751                 RB_INIT(&nat->rb_tcp_out);
752                 RB_INIT(&nat->rb_udp_out);
753                 if (mycpuid == 0) {
754                         RB_INIT(&nat->rb_icmp_out);
755                 }
756                 nat->id = ioc->id;
757                 nat->count = ioc->count;
758                 ip = &ioc->ip;
759                 for (n = 0; n < ioc->count; n++) {
760                         alias = kmalloc(LEN_CFG_ALIAS,
761                                         M_IPFW3_NAT, M_WAITOK | M_ZERO);
762                         memcpy(&alias->ip, ip, LEN_IN_ADDR);
763                         LIST_INSERT_HEAD((&nat->alias), alias, next);
764                         ip++;
765                 }
766                 nat_ctx->nats[ioc->id - 1] = nat;
767         }
768         netisr_forwardmsg_all(&msg->base, mycpuid + 1);
769 }
770
771 int
772 ip_fw3_ctl_nat_add(struct sockopt *sopt)
773 {
774         struct netmsg_nat_add nat_add_msg, *msg;
775         struct ioc_nat *ioc;
776         msg = &nat_add_msg;
777
778         ioc = (struct ioc_nat *)(sopt->sopt_val);
779         sooptcopyin(sopt, &msg->ioc_nat, sopt->sopt_valsize,
780                         sizeof(struct ioc_nat));
781         netmsg_init(&msg->base, NULL, &curthread->td_msgport, 0,
782                         nat_add_dispatch);
783         netisr_domsg(&msg->base, 0);
784         return 0;
785 }
786
787 void
788 nat_del_dispatch(netmsg_t nat_del_msg)
789 {
790         struct ip_fw3_nat_context *nat_ctx;
791         struct netmsg_nat_del *msg;
792         struct cfg_nat *nat;
793         struct nat_state *s, *tmp;
794         struct cfg_alias *alias, *tmp3;
795
796         msg = (struct netmsg_nat_del *)nat_del_msg;
797
798         nat_ctx = ip_fw3_nat_ctx[mycpuid];
799         nat = nat_ctx->nats[msg->id - 1];
800         if (nat != NULL) {
801                 /* the icmp states will only stored in cpu 0 */
802                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_icmp_out, tmp) {
803                         RB_REMOVE(state_tree, &nat->rb_icmp_out, s);
804                         if (s != NULL) {
805                                 kfree(s, M_IPFW3_NAT);
806                         }
807                 }
808                 /*
809                 LIST_FOREACH_MUTABLE(s2, &nat->alias->icmp_in, next, tmp2) {
810                         LIST_REMOVE(s2, next);
811                         if (s != NULL) {
812                                 kfree(s, M_IPFW3_NAT);
813                         }
814                 }
815                 */
816
817                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_tcp_out, tmp) {
818                         RB_REMOVE(state_tree, &nat->rb_tcp_out, s);
819                         if (s != NULL) {
820                                 kfree(s, M_IPFW3_NAT);
821                         }
822                 }
823                 /*
824                 LIST_FOREACH_MUTABLE(s2, &nat->alias->tcp_in, next, tmp2) {
825                         LIST_REMOVE(s2, next);
826                         if (s != NULL) {
827                                 kfree(s, M_IPFW3_NAT);
828                         }
829                 }
830                 */
831                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_udp_out, tmp) {
832                         RB_REMOVE(state_tree, &nat->rb_udp_out, s);
833                         if (s != NULL) {
834                                 kfree(s, M_IPFW3_NAT);
835                         }
836                 }
837                 /*
838                 LIST_FOREACH_MUTABLE(s2, &nat->alias->udp_in, next, tmp2) {
839                         LIST_REMOVE(s2, next);
840                         if (s != NULL) {
841                                 kfree(s, M_IPFW3_NAT);
842                         }
843                 }
844                 */
845                 LIST_FOREACH_MUTABLE(alias, &nat->alias, next, tmp3) {
846                         kfree(alias, M_IPFW3_NAT);
847                 }
848                 kfree(nat, M_IPFW3_NAT);
849                 nat_ctx->nats[msg->id - 1] = NULL;
850         }
851         netisr_forwardmsg_all(&nat_del_msg->base, mycpuid + 1);
852 }
853 int
854 ip_fw3_ctl_nat_del(struct sockopt *sopt)
855 {
856         struct netmsg_nat_del nat_del_msg, *msg;
857
858         msg = &nat_del_msg;
859         msg->id = *((int *)sopt->sopt_val);
860         netmsg_init(&msg->base, NULL, &curthread->td_msgport,
861                         0, nat_del_dispatch);
862
863         netisr_domsg(&msg->base, 0);
864         return 0;
865 }
866 int
867 ip_fw3_ctl_nat_flush(struct sockopt *sopt)
868 {
869         struct netmsg_nat_del nat_del_msg, *msg;
870         int i;
871         msg = &nat_del_msg;
872         for (i = 0; i < NAT_ID_MAX; i++) {
873                 msg->id = i + 1;
874                 netmsg_init(&msg->base, NULL, &curthread->td_msgport,
875                                 0, nat_del_dispatch);
876
877                 netisr_domsg(&msg->base, 0);
878         }
879         return 0;
880 }
881
882 int
883 ip_fw3_ctl_nat_sockopt(struct sockopt *sopt)
884 {
885         int error = 0;
886         switch (sopt->sopt_name) {
887         case IP_FW_NAT_ADD:
888                 error = ip_fw3_ctl_nat_add(sopt);
889                 break;
890         case IP_FW_NAT_DEL:
891                 error = ip_fw3_ctl_nat_del(sopt);
892                 break;
893         case IP_FW_NAT_FLUSH:
894                 error = ip_fw3_ctl_nat_flush(sopt);
895                 break;
896         case IP_FW_NAT_GET:
897                 error = ip_fw3_ctl_nat_get_cfg(sopt);
898                 break;
899         case IP_FW_NAT_GET_RECORD:
900                 error = ip_fw3_ctl_nat_get_record(sopt);
901                 break;
902         default:
903                 kprintf("ipfw3 nat invalid socket option %d\n",
904                                 sopt->sopt_name);
905         }
906         return error;
907 }
908
909 void
910 nat_init_ctx_dispatch(netmsg_t msg)
911 {
912         struct ip_fw3_nat_context *tmp;
913         tmp = kmalloc(sizeof(struct ip_fw3_nat_context),
914                                 M_IPFW3_NAT, M_WAITOK | M_ZERO);
915
916         ip_fw3_nat_ctx[mycpuid] = tmp;
917         netisr_forwardmsg_all(&msg->base, mycpuid + 1);
918 }
919
920 void
921 nat_fnit_ctx_dispatch(netmsg_t msg)
922 {
923         kfree(ip_fw3_nat_ctx[mycpuid], M_IPFW3_NAT);
924         netisr_forwardmsg_all(&msg->base, mycpuid + 1);
925 }
926
927 static void
928 nat_cleanup_func_dispatch(netmsg_t nmsg)
929 {
930         struct nat_state *s, *tmp;
931         struct ip_fw3_nat_context *nat_ctx;
932         struct cfg_nat *nat;
933         struct cfg_alias *a1, *tmp2;
934         struct nat_state2 *s2;
935         int i, j;
936
937         nat_ctx = ip_fw3_nat_ctx[mycpuid];
938         for (j = 0; j < NAT_ID_MAX; j++) {
939                 nat = nat_ctx->nats[j];
940                 if (nat == NULL)
941                         continue;
942                 /* check the nat_states, remove the expired state */
943                 /* the icmp states will only stored in cpu 0 */
944                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_icmp_out, tmp) {
945                         if (time_uptime - s->timestamp > sysctl_var_icmp_timeout) {
946                                 RB_REMOVE(state_tree, &nat->rb_icmp_out, s);
947                                 kfree(s, M_IPFW3_NAT);
948                         }
949                 }
950                 LIST_FOREACH_MUTABLE(a1, &nat->alias, next, tmp2) {
951                         for (i = 0; i < ALIAS_RANGE; i++) {
952                                 s2 = a1->icmp_in[i];
953                                 if (s2 != NULL) {
954                                         if (time_uptime - s2->timestamp > sysctl_var_icmp_timeout) {
955                                                 a1->icmp_in[i] = NULL;
956                                                 kfree(s2, M_IPFW3_NAT);
957                                         }
958                                 }
959
960                         }
961                 }
962
963                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_tcp_out, tmp) {
964                         if (time_uptime - s->timestamp > sysctl_var_tcp_timeout) {
965                                 RB_REMOVE(state_tree, &nat->rb_tcp_out, s);
966                                 kfree(s, M_IPFW3_NAT);
967                         }
968                 }
969                 LIST_FOREACH_MUTABLE(a1, &nat->alias, next, tmp2) {
970                         for (i = 0; i < ALIAS_RANGE; i++) {
971                                 s2 = a1->tcp_in[i];
972                                 if (s2 != NULL) {
973                                         if (time_uptime - s2->timestamp > sysctl_var_icmp_timeout) {
974                                                 a1->tcp_in[i] = NULL;
975                                                 kfree(s2, M_IPFW3_NAT);
976                                         }
977                                 }
978
979                         }
980                 }
981                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_udp_out, tmp) {
982                         if (time_uptime - s->timestamp > sysctl_var_udp_timeout) {
983                                 RB_REMOVE(state_tree, &nat->rb_udp_out, s);
984                                 kfree(s, M_IPFW3_NAT);
985                         }
986                 }
987                 LIST_FOREACH_MUTABLE(a1, &nat->alias, next, tmp2) {
988                         for (i = 0; i < ALIAS_RANGE; i++) {
989                                 s2 = a1->udp_in[i];
990                                 if (s2 != NULL) {
991                                         if (time_uptime - s2->timestamp > sysctl_var_icmp_timeout) {
992                                                 a1->udp_in[i] = NULL;
993                                                 kfree(s2, M_IPFW3_NAT);
994                                         }
995                                 }
996
997                         }
998                 }
999         }
1000         netisr_forwardmsg_all(&nmsg->base, mycpuid + 1);
1001 }
1002
1003 static void
1004 ip_fw3_nat_cleanup_func(void *dummy __unused)
1005 {
1006         struct netmsg_base msg;
1007         netmsg_init(&msg, NULL, &curthread->td_msgport, 0,
1008                         nat_cleanup_func_dispatch);
1009         netisr_domsg(&msg, 0);
1010
1011         callout_reset(&ip_fw3_nat_cleanup_callout,
1012                         sysctl_var_cleanup_interval * hz,
1013                         ip_fw3_nat_cleanup_func, NULL);
1014 }
1015
1016 static
1017 int ip_fw3_nat_init(void)
1018 {
1019         struct netmsg_base msg;
1020         ip_fw3_register_module(MODULE_NAT_ID, MODULE_NAT_NAME);
1021         ip_fw3_register_filter_funcs(MODULE_NAT_ID, O_NAT_NAT,
1022                         (filter_func)check_nat);
1023         ip_fw3_ctl_nat_ptr = ip_fw3_ctl_nat_sockopt;
1024         netmsg_init(&msg, NULL, &curthread->td_msgport,
1025                         0, nat_init_ctx_dispatch);
1026         netisr_domsg(&msg, 0);
1027
1028         callout_init_mp(&ip_fw3_nat_cleanup_callout);
1029         callout_reset(&ip_fw3_nat_cleanup_callout,
1030                         sysctl_var_cleanup_interval * hz,
1031                         ip_fw3_nat_cleanup_func,
1032                         NULL);
1033         return 0;
1034 }
1035
1036 static int
1037 ip_fw3_nat_fini(void)
1038 {
1039         struct netmsg_base msg;
1040         struct netmsg_nat_del nat_del_msg, *msg1;
1041         int i;
1042
1043         callout_stop(&ip_fw3_nat_cleanup_callout);
1044
1045         msg1 = &nat_del_msg;
1046         for (i = 0; i < NAT_ID_MAX; i++) {
1047                 msg1->id = i + 1;
1048                 netmsg_init(&msg1->base, NULL, &curthread->td_msgport,
1049                                 0, nat_del_dispatch);
1050
1051                 netisr_domsg(&msg1->base, 0);
1052         }
1053
1054         netmsg_init(&msg, NULL, &curthread->td_msgport,
1055                         0, nat_fnit_ctx_dispatch);
1056         netisr_domsg(&msg, 0);
1057
1058         return ip_fw3_unregister_module(MODULE_NAT_ID);
1059 }
1060
1061 static int
1062 ip_fw3_nat_modevent(module_t mod, int type, void *data)
1063 {
1064         switch (type) {
1065         case MOD_LOAD:
1066                 return ip_fw3_nat_init();
1067         case MOD_UNLOAD:
1068                 return ip_fw3_nat_fini();
1069         default:
1070                 break;
1071         }
1072         return 0;
1073 }
1074
1075 moduledata_t ip_fw3_nat_mod = {
1076         "ipfw3_nat",
1077         ip_fw3_nat_modevent,
1078         NULL
1079 };
1080
1081 DECLARE_MODULE(ipfw3_nat, ip_fw3_nat_mod,
1082                 SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY);
1083 MODULE_DEPEND(ipfw3_nat, ipfw3_basic, 1, 1, 1);
1084 MODULE_VERSION(ipfw3_nat, 1);