05a2b613a7bf31c94558aa521437b1a9028c528a
[dragonfly.git] / sys / net / ipfw3_basic / ip_fw3_basic.c
1 /*
2  * Copyright (c) 2014 The DragonFly Project.  All rights reserved.
3  *
4  * This code is derived from software contributed to The DragonFly Project
5  * by Bill Yuan <bycn82@gmail.com>
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  * 3. Neither the name of The DragonFly Project nor the names of its
18  *    contributors may be used to endorse or promote products derived
19  *    from this software without specific, prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
25  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
27  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
29  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
31  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32  * SUCH DAMAGE.
33  */
34
35 #include <sys/param.h>
36 #include <sys/kernel.h>
37 #include <sys/malloc.h>
38 #include <sys/mbuf.h>
39 #include <sys/socketvar.h>
40 #include <sys/sysctl.h>
41 #include <sys/syslog.h>
42 #include <sys/systimer.h>
43 #include <sys/thread2.h>
44 #include <sys/in_cksum.h>
45
46 #include <net/if.h>
47 #include <net/ethernet.h>
48 #include <net/netmsg2.h>
49 #include <net/netisr2.h>
50 #include <net/route.h>
51
52 #include <netinet/ip.h>
53 #include <netinet/in.h>
54 #include <netinet/in_systm.h>
55 #include <netinet/in_var.h>
56 #include <netinet/in_pcb.h>
57 #include <netinet/ip_var.h>
58 #include <netinet/ip_icmp.h>
59 #include <netinet/tcp.h>
60 #include <netinet/tcp_timer.h>
61 #include <netinet/tcp_var.h>
62 #include <netinet/tcpip.h>
63 #include <netinet/udp.h>
64 #include <netinet/udp_var.h>
65 #include <netinet/ip_divert.h>
66 #include <netinet/if_ether.h>
67
68 #include <net/ipfw3/ip_fw.h>
69
70 #include "ip_fw3_basic.h"
71
/*
 * Wraparound-safe "a <= b" for time/tick counters: compares the signed
 * difference.  NOTE(review): not referenced in the visible part of this file.
 */
#define TIME_LEQ(a, b)  ((int)((a) - (b)) <= 0)

/* Per-CPU ipfw3 contexts (rule chain and state hash table), owned by the core. */
extern struct ipfw_context      *ipfw_ctx[MAXCPU];
extern int fw_verbose;
/* Hooks through which the ipfw3 core calls into this module; set in ipfw_basic_init(). */
extern ipfw_basic_delete_state_t *ipfw_basic_flush_state_prt;
extern ipfw_basic_append_state_t *ipfw_basic_append_state_prt;

/* Set to 1 by ipfw_basic_init_dispatch() once the expiry timer is set up. */
static int ip_fw_basic_loaded;
static struct netmsg_base ipfw_timeout_netmsg;  /* schedule ipfw timeout */
/* Periodic timer driving state expiry; fires ipfw_tick(). */
static struct callout ipfw_tick_callout;
/*
 * Lifetime given to states whose rule/ioctl specifies 0; compared against
 * time_second based timestamps, so presumably seconds.
 */
static int state_lifetime = 20;
/* Period of the expiry callout, in seconds (multiplied by hz when armed). */
static int state_expiry_check_interval = 10;
/*
 * Exported as "maximum of state".  NOTE(review): nothing in the visible part
 * of this file reads it, so no limit appears to be enforced — confirm.
 */
static int state_count_max = 4096;
/* Table size before the latest resize; walked by adjust_hash_size_dispatch(). */
static int state_hash_size_old = 0;
/* Number of buckets in each per-CPU state table; must be a power of two. */
static int state_hash_size = 4096;


static int ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS);
void adjust_hash_size_dispatch(netmsg_t nmsg);

/* net.inet.ip.fw_basic.* tunables. */
SYSCTL_NODE(_net_inet_ip, OID_AUTO, fw_basic,
		CTLFLAG_RW, 0, "Firewall Basic");
SYSCTL_PROC(_net_inet_ip_fw_basic, OID_AUTO, state_hash_size,
		CTLTYPE_INT | CTLFLAG_RW, &state_hash_size, 0,
		ipfw_sysctl_adjust_hash_size, "I", "Adjust hash size");

SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_lifetime, CTLFLAG_RW,
		&state_lifetime, 0, "default life time");
SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO,
		state_expiry_check_interval, CTLFLAG_RW,
		&state_expiry_check_interval, 0,
		"default state expiry check interval");
SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_count_max, CTLFLAG_RW,
		&state_count_max, 0, "maximum of state");
106
107 static int
108 ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS)
109 {
110         int error, value = 0;
111
112         state_hash_size_old = state_hash_size;
113         value = state_hash_size;
114         error = sysctl_handle_int(oidp, &value, 0, req);
115         if (error || !req->newptr) {
116                 goto back;
117         }
118         /*
119          * Make sure we have a power of 2 and
120          * do not allow more than 64k entries.
121          */
122         error = EINVAL;
123         if (value <= 1 || value > 65536) {
124                 goto back;
125         }
126         if ((value & (value - 1)) != 0) {
127                 goto back;
128         }
129
130         error = 0;
131         if (state_hash_size != value) {
132                 state_hash_size = value;
133
134                 struct netmsg_base *msg, the_msg;
135                 msg = &the_msg;
136                 bzero(msg,sizeof(struct netmsg_base));
137
138                 netmsg_init(msg, NULL, &curthread->td_msgport,
139                                 0, adjust_hash_size_dispatch);
140                 ifnet_domsg(&msg->lmsg, 0);
141         }
142 back:
143         return error;
144 }
145
146 void
147 adjust_hash_size_dispatch(netmsg_t nmsg)
148 {
149         struct ipfw_state_context *state_ctx;
150         struct ip_fw_state *the_state, *state;
151         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
152         int i;
153
154         for (i = 0; i < state_hash_size_old; i++) {
155                 state_ctx = &ctx->state_ctx[i];
156                 if (state_ctx != NULL) {
157                         state = state_ctx->state;
158                         while (state != NULL) {
159                                 the_state = state;
160                                 state = state->next;
161                                 kfree(the_state, M_IPFW3_BASIC);
162                                 the_state = NULL;
163                         }
164                 }
165         }
166         kfree(ctx->state_ctx,M_IPFW3_BASIC);
167         ctx->state_ctx = kmalloc(state_hash_size *
168                                 sizeof(struct ipfw_state_context),
169                                 M_IPFW3_BASIC, M_WAITOK | M_ZERO);
170         ctx->state_hash_size = state_hash_size;
171         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
172 }
173
174
/*
 * Prototypes of the checker (filter) functions.  They are registered with
 * the ipfw3 core via register_ipfw_filter_funcs() and share one signature:
 *   cmd_ctl - out: how the caller should proceed (IP_FW_CTL_NO,
 *             IP_FW_CTL_NEXT, IP_FW_CTL_AGAIN, IP_FW_CTL_DONE,
 *             IP_FW_CTL_CHK_STATE)
 *   cmd_val - out: match result (with IP_FW_CTL_NO) or final verdict such
 *             as IP_FW_PASS / IP_FW_DENY (with IP_FW_CTL_DONE)
 *   args    - in/out: per-packet arguments (mbuf, flow id, interfaces)
 *   f       - in/out: current rule; check_check_state may replace it
 *   cmd     - the instruction being evaluated
 *   ip_len  - packet length added to the byte counters
 */
void check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);

void check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
void check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
	struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);

/* Prototypes of the file-local utility functions. */
static struct ip_fw *lookup_next_rule(struct ip_fw *me);
static int iface_match(struct ifnet *ifp, ipfw_insn_if *cmd);
static __inline int hash_packet(struct ipfw_flow_id *id);
212
213 static __inline int
214 hash_packet(struct ipfw_flow_id *id)
215 {
216         uint32_t i;
217         i = (id->proto) ^ (id->dst_ip) ^ (id->src_ip) ^
218                 (id->dst_port) ^ (id->src_port);
219         i &= state_hash_size - 1;
220         return i;
221 }
222
223 static struct ip_fw *
224 lookup_next_rule(struct ip_fw *me)
225 {
226         struct ip_fw *rule = NULL;
227         ipfw_insn *cmd;
228
229         /* look for action, in case it is a skipto */
230         cmd = ACTION_PTR(me);
231         if ((int)cmd->module == MODULE_BASIC_ID &&
232                 (int)cmd->opcode == O_BASIC_SKIPTO) {
233                 for (rule = me->next; rule; rule = rule->next) {
234                         if (rule->rulenum >= cmd->arg1)
235                                 break;
236                 }
237         }
238         if (rule == NULL) /* failure or not a skipto */
239                 rule = me->next;
240
241         me->next_rule = rule;
242         return rule;
243 }
244
245 /*
246  * when all = 1, it will check all the state_ctx
247  */
248 static struct ip_fw_state *
249 lookup_state(struct ip_fw_args *args, ipfw_insn *cmd, int *limited, int all)
250 {
251         struct ip_fw_state *state = NULL;
252         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
253         struct ipfw_state_context *state_ctx;
254         int start, end, i, count = 0;
255
256         if (all && cmd->arg1) {
257                 start = 0;
258                 end = state_hash_size - 1;
259         } else {
260                 start = hash_packet(&args->f_id);
261                 end = hash_packet(&args->f_id);
262         }
263         for (i = start; i <= end; i++) {
264                 state_ctx = &ctx->state_ctx[i];
265                 if (state_ctx != NULL) {
266                         state = state_ctx->state;
267                         struct ipfw_flow_id     *fid = &args->f_id;
268                         while (state != NULL) {
269                                 if (cmd->arg1) {
270                                         if ((cmd->arg3 == 1 &&
271                                                 fid->src_ip ==
272                                                 state->flow_id.src_ip) ||
273                                                 (cmd->arg3 == 2 &&
274                                                 fid->src_port ==
275                                                 state->flow_id.src_port) ||
276                                                 (cmd->arg3 == 3 &&
277                                                 fid->dst_ip ==
278                                                 state->flow_id.dst_ip) ||
279                                                 (cmd->arg3 == 4 &&
280                                                 fid->dst_port ==
281                                                 state->flow_id.dst_port)) {
282
283                                                 count++;
284                                                 if (count >= cmd->arg1) {
285                                                         *limited = 1;
286                                                         goto done;
287                                                 }
288                                         }
289                                 }
290
291                                 if (fid->proto == state->flow_id.proto) {
292                                         if (fid->src_ip ==
293                                         state->flow_id.src_ip &&
294                                         fid->dst_ip ==
295                                         state->flow_id.dst_ip &&
296                                         (fid->src_port ==
297                                         state->flow_id.src_port ||
298                                         state->flow_id.src_port == 0) &&
299                                         (fid->dst_port ==
300                                         state->flow_id.dst_port ||
301                                         state->flow_id.dst_port == 0)) {
302                                                 goto done;
303                                         }
304                                         if (fid->src_ip ==
305                                         state->flow_id.dst_ip &&
306                                         fid->dst_ip ==
307                                         state->flow_id.src_ip &&
308                                         (fid->src_port ==
309                                         state->flow_id.dst_port ||
310                                         state->flow_id.dst_port == 0) &&
311                                         (fid->dst_port ==
312                                         state->flow_id.src_port ||
313                                         state->flow_id.src_port == 0)) {
314                                                 goto done;
315                                         }
316                                 }
317                                 state = state->next;
318                         }
319                 }
320         }
321 done:
322         return state;
323 }
324
325 static struct ip_fw_state *
326 install_state(struct ip_fw *rule, ipfw_insn *cmd, struct ip_fw_args *args)
327 {
328         struct ip_fw_state *state;
329         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
330         struct ipfw_state_context *state_ctx;
331         state_ctx = &ctx->state_ctx[hash_packet(&args->f_id)];
332         state = kmalloc(sizeof(struct ip_fw_state),
333                         M_IPFW3_BASIC, M_NOWAIT | M_ZERO);
334         if (state == NULL) {
335                 return NULL;
336         }
337         state->stub = rule;
338         state->lifetime = cmd->arg2 == 0 ? state_lifetime : cmd->arg2 ;
339         state->timestamp = time_second;
340         state->expiry = 0;
341         bcopy(&args->f_id,&state->flow_id,sizeof(struct ipfw_flow_id));
342         //append the state into the state chian
343         if (state_ctx->last != NULL)
344                 state_ctx->last->next = state;
345         else
346                 state_ctx->state = state;
347         state_ctx->last = state;
348         state_ctx->count++;
349         return state;
350 }
351
352
353 static int
354 iface_match(struct ifnet *ifp, ipfw_insn_if *cmd)
355 {
356         if (ifp == NULL)        /* no iface with this packet, match fails */
357                 return 0;
358
359         /* Check by name or by IP address */
360         if (cmd->name[0] != '\0') { /* match by name */
361                 /* Check name */
362                 if (cmd->p.glob) {
363                         if (kfnmatch(cmd->name, ifp->if_xname, 0) == 0)
364                                 return(1);
365                 } else {
366                         if (strncmp(ifp->if_xname, cmd->name, IFNAMSIZ) == 0)
367                                 return(1);
368                 }
369         } else {
370                 struct ifaddr_container *ifac;
371
372                 TAILQ_FOREACH(ifac, &ifp->if_addrheads[mycpuid], ifa_link) {
373                         struct ifaddr *ia = ifac->ifa;
374
375                         if (ia->ifa_addr == NULL)
376                                 continue;
377                         if (ia->ifa_addr->sa_family != AF_INET)
378                                 continue;
379                         if (cmd->p.ip.s_addr ==
380                                 ((struct sockaddr_in *)
381                                 (ia->ifa_addr))->sin_addr.s_addr)
382                                         return(1);      /* match */
383
384                 }
385         }
386         return 0;       /* no match, fail ... */
387 }
388
/* implementation of the checker functions */
390 void
391 check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
392         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
393 {
394         (*f)->pcnt++;
395         (*f)->bcnt += ip_len;
396         (*f)->timestamp = time_second;
397         *cmd_ctl = IP_FW_CTL_NEXT;
398 }
399
400 void
401 check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
402         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
403 {
404         (*f)->pcnt++;
405         (*f)->bcnt += ip_len;
406         (*f)->timestamp = time_second;
407         if ((*f)->next_rule == NULL)
408                 lookup_next_rule(*f);
409
410         *cmd_ctl = IP_FW_CTL_AGAIN;
411 }
412
413 void
414 check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
415         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
416 {
417         struct sockaddr_in *sin, *sa;
418         struct m_tag *mtag;
419
420         if ((*args)->eh) {      /* not valid on layer2 pkts */
421                 *cmd_ctl=IP_FW_CTL_NEXT;
422                 return;
423         }
424
425         (*f)->pcnt++;
426         (*f)->bcnt += ip_len;
427         (*f)->timestamp = time_second;
428         if ((*f)->next_rule == NULL)
429                 lookup_next_rule(*f);
430
431         mtag = m_tag_get(PACKET_TAG_IPFORWARD,
432                         sizeof(*sin), M_NOWAIT);
433         if (mtag == NULL) {
434                 *cmd_val = IP_FW_DENY;
435                 *cmd_ctl = IP_FW_CTL_DONE;
436                 return;
437         }
438         sin = m_tag_data(mtag);
439         sa = &((ipfw_insn_sa *)cmd)->sa;
440         /* arg3: count of the dest, arg1: type of fwd */
441         int i = 0;
442         if(cmd->arg3 > 1) {
443                 if (cmd->arg1 == 0) {           /* type: random */
444                         i = krandom() % cmd->arg3;
445                 } else if (cmd->arg1 == 1) {    /* type: round-robin */
446                         i = cmd->arg2++ % cmd->arg3;
447                 } else if (cmd->arg1 == 2) {    /* type: sticky */
448                         struct ip *ip = mtod((*args)->m, struct ip *);
449                         i = ip->ip_src.s_addr & (cmd->arg3 - 1);
450                 }
451                 sa += i;
452         }
453         *sin = *sa;     /* apply the destination */
454         m_tag_prepend((*args)->m, mtag);
455         (*args)->m->m_pkthdr.fw_flags |= IPFORWARD_MBUF_TAGGED;
456         (*args)->m->m_pkthdr.fw_flags &= ~BRIDGE_MBUF_TAGGED;
457         *cmd_ctl = IP_FW_CTL_DONE;
458         *cmd_val = IP_FW_PASS;
459 }
460
461 void
462 check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
463         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
464 {
465         struct ip_fw_state *state=NULL;
466         int limited = 0 ;
467         state = lookup_state(*args, cmd, &limited, 0);
468         if (state != NULL) {
469                 state->pcnt++;
470                 state->bcnt += ip_len;
471                 state->timestamp = time_second;
472                 (*f)->pcnt++;
473                 (*f)->bcnt += ip_len;
474                 (*f)->timestamp = time_second;
475                 *f = state->stub;
476                 *cmd_ctl = IP_FW_CTL_CHK_STATE;
477         } else {
478                 *cmd_ctl = IP_FW_CTL_NEXT;
479         }
480 }
481
482 void
483 check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
484         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
485 {
486         *cmd_ctl = IP_FW_CTL_NO;
487         *cmd_val = ((*args)->oif == NULL);
488 }
489
490 void
491 check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
492         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
493 {
494         *cmd_ctl = IP_FW_CTL_NO;
495         *cmd_val = ((*args)->oif != NULL);
496 }
497
498 void
499 check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
500         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
501 {
502         *cmd_ctl = IP_FW_CTL_NO;
503         *cmd_val = iface_match((*args)->oif ?
504                         (*args)->oif : (*args)->m->m_pkthdr.rcvif,
505                         (ipfw_insn_if *)cmd);
506 }
507
508 void
509 check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
510         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
511 {
512         *cmd_ctl = IP_FW_CTL_NO;
513         *cmd_val = ((*args)->f_id.proto == cmd->arg1);
514 }
515
516 void
517 check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
518         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
519 {
520         *cmd_ctl = IP_FW_CTL_NO;
521         *cmd_val = (krandom() % 100) < cmd->arg1;
522 }
523
524 void
525 check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
526         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
527 {
528         struct in_addr src_ip;
529         u_int hlen = 0;
530         struct mbuf *m = (*args)->m;
531         struct ip *ip = mtod(m, struct ip *);
532         src_ip = ip->ip_src;
533         if ((*args)->eh == NULL ||
534                 (m->m_pkthdr.len >= sizeof(struct ip) &&
535                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
536                 hlen = ip->ip_hl << 2;
537         }
538         *cmd_val = (hlen > 0 &&
539                         ((ipfw_insn_ip *)cmd)->addr.s_addr == src_ip.s_addr);
540         *cmd_ctl = IP_FW_CTL_NO;
541 }
542
543 void
544 check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
545         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
546 {
547         struct in_addr dst_ip;
548         u_int hlen = 0;
549         struct mbuf *m = (*args)->m;
550         struct ip *ip = mtod(m, struct ip *);
551         dst_ip = ip->ip_dst;
552         if ((*args)->eh == NULL ||
553                 (m->m_pkthdr.len >= sizeof(struct ip) &&
554                  ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
555                 hlen = ip->ip_hl << 2;
556         }
557         *cmd_val = (hlen > 0 &&
558                         ((ipfw_insn_ip *)cmd)->addr.s_addr == dst_ip.s_addr);
559         *cmd_ctl = IP_FW_CTL_NO;
560 }
561
562 void
563 check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
564         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
565 {
566         struct ip_fw_state *state;
567         int limited = 0;
568
569         *cmd_ctl = IP_FW_CTL_NO;
570         state = lookup_state(*args, cmd, &limited, 1);
571         if (limited == 1) {
572                 *cmd_val = IP_FW_NOT_MATCH;
573         } else {
574                 if (state == NULL)
575                         state = install_state(*f, cmd, *args);
576
577                 if (state != NULL) {
578                         state->pcnt++;
579                         state->bcnt += ip_len;
580                         state->timestamp = time_second;
581                         *cmd_val = IP_FW_MATCH;
582                 } else {
583                         *cmd_val = IP_FW_NOT_MATCH;
584                 }
585         }
586 }
587
588 void
589 check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
590         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
591 {
592         struct m_tag *mtag = m_tag_locate((*args)->m,
593                         MTAG_IPFW, cmd->arg1, NULL);
594         if (mtag == NULL) {
595                 mtag = m_tag_alloc(MTAG_IPFW,cmd->arg1, 0, M_NOWAIT);
596                 if (mtag != NULL)
597                         m_tag_prepend((*args)->m, mtag);
598
599         }
600         (*f)->pcnt++;
601         (*f)->bcnt += ip_len;
602         (*f)->timestamp = time_second;
603         *cmd_ctl = IP_FW_CTL_NEXT;
604 }
605
606 void
607 check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
608         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
609 {
610         struct m_tag *mtag = m_tag_locate((*args)->m,
611                         MTAG_IPFW, cmd->arg1, NULL);
612         if (mtag != NULL)
613                 m_tag_delete((*args)->m, mtag);
614
615         (*f)->pcnt++;
616         (*f)->bcnt += ip_len;
617         (*f)->timestamp = time_second;
618         *cmd_ctl = IP_FW_CTL_NEXT;
619 }
620
621 void
622 check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
623         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
624 {
625         *cmd_ctl = IP_FW_CTL_NO;
626         if (m_tag_locate( (*args)->m, MTAG_IPFW,cmd->arg1, NULL) != NULL )
627                 *cmd_val = IP_FW_MATCH;
628         else
629                 *cmd_val = IP_FW_NOT_MATCH;
630 }
631
632 static void
633 ipfw_basic_add_state(struct ipfw_ioc_state *ioc_state)
634 {
635         struct ip_fw_state *state;
636         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
637         struct ipfw_state_context *state_ctx;
638         state_ctx = &ctx->state_ctx[hash_packet(&(ioc_state->flow_id))];
639         state = kmalloc(sizeof(struct ip_fw_state),
640                         M_IPFW3_BASIC, M_WAITOK | M_ZERO);
641         struct ip_fw *rule = ctx->ipfw_rule_chain;
642         while (rule != NULL) {
643                 if (rule->rulenum == ioc_state->rulenum) {
644                         break;
645                 }
646                 rule = rule->next;
647         }
648         if (rule == NULL)
649                 return;
650
651         state->stub = rule;
652
653         state->lifetime = ioc_state->lifetime == 0 ?
654                 state_lifetime : ioc_state->lifetime ;
655         state->timestamp = time_second;
656         state->expiry = ioc_state->expiry;
657         bcopy(&ioc_state->flow_id, &state->flow_id,
658                         sizeof(struct ipfw_flow_id));
659         //append the state into the state chian
660         if (state_ctx->last != NULL)
661                 state_ctx->last->next = state;
662         else
663                 state_ctx->state = state;
664
665         state_ctx->last = state;
666         state_ctx->count++;
667 }
668
669 /*
670  * if rule is NULL
671  *              flush all states
672  * else
673  *              flush states which stub is the rule
674  */
675 static void
676 ipfw_basic_flush_state(struct ip_fw *rule)
677 {
678         struct ipfw_state_context *state_ctx;
679         struct ip_fw_state *state,*the_state, *prev_state;
680         struct ipfw_context *ctx;
681         int i;
682
683         ctx = ipfw_ctx[mycpuid];
684         for (i = 0; i < state_hash_size; i++) {
685                 state_ctx = &ctx->state_ctx[i];
686                 if (state_ctx != NULL) {
687                         state = state_ctx->state;
688                         prev_state = NULL;
689                         while (state != NULL) {
690                                 if (rule != NULL && state->stub != rule) {
691                                         prev_state = state;
692                                         state = state->next;
693                                 } else {
694                                         if (prev_state == NULL)
695                                                 state_ctx->state = state->next;
696                                         else
697                                                 prev_state->next = state->next;
698
699                                         the_state = state;
700                                         state = state->next;
701                                         kfree(the_state, M_IPFW3_BASIC);
702                                         state_ctx->count--;
703                                         if (state == NULL)
704                                                 state_ctx->last = prev_state;
705
706                                 }
707                         }
708                 }
709         }
710 }
711
712 /*
713  * clean up expired state in every tick
714  */
715 static void
716 ipfw_cleanup_expired_state(netmsg_t nmsg)
717 {
718         struct ip_fw_state *state,*the_state,*prev_state;
719         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
720         struct ipfw_state_context *state_ctx;
721         int i;
722
723         for (i = 0; i < state_hash_size; i++) {
724                 prev_state = NULL;
725                 state_ctx = &(ctx->state_ctx[i]);
726                 if (ctx->state_ctx != NULL) {
727                         state = state_ctx->state;
728                         while (state != NULL) {
729                                 if (IS_EXPIRED(state)) {
730                                         if (prev_state == NULL)
731                                                 state_ctx->state = state->next;
732                                         else
733                                                 prev_state->next = state->next;
734
735                                         the_state =state;
736                                         state = state->next;
737
738                                         if (the_state == state_ctx->last)
739                                                 state_ctx->last = NULL;
740
741
742                                         kfree(the_state, M_IPFW3_BASIC);
743                                         state_ctx->count--;
744                                 } else {
745                                         prev_state = state;
746                                         state = state->next;
747                                 }
748                         }
749                 }
750         }
751         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
752 }
753
754 static void
755 ipfw_tick(void *dummy __unused)
756 {
757         struct lwkt_msg *lmsg = &ipfw_timeout_netmsg.lmsg;
758         KKASSERT(mycpuid == IPFW_CFGCPUID);
759
760         crit_enter();
761         KKASSERT(lmsg->ms_flags & MSGF_DONE);
762         if (IPFW_BASIC_LOADED) {
763                 lwkt_sendmsg_oncpu(IPFW_CFGPORT, lmsg);
764                 /* ipfw_timeout_netmsg's handler reset this callout */
765         }
766         crit_exit();
767
768         struct netmsg_base *msg;
769         struct netmsg_base the_msg;
770         msg = &the_msg;
771         bzero(msg,sizeof(struct netmsg_base));
772
773         netmsg_init(msg, NULL, &curthread->td_msgport, 0,
774                         ipfw_cleanup_expired_state);
775         ifnet_domsg(&msg->lmsg, 0);
776 }
777
778 static void
779 ipfw_tick_dispatch(netmsg_t nmsg)
780 {
781         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
782         KKASSERT(IPFW_BASIC_LOADED);
783
784         /* Reply ASAP */
785         crit_enter();
786         lwkt_replymsg(&nmsg->lmsg, 0);
787         crit_exit();
788
789         callout_reset(&ipfw_tick_callout,
790                         state_expiry_check_interval * hz, ipfw_tick, NULL);
791 }
792
793 static void
794 ipfw_basic_init_dispatch(netmsg_t nmsg)
795 {
796         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
797         KKASSERT(IPFW3_LOADED);
798
799         int error = 0;
800         callout_init_mp(&ipfw_tick_callout);
801         netmsg_init(&ipfw_timeout_netmsg, NULL, &netisr_adone_rport,
802                         MSGF_DROPABLE | MSGF_PRIORITY, ipfw_tick_dispatch);
803         callout_reset(&ipfw_tick_callout,
804                         state_expiry_check_interval * hz, ipfw_tick, NULL);
805         lwkt_replymsg(&nmsg->lmsg, error);
806         ip_fw_basic_loaded=1;
807 }
808
809 static int
810 ipfw_basic_init(void)
811 {
812         ipfw_basic_flush_state_prt = ipfw_basic_flush_state;
813         ipfw_basic_append_state_prt = ipfw_basic_add_state;
814
815         register_ipfw_module(MODULE_BASIC_ID, MODULE_BASIC_NAME);
816         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_COUNT,
817                         (filter_func)check_count);
818         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_SKIPTO,
819                         (filter_func)check_skipto);
820         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_FORWARD,
821                         (filter_func)check_forward);
822         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_KEEP_STATE,
823                         (filter_func)check_keep_state);
824         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_CHECK_STATE,
825                         (filter_func)check_check_state);
826
827         register_ipfw_filter_funcs(MODULE_BASIC_ID,
828                         O_BASIC_IN, (filter_func)check_in);
829         register_ipfw_filter_funcs(MODULE_BASIC_ID,
830                         O_BASIC_OUT, (filter_func)check_out);
831         register_ipfw_filter_funcs(MODULE_BASIC_ID,
832                         O_BASIC_VIA, (filter_func)check_via);
833         register_ipfw_filter_funcs(MODULE_BASIC_ID,
834                         O_BASIC_XMIT, (filter_func)check_via);
835         register_ipfw_filter_funcs(MODULE_BASIC_ID,
836                         O_BASIC_RECV, (filter_func)check_via);
837
838         register_ipfw_filter_funcs(MODULE_BASIC_ID,
839                         O_BASIC_PROTO, (filter_func)check_proto);
840         register_ipfw_filter_funcs(MODULE_BASIC_ID,
841                         O_BASIC_PROB, (filter_func)check_prob);
842         register_ipfw_filter_funcs(MODULE_BASIC_ID,
843                         O_BASIC_IP_SRC, (filter_func)check_from);
844         register_ipfw_filter_funcs(MODULE_BASIC_ID,
845                         O_BASIC_IP_DST, (filter_func)check_to);
846
847         register_ipfw_filter_funcs(MODULE_BASIC_ID,
848                         O_BASIC_TAG, (filter_func)check_tag);
849         register_ipfw_filter_funcs(MODULE_BASIC_ID,
850                         O_BASIC_UNTAG, (filter_func)check_untag);
851         register_ipfw_filter_funcs(MODULE_BASIC_ID,
852                         O_BASIC_TAGGED, (filter_func)check_tagged);
853
854         int cpu;
855         struct ipfw_context *ctx;
856
857         for (cpu = 0; cpu < ncpus; cpu++) {
858                 ctx = ipfw_ctx[cpu];
859                 if (ctx != NULL) {
860                         ctx->state_ctx = kmalloc(state_hash_size *
861                                         sizeof(struct ipfw_state_context),
862                                         M_IPFW3_BASIC, M_WAITOK | M_ZERO);
863                         ctx->state_hash_size = state_hash_size;
864                 }
865         }
866
867         struct netmsg_base smsg;
868         netmsg_init(&smsg, NULL, &curthread->td_msgport,
869                         0, ipfw_basic_init_dispatch);
870         lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
871         return 0;
872 }
873
874 static void
875 ipfw_basic_stop_dispatch(netmsg_t nmsg)
876 {
877         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
878         KKASSERT(IPFW3_LOADED);
879         int error = 0;
880         callout_stop(&ipfw_tick_callout);
881         netmsg_service_sync();
882         crit_enter();
883         lwkt_dropmsg(&ipfw_timeout_netmsg.lmsg);
884         crit_exit();
885         lwkt_replymsg(&nmsg->lmsg, error);
886         ip_fw_basic_loaded=0;
887 }
888
889 static int
890 ipfw_basic_stop(void)
891 {
892         int cpu,i;
893         struct ipfw_state_context *state_ctx;
894         struct ip_fw_state *state,*the_state;
895         struct ipfw_context *ctx;
896         if (unregister_ipfw_module(MODULE_BASIC_ID) ==0 ) {
897                 ipfw_basic_flush_state_prt = NULL;
898                 ipfw_basic_append_state_prt = NULL;
899
900                 for (cpu = 0; cpu < ncpus; cpu++) {
901                         ctx = ipfw_ctx[cpu];
902                         if (ctx != NULL) {
903                                 for (i = 0; i < state_hash_size; i++) {
904                                         state_ctx = &ctx->state_ctx[i];
905                                         if (state_ctx != NULL) {
906                                                 state = state_ctx->state;
907                                                 while (state != NULL) {
908                                                         the_state = state;
909                                                         state = state->next;
910                                                         if (the_state ==
911                                                                 state_ctx->last)
912                                                         state_ctx->last = NULL;
913
914                                                         kfree(the_state,
915                                                                 M_IPFW3_BASIC);
916                                                 }
917                                         }
918                                 }
919                                 ctx->state_hash_size = 0;
920                                 kfree(ctx->state_ctx, M_IPFW3_BASIC);
921                                 ctx->state_ctx = NULL;
922                         }
923                 }
924                 struct netmsg_base smsg;
925                 netmsg_init(&smsg, NULL, &curthread->td_msgport,
926                                 0, ipfw_basic_stop_dispatch);
927                 return lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
928         }
929         return 1;
930 }
931
932
933 static int
934 ipfw3_basic_modevent(module_t mod, int type, void *data)
935 {
936         int err;
937         switch (type) {
938                 case MOD_LOAD:
939                         err = ipfw_basic_init();
940                         break;
941                 case MOD_UNLOAD:
942                         err = ipfw_basic_stop();
943                         break;
944                 default:
945                         err = 1;
946         }
947         return err;
948 }
949
950 static moduledata_t ipfw3_basic_mod = {
951         "ipfw3_basic",
952         ipfw3_basic_modevent,
953         NULL
954 };
955 DECLARE_MODULE(ipfw3_basic, ipfw3_basic_mod, SI_SUB_PROTO_END, SI_ORDER_ANY);
956 MODULE_DEPEND(ipfw3_basic, ipfw3, 1, 1, 1);
957 MODULE_VERSION(ipfw3_basic, 1);