e4867c4b5d6bbe43627f9b608e23595614da2985
[dragonfly.git] / sys / net / ipfw2_basic / ip_fw2_basic.c
1 /*
2  * Copyright (c) 2014 The DragonFly Project.  All rights reserved.
3  *
4  * This code is derived from software contributed to The DragonFly Project
5  * by Bill Yuan <bycn82@gmail.com>
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  * 3. Neither the name of The DragonFly Project nor the names of its
18  *    contributors may be used to endorse or promote products derived
19  *    from this software without specific, prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
25  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
27  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
29  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
31  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32  * SUCH DAMAGE.
33  */
34
35 #include <sys/param.h>
36 #include <sys/kernel.h>
37 #include <sys/malloc.h>
38 #include <sys/mbuf.h>
39 #include <sys/socketvar.h>
40 #include <sys/sysctl.h>
41 #include <sys/syslog.h>
42 #include <sys/systimer.h>
43 #include <sys/thread2.h>
44 #include <sys/in_cksum.h>
45
46 #include <net/if.h>
47 #include <net/ethernet.h>
48 #include <net/netmsg2.h>
49 #include <net/netisr2.h>
50 #include <net/route.h>
51
52 #include <netinet/ip.h>
53 #include <netinet/in.h>
54 #include <netinet/in_systm.h>
55 #include <netinet/in_var.h>
56 #include <netinet/in_pcb.h>
57 #include <netinet/ip_var.h>
58 #include <netinet/ip_icmp.h>
59 #include <netinet/tcp.h>
60 #include <netinet/tcp_timer.h>
61 #include <netinet/tcp_var.h>
62 #include <netinet/tcpip.h>
63 #include <netinet/udp.h>
64 #include <netinet/udp_var.h>
65 #include <netinet/ip_divert.h>
66 #include <netinet/if_ether.h>
67
68 #include <net/ipfw2/ip_fw.h>
69
70 #include "ip_fw2_basic.h"
71
/*
 * Wrap-tolerant time comparison: true when a is not later than b.  The
 * difference is taken first and then cast to int, so the result stays
 * correct across counter wraparound as long as a and b are close.
 */
72 #define TIME_LEQ(a, b)  ((int)((a) - (b)) <= 0)
73
/* Per-CPU ipfw contexts, indexed by mycpuid throughout this file. */
74 extern struct ipfw_context      *ipfw_ctx[MAXCPU];
75 extern int fw_verbose;
/*
 * Hooks exported to the ipfw core; ipfw_basic_init() points them at
 * ipfw_basic_flush_state() and ipfw_basic_add_state() respectively.
 */
76 extern ipfw_basic_delete_state_t *ipfw_basic_flush_state_prt;
77 extern ipfw_basic_append_state_t *ipfw_basic_append_state_prt;
78
/* Set to 1 by ipfw_basic_init_dispatch() once the module is set up. */
79 static int ip_fw_basic_loaded;
80 static struct netmsg_base ipfw_timeout_netmsg;  /* schedule ipfw timeout */
81 static struct callout ipfw_tick_callout;
/*
 * Lifetime given to a new state when the rule (arg2) or the ioctl request
 * supplies 0.  Compared against time_second elsewhere, so presumably in
 * seconds -- TODO confirm against IS_EXPIRED().
 */
82 static int state_lifetime = 20;
/* Period of the expiry callout in seconds (multiplied by hz when armed). */
83 static int state_expiry_check_interval = 10;
/*
 * Exported as net.inet.ip.fw_basic.state_count_max.
 * NOTE(review): install_state() and ipfw_basic_add_state() do not consult
 * it; confirm whether another part of the module enforces it.
 */
84 static int state_count_max = 4096;
/* Table size before the last resize; read by adjust_hash_size_dispatch(). */
85 static int state_hash_size_old = 0;
/*
 * Number of buckets in each per-CPU state table.  Must be a power of two:
 * hash_packet() masks with state_hash_size - 1.
 */
86 static int state_hash_size = 4096;
87
88
89 static int ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS);
90 void adjust_hash_size_dispatch(netmsg_t nmsg);
91
/* net.inet.ip.fw_basic.* tunables. */
92 SYSCTL_NODE(_net_inet_ip, OID_AUTO, fw_basic,
93                 CTLFLAG_RW, 0, "Firewall Basic");
/* Writing this resizes (and empties) every per-CPU state table. */
94 SYSCTL_PROC(_net_inet_ip_fw_basic, OID_AUTO, state_hash_size,
95                 CTLTYPE_INT | CTLFLAG_RW, &state_hash_size, 0,
96                 ipfw_sysctl_adjust_hash_size, "I", "Adjust hash size");
97
98 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_lifetime, CTLFLAG_RW,
99                 &state_lifetime, 0, "default life time");
100 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO,
101                 state_expiry_check_interval, CTLFLAG_RW,
102                 &state_expiry_check_interval, 0,
103                 "default state expiry check interval");
104 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_count_max, CTLFLAG_RW,
105                 &state_count_max, 0, "maximum of state");
106
107 static int
108 ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS)
109 {
110         int error, value = 0;
111
112         state_hash_size_old = state_hash_size;
113         value = state_hash_size;
114         error = sysctl_handle_int(oidp, &value, 0, req);
115         if (error || !req->newptr) {
116                 goto back;
117         }
118         /*
119          * Make sure we have a power of 2 and
120          * do not allow more than 64k entries.
121          */
122         error = EINVAL;
123         if (value <= 1 || value > 65536) {
124                 goto back;
125         }
126         if ((value & (value - 1)) != 0) {
127                 goto back;
128         }
129
130         error = 0;
131         if (state_hash_size != value) {
132                 state_hash_size = value;
133
134                 struct netmsg_base *msg, the_msg;
135                 msg = &the_msg;
136                 bzero(msg,sizeof(struct netmsg_base));
137
138                 netmsg_init(msg, NULL, &curthread->td_msgport,
139                                 0, adjust_hash_size_dispatch);
140                 ifnet_domsg(&msg->lmsg, 0);
141         }
142 back:
143         return error;
144 }
145
146 void
147 adjust_hash_size_dispatch(netmsg_t nmsg)
148 {
149         struct ipfw_state_context *state_ctx;
150         struct ip_fw_state *the_state, *state;
151         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
152         int i;
153
154         for (i = 0; i < state_hash_size_old; i++) {
155                 state_ctx = &ctx->state_ctx[i];
156                 if (state_ctx != NULL) {
157                         state = state_ctx->state;
158                         while (state != NULL) {
159                                 the_state = state;
160                                 state = state->next;
161                                 kfree(the_state, M_IPFW2_BASIC);
162                                 the_state = NULL;
163                         }
164                 }
165         }
166         kfree(ctx->state_ctx,M_IPFW2_BASIC);
167         ctx->state_ctx = kmalloc(state_hash_size *
168                                 sizeof(struct ipfw_state_context),
169                                 M_IPFW2_BASIC, M_WAITOK | M_ZERO);
170         ctx->state_hash_size = state_hash_size;
171         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
172 }
173
174
175 /*      prototype of the checker functions      */
/*
 * All checkers share one signature.  On return *cmd_ctl holds an
 * IP_FW_CTL_* code telling the caller how to proceed, and *cmd_val (when
 * set) holds the match result or verdict.  *args is the packet being
 * filtered. *f is the current rule; check_check_state() replaces it with
 * the rule that created the matching state. cmd is the instruction being
 * evaluated, and ip_len is the byte count added to the rule/state
 * counters.  They are registered with the ipfw core by ipfw_basic_init().
 */
176 void check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
177         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
178 void check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
179         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
180 void check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
181         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
182 void check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
183         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
184
/* Match-type checkers: they answer IP_FW_CTL_NO and set *cmd_val. */
185 void check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
186         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
187 void check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
188         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
189 void check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
190         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
191 void check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
192         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
193 void check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
194         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
195 void check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
196         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
197 void check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
198         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
199 void check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
200         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
201 void check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
202         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
203 void check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
204         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
205 void check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
206         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
207
208 /*      prototype of the utility functions      */
209 static struct ip_fw *lookup_next_rule(struct ip_fw *me);
210 static int iface_match(struct ifnet *ifp, ipfw_insn_if *cmd);
211 static __inline int hash_packet(struct ipfw_flow_id *id);
212
213 static __inline int
214 hash_packet(struct ipfw_flow_id *id)
215 {
216         uint32_t i;
217         i = (id->proto) ^ (id->dst_ip) ^ (id->src_ip) ^
218                 (id->dst_port) ^ (id->src_port);
219         i &= state_hash_size - 1;
220         return i;
221 }
222
223 static struct ip_fw *
224 lookup_next_rule(struct ip_fw *me)
225 {
226         struct ip_fw *rule = NULL;
227         ipfw_insn *cmd;
228
229         /* look for action, in case it is a skipto */
230         cmd = ACTION_PTR(me);
231         if ((int)cmd->module == MODULE_BASIC_ID &&
232                 (int)cmd->opcode == O_BASIC_SKIPTO) {
233                 for (rule = me->next; rule; rule = rule->next) {
234                         if (rule->rulenum >= cmd->arg1)
235                                 break;
236                 }
237         }
238         if (rule == NULL) /* failure or not a skipto */
239                 rule = me->next;
240
241         me->next_rule = rule;
242         return rule;
243 }
244
245 /*
246  * when all = 1, it will check all the state_ctx
247  */
248 static struct ip_fw_state *
249 lookup_state(struct ip_fw_args *args, ipfw_insn *cmd, int *limited, int all)
250 {
251         struct ip_fw_state *state = NULL;
252         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
253         struct ipfw_state_context *state_ctx;
254         int start, end, i, count = 0;
255
256         if (all && cmd->arg1) {
257                 start = 0;
258                 end = state_hash_size - 1;
259         } else {
260                 start = hash_packet(&args->f_id);
261                 end = hash_packet(&args->f_id);
262         }
263         for (i = start; i <= end; i++) {
264                 state_ctx = &ctx->state_ctx[i];
265                 if (state_ctx != NULL) {
266                         state = state_ctx->state;
267                         struct ipfw_flow_id     *fid = &args->f_id;
268                         while (state != NULL) {
269                                 if (cmd->arg1) {
270                                         if ((cmd->arg3 == 1 &&
271                                                 fid->src_ip ==
272                                                 state->flow_id.src_ip) ||
273                                                 (cmd->arg3 == 2 &&
274                                                 fid->src_port ==
275                                                 state->flow_id.src_port) ||
276                                                 (cmd->arg3 == 3 &&
277                                                 fid->dst_ip ==
278                                                 state->flow_id.dst_ip) ||
279                                                 (cmd->arg3 == 4 &&
280                                                 fid->dst_port ==
281                                                 state->flow_id.dst_port)) {
282
283                                                 count++;
284                                                 if (count >= cmd->arg1) {
285                                                         *limited = 1;
286                                                         goto done;
287                                                 }
288                                         }
289                                 }
290
291                                 if (fid->proto == state->flow_id.proto) {
292                                         if (fid->src_ip ==
293                                         state->flow_id.src_ip &&
294                                         fid->dst_ip ==
295                                         state->flow_id.dst_ip &&
296                                         (fid->src_port ==
297                                         state->flow_id.src_port ||
298                                         state->flow_id.src_port == 0) &&
299                                         (fid->dst_port ==
300                                         state->flow_id.dst_port ||
301                                         state->flow_id.dst_port == 0)) {
302                                                 goto done;
303                                         }
304                                         if (fid->src_ip ==
305                                         state->flow_id.dst_ip &&
306                                         fid->dst_ip ==
307                                         state->flow_id.src_ip &&
308                                         (fid->src_port ==
309                                         state->flow_id.dst_port ||
310                                         state->flow_id.dst_port == 0) &&
311                                         (fid->dst_port ==
312                                         state->flow_id.src_port ||
313                                         state->flow_id.src_port == 0)) {
314                                                 goto done;
315                                         }
316                                 }
317                                 state = state->next;
318                         }
319                 }
320         }
321 done:
322         return state;
323 }
324
325 static struct ip_fw_state *
326 install_state(struct ip_fw *rule, ipfw_insn *cmd, struct ip_fw_args *args)
327 {
328         struct ip_fw_state *state;
329         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
330         struct ipfw_state_context *state_ctx;
331         state_ctx = &ctx->state_ctx[hash_packet(&args->f_id)];
332         state = kmalloc(sizeof(struct ip_fw_state),
333                         M_IPFW2_BASIC, M_NOWAIT | M_ZERO);
334         if (state == NULL) {
335                 return NULL;
336         }
337         state->stub = rule;
338         state->lifetime = cmd->arg2 == 0 ? state_lifetime : cmd->arg2 ;
339         state->timestamp = time_second;
340         state->expiry = 0;
341         bcopy(&args->f_id,&state->flow_id,sizeof(struct ipfw_flow_id));
342         //append the state into the state chian
343         if (state_ctx->last != NULL)
344                 state_ctx->last->next = state;
345         else
346                 state_ctx->state = state;
347         state_ctx->last = state;
348         state_ctx->count++;
349         return state;
350 }
351
352
353 static int
354 iface_match(struct ifnet *ifp, ipfw_insn_if *cmd)
355 {
356         if (ifp == NULL)        /* no iface with this packet, match fails */
357                 return 0;
358
359         /* Check by name or by IP address */
360         if (cmd->name[0] != '\0') { /* match by name */
361                 /* Check name */
362                 if (cmd->p.glob) {
363                         if (kfnmatch(cmd->name, ifp->if_xname, 0) == 0)
364                                 return(1);
365                 } else {
366                         if (strncmp(ifp->if_xname, cmd->name, IFNAMSIZ) == 0)
367                                 return(1);
368                 }
369         } else {
370                 struct ifaddr_container *ifac;
371
372                 TAILQ_FOREACH(ifac, &ifp->if_addrheads[mycpuid], ifa_link) {
373                         struct ifaddr *ia = ifac->ifa;
374
375                         if (ia->ifa_addr == NULL)
376                                 continue;
377                         if (ia->ifa_addr->sa_family != AF_INET)
378                                 continue;
379                         if (cmd->p.ip.s_addr ==
380                                 ((struct sockaddr_in *)
381                                 (ia->ifa_addr))->sin_addr.s_addr)
382                                         return(1);      /* match */
383
384                 }
385         }
386         return 0;       /* no match, fail ... */
387 }
388
389 /* implimentation of the checker functions */
390 void
391 check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
392         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
393 {
394         (*f)->pcnt++;
395         (*f)->bcnt += ip_len;
396         (*f)->timestamp = time_second;
397         *cmd_ctl = IP_FW_CTL_NEXT;
398 }
399
400 void
401 check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
402         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
403 {
404         (*f)->pcnt++;
405         (*f)->bcnt += ip_len;
406         (*f)->timestamp = time_second;
407         if ((*f)->next_rule == NULL)
408                 lookup_next_rule(*f);
409
410         *cmd_ctl = IP_FW_CTL_AGAIN;
411 }
412
413 void
414 check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
415         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
416 {
417         struct sockaddr_in *sin;
418         struct m_tag *mtag;
419
420         if ((*args)->eh) {      /* not valid on layer2 pkts */
421                 *cmd_ctl=IP_FW_CTL_NEXT;
422                 return;
423         }
424
425         (*f)->pcnt++;
426         (*f)->bcnt += ip_len;
427         (*f)->timestamp = time_second;
428         if ((*f)->next_rule == NULL)
429                 lookup_next_rule(*f);
430
431         mtag = m_tag_get(PACKET_TAG_IPFORWARD,
432                         sizeof(*sin), MB_DONTWAIT);
433         if (mtag == NULL) {
434                 *cmd_val = IP_FW_DENY;
435                 *cmd_ctl = IP_FW_CTL_DONE;
436                 return;
437         }
438         sin = m_tag_data(mtag);
439         *sin = ((ipfw_insn_sa *)cmd)->sa;
440         /* arg3: count of the dest, arg1: type of fwd */
441         int i;
442         if(cmd->arg3 == 1) {
443                 i = 0;
444         } else {
445                 if (cmd->arg1 == 0) {           /* type: random */
446                         i = krandom() % cmd->arg3;
447                 } else if (cmd->arg1 == 1) {    /* type: round-robin */
448                         i = cmd->arg2++ % cmd->arg3;
449                 } else if (cmd->arg1 == 2) {    /* type: sticky */
450                         struct ip *ip = mtod((*args)->m, struct ip *);
451                         i = ip->ip_src.s_addr & (cmd->arg3 - 1);
452                 }
453         }
454         sin += i;
455         m_tag_prepend((*args)->m, mtag);
456         (*args)->m->m_pkthdr.fw_flags |= IPFORWARD_MBUF_TAGGED;
457         (*args)->m->m_pkthdr.fw_flags &= ~BRIDGE_MBUF_TAGGED;
458         *cmd_ctl = IP_FW_CTL_DONE;
459         *cmd_val = IP_FW_PASS;
460 }
461
462 void
463 check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
464         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
465 {
466         struct ip_fw_state *state=NULL;
467         int limited = 0 ;
468         state = lookup_state(*args, cmd, &limited, 0);
469         if (state != NULL) {
470                 state->pcnt++;
471                 state->bcnt += ip_len;
472                 state->timestamp = time_second;
473                 (*f)->pcnt++;
474                 (*f)->bcnt += ip_len;
475                 (*f)->timestamp = time_second;
476                 *f = state->stub;
477                 *cmd_ctl = IP_FW_CTL_CHK_STATE;
478         } else {
479                 *cmd_ctl = IP_FW_CTL_NEXT;
480         }
481 }
482
483 void
484 check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
485         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
486 {
487         *cmd_ctl = IP_FW_CTL_NO;
488         *cmd_val = ((*args)->oif == NULL);
489 }
490
491 void
492 check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
493         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
494 {
495         *cmd_ctl = IP_FW_CTL_NO;
496         *cmd_val = ((*args)->oif != NULL);
497 }
498
499 void
500 check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
501         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
502 {
503         *cmd_ctl = IP_FW_CTL_NO;
504         *cmd_val = iface_match((*args)->oif ?
505                         (*args)->oif : (*args)->m->m_pkthdr.rcvif,
506                         (ipfw_insn_if *)cmd);
507 }
508
509 void
510 check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
511         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
512 {
513         *cmd_ctl = IP_FW_CTL_NO;
514         *cmd_val = ((*args)->f_id.proto == cmd->arg1);
515 }
516
517 void
518 check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
519         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
520 {
521         *cmd_ctl = IP_FW_CTL_NO;
522         *cmd_val = (krandom() % 100) < cmd->arg1;
523 }
524
525 void
526 check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
527         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
528 {
529         struct in_addr src_ip;
530         u_int hlen = 0;
531         struct mbuf *m = (*args)->m;
532         struct ip *ip = mtod(m, struct ip *);
533         src_ip = ip->ip_src;
534         if ((*args)->eh == NULL ||
535                 (m->m_pkthdr.len >= sizeof(struct ip) &&
536                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
537                 hlen = ip->ip_hl << 2;
538         }
539         *cmd_val = (hlen > 0 &&
540                         ((ipfw_insn_ip *)cmd)->addr.s_addr == src_ip.s_addr);
541         *cmd_ctl = IP_FW_CTL_NO;
542 }
543
544 void
545 check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
546         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
547 {
548         struct in_addr dst_ip;
549         u_int hlen = 0;
550         struct mbuf *m = (*args)->m;
551         struct ip *ip = mtod(m, struct ip *);
552         dst_ip = ip->ip_dst;
553         if ((*args)->eh == NULL ||
554                 (m->m_pkthdr.len >= sizeof(struct ip) &&
555                  ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
556                 hlen = ip->ip_hl << 2;
557         }
558         *cmd_val = (hlen > 0 &&
559                         ((ipfw_insn_ip *)cmd)->addr.s_addr == dst_ip.s_addr);
560         *cmd_ctl = IP_FW_CTL_NO;
561 }
562
563 void
564 check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
565         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
566 {
567         struct ip_fw_state *state;
568         int limited = 0;
569
570         *cmd_ctl = IP_FW_CTL_NO;
571         state = lookup_state(*args, cmd, &limited, 1);
572         if (limited == 1) {
573                 *cmd_val = IP_FW_NOT_MATCH;
574         } else {
575                 if (state == NULL)
576                         state = install_state(*f, cmd, *args);
577
578                 if (state != NULL) {
579                         state->pcnt++;
580                         state->bcnt += ip_len;
581                         state->timestamp = time_second;
582                         *cmd_val = IP_FW_MATCH;
583                 } else {
584                         *cmd_val = IP_FW_NOT_MATCH;
585                 }
586         }
587 }
588
589 void
590 check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
591         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
592 {
593         struct m_tag *mtag = m_tag_locate((*args)->m,
594                         MTAG_IPFW, cmd->arg1, NULL);
595         if (mtag == NULL) {
596                 mtag = m_tag_alloc(MTAG_IPFW,cmd->arg1, 0, M_NOWAIT);
597                 if (mtag != NULL)
598                         m_tag_prepend((*args)->m, mtag);
599
600         }
601         (*f)->pcnt++;
602         (*f)->bcnt += ip_len;
603         (*f)->timestamp = time_second;
604         *cmd_ctl = IP_FW_CTL_NEXT;
605 }
606
607 void
608 check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
609         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
610 {
611         struct m_tag *mtag = m_tag_locate((*args)->m,
612                         MTAG_IPFW, cmd->arg1, NULL);
613         if (mtag != NULL)
614                 m_tag_delete((*args)->m, mtag);
615
616         (*f)->pcnt++;
617         (*f)->bcnt += ip_len;
618         (*f)->timestamp = time_second;
619         *cmd_ctl = IP_FW_CTL_NEXT;
620 }
621
622 void
623 check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
624         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
625 {
626         *cmd_ctl = IP_FW_CTL_NO;
627         if (m_tag_locate( (*args)->m, MTAG_IPFW,cmd->arg1, NULL) != NULL )
628                 *cmd_val = IP_FW_MATCH;
629         else
630                 *cmd_val = IP_FW_NOT_MATCH;
631 }
632
633 static void
634 ipfw_basic_add_state(struct ipfw_ioc_state *ioc_state)
635 {
636         struct ip_fw_state *state;
637         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
638         struct ipfw_state_context *state_ctx;
639         state_ctx = &ctx->state_ctx[hash_packet(&(ioc_state->flow_id))];
640         state = kmalloc(sizeof(struct ip_fw_state),
641                         M_IPFW2_BASIC, M_WAITOK | M_ZERO);
642         struct ip_fw *rule = ctx->ipfw_rule_chain;
643         while (rule != NULL) {
644                 if (rule->rulenum == ioc_state->rulenum) {
645                         break;
646                 }
647                 rule = rule->next;
648         }
649         if (rule == NULL)
650                 return;
651
652         state->stub = rule;
653
654         state->lifetime = ioc_state->lifetime == 0 ?
655                 state_lifetime : ioc_state->lifetime ;
656         state->timestamp = time_second;
657         state->expiry = ioc_state->expiry;
658         bcopy(&ioc_state->flow_id, &state->flow_id,
659                         sizeof(struct ipfw_flow_id));
660         //append the state into the state chian
661         if (state_ctx->last != NULL)
662                 state_ctx->last->next = state;
663         else
664                 state_ctx->state = state;
665
666         state_ctx->last = state;
667         state_ctx->count++;
668 }
669
670 /*
671  * if rule is NULL
672  *              flush all states
673  * else
674  *              flush states which stub is the rule
675  */
676 static void
677 ipfw_basic_flush_state(struct ip_fw *rule)
678 {
679         struct ipfw_state_context *state_ctx;
680         struct ip_fw_state *state,*the_state, *prev_state;
681         struct ipfw_context *ctx;
682         int i;
683
684         ctx = ipfw_ctx[mycpuid];
685         for (i = 0; i < state_hash_size; i++) {
686                 state_ctx = &ctx->state_ctx[i];
687                 if (state_ctx != NULL) {
688                         state = state_ctx->state;
689                         prev_state = NULL;
690                         while (state != NULL) {
691                                 if (rule != NULL && state->stub != rule) {
692                                         prev_state = state;
693                                         state = state->next;
694                                 } else {
695                                         if (prev_state == NULL)
696                                                 state_ctx->state = state->next;
697                                         else
698                                                 prev_state->next = state->next;
699
700                                         the_state = state;
701                                         state = state->next;
702                                         kfree(the_state, M_IPFW2_BASIC);
703                                         state_ctx->count--;
704                                         if (state == NULL)
705                                                 state_ctx->last = prev_state;
706
707                                 }
708                         }
709                 }
710         }
711 }
712
713 /*
714  * clean up expired state in every tick
715  */
716 static void
717 ipfw_cleanup_expired_state(netmsg_t nmsg)
718 {
719         struct ip_fw_state *state,*the_state,*prev_state;
720         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
721         struct ipfw_state_context *state_ctx;
722         int i;
723
724         for (i = 0; i < state_hash_size; i++) {
725                 prev_state = NULL;
726                 state_ctx = &(ctx->state_ctx[i]);
727                 if (ctx->state_ctx != NULL) {
728                         state = state_ctx->state;
729                         while (state != NULL) {
730                                 if (IS_EXPIRED(state)) {
731                                         if (prev_state == NULL)
732                                                 state_ctx->state = state->next;
733                                         else
734                                                 prev_state->next = state->next;
735
736                                         the_state =state;
737                                         state = state->next;
738
739                                         if (the_state == state_ctx->last)
740                                                 state_ctx->last = NULL;
741
742
743                                         kfree(the_state, M_IPFW2_BASIC);
744                                         state_ctx->count--;
745                                 } else {
746                                         prev_state = state;
747                                         state = state->next;
748                                 }
749                         }
750                 }
751         }
752         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
753 }
754
/*
 * Expiry callout handler; asserts that it runs on IPFW_CFGCPUID.
 *
 * If IPFW_BASIC_LOADED, it queues the static ipfw_timeout_netmsg to the
 * ipfw configuration port; that message's handler, ipfw_tick_dispatch(),
 * re-arms this callout.  It then performs one expiry sweep by sending a
 * stack-allocated message handled by ipfw_cleanup_expired_state() through
 * ifnet_domsg().
 *
 * NOTE(review): the sweep is issued even when IPFW_BASIC_LOADED is false.
 * Also, because the message lives on this stack frame, ifnet_domsg()
 * presumably waits for the reply -- confirm that blocking is acceptable
 * in callout context.
 */
755 static void
756 ipfw_tick(void *dummy __unused)
757 {
758         struct lwkt_msg *lmsg = &ipfw_timeout_netmsg.lmsg;
759         KKASSERT(mycpuid == IPFW_CFGCPUID);
760
761         crit_enter();
        /* The previous round's re-arm message must be back before reuse. */
762         KKASSERT(lmsg->ms_flags & MSGF_DONE);
763         if (IPFW_BASIC_LOADED) {
764                 lwkt_sendmsg_oncpu(IPFW_CFGPORT, lmsg);
765                 /* ipfw_timeout_netmsg's handler reset this callout */
766         }
767         crit_exit();
768
        /* One-shot sweep message for ipfw_cleanup_expired_state(). */
769         struct netmsg_base *msg;
770         struct netmsg_base the_msg;
771         msg = &the_msg;
772         bzero(msg,sizeof(struct netmsg_base));
773
774         netmsg_init(msg, NULL, &curthread->td_msgport, 0,
775                         ipfw_cleanup_expired_state);
776         ifnet_domsg(&msg->lmsg, 0);
777 }
778
/*
 * Handler for ipfw_timeout_netmsg, run on the ipfw configuration port.
 *
 * It replies to the message first; ipfw_tick() asserts MSGF_DONE on this
 * message before queueing it again, which is presumably why the reply is
 * immediate.  It then re-arms the expiry callout for
 * state_expiry_check_interval * hz ticks.
 */
779 static void
780 ipfw_tick_dispatch(netmsg_t nmsg)
781 {
782         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
783         KKASSERT(IPFW_BASIC_LOADED);
784
785         /* Reply ASAP */
786         crit_enter();
787         lwkt_replymsg(&nmsg->lmsg, 0);
788         crit_exit();
789
790         callout_reset(&ipfw_tick_callout,
791                         state_expiry_check_interval * hz, ipfw_tick, NULL);
792 }
793
794 static void
795 ipfw_basic_init_dispatch(netmsg_t nmsg)
796 {
797         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
798         KKASSERT(IPFW_LOADED);
799
800         int error = 0;
801         callout_init_mp(&ipfw_tick_callout);
802         netmsg_init(&ipfw_timeout_netmsg, NULL, &netisr_adone_rport,
803                         MSGF_DROPABLE | MSGF_PRIORITY, ipfw_tick_dispatch);
804         callout_reset(&ipfw_tick_callout,
805                         state_expiry_check_interval * hz, ipfw_tick, NULL);
806         lwkt_replymsg(&nmsg->lmsg, error);
807         ip_fw_basic_loaded=1;
808 }
809
810 static int
811 ipfw_basic_init(void)
812 {
813         ipfw_basic_flush_state_prt = ipfw_basic_flush_state;
814         ipfw_basic_append_state_prt = ipfw_basic_add_state;
815
816         register_ipfw_module(MODULE_BASIC_ID, MODULE_BASIC_NAME);
817         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_COUNT,
818                         (filter_func)check_count);
819         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_SKIPTO,
820                         (filter_func)check_skipto);
821         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_FORWARD,
822                         (filter_func)check_forward);
823         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_KEEP_STATE,
824                         (filter_func)check_keep_state);
825         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_CHECK_STATE,
826                         (filter_func)check_check_state);
827
828         register_ipfw_filter_funcs(MODULE_BASIC_ID,
829                         O_BASIC_IN, (filter_func)check_in);
830         register_ipfw_filter_funcs(MODULE_BASIC_ID,
831                         O_BASIC_OUT, (filter_func)check_out);
832         register_ipfw_filter_funcs(MODULE_BASIC_ID,
833                         O_BASIC_VIA, (filter_func)check_via);
834         register_ipfw_filter_funcs(MODULE_BASIC_ID,
835                         O_BASIC_XMIT, (filter_func)check_via);
836         register_ipfw_filter_funcs(MODULE_BASIC_ID,
837                         O_BASIC_RECV, (filter_func)check_via);
838
839         register_ipfw_filter_funcs(MODULE_BASIC_ID,
840                         O_BASIC_PROTO, (filter_func)check_proto);
841         register_ipfw_filter_funcs(MODULE_BASIC_ID,
842                         O_BASIC_PROB, (filter_func)check_prob);
843         register_ipfw_filter_funcs(MODULE_BASIC_ID,
844                         O_BASIC_IP_SRC, (filter_func)check_from);
845         register_ipfw_filter_funcs(MODULE_BASIC_ID,
846                         O_BASIC_IP_DST, (filter_func)check_to);
847
848         register_ipfw_filter_funcs(MODULE_BASIC_ID,
849                         O_BASIC_TAG, (filter_func)check_tag);
850         register_ipfw_filter_funcs(MODULE_BASIC_ID,
851                         O_BASIC_UNTAG, (filter_func)check_untag);
852         register_ipfw_filter_funcs(MODULE_BASIC_ID,
853                         O_BASIC_TAGGED, (filter_func)check_tagged);
854
855         int cpu;
856         struct ipfw_context *ctx;
857
858         for (cpu = 0; cpu < ncpus; cpu++) {
859                 ctx = ipfw_ctx[cpu];
860                 if (ctx != NULL) {
861                         ctx->state_ctx = kmalloc(state_hash_size *
862                                         sizeof(struct ipfw_state_context),
863                                         M_IPFW2_BASIC, M_WAITOK | M_ZERO);
864                         ctx->state_hash_size = state_hash_size;
865                 }
866         }
867
868         struct netmsg_base smsg;
869         netmsg_init(&smsg, NULL, &curthread->td_msgport,
870                         0, ipfw_basic_init_dispatch);
871         lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
872         return 0;
873 }
874
875 static void
876 ipfw_basic_stop_dispatch(netmsg_t nmsg)
877 {
878         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
879         KKASSERT(IPFW_LOADED);
880         int error = 0;
881         callout_stop(&ipfw_tick_callout);
882         netmsg_service_sync();
883         crit_enter();
884         lwkt_dropmsg(&ipfw_timeout_netmsg.lmsg);
885         crit_exit();
886         lwkt_replymsg(&nmsg->lmsg, error);
887         ip_fw_basic_loaded=0;
888 }
889
890 static int
891 ipfw_basic_stop(void)
892 {
893         int cpu,i;
894         struct ipfw_state_context *state_ctx;
895         struct ip_fw_state *state,*the_state;
896         struct ipfw_context *ctx;
897         if (unregister_ipfw_module(MODULE_BASIC_ID) ==0 ) {
898                 ipfw_basic_flush_state_prt = NULL;
899                 ipfw_basic_append_state_prt = NULL;
900
901                 for (cpu = 0; cpu < ncpus; cpu++) {
902                         ctx = ipfw_ctx[cpu];
903                         if (ctx != NULL) {
904                                 for (i = 0; i < state_hash_size; i++) {
905                                         state_ctx = &ctx->state_ctx[i];
906                                         if (state_ctx != NULL) {
907                                                 state = state_ctx->state;
908                                                 while (state != NULL) {
909                                                         the_state = state;
910                                                         state = state->next;
911                                                         if (the_state ==
912                                                                 state_ctx->last)
913                                                         state_ctx->last = NULL;
914
915                                                         kfree(the_state,
916                                                                 M_IPFW2_BASIC);
917                                                 }
918                                         }
919                                 }
920                                 ctx->state_hash_size = 0;
921                                 kfree(ctx->state_ctx, M_IPFW2_BASIC);
922                                 ctx->state_ctx = NULL;
923                         }
924                 }
925                 struct netmsg_base smsg;
926                 netmsg_init(&smsg, NULL, &curthread->td_msgport,
927                                 0, ipfw_basic_stop_dispatch);
928                 return lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
929         }
930         return 1;
931 }
932
933
934 static int
935 ipfw2_basic_modevent(module_t mod, int type, void *data)
936 {
937         int err;
938         switch (type) {
939                 case MOD_LOAD:
940                         err = ipfw_basic_init();
941                         break;
942                 case MOD_UNLOAD:
943                         err = ipfw_basic_stop();
944                         break;
945                 default:
946                         err = 1;
947         }
948         return err;
949 }
950
951 static moduledata_t ipfw2_basic_mod = {
952         "ipfw2_basic",
953         ipfw2_basic_modevent,
954         NULL
955 };
956 DECLARE_MODULE(ipfw2_basic, ipfw2_basic_mod, SI_SUB_PROTO_END, SI_ORDER_ANY);
957 MODULE_DEPEND(ipfw2_basic, ipfw2, 1, 1, 1);
958 MODULE_VERSION(ipfw2_basic, 1);