ipfw3: filtering with lookup table
[dragonfly.git] / sys / net / ipfw3_basic / ip_fw3_basic.c
1 /*
2  * Copyright (c) 2014 The DragonFly Project.  All rights reserved.
3  *
4  * This code is derived from software contributed to The DragonFly Project
5  * by Bill Yuan <bycn82@gmail.com>
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  * 3. Neither the name of The DragonFly Project nor the names of its
18  *    contributors may be used to endorse or promote products derived
19  *    from this software without specific, prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
25  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
27  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
29  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
31  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32  * SUCH DAMAGE.
33  */
34
35 #include <sys/param.h>
36 #include <sys/kernel.h>
37 #include <sys/malloc.h>
38 #include <sys/mbuf.h>
39 #include <sys/socketvar.h>
40 #include <sys/sysctl.h>
41 #include <sys/syslog.h>
42 #include <sys/systimer.h>
43 #include <sys/thread2.h>
44 #include <sys/in_cksum.h>
45
46 #include <net/if.h>
47 #include <net/ethernet.h>
48 #include <net/netmsg2.h>
49 #include <net/netisr2.h>
50 #include <net/route.h>
51
52 #include <netinet/ip.h>
53 #include <netinet/in.h>
54 #include <netinet/in_systm.h>
55 #include <netinet/in_var.h>
56 #include <netinet/in_pcb.h>
57 #include <netinet/ip_var.h>
58 #include <netinet/ip_icmp.h>
59 #include <netinet/tcp.h>
60 #include <netinet/tcp_timer.h>
61 #include <netinet/tcp_var.h>
62 #include <netinet/tcpip.h>
63 #include <netinet/udp.h>
64 #include <netinet/udp_var.h>
65 #include <netinet/ip_divert.h>
66 #include <netinet/if_ether.h>
67
68 #include <net/ipfw3/ip_fw.h>
69 #include <net/ipfw3/ip_fw3_table.h>
70
71 #include "ip_fw3_basic.h"
72
73 #define TIME_LEQ(a, b)  ((int)((a) - (b)) <= 0)
74
75 extern struct ipfw_context      *ipfw_ctx[MAXCPU];
76 extern int fw_verbose;
77 extern ipfw_basic_delete_state_t *ipfw_basic_flush_state_prt;
78 extern ipfw_basic_append_state_t *ipfw_basic_append_state_prt;
79
80 static int ip_fw_basic_loaded;
81 static struct netmsg_base ipfw_timeout_netmsg;  /* schedule ipfw timeout */
82 static struct callout ipfw_tick_callout;
83 static int state_lifetime = 20;
84 static int state_expiry_check_interval = 10;
85 static int state_count_max = 4096;
86 static int state_hash_size_old = 0;
87 static int state_hash_size = 4096;
88
89
90 static int ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS);
91 void adjust_hash_size_dispatch(netmsg_t nmsg);
92
93 SYSCTL_NODE(_net_inet_ip, OID_AUTO, fw_basic,
94                 CTLFLAG_RW, 0, "Firewall Basic");
95 SYSCTL_PROC(_net_inet_ip_fw_basic, OID_AUTO, state_hash_size,
96                 CTLTYPE_INT | CTLFLAG_RW, &state_hash_size, 0,
97                 ipfw_sysctl_adjust_hash_size, "I", "Adjust hash size");
98
99 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_lifetime, CTLFLAG_RW,
100                 &state_lifetime, 0, "default life time");
101 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO,
102                 state_expiry_check_interval, CTLFLAG_RW,
103                 &state_expiry_check_interval, 0,
104                 "default state expiry check interval");
105 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_count_max, CTLFLAG_RW,
106                 &state_count_max, 0, "maximum of state");
107
108 static int
109 ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS)
110 {
111         int error, value = 0;
112
113         state_hash_size_old = state_hash_size;
114         value = state_hash_size;
115         error = sysctl_handle_int(oidp, &value, 0, req);
116         if (error || !req->newptr) {
117                 goto back;
118         }
119         /*
120          * Make sure we have a power of 2 and
121          * do not allow more than 64k entries.
122          */
123         error = EINVAL;
124         if (value <= 1 || value > 65536) {
125                 goto back;
126         }
127         if ((value & (value - 1)) != 0) {
128                 goto back;
129         }
130
131         error = 0;
132         if (state_hash_size != value) {
133                 state_hash_size = value;
134
135                 struct netmsg_base *msg, the_msg;
136                 msg = &the_msg;
137                 bzero(msg,sizeof(struct netmsg_base));
138
139                 netmsg_init(msg, NULL, &curthread->td_msgport,
140                                 0, adjust_hash_size_dispatch);
141                 ifnet_domsg(&msg->lmsg, 0);
142         }
143 back:
144         return error;
145 }
146
147 void
148 adjust_hash_size_dispatch(netmsg_t nmsg)
149 {
150         struct ipfw_state_context *state_ctx;
151         struct ip_fw_state *the_state, *state;
152         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
153         int i;
154
155         for (i = 0; i < state_hash_size_old; i++) {
156                 state_ctx = &ctx->state_ctx[i];
157                 if (state_ctx != NULL) {
158                         state = state_ctx->state;
159                         while (state != NULL) {
160                                 the_state = state;
161                                 state = state->next;
162                                 kfree(the_state, M_IPFW3_BASIC);
163                                 the_state = NULL;
164                         }
165                 }
166         }
167         kfree(ctx->state_ctx,M_IPFW3_BASIC);
168         ctx->state_ctx = kmalloc(state_hash_size *
169                                 sizeof(struct ipfw_state_context),
170                                 M_IPFW3_BASIC, M_WAITOK | M_ZERO);
171         ctx->state_hash_size = state_hash_size;
172         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
173 }
174
175
176 /*      prototype of the checker functions      */
177 void check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
178         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
179 void check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
180         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
181 void check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
182         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
183 void check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
184         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
185
186 void check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
187         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
188 void check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
189         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
190 void check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
191         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
192 void check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
193         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
194 void check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
195         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
196 void check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
197         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
198 void check_from_lookup(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
199         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
200 void check_from_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
201         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
202 void check_from_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
203         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
204 void check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
205         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
206 void check_to_lookup(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
207         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
208 void check_to_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
209         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
210 void check_to_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
211         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
212 void check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
213         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
214 void check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
215         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
216 void check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
217         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
218 void check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
219         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
220
221 /*      prototype of the utility functions      */
222 static struct ip_fw *lookup_next_rule(struct ip_fw *me);
223 static int iface_match(struct ifnet *ifp, ipfw_insn_if *cmd);
224 static __inline int hash_packet(struct ipfw_flow_id *id);
225
226 static __inline int
227 hash_packet(struct ipfw_flow_id *id)
228 {
229         uint32_t i;
230         i = (id->proto) ^ (id->dst_ip) ^ (id->src_ip) ^
231                 (id->dst_port) ^ (id->src_port);
232         i &= state_hash_size - 1;
233         return i;
234 }
235
236 static struct ip_fw *
237 lookup_next_rule(struct ip_fw *me)
238 {
239         struct ip_fw *rule = NULL;
240         ipfw_insn *cmd;
241
242         /* look for action, in case it is a skipto */
243         cmd = ACTION_PTR(me);
244         if ((int)cmd->module == MODULE_BASIC_ID &&
245                 (int)cmd->opcode == O_BASIC_SKIPTO) {
246                 for (rule = me->next; rule; rule = rule->next) {
247                         if (rule->rulenum >= cmd->arg1)
248                                 break;
249                 }
250         }
251         if (rule == NULL) /* failure or not a skipto */
252                 rule = me->next;
253
254         me->next_rule = rule;
255         return rule;
256 }
257
258 /*
259  * when all = 1, it will check all the state_ctx
260  */
261 static struct ip_fw_state *
262 lookup_state(struct ip_fw_args *args, ipfw_insn *cmd, int *limited, int all)
263 {
264         struct ip_fw_state *state = NULL;
265         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
266         struct ipfw_state_context *state_ctx;
267         int start, end, i, count = 0;
268
269         if (all && cmd->arg1) {
270                 start = 0;
271                 end = state_hash_size - 1;
272         } else {
273                 start = hash_packet(&args->f_id);
274                 end = hash_packet(&args->f_id);
275         }
276         for (i = start; i <= end; i++) {
277                 state_ctx = &ctx->state_ctx[i];
278                 if (state_ctx != NULL) {
279                         state = state_ctx->state;
280                         struct ipfw_flow_id     *fid = &args->f_id;
281                         while (state != NULL) {
282                                 if (cmd->arg1) {
283                                         if ((cmd->arg3 == 1 &&
284                                                 fid->src_ip ==
285                                                 state->flow_id.src_ip) ||
286                                                 (cmd->arg3 == 2 &&
287                                                 fid->src_port ==
288                                                 state->flow_id.src_port) ||
289                                                 (cmd->arg3 == 3 &&
290                                                 fid->dst_ip ==
291                                                 state->flow_id.dst_ip) ||
292                                                 (cmd->arg3 == 4 &&
293                                                 fid->dst_port ==
294                                                 state->flow_id.dst_port)) {
295
296                                                 count++;
297                                                 if (count >= cmd->arg1) {
298                                                         *limited = 1;
299                                                         goto done;
300                                                 }
301                                         }
302                                 }
303
304                                 if (fid->proto == state->flow_id.proto) {
305                                         if (fid->src_ip ==
306                                         state->flow_id.src_ip &&
307                                         fid->dst_ip ==
308                                         state->flow_id.dst_ip &&
309                                         (fid->src_port ==
310                                         state->flow_id.src_port ||
311                                         state->flow_id.src_port == 0) &&
312                                         (fid->dst_port ==
313                                         state->flow_id.dst_port ||
314                                         state->flow_id.dst_port == 0)) {
315                                                 goto done;
316                                         }
317                                         if (fid->src_ip ==
318                                         state->flow_id.dst_ip &&
319                                         fid->dst_ip ==
320                                         state->flow_id.src_ip &&
321                                         (fid->src_port ==
322                                         state->flow_id.dst_port ||
323                                         state->flow_id.dst_port == 0) &&
324                                         (fid->dst_port ==
325                                         state->flow_id.src_port ||
326                                         state->flow_id.src_port == 0)) {
327                                                 goto done;
328                                         }
329                                 }
330                                 state = state->next;
331                         }
332                 }
333         }
334 done:
335         return state;
336 }
337
338 static struct ip_fw_state *
339 install_state(struct ip_fw *rule, ipfw_insn *cmd, struct ip_fw_args *args)
340 {
341         struct ip_fw_state *state;
342         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
343         struct ipfw_state_context *state_ctx;
344         state_ctx = &ctx->state_ctx[hash_packet(&args->f_id)];
345         state = kmalloc(sizeof(struct ip_fw_state),
346                         M_IPFW3_BASIC, M_NOWAIT | M_ZERO);
347         if (state == NULL) {
348                 return NULL;
349         }
350         state->stub = rule;
351         state->lifetime = cmd->arg2 == 0 ? state_lifetime : cmd->arg2 ;
352         state->timestamp = time_second;
353         state->expiry = 0;
354         bcopy(&args->f_id,&state->flow_id,sizeof(struct ipfw_flow_id));
355         //append the state into the state chian
356         if (state_ctx->last != NULL)
357                 state_ctx->last->next = state;
358         else
359                 state_ctx->state = state;
360         state_ctx->last = state;
361         state_ctx->count++;
362         return state;
363 }
364
365
366 static int
367 iface_match(struct ifnet *ifp, ipfw_insn_if *cmd)
368 {
369         if (ifp == NULL)        /* no iface with this packet, match fails */
370                 return 0;
371
372         /* Check by name or by IP address */
373         if (cmd->name[0] != '\0') { /* match by name */
374                 /* Check name */
375                 if (cmd->p.glob) {
376                         if (kfnmatch(cmd->name, ifp->if_xname, 0) == 0)
377                                 return(1);
378                 } else {
379                         if (strncmp(ifp->if_xname, cmd->name, IFNAMSIZ) == 0)
380                                 return(1);
381                 }
382         } else {
383                 struct ifaddr_container *ifac;
384
385                 TAILQ_FOREACH(ifac, &ifp->if_addrheads[mycpuid], ifa_link) {
386                         struct ifaddr *ia = ifac->ifa;
387
388                         if (ia->ifa_addr == NULL)
389                                 continue;
390                         if (ia->ifa_addr->sa_family != AF_INET)
391                                 continue;
392                         if (cmd->p.ip.s_addr ==
393                                 ((struct sockaddr_in *)
394                                 (ia->ifa_addr))->sin_addr.s_addr)
395                                         return(1);      /* match */
396
397                 }
398         }
399         return 0;       /* no match, fail ... */
400 }
401
402 /* implimentation of the checker functions */
403 void
404 check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
405         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
406 {
407         (*f)->pcnt++;
408         (*f)->bcnt += ip_len;
409         (*f)->timestamp = time_second;
410         *cmd_ctl = IP_FW_CTL_NEXT;
411 }
412
413 void
414 check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
415         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
416 {
417         (*f)->pcnt++;
418         (*f)->bcnt += ip_len;
419         (*f)->timestamp = time_second;
420         if ((*f)->next_rule == NULL)
421                 lookup_next_rule(*f);
422
423         *cmd_ctl = IP_FW_CTL_AGAIN;
424 }
425
426 void
427 check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
428         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
429 {
430         struct sockaddr_in *sin, *sa;
431         struct m_tag *mtag;
432
433         if ((*args)->eh) {      /* not valid on layer2 pkts */
434                 *cmd_ctl=IP_FW_CTL_NEXT;
435                 return;
436         }
437
438         (*f)->pcnt++;
439         (*f)->bcnt += ip_len;
440         (*f)->timestamp = time_second;
441         if ((*f)->next_rule == NULL)
442                 lookup_next_rule(*f);
443
444         mtag = m_tag_get(PACKET_TAG_IPFORWARD,
445                         sizeof(*sin), M_NOWAIT);
446         if (mtag == NULL) {
447                 *cmd_val = IP_FW_DENY;
448                 *cmd_ctl = IP_FW_CTL_DONE;
449                 return;
450         }
451         sin = m_tag_data(mtag);
452         sa = &((ipfw_insn_sa *)cmd)->sa;
453         /* arg3: count of the dest, arg1: type of fwd */
454         int i = 0;
455         if(cmd->arg3 > 1) {
456                 if (cmd->arg1 == 0) {           /* type: random */
457                         i = krandom() % cmd->arg3;
458                 } else if (cmd->arg1 == 1) {    /* type: round-robin */
459                         i = cmd->arg2++ % cmd->arg3;
460                 } else if (cmd->arg1 == 2) {    /* type: sticky */
461                         struct ip *ip = mtod((*args)->m, struct ip *);
462                         i = ip->ip_src.s_addr & (cmd->arg3 - 1);
463                 }
464                 sa += i;
465         }
466         *sin = *sa;     /* apply the destination */
467         m_tag_prepend((*args)->m, mtag);
468         (*args)->m->m_pkthdr.fw_flags |= IPFORWARD_MBUF_TAGGED;
469         (*args)->m->m_pkthdr.fw_flags &= ~BRIDGE_MBUF_TAGGED;
470         *cmd_ctl = IP_FW_CTL_DONE;
471         *cmd_val = IP_FW_PASS;
472 }
473
474 void
475 check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
476         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
477 {
478         struct ip_fw_state *state=NULL;
479         int limited = 0 ;
480         state = lookup_state(*args, cmd, &limited, 0);
481         if (state != NULL) {
482                 state->pcnt++;
483                 state->bcnt += ip_len;
484                 state->timestamp = time_second;
485                 (*f)->pcnt++;
486                 (*f)->bcnt += ip_len;
487                 (*f)->timestamp = time_second;
488                 *f = state->stub;
489                 *cmd_ctl = IP_FW_CTL_CHK_STATE;
490         } else {
491                 *cmd_ctl = IP_FW_CTL_NEXT;
492         }
493 }
494
495 void
496 check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
497         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
498 {
499         *cmd_ctl = IP_FW_CTL_NO;
500         *cmd_val = ((*args)->oif == NULL);
501 }
502
503 void
504 check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
505         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
506 {
507         *cmd_ctl = IP_FW_CTL_NO;
508         *cmd_val = ((*args)->oif != NULL);
509 }
510
511 void
512 check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
513         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
514 {
515         *cmd_ctl = IP_FW_CTL_NO;
516         *cmd_val = iface_match((*args)->oif ?
517                         (*args)->oif : (*args)->m->m_pkthdr.rcvif,
518                         (ipfw_insn_if *)cmd);
519 }
520
521 void
522 check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
523         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
524 {
525         *cmd_ctl = IP_FW_CTL_NO;
526         *cmd_val = ((*args)->f_id.proto == cmd->arg1);
527 }
528
529 void
530 check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
531         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
532 {
533         *cmd_ctl = IP_FW_CTL_NO;
534         *cmd_val = (krandom() % 100) < cmd->arg1;
535 }
536
537 void
538 check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
539         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
540 {
541         struct in_addr src_ip;
542         u_int hlen = 0;
543         struct mbuf *m = (*args)->m;
544         struct ip *ip = mtod(m, struct ip *);
545         src_ip = ip->ip_src;
546         if ((*args)->eh == NULL ||
547                 (m->m_pkthdr.len >= sizeof(struct ip) &&
548                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
549                 hlen = ip->ip_hl << 2;
550         }
551         *cmd_val = (hlen > 0 &&
552                         ((ipfw_insn_ip *)cmd)->addr.s_addr == src_ip.s_addr);
553         *cmd_ctl = IP_FW_CTL_NO;
554 }
555
556 void
557 check_from_lookup(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
558         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
559 {
560         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
561         struct ipfw_table_context *table_ctx;
562         struct radix_node_head *rnh;
563         struct sockaddr_in sa;
564
565         struct mbuf *m = (*args)->m;
566         struct ip *ip = mtod(m, struct ip *);
567         struct in_addr src_ip = ip->ip_src;
568
569         *cmd_val = IP_FW_NOT_MATCH;
570
571         table_ctx = ctx->table_ctx;
572         table_ctx += cmd->arg1;
573
574         if (table_ctx->type != 0) {
575                 rnh = table_ctx->node;
576                 sa.sin_len = 8;
577                 sa.sin_addr.s_addr = src_ip.s_addr;
578                 if(rnh->rnh_lookup((char *)&sa, NULL, rnh) != NULL)
579                         *cmd_val = IP_FW_MATCH;
580         }
581         *cmd_ctl = IP_FW_CTL_NO;
582 }
583
584 void
585 check_from_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
586         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
587 {
588         struct in_addr src_ip;
589         u_int hlen = 0;
590         struct mbuf *m = (*args)->m;
591         struct ip *ip = mtod(m, struct ip *);
592         src_ip = ip->ip_src;
593         if ((*args)->eh == NULL ||
594                 (m->m_pkthdr.len >= sizeof(struct ip) &&
595                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
596                 hlen = ip->ip_hl << 2;
597         }
598         *cmd_ctl = IP_FW_CTL_NO;
599         if (hlen > 0) {
600                 struct ifnet *tif;
601                 tif = INADDR_TO_IFP(&src_ip);
602                 *cmd_val = (tif != NULL);
603         } else {
604                 *cmd_val = IP_FW_NOT_MATCH;
605         }
606 }
607
608 void
609 check_from_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
610         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
611 {
612         struct in_addr src_ip;
613         u_int hlen = 0;
614         struct mbuf *m = (*args)->m;
615         struct ip *ip = mtod(m, struct ip *);
616         src_ip = ip->ip_src;
617         if ((*args)->eh == NULL ||
618                 (m->m_pkthdr.len >= sizeof(struct ip) &&
619                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
620                 hlen = ip->ip_hl << 2;
621         }
622
623         *cmd_ctl = IP_FW_CTL_NO;
624         *cmd_val = (hlen > 0 &&
625                         ((ipfw_insn_ip *)cmd)->addr.s_addr ==
626                         (src_ip.s_addr &
627                         ((ipfw_insn_ip *)cmd)->mask.s_addr));
628 }
629
630 void
631 check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
632         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
633 {
634         struct in_addr dst_ip;
635         u_int hlen = 0;
636         struct mbuf *m = (*args)->m;
637         struct ip *ip = mtod(m, struct ip *);
638         dst_ip = ip->ip_dst;
639         if ((*args)->eh == NULL ||
640                 (m->m_pkthdr.len >= sizeof(struct ip) &&
641                  ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
642                 hlen = ip->ip_hl << 2;
643         }
644         *cmd_val = (hlen > 0 &&
645                         ((ipfw_insn_ip *)cmd)->addr.s_addr == dst_ip.s_addr);
646         *cmd_ctl = IP_FW_CTL_NO;
647 }
648
649 void
650 check_to_lookup(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
651         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
652 {
653         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
654         struct ipfw_table_context *table_ctx;
655         struct radix_node_head *rnh;
656         struct sockaddr_in sa;
657
658         struct mbuf *m = (*args)->m;
659         struct ip *ip = mtod(m, struct ip *);
660         struct in_addr dst_ip = ip->ip_dst;
661
662         *cmd_val = IP_FW_NOT_MATCH;
663
664         table_ctx = ctx->table_ctx;
665         table_ctx += cmd->arg1;
666
667         if (table_ctx->type != 0) {
668                 rnh = table_ctx->node;
669                 sa.sin_len = 8;
670                 sa.sin_addr.s_addr = dst_ip.s_addr;
671                 if(rnh->rnh_lookup((char *)&sa, NULL, rnh) != NULL)
672                         *cmd_val = IP_FW_MATCH;
673         }
674         *cmd_ctl = IP_FW_CTL_NO;
675 }
676
677 void
678 check_to_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
679         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
680 {
681         struct in_addr dst_ip;
682         u_int hlen = 0;
683         struct mbuf *m = (*args)->m;
684         struct ip *ip = mtod(m, struct ip *);
685         dst_ip = ip->ip_src;
686         if ((*args)->eh == NULL ||
687                 (m->m_pkthdr.len >= sizeof(struct ip) &&
688                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
689                 hlen = ip->ip_hl << 2;
690         }
691         *cmd_ctl = IP_FW_CTL_NO;
692         if (hlen > 0) {
693                 struct ifnet *tif;
694                 tif = INADDR_TO_IFP(&dst_ip);
695                 *cmd_val = (tif != NULL);
696         } else {
697                 *cmd_val = IP_FW_NOT_MATCH;
698         }
699 }
700
701 void
702 check_to_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
703         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
704 {
705         struct in_addr dst_ip;
706         u_int hlen = 0;
707         struct mbuf *m = (*args)->m;
708         struct ip *ip = mtod(m, struct ip *);
709         dst_ip = ip->ip_src;
710         if ((*args)->eh == NULL ||
711                 (m->m_pkthdr.len >= sizeof(struct ip) &&
712                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
713                 hlen = ip->ip_hl << 2;
714         }
715
716         *cmd_ctl = IP_FW_CTL_NO;
717         *cmd_val = (hlen > 0 &&
718                         ((ipfw_insn_ip *)cmd)->addr.s_addr ==
719                         (dst_ip.s_addr &
720                         ((ipfw_insn_ip *)cmd)->mask.s_addr));
721 }
722
723 void
724 check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
725         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
726 {
727         struct ip_fw_state *state;
728         int limited = 0;
729
730         *cmd_ctl = IP_FW_CTL_NO;
731         state = lookup_state(*args, cmd, &limited, 1);
732         if (limited == 1) {
733                 *cmd_val = IP_FW_NOT_MATCH;
734         } else {
735                 if (state == NULL)
736                         state = install_state(*f, cmd, *args);
737
738                 if (state != NULL) {
739                         state->pcnt++;
740                         state->bcnt += ip_len;
741                         state->timestamp = time_second;
742                         *cmd_val = IP_FW_MATCH;
743                 } else {
744                         *cmd_val = IP_FW_NOT_MATCH;
745                 }
746         }
747 }
748
749 void
750 check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
751         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
752 {
753         struct m_tag *mtag = m_tag_locate((*args)->m,
754                         MTAG_IPFW, cmd->arg1, NULL);
755         if (mtag == NULL) {
756                 mtag = m_tag_alloc(MTAG_IPFW,cmd->arg1, 0, M_NOWAIT);
757                 if (mtag != NULL)
758                         m_tag_prepend((*args)->m, mtag);
759
760         }
761         (*f)->pcnt++;
762         (*f)->bcnt += ip_len;
763         (*f)->timestamp = time_second;
764         *cmd_ctl = IP_FW_CTL_NEXT;
765 }
766
767 void
768 check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
769         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
770 {
771         struct m_tag *mtag = m_tag_locate((*args)->m,
772                         MTAG_IPFW, cmd->arg1, NULL);
773         if (mtag != NULL)
774                 m_tag_delete((*args)->m, mtag);
775
776         (*f)->pcnt++;
777         (*f)->bcnt += ip_len;
778         (*f)->timestamp = time_second;
779         *cmd_ctl = IP_FW_CTL_NEXT;
780 }
781
782 void
783 check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
784         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
785 {
786         *cmd_ctl = IP_FW_CTL_NO;
787         if (m_tag_locate( (*args)->m, MTAG_IPFW,cmd->arg1, NULL) != NULL )
788                 *cmd_val = IP_FW_MATCH;
789         else
790                 *cmd_val = IP_FW_NOT_MATCH;
791 }
792
793 static void
794 ipfw_basic_add_state(struct ipfw_ioc_state *ioc_state)
795 {
796         struct ip_fw_state *state;
797         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
798         struct ipfw_state_context *state_ctx;
799         state_ctx = &ctx->state_ctx[hash_packet(&(ioc_state->flow_id))];
800         state = kmalloc(sizeof(struct ip_fw_state),
801                         M_IPFW3_BASIC, M_WAITOK | M_ZERO);
802         struct ip_fw *rule = ctx->ipfw_rule_chain;
803         while (rule != NULL) {
804                 if (rule->rulenum == ioc_state->rulenum) {
805                         break;
806                 }
807                 rule = rule->next;
808         }
809         if (rule == NULL)
810                 return;
811
812         state->stub = rule;
813
814         state->lifetime = ioc_state->lifetime == 0 ?
815                 state_lifetime : ioc_state->lifetime ;
816         state->timestamp = time_second;
817         state->expiry = ioc_state->expiry;
818         bcopy(&ioc_state->flow_id, &state->flow_id,
819                         sizeof(struct ipfw_flow_id));
820         //append the state into the state chian
821         if (state_ctx->last != NULL)
822                 state_ctx->last->next = state;
823         else
824                 state_ctx->state = state;
825
826         state_ctx->last = state;
827         state_ctx->count++;
828 }
829
830 /*
831  * if rule is NULL
832  *              flush all states
833  * else
834  *              flush states which stub is the rule
835  */
836 static void
837 ipfw_basic_flush_state(struct ip_fw *rule)
838 {
839         struct ipfw_state_context *state_ctx;
840         struct ip_fw_state *state,*the_state, *prev_state;
841         struct ipfw_context *ctx;
842         int i;
843
844         ctx = ipfw_ctx[mycpuid];
845         for (i = 0; i < state_hash_size; i++) {
846                 state_ctx = &ctx->state_ctx[i];
847                 if (state_ctx != NULL) {
848                         state = state_ctx->state;
849                         prev_state = NULL;
850                         while (state != NULL) {
851                                 if (rule != NULL && state->stub != rule) {
852                                         prev_state = state;
853                                         state = state->next;
854                                 } else {
855                                         if (prev_state == NULL)
856                                                 state_ctx->state = state->next;
857                                         else
858                                                 prev_state->next = state->next;
859
860                                         the_state = state;
861                                         state = state->next;
862                                         kfree(the_state, M_IPFW3_BASIC);
863                                         state_ctx->count--;
864                                         if (state == NULL)
865                                                 state_ctx->last = prev_state;
866
867                                 }
868                         }
869                 }
870         }
871 }
872
873 /*
874  * clean up expired state in every tick
875  */
876 static void
877 ipfw_cleanup_expired_state(netmsg_t nmsg)
878 {
879         struct ip_fw_state *state,*the_state,*prev_state;
880         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
881         struct ipfw_state_context *state_ctx;
882         int i;
883
884         for (i = 0; i < state_hash_size; i++) {
885                 prev_state = NULL;
886                 state_ctx = &(ctx->state_ctx[i]);
887                 if (ctx->state_ctx != NULL) {
888                         state = state_ctx->state;
889                         while (state != NULL) {
890                                 if (IS_EXPIRED(state)) {
891                                         if (prev_state == NULL)
892                                                 state_ctx->state = state->next;
893                                         else
894                                                 prev_state->next = state->next;
895
896                                         the_state =state;
897                                         state = state->next;
898
899                                         if (the_state == state_ctx->last)
900                                                 state_ctx->last = NULL;
901
902
903                                         kfree(the_state, M_IPFW3_BASIC);
904                                         state_ctx->count--;
905                                 } else {
906                                         prev_state = state;
907                                         state = state->next;
908                                 }
909                         }
910                 }
911         }
912         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
913 }
914
915 static void
916 ipfw_tick(void *dummy __unused)
917 {
918         struct lwkt_msg *lmsg = &ipfw_timeout_netmsg.lmsg;
919         KKASSERT(mycpuid == IPFW_CFGCPUID);
920
921         crit_enter();
922         KKASSERT(lmsg->ms_flags & MSGF_DONE);
923         if (IPFW_BASIC_LOADED) {
924                 lwkt_sendmsg_oncpu(IPFW_CFGPORT, lmsg);
925                 /* ipfw_timeout_netmsg's handler reset this callout */
926         }
927         crit_exit();
928
929         struct netmsg_base *msg;
930         struct netmsg_base the_msg;
931         msg = &the_msg;
932         bzero(msg,sizeof(struct netmsg_base));
933
934         netmsg_init(msg, NULL, &curthread->td_msgport, 0,
935                         ipfw_cleanup_expired_state);
936         ifnet_domsg(&msg->lmsg, 0);
937 }
938
939 static void
940 ipfw_tick_dispatch(netmsg_t nmsg)
941 {
942         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
943         KKASSERT(IPFW_BASIC_LOADED);
944
945         /* Reply ASAP */
946         crit_enter();
947         lwkt_replymsg(&nmsg->lmsg, 0);
948         crit_exit();
949
950         callout_reset(&ipfw_tick_callout,
951                         state_expiry_check_interval * hz, ipfw_tick, NULL);
952 }
953
954 static void
955 ipfw_basic_init_dispatch(netmsg_t nmsg)
956 {
957         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
958         KKASSERT(IPFW3_LOADED);
959
960         int error = 0;
961         callout_init_mp(&ipfw_tick_callout);
962         netmsg_init(&ipfw_timeout_netmsg, NULL, &netisr_adone_rport,
963                         MSGF_DROPABLE | MSGF_PRIORITY, ipfw_tick_dispatch);
964         callout_reset(&ipfw_tick_callout,
965                         state_expiry_check_interval * hz, ipfw_tick, NULL);
966         lwkt_replymsg(&nmsg->lmsg, error);
967         ip_fw_basic_loaded=1;
968 }
969
970 static int
971 ipfw_basic_init(void)
972 {
973         ipfw_basic_flush_state_prt = ipfw_basic_flush_state;
974         ipfw_basic_append_state_prt = ipfw_basic_add_state;
975
976         register_ipfw_module(MODULE_BASIC_ID, MODULE_BASIC_NAME);
977         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_COUNT,
978                         (filter_func)check_count);
979         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_SKIPTO,
980                         (filter_func)check_skipto);
981         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_FORWARD,
982                         (filter_func)check_forward);
983         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_KEEP_STATE,
984                         (filter_func)check_keep_state);
985         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_CHECK_STATE,
986                         (filter_func)check_check_state);
987
988         register_ipfw_filter_funcs(MODULE_BASIC_ID,
989                         O_BASIC_IN, (filter_func)check_in);
990         register_ipfw_filter_funcs(MODULE_BASIC_ID,
991                         O_BASIC_OUT, (filter_func)check_out);
992         register_ipfw_filter_funcs(MODULE_BASIC_ID,
993                         O_BASIC_VIA, (filter_func)check_via);
994         register_ipfw_filter_funcs(MODULE_BASIC_ID,
995                         O_BASIC_XMIT, (filter_func)check_via);
996         register_ipfw_filter_funcs(MODULE_BASIC_ID,
997                         O_BASIC_RECV, (filter_func)check_via);
998
999         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1000                         O_BASIC_PROTO, (filter_func)check_proto);
1001         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1002                         O_BASIC_PROB, (filter_func)check_prob);
1003         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1004                         O_BASIC_IP_SRC, (filter_func)check_from);
1005         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1006                         O_BASIC_IP_SRC_LOOKUP, (filter_func)check_from_lookup);
1007         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1008                         O_BASIC_IP_SRC_ME, (filter_func)check_from_me);
1009         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1010                         O_BASIC_IP_SRC_MASK, (filter_func)check_from_mask);
1011         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1012                         O_BASIC_IP_DST, (filter_func)check_to);
1013         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1014                         O_BASIC_IP_DST_LOOKUP, (filter_func)check_to_lookup);
1015         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1016                         O_BASIC_IP_DST_ME, (filter_func)check_to_me);
1017         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1018                         O_BASIC_IP_DST_MASK, (filter_func)check_to_mask);
1019         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1020                         O_BASIC_TAG, (filter_func)check_tag);
1021         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1022                         O_BASIC_UNTAG, (filter_func)check_untag);
1023         register_ipfw_filter_funcs(MODULE_BASIC_ID,
1024                         O_BASIC_TAGGED, (filter_func)check_tagged);
1025
1026         int cpu;
1027         struct ipfw_context *ctx;
1028
1029         for (cpu = 0; cpu < ncpus; cpu++) {
1030                 ctx = ipfw_ctx[cpu];
1031                 if (ctx != NULL) {
1032                         ctx->state_ctx = kmalloc(state_hash_size *
1033                                         sizeof(struct ipfw_state_context),
1034                                         M_IPFW3_BASIC, M_WAITOK | M_ZERO);
1035                         ctx->state_hash_size = state_hash_size;
1036                 }
1037         }
1038
1039         struct netmsg_base smsg;
1040         netmsg_init(&smsg, NULL, &curthread->td_msgport,
1041                         0, ipfw_basic_init_dispatch);
1042         lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
1043         return 0;
1044 }
1045
1046 static void
1047 ipfw_basic_stop_dispatch(netmsg_t nmsg)
1048 {
1049         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
1050         KKASSERT(IPFW3_LOADED);
1051         int error = 0;
1052         callout_stop(&ipfw_tick_callout);
1053         netmsg_service_sync();
1054         crit_enter();
1055         lwkt_dropmsg(&ipfw_timeout_netmsg.lmsg);
1056         crit_exit();
1057         lwkt_replymsg(&nmsg->lmsg, error);
1058         ip_fw_basic_loaded=0;
1059 }
1060
1061 static int
1062 ipfw_basic_stop(void)
1063 {
1064         int cpu,i;
1065         struct ipfw_state_context *state_ctx;
1066         struct ip_fw_state *state,*the_state;
1067         struct ipfw_context *ctx;
1068         if (unregister_ipfw_module(MODULE_BASIC_ID) ==0 ) {
1069                 ipfw_basic_flush_state_prt = NULL;
1070                 ipfw_basic_append_state_prt = NULL;
1071
1072                 for (cpu = 0; cpu < ncpus; cpu++) {
1073                         ctx = ipfw_ctx[cpu];
1074                         if (ctx != NULL) {
1075                                 for (i = 0; i < state_hash_size; i++) {
1076                                         state_ctx = &ctx->state_ctx[i];
1077                                         if (state_ctx != NULL) {
1078                                                 state = state_ctx->state;
1079                                                 while (state != NULL) {
1080                                                         the_state = state;
1081                                                         state = state->next;
1082                                                         if (the_state ==
1083                                                                 state_ctx->last)
1084                                                         state_ctx->last = NULL;
1085
1086                                                         kfree(the_state,
1087                                                                 M_IPFW3_BASIC);
1088                                                 }
1089                                         }
1090                                 }
1091                                 ctx->state_hash_size = 0;
1092                                 kfree(ctx->state_ctx, M_IPFW3_BASIC);
1093                                 ctx->state_ctx = NULL;
1094                         }
1095                 }
1096                 struct netmsg_base smsg;
1097                 netmsg_init(&smsg, NULL, &curthread->td_msgport,
1098                                 0, ipfw_basic_stop_dispatch);
1099                 return lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
1100         }
1101         return 1;
1102 }
1103
1104
1105 static int
1106 ipfw3_basic_modevent(module_t mod, int type, void *data)
1107 {
1108         int err;
1109         switch (type) {
1110                 case MOD_LOAD:
1111                         err = ipfw_basic_init();
1112                         break;
1113                 case MOD_UNLOAD:
1114                         err = ipfw_basic_stop();
1115                         break;
1116                 default:
1117                         err = 1;
1118         }
1119         return err;
1120 }
1121
1122 static moduledata_t ipfw3_basic_mod = {
1123         "ipfw3_basic",
1124         ipfw3_basic_modevent,
1125         NULL
1126 };
1127 DECLARE_MODULE(ipfw3_basic, ipfw3_basic_mod, SI_SUB_PROTO_END, SI_ORDER_ANY);
1128 MODULE_DEPEND(ipfw3_basic, ipfw3, 1, 1, 1);
1129 MODULE_VERSION(ipfw3_basic, 1);