kernel: Remove numerous #include <sys/thread2.h>.
[dragonfly.git] / sys / net / ipfw3_nat / ip_fw3_nat.c
1 /*
2  * Copyright (c) 2014 - 2018 The DragonFly Project.  All rights reserved.
3  *
4  * This code is derived from software contributed to The DragonFly Project
5  * by Bill Yuan <bycn82@dragonflybsd.org>
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  * 3. Neither the name of The DragonFly Project nor the names of its
18  *    contributors may be used to endorse or promote products derived
19  *    from this software without specific, prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
25  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
27  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
29  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
31  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32  * SUCH DAMAGE.
33  */
34
35 #include "opt_ipfw.h"
36 #include "opt_inet.h"
37 #ifndef INET
38 #error IPFIREWALL3 requires INET.
39 #endif /* INET */
40
41 #include <sys/param.h>
42 #include <sys/kernel.h>
43 #include <sys/malloc.h>
44 #include <sys/mbuf.h>
45 #include <sys/socketvar.h>
46 #include <sys/sysctl.h>
47 #include <sys/systimer.h>
48 #include <sys/in_cksum.h>
49 #include <sys/systm.h>
50 #include <sys/proc.h>
51 #include <sys/socket.h>
52 #include <sys/syslog.h>
53 #include <sys/ucred.h>
54 #include <sys/lock.h>
55 #include <sys/mplock2.h>
56
57 #include <net/ethernet.h>
58 #include <net/netmsg2.h>
59 #include <net/netisr2.h>
60 #include <net/route.h>
61 #include <net/if.h>
62
63 #include <netinet/in.h>
64 #include <netinet/ip.h>
65 #include <netinet/ip_icmp.h>
66 #include <netinet/tcp.h>
67 #include <netinet/tcp_timer.h>
68 #include <netinet/tcp_var.h>
69 #include <netinet/tcpip.h>
70 #include <netinet/udp.h>
71 #include <netinet/udp_var.h>
72 #include <netinet/in_systm.h>
73 #include <netinet/in_var.h>
74 #include <netinet/in_pcb.h>
75 #include <netinet/ip_var.h>
76 #include <netinet/ip_divert.h>
77 #include <net/ipfw3/ip_fw.h>
78
79 #include "ip_fw3_nat.h"
80
81 MALLOC_DEFINE(M_IPFW3_NAT, "IP_FW3_NAT", "ipfw3_nat module");
82
83 /*
84  * Highspeed Lockless Kernel NAT
85  *
86  * Kernel NAT
87  * The network address translation (NAT) will replace the `src` of the packet
88  * with an `alias` (alias_addr & alias_port). According to the configuration,
89  * the alias will be randomly picked from the configured range.
90  *
91  * Highspeed
92  * The first outgoing packet should trigger the creation of the `nat_state`,
93  * and the `nat_state` will be kept in an RB-tree for the subsequent outgoing
94  * packets.
95  * The first returning packet will trigger the creation of the `nat_state2`,
96  * which will be stored in a multidimensional array of pointers (to nat_state2).
97  *
98  * Lockless
99  * The `nat_state` for outgoing packets will be stored in the nat_context of
100  * the current CPU. But due to the nature of NAT, the returning packet may be
101  * handled by another CPU. Hence, the `nat_state2` for the returning packet
102  * will be prepared and stored into the nat_context of the right CPU.
103  */
104
105 struct ip_fw3_nat_context       *ip_fw3_nat_ctx[MAXCPU];
106 static struct callout           ip_fw3_nat_cleanup_callout;
107 extern struct ipfw3_context     *fw3_ctx[MAXCPU];
108 extern ip_fw_ctl_t              *ip_fw3_ctl_nat_ptr;
109
110 static int                      sysctl_var_cleanup_interval = 1;
111 static int                      sysctl_var_icmp_timeout = 10;
112 static int                      sysctl_var_tcp_timeout = 60;
113 static int                      sysctl_var_udp_timeout = 30;
114
115 SYSCTL_NODE(_net_inet_ip, OID_AUTO, fw3_nat, CTLFLAG_RW, 0, "ipfw3 NAT");
116 SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, cleanup_interval, CTLFLAG_RW,
117                 &sysctl_var_cleanup_interval, 0, "default life time");
118 SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, icmp_timeout, CTLFLAG_RW,
119                 &sysctl_var_icmp_timeout, 0, "default icmp state life time");
120 SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, tcp_timeout, CTLFLAG_RW,
121                 &sysctl_var_tcp_timeout, 0, "default tcp state life time");
122 SYSCTL_INT(_net_inet_ip_fw3_nat, OID_AUTO, udp_timeout, CTLFLAG_RW,
123                 &sysctl_var_udp_timeout, 0, "default udp state life time");
124
125 RB_PROTOTYPE(state_tree, nat_state, entries, ip_fw3_nat_state_cmp);
126 RB_GENERATE(state_tree, nat_state, entries, ip_fw3_nat_state_cmp);
127
/*
 * Incrementally adjust a 16-bit one's-complement checksum after one
 * 16-bit word covered by it changed from `old_info` to `new_info`.
 *
 * For UDP (`is_udp` non-zero) a checksum of 0 is passed through as 0,
 * and a computed result of 0 is returned as 0xFFFF instead.
 */
static __inline uint16_t
fix_cksum(uint16_t cksum, uint16_t old_info, uint16_t new_info, uint8_t is_udp)
{
	uint32_t acc;
	uint16_t folded;

	if (is_udp && cksum == 0)
		return (0x0000);

	/* May wrap below zero; the fold plus truncation accounts for it. */
	acc = cksum + old_info - new_info;
	acc = (acc & 0xFFFF) + (acc >> 16);
	folded = (uint16_t)(acc & 0xFFFF);

	if (is_udp && folded == 0)
		return (0xFFFF);
	return (folded);
}
142
143 void
144 check_nat(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
145                 struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
146 {
147         if ((*args)->eh != NULL) {
148                 *cmd_ctl = IP_FW_CTL_NO;
149                 *cmd_val = IP_FW_NOT_MATCH;
150                 return;
151         }
152
153         struct ip_fw3_nat_context *nat_ctx;
154         struct cfg_nat *nat;
155         int nat_id;
156
157         nat_ctx = ip_fw3_nat_ctx[mycpuid];
158         (*args)->rule = *f;
159         nat = ((ipfw_insn_nat *)cmd)->nat;
160         if (nat == NULL) {
161                 nat_id = cmd->arg1;
162                 nat = nat_ctx->nats[nat_id - 1];
163                 if (nat == NULL) {
164                         *cmd_val = IP_FW_DENY;
165                         *cmd_ctl = IP_FW_CTL_DONE;
166                         return;
167                 }
168                 ((ipfw_insn_nat *)cmd)->nat = nat;
169         }
170         *cmd_val = ip_fw3_nat(*args, nat, (*args)->m);
171         *cmd_ctl = IP_FW_CTL_NAT;
172 }
173
174 int
175 ip_fw3_nat(struct ip_fw_args *args, struct cfg_nat *nat, struct mbuf *m)
176 {
177         struct state_tree *tree_out = NULL;
178         struct nat_state *s = NULL, *dup, *k, key;
179         struct nat_state2 *s2 = NULL;
180         struct ip *ip = mtod(m, struct ip *);
181         struct in_addr *old_addr = NULL, new_addr;
182         uint16_t *old_port = NULL, new_port;
183         uint16_t *csum = NULL, dlen = 0;
184         uint8_t udp = 0;
185         boolean_t pseudo = FALSE, need_return_state = FALSE;
186         struct cfg_alias *alias;
187         int i = 0, rand_n = 0;
188
189         k = &key;
190         memset(k, 0, LEN_NAT_STATE);
191         if (args->oif == NULL) {
192                 old_addr = &ip->ip_dst;
193                 k->dst_addr = ntohl(args->f_id.dst_ip);
194                 LIST_FOREACH(alias, &nat->alias, next) {
195                         if (alias->ip.s_addr == ntohl(args->f_id.dst_ip)) {
196                                 break;
197                         }
198                 }
199                 if (alias == NULL) {
200                         goto oops;
201                 }
202                 switch (ip->ip_p) {
203                 case IPPROTO_TCP:
204                         old_port = &L3HDR(struct tcphdr, ip)->th_dport;
205                         s2 = alias->tcp_in[*old_port - ALIAS_BEGIN];
206                         csum = &L3HDR(struct tcphdr, ip)->th_sum;
207                         break;
208                 case IPPROTO_UDP:
209                         old_port = &L3HDR(struct udphdr, ip)->uh_dport;
210                         s2 = alias->udp_in[*old_port - ALIAS_BEGIN];
211                         csum = &L3HDR(struct udphdr, ip)->uh_sum;
212                         udp = 1;
213                         break;
214                 case IPPROTO_ICMP:
215                         old_port = &L3HDR(struct icmp, ip)->icmp_id;
216                         s2 = alias->icmp_in[*old_port];
217                         csum = &L3HDR(struct icmp, ip)->icmp_cksum;
218                         break;
219                 default:
220                         panic("ipfw3: unsupported proto %u", ip->ip_p);
221                 }
222                 if (s2 == NULL) {
223                         goto oops;
224                 }
225         } else {
226                 old_addr = &ip->ip_src;
227                 k->src_addr = args->f_id.src_ip;
228                 k->dst_addr = args->f_id.dst_ip;
229                 switch (ip->ip_p) {
230                 case IPPROTO_TCP:
231                         k->src_port = args->f_id.src_port;
232                         k->dst_port = args->f_id.dst_port;
233                         m->m_pkthdr.csum_flags = CSUM_TCP;
234                         tree_out = &nat->rb_tcp_out;
235                         old_port = &L3HDR(struct tcphdr, ip)->th_sport;
236                         csum = &L3HDR(struct tcphdr, ip)->th_sum;
237                         break;
238                 case IPPROTO_UDP:
239                         k->src_port = args->f_id.src_port;
240                         k->dst_port = args->f_id.dst_port;
241                         m->m_pkthdr.csum_flags = CSUM_UDP;
242                         tree_out = &nat->rb_udp_out;
243                         old_port = &L3HDR(struct udphdr, ip)->uh_sport;
244                         csum = &L3HDR(struct udphdr, ip)->uh_sum;
245                         udp = 1;
246                         break;
247                 case IPPROTO_ICMP:
248                         k->src_port = L3HDR(struct icmp, ip)->icmp_id;
249                         k->dst_port = k->src_port;
250                         tree_out = &nat->rb_icmp_out;
251                         old_port = &L3HDR(struct icmp, ip)->icmp_id;
252                         csum = &L3HDR(struct icmp, ip)->icmp_cksum;
253                         break;
254                 default:
255                         panic("ipfw3: unsupported proto %u", ip->ip_p);
256                 }
257                 s = RB_FIND(state_tree, tree_out, k);
258                 if (s == NULL) {
259                         /* pick an alias ip randomly when there are multiple */
260                         if (nat->count > 1) {
261                                 rand_n = krandom() % nat->count;
262                         }
263                         LIST_FOREACH(alias, &nat->alias, next) {
264                                 if (i++ == rand_n) {
265                                         break;
266                                 }
267                         }
268                         switch  (ip->ip_p) {
269                         case IPPROTO_TCP:
270                                 m->m_pkthdr.csum_flags = CSUM_TCP;
271                                 s = kmalloc(LEN_NAT_STATE, M_IPFW3_NAT,
272                                                 M_INTWAIT | M_NULLOK | M_ZERO);
273
274                                 s->src_addr = args->f_id.src_ip;
275                                 s->src_port = args->f_id.src_port;
276
277                                 s->dst_addr = args->f_id.dst_ip;
278                                 s->dst_port = args->f_id.dst_port;
279
280                                 s->alias_addr = alias->ip.s_addr;
281                                 pick_alias_port(s, tree_out);
282                                 dup = RB_INSERT(state_tree, tree_out, s);
283                                 need_return_state = TRUE;
284                                 break;
285                         case IPPROTO_UDP:
286                                 m->m_pkthdr.csum_flags = CSUM_UDP;
287                                 s = kmalloc(LEN_NAT_STATE, M_IPFW3_NAT,
288                                                 M_INTWAIT | M_NULLOK | M_ZERO);
289
290                                 s->src_addr = args->f_id.src_ip;
291                                 s->src_port = args->f_id.src_port;
292
293                                 s->dst_addr = args->f_id.dst_ip;
294                                 s->dst_port = args->f_id.dst_port;
295
296                                 s->alias_addr = alias->ip.s_addr;
297                                 pick_alias_port(s, tree_out);
298                                 dup = RB_INSERT(state_tree, tree_out, s);
299                                 need_return_state = TRUE;
300                                 break;
301                         case IPPROTO_ICMP:
302                                 s = kmalloc(LEN_NAT_STATE, M_IPFW3_NAT,
303                                                 M_INTWAIT | M_NULLOK | M_ZERO);
304                                 s->src_addr = args->f_id.src_ip;
305                                 s->dst_addr = args->f_id.dst_ip;
306
307                                 s->src_port = *old_port;
308                                 s->dst_port = *old_port;
309
310                                 s->alias_addr = alias->ip.s_addr;
311                                 s->alias_port = htons(s->src_addr % ALIAS_RANGE);
312                                 dup = RB_INSERT(state_tree, tree_out, s);
313
314                                 s2 = kmalloc(LEN_NAT_STATE2, M_IPFW3_NAT,
315                                                 M_INTWAIT | M_NULLOK | M_ZERO);
316
317                                 s2->src_addr = args->f_id.dst_ip;
318                                 s2->dst_addr = alias->ip.s_addr;
319
320                                 s2->src_port = s->alias_port;
321                                 s2->dst_port = s->alias_port;
322
323                                 s2->alias_addr = htonl(args->f_id.src_ip);
324                                 s2->alias_port = *old_port;
325
326                                 alias->icmp_in[s->alias_port] = s2;
327                                 break;
328                         default :
329                                 goto oops;
330                         }
331                 }
332         }
333         if (args->oif == NULL) {
334                 new_addr.s_addr = s2->src_addr;
335                 new_port = s2->src_port;
336                 s2->timestamp = time_uptime;
337         } else {
338                 new_addr.s_addr = s->alias_addr;
339                 new_port = s->alias_port;
340                 s->timestamp = time_uptime;
341         }
342
343         /* replace src/dst and fix the checksum */
344         if (m->m_pkthdr.csum_flags & (CSUM_UDP | CSUM_TCP | CSUM_TSO)) {
345                 if ((m->m_pkthdr.csum_flags & CSUM_TSO) == 0) {
346                         dlen = ip->ip_len - (ip->ip_hl << 2);
347                 }
348                 pseudo = TRUE;
349         }
350         if (!pseudo) {
351                 const uint16_t *oaddr, *naddr;
352                 oaddr = (const uint16_t *)&old_addr->s_addr;
353                 naddr = (const uint16_t *)&new_addr.s_addr;
354                 ip->ip_sum = fix_cksum(ip->ip_sum, oaddr[0], naddr[0], 0);
355                 ip->ip_sum = fix_cksum(ip->ip_sum, oaddr[1], naddr[1], 0);
356                 if (ip->ip_p != IPPROTO_ICMP) {
357                         *csum = fix_cksum(*csum, oaddr[0], naddr[0], udp);
358                         *csum = fix_cksum(*csum, oaddr[1], naddr[1], udp);
359                 }
360         }
361         old_addr->s_addr = new_addr.s_addr;
362         if (!pseudo) {
363                 *csum = fix_cksum(*csum, *old_port, new_port, udp);
364         }
365         *old_port = new_port;
366
367         if (pseudo) {
368                 *csum = in_pseudo(ip->ip_src.s_addr,
369                                 ip->ip_dst.s_addr, htons(dlen + ip->ip_p));
370         }
371
372         /* prepare the state for return traffic */
373         if (need_return_state) {
374                 ip->ip_len = htons(ip->ip_len);
375                 ip->ip_off = htons(ip->ip_off);
376
377                 m->m_flags &= ~M_HASH;
378                 ip_hashfn(&m, 0);
379
380                 ip->ip_len = ntohs(ip->ip_len);
381                 ip->ip_off = ntohs(ip->ip_off);
382
383                 int nextcpu = netisr_hashcpu(m->m_pkthdr.hash);
384                 if (nextcpu != mycpuid) {
385                         struct netmsg_nat_state_add *msg;
386                         msg = kmalloc(LEN_NMSG_NAT_STATE_ADD,
387                                         M_LWKTMSG, M_NOWAIT | M_ZERO);
388                         netmsg_init(&msg->base, NULL, &curthread->td_msgport,
389                                         0, nat_state_add_dispatch);
390                         s2 = kmalloc(LEN_NAT_STATE2, M_IPFW3_NAT,
391                                         M_INTWAIT | M_NULLOK | M_ZERO);
392
393                         s2->src_addr = args->f_id.dst_ip;
394                         s2->src_port = args->f_id.dst_port;
395
396                         s2->dst_addr = alias->ip.s_addr;
397                         s2->dst_port = s->alias_port;
398
399                         s2->src_addr = htonl(args->f_id.src_ip);
400                         s2->src_port = htons(args->f_id.src_port);
401
402                         s2->timestamp = s->timestamp;
403                         msg->alias_addr.s_addr = alias->ip.s_addr;
404                         msg->alias_port = s->alias_port;
405                         msg->state = s2;
406                         msg->nat_id = nat->id;
407                         msg->proto = ip->ip_p;
408                         netisr_sendmsg(&msg->base, nextcpu);
409                 } else {
410                         s2 = kmalloc(LEN_NAT_STATE2, M_IPFW3_NAT,
411                                         M_INTWAIT | M_NULLOK | M_ZERO);
412
413                         s2->src_addr = args->f_id.dst_ip;
414                         s2->dst_addr = alias->ip.s_addr;
415
416                         s2->src_port = s->alias_port;
417                         s2->dst_port = s->alias_port;
418
419                         s2->src_addr = htonl(args->f_id.src_ip);
420                         s2->src_port = htons(args->f_id.src_port);
421
422                         s2->timestamp = s->timestamp;
423                         if (ip->ip_p == IPPROTO_TCP) {
424                                 alias->tcp_in[s->alias_port - ALIAS_BEGIN] = s2;
425                         } else {
426                                 alias->udp_in[s->alias_port - ALIAS_BEGIN] = s2;
427                         }
428                 }
429         }
430         return IP_FW_NAT;
431 oops:
432         IPFW3_DEBUG1("oops\n");
433         return IP_FW_DENY;
434 }
435
436 void
437 pick_alias_port(struct nat_state *s, struct state_tree *tree)
438 {
439         do {
440                 s->alias_port = htons(krandom() % ALIAS_RANGE + ALIAS_BEGIN);
441         } while (RB_FIND(state_tree, tree, s) != NULL);
442 }
443
444 int
445 ip_fw3_nat_state_cmp(struct nat_state *s1, struct nat_state *s2)
446 {
447         if (s1->src_addr > s2->src_addr)
448                 return 1;
449         if (s1->src_addr < s2->src_addr)
450                 return -1;
451
452         if (s1->dst_addr > s2->dst_addr)
453                 return 1;
454         if (s1->dst_addr < s2->dst_addr)
455                 return -1;
456
457         if (s1->src_port > s2->src_port)
458                 return 1;
459         if (s1->src_port < s2->src_port)
460                 return -1;
461
462         if (s1->dst_port > s2->dst_port)
463                 return 1;
464         if (s1->dst_port < s2->dst_port)
465                 return -1;
466
467         return 0;
468 }
469
470 int
471 ip_fw3_ctl_nat_get_cfg(struct sockopt *sopt)
472 {
473         struct ip_fw3_nat_context *nat_ctx;
474         struct ioc_nat *ioc;
475         struct cfg_nat *nat;
476         struct cfg_alias *alias;
477         struct in_addr *ip;
478         size_t valsize;
479         int i, len;
480
481         len = 0;
482         nat_ctx = ip_fw3_nat_ctx[mycpuid];
483         valsize = sopt->sopt_valsize;
484         ioc = (struct ioc_nat *)sopt->sopt_val;
485
486         for (i = 0; i < NAT_ID_MAX; i++) {
487                 nat = nat_ctx->nats[i];
488                 if (nat != NULL) {
489                         len += LEN_IOC_NAT;
490                         if (len >= valsize) {
491                                 goto nospace;
492                         }
493                         ioc->id = nat->id;
494                         ioc->count = nat->count;
495                         ip = &ioc->ip;
496                         LIST_FOREACH(alias, &nat->alias, next) {
497                                 len += LEN_IN_ADDR;
498                                 if (len > valsize) {
499                                         goto nospace;
500                                 }
501                                 bcopy(&alias->ip, ip, LEN_IN_ADDR);
502                                 ip++;
503                         }
504                 }
505         }
506         sopt->sopt_valsize = len;
507         return 0;
508 nospace:
509         bzero(sopt->sopt_val, sopt->sopt_valsize);
510         sopt->sopt_valsize = 0;
511         return 0;
512 }
513
514 int
515 ip_fw3_ctl_nat_get_record(struct sockopt *sopt)
516 {
517         struct ip_fw3_nat_context *nat_ctx;
518         struct cfg_nat *the;
519         size_t sopt_size, total_len = 0;
520         struct ioc_nat_state *ioc;
521         int ioc_nat_id, i, n, cpu;
522         struct nat_state        *s;
523         struct nat_state2       *s2;
524         struct cfg_alias        *a1;
525
526         ioc_nat_id = *((int *)(sopt->sopt_val));
527         sopt_size = sopt->sopt_valsize;
528         ioc = (struct ioc_nat_state *)sopt->sopt_val;
529         /* icmp states only in CPU 0 */
530         cpu = 0;
531         nat_ctx = ip_fw3_nat_ctx[cpu];
532         for (n = 0; n < NAT_ID_MAX; n++) {
533                 if (ioc_nat_id == 0 || ioc_nat_id == n + 1) {
534                         if (nat_ctx->nats[n] == NULL)
535                                 break;
536                         the = nat_ctx->nats[n];
537                         RB_FOREACH(s, state_tree, &the->rb_icmp_out) {
538                                 total_len += LEN_IOC_NAT_STATE;
539                                 if (total_len > sopt_size)
540                                         goto nospace;
541                                 ioc->src_addr.s_addr = ntohl(s->src_addr);
542                                 ioc->dst_addr.s_addr = s->dst_addr;
543                                 ioc->alias_addr.s_addr = s->alias_addr;
544                                 ioc->src_port = s->src_port;
545                                 ioc->dst_port = s->dst_port;
546                                 ioc->alias_port = s->alias_port;
547                                 ioc->nat_id = n + 1;
548                                 ioc->cpu_id = cpu;
549                                 ioc->proto = IPPROTO_ICMP;
550                                 ioc->direction = 1;
551                                 ioc->life = s->timestamp +
552                                         sysctl_var_icmp_timeout - time_uptime;
553                                 ioc++;
554                         }
555
556                         LIST_FOREACH(a1, &the->alias, next) {
557                         for (i = 0; i < ALIAS_RANGE; i++) {
558                                 s2 = a1->icmp_in[i];
559                                 if (s2 == NULL) {
560                                         continue;
561                                 }
562
563                                 total_len += LEN_IOC_NAT_STATE;
564                                 if (total_len > sopt_size)
565                                         goto nospace;
566
567                                 ioc->src_addr.s_addr = ntohl(s2->src_addr);
568                                 ioc->dst_addr.s_addr = s2->dst_addr;
569                                 ioc->alias_addr.s_addr = s2->alias_addr;
570                                 ioc->src_port = s2->src_port;
571                                 ioc->dst_port = s2->dst_port;
572                                 ioc->alias_port = s2->alias_port;
573                                 ioc->nat_id = n + 1;
574                                 ioc->cpu_id = cpu;
575                                 ioc->proto = IPPROTO_ICMP;
576                                 ioc->direction = 0;
577                                 ioc->life = s2->timestamp +
578                                         sysctl_var_icmp_timeout - time_uptime;
579                                 ioc++;
580                         }
581                         }
582                 }
583         }
584
585         /* tcp states */
586         for (cpu = 0; cpu < ncpus; cpu++) {
587                 nat_ctx = ip_fw3_nat_ctx[cpu];
588                 for (n = 0; n < NAT_ID_MAX; n++) {
589                         if (ioc_nat_id == 0 || ioc_nat_id == n + 1) {
590                                 if (nat_ctx->nats[n] == NULL)
591                                         break;
592                                 the = nat_ctx->nats[n];
593                                 RB_FOREACH(s, state_tree, &the->rb_tcp_out) {
594                                         total_len += LEN_IOC_NAT_STATE;
595                                         if (total_len > sopt_size)
596                                                 goto nospace;
597                                         ioc->src_addr.s_addr = ntohl(s->src_addr);
598                                         ioc->dst_addr.s_addr = ntohl(s->dst_addr);
599                                         ioc->alias_addr.s_addr = s->alias_addr;
600                                         ioc->src_port = ntohs(s->src_port);
601                                         ioc->dst_port = ntohs(s->dst_port);
602                                         ioc->alias_port = s->alias_port;
603                                         ioc->nat_id = n + 1;
604                                         ioc->cpu_id = cpu;
605                                         ioc->proto = IPPROTO_TCP;
606                                         ioc->direction = 1;
607                                         ioc->life = s->timestamp +
608                                                 sysctl_var_tcp_timeout - time_uptime;
609                                         ioc++;
610                                 }
611                                 LIST_FOREACH(a1, &the->alias, next) {
612                                         for (i = 0; i < ALIAS_RANGE; i++) {
613                                                 s2 = a1->tcp_in[i];
614                                                 if (s2 == NULL) {
615                                                         continue;
616                                                 }
617
618                                                 total_len += LEN_IOC_NAT_STATE;
619                                                 if (total_len > sopt_size)
620                                                         goto nospace;
621
622                                                 ioc->src_addr.s_addr = ntohl(s2->src_addr);
623                                                 ioc->dst_addr.s_addr = s2->dst_addr;
624                                                 ioc->alias_addr.s_addr = s2->alias_addr;
625                                                 ioc->src_port = s2->src_port;
626                                                 ioc->dst_port = s2->dst_port;
627                                                 ioc->alias_port = s2->alias_port;
628                                                 ioc->nat_id = n + 1;
629                                                 ioc->cpu_id = cpu;
630                                                 ioc->proto = IPPROTO_TCP;
631                                                 ioc->direction = 0;
632                                                 ioc->life = s2->timestamp +
633                                                         sysctl_var_icmp_timeout - time_uptime;
634                                                 ioc++;
635                                         }
636                                 }
637                         }
638                 }
639         }
640
641         /* udp states */
642         for (cpu = 0; cpu < ncpus; cpu++) {
643                 nat_ctx = ip_fw3_nat_ctx[cpu];
644                 for (n = 0; n < NAT_ID_MAX; n++) {
645                         if (ioc_nat_id == 0 || ioc_nat_id == n + 1) {
646                                 if (nat_ctx->nats[n] == NULL)
647                                         break;
648                                 the = nat_ctx->nats[n];
649                                 RB_FOREACH(s, state_tree, &the->rb_udp_out) {
650                                         total_len += LEN_IOC_NAT_STATE;
651                                         if (total_len > sopt_size)
652                                                 goto nospace;
653                                         ioc->src_addr.s_addr = ntohl(s->src_addr);
654                                         ioc->dst_addr.s_addr = s->dst_addr;
655                                         ioc->alias_addr.s_addr = s->alias_addr;
656                                         ioc->src_port = s->src_port;
657                                         ioc->dst_port = s->dst_port;
658                                         ioc->alias_port = s->alias_port;
659                                         ioc->nat_id = n + 1;
660                                         ioc->cpu_id = cpu;
661                                         ioc->proto = IPPROTO_UDP;
662                                         ioc->direction = 1;
663                                         ioc->life = s->timestamp +
664                                                 sysctl_var_udp_timeout - time_uptime;
665                                         ioc++;
666                                 }
667                                 LIST_FOREACH(a1, &the->alias, next) {
668                                         for (i = 0; i < ALIAS_RANGE; i++) {
669                                                 s2 = a1->udp_in[i];
670                                                 if (s2 == NULL) {
671                                                         continue;
672                                                 }
673
674                                                 total_len += LEN_IOC_NAT_STATE;
675                                                 if (total_len > sopt_size)
676                                                         goto nospace;
677
678                                                 ioc->src_addr.s_addr = ntohl(s2->src_addr);
679                                                 ioc->dst_addr.s_addr = s2->dst_addr;
680                                                 ioc->alias_addr.s_addr = s2->alias_addr;
681                                                 ioc->src_port = s2->src_port;
682                                                 ioc->dst_port = s2->dst_port;
683                                                 ioc->alias_port = s2->alias_port;
684                                                 ioc->nat_id = n + 1;
685                                                 ioc->cpu_id = cpu;
686                                                 ioc->proto = IPPROTO_UDP;
687                                                 ioc->direction = 0;
688                                                 ioc->life = s2->timestamp +
689                                                         sysctl_var_icmp_timeout - time_uptime;
690                                                 ioc++;
691                                         }
692                                 }
693                         }
694                 }
695         }
696         sopt->sopt_valsize = total_len;
697         return 0;
698 nospace:
699         return 0;
700 }
701
702 void
703 nat_state_add_dispatch(netmsg_t add_msg)
704 {
705         struct ip_fw3_nat_context *nat_ctx;
706         struct netmsg_nat_state_add *msg;
707         struct cfg_nat *nat;
708         struct nat_state2 *s2;
709         struct cfg_alias *alias;
710
711         nat_ctx = ip_fw3_nat_ctx[mycpuid];
712         msg = (struct netmsg_nat_state_add *)add_msg;
713         nat = nat_ctx->nats[msg->nat_id - 1];
714
715         LIST_FOREACH(alias, &nat->alias, next) {
716                 if (alias->ip.s_addr == msg->alias_addr.s_addr) {
717                         break;
718                 }
719         }
720         s2 = msg->state;
721         if (msg->proto == IPPROTO_TCP) {
722                 alias->tcp_in[msg->alias_port - ALIAS_BEGIN] = s2;
723         } else {
724                 alias->udp_in[msg->alias_port - ALIAS_BEGIN] = s2;
725         }
726 }
727
728 /*
729  * Init the RB trees only when the NAT is configured.
730  */
731 void
732 nat_add_dispatch(netmsg_t nat_add_msg)
733 {
734         struct ip_fw3_nat_context *nat_ctx;
735         struct netmsg_nat_add *msg;
736         struct ioc_nat *ioc;
737         struct cfg_nat *nat;
738         struct cfg_alias *alias;
739         struct in_addr *ip;
740         int n;
741
742         msg = (struct netmsg_nat_add *)nat_add_msg;
743         ioc = &msg->ioc_nat;
744         nat_ctx = ip_fw3_nat_ctx[mycpuid];
745
746         if (nat_ctx->nats[ioc->id - 1] == NULL) {
747                 /* op = set, and nat not exists */
748                 nat = kmalloc(LEN_CFG_NAT, M_IPFW3_NAT, M_WAITOK | M_ZERO);
749                 LIST_INIT(&nat->alias);
750                 RB_INIT(&nat->rb_tcp_out);
751                 RB_INIT(&nat->rb_udp_out);
752                 if (mycpuid == 0) {
753                         RB_INIT(&nat->rb_icmp_out);
754                 }
755                 nat->id = ioc->id;
756                 nat->count = ioc->count;
757                 ip = &ioc->ip;
758                 for (n = 0; n < ioc->count; n++) {
759                         alias = kmalloc(LEN_CFG_ALIAS,
760                                         M_IPFW3_NAT, M_WAITOK | M_ZERO);
761                         memcpy(&alias->ip, ip, LEN_IN_ADDR);
762                         LIST_INSERT_HEAD((&nat->alias), alias, next);
763                         ip++;
764                 }
765                 nat_ctx->nats[ioc->id - 1] = nat;
766         }
767         netisr_forwardmsg_all(&msg->base, mycpuid + 1);
768 }
769
770 int
771 ip_fw3_ctl_nat_add(struct sockopt *sopt)
772 {
773         struct netmsg_nat_add nat_add_msg, *msg;
774         struct ioc_nat *ioc;
775         msg = &nat_add_msg;
776
777         ioc = (struct ioc_nat *)(sopt->sopt_val);
778         sooptcopyin(sopt, &msg->ioc_nat, sopt->sopt_valsize,
779                         sizeof(struct ioc_nat));
780         netmsg_init(&msg->base, NULL, &curthread->td_msgport, 0,
781                         nat_add_dispatch);
782         netisr_domsg(&msg->base, 0);
783         return 0;
784 }
785
786 void
787 nat_del_dispatch(netmsg_t nat_del_msg)
788 {
789         struct ip_fw3_nat_context *nat_ctx;
790         struct netmsg_nat_del *msg;
791         struct cfg_nat *nat;
792         struct nat_state *s, *tmp;
793         struct cfg_alias *alias, *tmp3;
794
795         msg = (struct netmsg_nat_del *)nat_del_msg;
796
797         nat_ctx = ip_fw3_nat_ctx[mycpuid];
798         nat = nat_ctx->nats[msg->id - 1];
799         if (nat != NULL) {
800                 /* the icmp states will only stored in cpu 0 */
801                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_icmp_out, tmp) {
802                         RB_REMOVE(state_tree, &nat->rb_icmp_out, s);
803                         if (s != NULL) {
804                                 kfree(s, M_IPFW3_NAT);
805                         }
806                 }
807                 /*
808                 LIST_FOREACH_MUTABLE(s2, &nat->alias->icmp_in, next, tmp2) {
809                         LIST_REMOVE(s2, next);
810                         if (s != NULL) {
811                                 kfree(s, M_IPFW3_NAT);
812                         }
813                 }
814                 */
815
816                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_tcp_out, tmp) {
817                         RB_REMOVE(state_tree, &nat->rb_tcp_out, s);
818                         if (s != NULL) {
819                                 kfree(s, M_IPFW3_NAT);
820                         }
821                 }
822                 /*
823                 LIST_FOREACH_MUTABLE(s2, &nat->alias->tcp_in, next, tmp2) {
824                         LIST_REMOVE(s2, next);
825                         if (s != NULL) {
826                                 kfree(s, M_IPFW3_NAT);
827                         }
828                 }
829                 */
830                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_udp_out, tmp) {
831                         RB_REMOVE(state_tree, &nat->rb_udp_out, s);
832                         if (s != NULL) {
833                                 kfree(s, M_IPFW3_NAT);
834                         }
835                 }
836                 /*
837                 LIST_FOREACH_MUTABLE(s2, &nat->alias->udp_in, next, tmp2) {
838                         LIST_REMOVE(s2, next);
839                         if (s != NULL) {
840                                 kfree(s, M_IPFW3_NAT);
841                         }
842                 }
843                 */
844                 LIST_FOREACH_MUTABLE(alias, &nat->alias, next, tmp3) {
845                         kfree(alias, M_IPFW3_NAT);
846                 }
847                 kfree(nat, M_IPFW3_NAT);
848                 nat_ctx->nats[msg->id - 1] = NULL;
849         }
850         netisr_forwardmsg_all(&nat_del_msg->base, mycpuid + 1);
851 }
852 int
853 ip_fw3_ctl_nat_del(struct sockopt *sopt)
854 {
855         struct netmsg_nat_del nat_del_msg, *msg;
856
857         msg = &nat_del_msg;
858         msg->id = *((int *)sopt->sopt_val);
859         netmsg_init(&msg->base, NULL, &curthread->td_msgport,
860                         0, nat_del_dispatch);
861
862         netisr_domsg(&msg->base, 0);
863         return 0;
864 }
865 int
866 ip_fw3_ctl_nat_flush(struct sockopt *sopt)
867 {
868         struct netmsg_nat_del nat_del_msg, *msg;
869         int i;
870         msg = &nat_del_msg;
871         for (i = 0; i < NAT_ID_MAX; i++) {
872                 msg->id = i + 1;
873                 netmsg_init(&msg->base, NULL, &curthread->td_msgport,
874                                 0, nat_del_dispatch);
875
876                 netisr_domsg(&msg->base, 0);
877         }
878         return 0;
879 }
880
881 int
882 ip_fw3_ctl_nat_sockopt(struct sockopt *sopt)
883 {
884         int error = 0;
885         switch (sopt->sopt_name) {
886         case IP_FW_NAT_ADD:
887                 error = ip_fw3_ctl_nat_add(sopt);
888                 break;
889         case IP_FW_NAT_DEL:
890                 error = ip_fw3_ctl_nat_del(sopt);
891                 break;
892         case IP_FW_NAT_FLUSH:
893                 error = ip_fw3_ctl_nat_flush(sopt);
894                 break;
895         case IP_FW_NAT_GET:
896                 error = ip_fw3_ctl_nat_get_cfg(sopt);
897                 break;
898         case IP_FW_NAT_GET_RECORD:
899                 error = ip_fw3_ctl_nat_get_record(sopt);
900                 break;
901         default:
902                 kprintf("ipfw3 nat invalid socket option %d\n",
903                                 sopt->sopt_name);
904         }
905         return error;
906 }
907
908 void
909 nat_init_ctx_dispatch(netmsg_t msg)
910 {
911         struct ip_fw3_nat_context *tmp;
912         tmp = kmalloc(sizeof(struct ip_fw3_nat_context),
913                                 M_IPFW3_NAT, M_WAITOK | M_ZERO);
914
915         ip_fw3_nat_ctx[mycpuid] = tmp;
916         netisr_forwardmsg_all(&msg->base, mycpuid + 1);
917 }
918
919 void
920 nat_fnit_ctx_dispatch(netmsg_t msg)
921 {
922         kfree(ip_fw3_nat_ctx[mycpuid], M_IPFW3_NAT);
923         netisr_forwardmsg_all(&msg->base, mycpuid + 1);
924 }
925
926 static void
927 nat_cleanup_func_dispatch(netmsg_t nmsg)
928 {
929         struct nat_state *s, *tmp;
930         struct ip_fw3_nat_context *nat_ctx;
931         struct cfg_nat *nat;
932         struct cfg_alias *a1, *tmp2;
933         struct nat_state2 *s2;
934         int i, j;
935
936         nat_ctx = ip_fw3_nat_ctx[mycpuid];
937         for (j = 0; j < NAT_ID_MAX; j++) {
938                 nat = nat_ctx->nats[j];
939                 if (nat == NULL)
940                         continue;
941                 /* check the nat_states, remove the expired state */
942                 /* the icmp states will only stored in cpu 0 */
943                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_icmp_out, tmp) {
944                         if (time_uptime - s->timestamp > sysctl_var_icmp_timeout) {
945                                 RB_REMOVE(state_tree, &nat->rb_icmp_out, s);
946                                 kfree(s, M_IPFW3_NAT);
947                         }
948                 }
949                 LIST_FOREACH_MUTABLE(a1, &nat->alias, next, tmp2) {
950                         for (i = 0; i < ALIAS_RANGE; i++) {
951                                 s2 = a1->icmp_in[i];
952                                 if (s2 != NULL) {
953                                         if (time_uptime - s2->timestamp > sysctl_var_icmp_timeout) {
954                                                 a1->icmp_in[i] = NULL;
955                                                 kfree(s2, M_IPFW3_NAT);
956                                         }
957                                 }
958
959                         }
960                 }
961
962                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_tcp_out, tmp) {
963                         if (time_uptime - s->timestamp > sysctl_var_tcp_timeout) {
964                                 RB_REMOVE(state_tree, &nat->rb_tcp_out, s);
965                                 kfree(s, M_IPFW3_NAT);
966                         }
967                 }
968                 LIST_FOREACH_MUTABLE(a1, &nat->alias, next, tmp2) {
969                         for (i = 0; i < ALIAS_RANGE; i++) {
970                                 s2 = a1->tcp_in[i];
971                                 if (s2 != NULL) {
972                                         if (time_uptime - s2->timestamp > sysctl_var_icmp_timeout) {
973                                                 a1->tcp_in[i] = NULL;
974                                                 kfree(s2, M_IPFW3_NAT);
975                                         }
976                                 }
977
978                         }
979                 }
980                 RB_FOREACH_SAFE(s, state_tree, &nat->rb_udp_out, tmp) {
981                         if (time_uptime - s->timestamp > sysctl_var_udp_timeout) {
982                                 RB_REMOVE(state_tree, &nat->rb_udp_out, s);
983                                 kfree(s, M_IPFW3_NAT);
984                         }
985                 }
986                 LIST_FOREACH_MUTABLE(a1, &nat->alias, next, tmp2) {
987                         for (i = 0; i < ALIAS_RANGE; i++) {
988                                 s2 = a1->udp_in[i];
989                                 if (s2 != NULL) {
990                                         if (time_uptime - s2->timestamp > sysctl_var_icmp_timeout) {
991                                                 a1->udp_in[i] = NULL;
992                                                 kfree(s2, M_IPFW3_NAT);
993                                         }
994                                 }
995
996                         }
997                 }
998         }
999         netisr_forwardmsg_all(&nmsg->base, mycpuid + 1);
1000 }
1001
1002 static void
1003 ip_fw3_nat_cleanup_func(void *dummy __unused)
1004 {
1005         struct netmsg_base msg;
1006         netmsg_init(&msg, NULL, &curthread->td_msgport, 0,
1007                         nat_cleanup_func_dispatch);
1008         netisr_domsg(&msg, 0);
1009
1010         callout_reset(&ip_fw3_nat_cleanup_callout,
1011                         sysctl_var_cleanup_interval * hz,
1012                         ip_fw3_nat_cleanup_func, NULL);
1013 }
1014
1015 static
1016 int ip_fw3_nat_init(void)
1017 {
1018         struct netmsg_base msg;
1019         ip_fw3_register_module(MODULE_NAT_ID, MODULE_NAT_NAME);
1020         ip_fw3_register_filter_funcs(MODULE_NAT_ID, O_NAT_NAT,
1021                         (filter_func)check_nat);
1022         ip_fw3_ctl_nat_ptr = ip_fw3_ctl_nat_sockopt;
1023         netmsg_init(&msg, NULL, &curthread->td_msgport,
1024                         0, nat_init_ctx_dispatch);
1025         netisr_domsg(&msg, 0);
1026
1027         callout_init_mp(&ip_fw3_nat_cleanup_callout);
1028         callout_reset(&ip_fw3_nat_cleanup_callout,
1029                         sysctl_var_cleanup_interval * hz,
1030                         ip_fw3_nat_cleanup_func,
1031                         NULL);
1032         return 0;
1033 }
1034
1035 static int
1036 ip_fw3_nat_fini(void)
1037 {
1038         struct netmsg_base msg;
1039         struct netmsg_nat_del nat_del_msg, *msg1;
1040         int i;
1041
1042         callout_stop(&ip_fw3_nat_cleanup_callout);
1043
1044         msg1 = &nat_del_msg;
1045         for (i = 0; i < NAT_ID_MAX; i++) {
1046                 msg1->id = i + 1;
1047                 netmsg_init(&msg1->base, NULL, &curthread->td_msgport,
1048                                 0, nat_del_dispatch);
1049
1050                 netisr_domsg(&msg1->base, 0);
1051         }
1052
1053         netmsg_init(&msg, NULL, &curthread->td_msgport,
1054                         0, nat_fnit_ctx_dispatch);
1055         netisr_domsg(&msg, 0);
1056
1057         return ip_fw3_unregister_module(MODULE_NAT_ID);
1058 }
1059
1060 static int
1061 ip_fw3_nat_modevent(module_t mod, int type, void *data)
1062 {
1063         switch (type) {
1064         case MOD_LOAD:
1065                 return ip_fw3_nat_init();
1066         case MOD_UNLOAD:
1067                 return ip_fw3_nat_fini();
1068         default:
1069                 break;
1070         }
1071         return 0;
1072 }
1073
1074 moduledata_t ip_fw3_nat_mod = {
1075         "ipfw3_nat",
1076         ip_fw3_nat_modevent,
1077         NULL
1078 };
1079
1080 DECLARE_MODULE(ipfw3_nat, ip_fw3_nat_mod,
1081                 SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY);
1082 MODULE_DEPEND(ipfw3_nat, ipfw3_basic, 1, 1, 1);
1083 MODULE_VERSION(ipfw3_nat, 1);