ipfw3: filter from/to supports me/any/subnet
[dragonfly.git] / sys / net / ipfw3_basic / ip_fw3_basic.c
1 /*
2  * Copyright (c) 2014 The DragonFly Project.  All rights reserved.
3  *
4  * This code is derived from software contributed to The DragonFly Project
5  * by Bill Yuan <bycn82@gmail.com>
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  * 1. Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  * 2. Redistributions in binary form must reproduce the above copyright
14  *    notice, this list of conditions and the following disclaimer in
15  *    the documentation and/or other materials provided with the
16  *    distribution.
17  * 3. Neither the name of The DragonFly Project nor the names of its
18  *    contributors may be used to endorse or promote products derived
19  *    from this software without specific, prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
25  * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26  * INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING,
27  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
29  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
31  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
32  * SUCH DAMAGE.
33  */
34
35 #include <sys/param.h>
36 #include <sys/kernel.h>
37 #include <sys/malloc.h>
38 #include <sys/mbuf.h>
39 #include <sys/socketvar.h>
40 #include <sys/sysctl.h>
41 #include <sys/syslog.h>
42 #include <sys/systimer.h>
43 #include <sys/thread2.h>
44 #include <sys/in_cksum.h>
45
46 #include <net/if.h>
47 #include <net/ethernet.h>
48 #include <net/netmsg2.h>
49 #include <net/netisr2.h>
50 #include <net/route.h>
51
52 #include <netinet/ip.h>
53 #include <netinet/in.h>
54 #include <netinet/in_systm.h>
55 #include <netinet/in_var.h>
56 #include <netinet/in_pcb.h>
57 #include <netinet/ip_var.h>
58 #include <netinet/ip_icmp.h>
59 #include <netinet/tcp.h>
60 #include <netinet/tcp_timer.h>
61 #include <netinet/tcp_var.h>
62 #include <netinet/tcpip.h>
63 #include <netinet/udp.h>
64 #include <netinet/udp_var.h>
65 #include <netinet/ip_divert.h>
66 #include <netinet/if_ether.h>
67
68 #include <net/ipfw3/ip_fw.h>
69
70 #include "ip_fw3_basic.h"
71
/* True when (a - b), interpreted as a signed int, is less than or equal to 0. */
72 #define TIME_LEQ(a, b)  ((int)((a) - (b)) <= 0)
73
/* Firewall contexts; the functions in this file index it with mycpuid. */
74 extern struct ipfw_context      *ipfw_ctx[MAXCPU];
75 extern int fw_verbose;
76 extern ipfw_basic_delete_state_t *ipfw_basic_flush_state_prt;
77 extern ipfw_basic_append_state_t *ipfw_basic_append_state_prt;
78
79 static int ip_fw_basic_loaded;
80 static struct netmsg_base ipfw_timeout_netmsg;  /* schedule ipfw timeout */
81 static struct callout ipfw_tick_callout;
/* Lifetime given to a new state when the rule or ioctl request supplies 0. */
82 static int state_lifetime = 20;
/* Exported via sysctl as "default state expiry check interval". */
83 static int state_expiry_check_interval = 10;
/* Exported via sysctl as "maximum of state". */
84 static int state_count_max = 4096;
/*
 * Snapshot of state_hash_size taken by the sysctl handler; it is the bucket
 * count adjust_hash_size_dispatch() walks when freeing the old table.
 */
85 static int state_hash_size_old = 0;
/*
 * Number of hash buckets; hash_packet() masks with (state_hash_size - 1),
 * and the sysctl handler only accepts powers of two up to 65536.
 */
86 static int state_hash_size = 4096;
87
88
/* Handler and per-CPU dispatch routine used to resize the state hash table. */
89 static int ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS);
90 void adjust_hash_size_dispatch(netmsg_t nmsg);
91
/* Tunables below are registered under the net.inet.ip.fw_basic node. */
92 SYSCTL_NODE(_net_inet_ip, OID_AUTO, fw_basic,
93                 CTLFLAG_RW, 0, "Firewall Basic");
/* Writes to state_hash_size go through the handler so the table is rebuilt. */
94 SYSCTL_PROC(_net_inet_ip_fw_basic, OID_AUTO, state_hash_size,
95                 CTLTYPE_INT | CTLFLAG_RW, &state_hash_size, 0,
96                 ipfw_sysctl_adjust_hash_size, "I", "Adjust hash size");
97
98 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_lifetime, CTLFLAG_RW,
99                 &state_lifetime, 0, "default life time");
100 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO,
101                 state_expiry_check_interval, CTLFLAG_RW,
102                 &state_expiry_check_interval, 0,
103                 "default state expiry check interval");
104 SYSCTL_INT(_net_inet_ip_fw_basic, OID_AUTO, state_count_max, CTLFLAG_RW,
105                 &state_count_max, 0, "maximum of state");
106
107 static int
108 ipfw_sysctl_adjust_hash_size(SYSCTL_HANDLER_ARGS)
109 {
110         int error, value = 0;
111
112         state_hash_size_old = state_hash_size;
113         value = state_hash_size;
114         error = sysctl_handle_int(oidp, &value, 0, req);
115         if (error || !req->newptr) {
116                 goto back;
117         }
118         /*
119          * Make sure we have a power of 2 and
120          * do not allow more than 64k entries.
121          */
122         error = EINVAL;
123         if (value <= 1 || value > 65536) {
124                 goto back;
125         }
126         if ((value & (value - 1)) != 0) {
127                 goto back;
128         }
129
130         error = 0;
131         if (state_hash_size != value) {
132                 state_hash_size = value;
133
134                 struct netmsg_base *msg, the_msg;
135                 msg = &the_msg;
136                 bzero(msg,sizeof(struct netmsg_base));
137
138                 netmsg_init(msg, NULL, &curthread->td_msgport,
139                                 0, adjust_hash_size_dispatch);
140                 ifnet_domsg(&msg->lmsg, 0);
141         }
142 back:
143         return error;
144 }
145
146 void
147 adjust_hash_size_dispatch(netmsg_t nmsg)
148 {
149         struct ipfw_state_context *state_ctx;
150         struct ip_fw_state *the_state, *state;
151         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
152         int i;
153
154         for (i = 0; i < state_hash_size_old; i++) {
155                 state_ctx = &ctx->state_ctx[i];
156                 if (state_ctx != NULL) {
157                         state = state_ctx->state;
158                         while (state != NULL) {
159                                 the_state = state;
160                                 state = state->next;
161                                 kfree(the_state, M_IPFW3_BASIC);
162                                 the_state = NULL;
163                         }
164                 }
165         }
166         kfree(ctx->state_ctx,M_IPFW3_BASIC);
167         ctx->state_ctx = kmalloc(state_hash_size *
168                                 sizeof(struct ipfw_state_context),
169                                 M_IPFW3_BASIC, M_WAITOK | M_ZERO);
170         ctx->state_hash_size = state_hash_size;
171         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
172 }
173
174
175 /*
 * Prototypes of the checker functions.
 *
 * All checkers share one signature: cmd_ctl and cmd_val are outputs written
 * by the checker, args points at the packet arguments, f at the current rule
 * pointer, cmd at the instruction being evaluated, and ip_len is the length
 * added to the byte counters.
 */
176 void check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
177         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
178 void check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
179         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
180 void check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
181         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
182 void check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
183         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
184
185 void check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
186         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
187 void check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
188         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
189 void check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
190         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
191 void check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
192         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
193 void check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
194         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
195 void check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
196         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
197 void check_from_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
198         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
199 void check_from_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
200         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
201 void check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
202         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
203 void check_to_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
204         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
205 void check_to_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
206         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
207 void check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
208         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
209 void check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
210         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
211 void check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
212         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
213 void check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
214         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len);
215
216 /* Prototypes of the file-local utility functions. */
217 static struct ip_fw *lookup_next_rule(struct ip_fw *me);
218 static int iface_match(struct ifnet *ifp, ipfw_insn_if *cmd);
219 static __inline int hash_packet(struct ipfw_flow_id *id);
220
221 static __inline int
222 hash_packet(struct ipfw_flow_id *id)
223 {
224         uint32_t i;
225         i = (id->proto) ^ (id->dst_ip) ^ (id->src_ip) ^
226                 (id->dst_port) ^ (id->src_port);
227         i &= state_hash_size - 1;
228         return i;
229 }
230
231 static struct ip_fw *
232 lookup_next_rule(struct ip_fw *me)
233 {
234         struct ip_fw *rule = NULL;
235         ipfw_insn *cmd;
236
237         /* look for action, in case it is a skipto */
238         cmd = ACTION_PTR(me);
239         if ((int)cmd->module == MODULE_BASIC_ID &&
240                 (int)cmd->opcode == O_BASIC_SKIPTO) {
241                 for (rule = me->next; rule; rule = rule->next) {
242                         if (rule->rulenum >= cmd->arg1)
243                                 break;
244                 }
245         }
246         if (rule == NULL) /* failure or not a skipto */
247                 rule = me->next;
248
249         me->next_rule = rule;
250         return rule;
251 }
252
253 /*
254  * when all = 1, it will check all the state_ctx
255  */
256 static struct ip_fw_state *
257 lookup_state(struct ip_fw_args *args, ipfw_insn *cmd, int *limited, int all)
258 {
259         struct ip_fw_state *state = NULL;
260         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
261         struct ipfw_state_context *state_ctx;
262         int start, end, i, count = 0;
263
264         if (all && cmd->arg1) {
265                 start = 0;
266                 end = state_hash_size - 1;
267         } else {
268                 start = hash_packet(&args->f_id);
269                 end = hash_packet(&args->f_id);
270         }
271         for (i = start; i <= end; i++) {
272                 state_ctx = &ctx->state_ctx[i];
273                 if (state_ctx != NULL) {
274                         state = state_ctx->state;
275                         struct ipfw_flow_id     *fid = &args->f_id;
276                         while (state != NULL) {
277                                 if (cmd->arg1) {
278                                         if ((cmd->arg3 == 1 &&
279                                                 fid->src_ip ==
280                                                 state->flow_id.src_ip) ||
281                                                 (cmd->arg3 == 2 &&
282                                                 fid->src_port ==
283                                                 state->flow_id.src_port) ||
284                                                 (cmd->arg3 == 3 &&
285                                                 fid->dst_ip ==
286                                                 state->flow_id.dst_ip) ||
287                                                 (cmd->arg3 == 4 &&
288                                                 fid->dst_port ==
289                                                 state->flow_id.dst_port)) {
290
291                                                 count++;
292                                                 if (count >= cmd->arg1) {
293                                                         *limited = 1;
294                                                         goto done;
295                                                 }
296                                         }
297                                 }
298
299                                 if (fid->proto == state->flow_id.proto) {
300                                         if (fid->src_ip ==
301                                         state->flow_id.src_ip &&
302                                         fid->dst_ip ==
303                                         state->flow_id.dst_ip &&
304                                         (fid->src_port ==
305                                         state->flow_id.src_port ||
306                                         state->flow_id.src_port == 0) &&
307                                         (fid->dst_port ==
308                                         state->flow_id.dst_port ||
309                                         state->flow_id.dst_port == 0)) {
310                                                 goto done;
311                                         }
312                                         if (fid->src_ip ==
313                                         state->flow_id.dst_ip &&
314                                         fid->dst_ip ==
315                                         state->flow_id.src_ip &&
316                                         (fid->src_port ==
317                                         state->flow_id.dst_port ||
318                                         state->flow_id.dst_port == 0) &&
319                                         (fid->dst_port ==
320                                         state->flow_id.src_port ||
321                                         state->flow_id.src_port == 0)) {
322                                                 goto done;
323                                         }
324                                 }
325                                 state = state->next;
326                         }
327                 }
328         }
329 done:
330         return state;
331 }
332
333 static struct ip_fw_state *
334 install_state(struct ip_fw *rule, ipfw_insn *cmd, struct ip_fw_args *args)
335 {
336         struct ip_fw_state *state;
337         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
338         struct ipfw_state_context *state_ctx;
339         state_ctx = &ctx->state_ctx[hash_packet(&args->f_id)];
340         state = kmalloc(sizeof(struct ip_fw_state),
341                         M_IPFW3_BASIC, M_NOWAIT | M_ZERO);
342         if (state == NULL) {
343                 return NULL;
344         }
345         state->stub = rule;
346         state->lifetime = cmd->arg2 == 0 ? state_lifetime : cmd->arg2 ;
347         state->timestamp = time_second;
348         state->expiry = 0;
349         bcopy(&args->f_id,&state->flow_id,sizeof(struct ipfw_flow_id));
350         //append the state into the state chian
351         if (state_ctx->last != NULL)
352                 state_ctx->last->next = state;
353         else
354                 state_ctx->state = state;
355         state_ctx->last = state;
356         state_ctx->count++;
357         return state;
358 }
359
360
361 static int
362 iface_match(struct ifnet *ifp, ipfw_insn_if *cmd)
363 {
364         if (ifp == NULL)        /* no iface with this packet, match fails */
365                 return 0;
366
367         /* Check by name or by IP address */
368         if (cmd->name[0] != '\0') { /* match by name */
369                 /* Check name */
370                 if (cmd->p.glob) {
371                         if (kfnmatch(cmd->name, ifp->if_xname, 0) == 0)
372                                 return(1);
373                 } else {
374                         if (strncmp(ifp->if_xname, cmd->name, IFNAMSIZ) == 0)
375                                 return(1);
376                 }
377         } else {
378                 struct ifaddr_container *ifac;
379
380                 TAILQ_FOREACH(ifac, &ifp->if_addrheads[mycpuid], ifa_link) {
381                         struct ifaddr *ia = ifac->ifa;
382
383                         if (ia->ifa_addr == NULL)
384                                 continue;
385                         if (ia->ifa_addr->sa_family != AF_INET)
386                                 continue;
387                         if (cmd->p.ip.s_addr ==
388                                 ((struct sockaddr_in *)
389                                 (ia->ifa_addr))->sin_addr.s_addr)
390                                         return(1);      /* match */
391
392                 }
393         }
394         return 0;       /* no match, fail ... */
395 }
396
397 /* implementation of the checker functions */
398 void
399 check_count(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
400         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
401 {
402         (*f)->pcnt++;
403         (*f)->bcnt += ip_len;
404         (*f)->timestamp = time_second;
405         *cmd_ctl = IP_FW_CTL_NEXT;
406 }
407
408 void
409 check_skipto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
410         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
411 {
412         (*f)->pcnt++;
413         (*f)->bcnt += ip_len;
414         (*f)->timestamp = time_second;
415         if ((*f)->next_rule == NULL)
416                 lookup_next_rule(*f);
417
418         *cmd_ctl = IP_FW_CTL_AGAIN;
419 }
420
421 void
422 check_forward(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
423         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
424 {
425         struct sockaddr_in *sin, *sa;
426         struct m_tag *mtag;
427
428         if ((*args)->eh) {      /* not valid on layer2 pkts */
429                 *cmd_ctl=IP_FW_CTL_NEXT;
430                 return;
431         }
432
433         (*f)->pcnt++;
434         (*f)->bcnt += ip_len;
435         (*f)->timestamp = time_second;
436         if ((*f)->next_rule == NULL)
437                 lookup_next_rule(*f);
438
439         mtag = m_tag_get(PACKET_TAG_IPFORWARD,
440                         sizeof(*sin), M_NOWAIT);
441         if (mtag == NULL) {
442                 *cmd_val = IP_FW_DENY;
443                 *cmd_ctl = IP_FW_CTL_DONE;
444                 return;
445         }
446         sin = m_tag_data(mtag);
447         sa = &((ipfw_insn_sa *)cmd)->sa;
448         /* arg3: count of the dest, arg1: type of fwd */
449         int i = 0;
450         if(cmd->arg3 > 1) {
451                 if (cmd->arg1 == 0) {           /* type: random */
452                         i = krandom() % cmd->arg3;
453                 } else if (cmd->arg1 == 1) {    /* type: round-robin */
454                         i = cmd->arg2++ % cmd->arg3;
455                 } else if (cmd->arg1 == 2) {    /* type: sticky */
456                         struct ip *ip = mtod((*args)->m, struct ip *);
457                         i = ip->ip_src.s_addr & (cmd->arg3 - 1);
458                 }
459                 sa += i;
460         }
461         *sin = *sa;     /* apply the destination */
462         m_tag_prepend((*args)->m, mtag);
463         (*args)->m->m_pkthdr.fw_flags |= IPFORWARD_MBUF_TAGGED;
464         (*args)->m->m_pkthdr.fw_flags &= ~BRIDGE_MBUF_TAGGED;
465         *cmd_ctl = IP_FW_CTL_DONE;
466         *cmd_val = IP_FW_PASS;
467 }
468
469 void
470 check_check_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
471         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
472 {
473         struct ip_fw_state *state=NULL;
474         int limited = 0 ;
475         state = lookup_state(*args, cmd, &limited, 0);
476         if (state != NULL) {
477                 state->pcnt++;
478                 state->bcnt += ip_len;
479                 state->timestamp = time_second;
480                 (*f)->pcnt++;
481                 (*f)->bcnt += ip_len;
482                 (*f)->timestamp = time_second;
483                 *f = state->stub;
484                 *cmd_ctl = IP_FW_CTL_CHK_STATE;
485         } else {
486                 *cmd_ctl = IP_FW_CTL_NEXT;
487         }
488 }
489
490 void
491 check_in(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
492         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
493 {
494         *cmd_ctl = IP_FW_CTL_NO;
495         *cmd_val = ((*args)->oif == NULL);
496 }
497
498 void
499 check_out(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
500         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
501 {
502         *cmd_ctl = IP_FW_CTL_NO;
503         *cmd_val = ((*args)->oif != NULL);
504 }
505
506 void
507 check_via(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
508         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
509 {
510         *cmd_ctl = IP_FW_CTL_NO;
511         *cmd_val = iface_match((*args)->oif ?
512                         (*args)->oif : (*args)->m->m_pkthdr.rcvif,
513                         (ipfw_insn_if *)cmd);
514 }
515
516 void
517 check_proto(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
518         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
519 {
520         *cmd_ctl = IP_FW_CTL_NO;
521         *cmd_val = ((*args)->f_id.proto == cmd->arg1);
522 }
523
524 void
525 check_prob(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
526         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
527 {
528         *cmd_ctl = IP_FW_CTL_NO;
529         *cmd_val = (krandom() % 100) < cmd->arg1;
530 }
531
532 void
533 check_from(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
534         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
535 {
536         struct in_addr src_ip;
537         u_int hlen = 0;
538         struct mbuf *m = (*args)->m;
539         struct ip *ip = mtod(m, struct ip *);
540         src_ip = ip->ip_src;
541         if ((*args)->eh == NULL ||
542                 (m->m_pkthdr.len >= sizeof(struct ip) &&
543                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
544                 hlen = ip->ip_hl << 2;
545         }
546         *cmd_val = (hlen > 0 &&
547                         ((ipfw_insn_ip *)cmd)->addr.s_addr == src_ip.s_addr);
548         *cmd_ctl = IP_FW_CTL_NO;
549 }
550
551 void
552 check_from_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
553         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
554 {
555         struct in_addr src_ip;
556         u_int hlen = 0;
557         struct mbuf *m = (*args)->m;
558         struct ip *ip = mtod(m, struct ip *);
559         src_ip = ip->ip_src;
560         if ((*args)->eh == NULL ||
561                 (m->m_pkthdr.len >= sizeof(struct ip) &&
562                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
563                 hlen = ip->ip_hl << 2;
564         }
565         *cmd_ctl = IP_FW_CTL_NO;
566         if (hlen > 0) {
567                 struct ifnet *tif;
568                 tif = INADDR_TO_IFP(&src_ip);
569                 *cmd_val = (tif != NULL);
570         } else {
571                 *cmd_val = IP_FW_NOT_MATCH;
572         }
573 }
574
575 void
576 check_from_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
577         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
578 {
579         struct in_addr src_ip;
580         u_int hlen = 0;
581         struct mbuf *m = (*args)->m;
582         struct ip *ip = mtod(m, struct ip *);
583         src_ip = ip->ip_src;
584         if ((*args)->eh == NULL ||
585                 (m->m_pkthdr.len >= sizeof(struct ip) &&
586                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
587                 hlen = ip->ip_hl << 2;
588         }
589
590         *cmd_ctl = IP_FW_CTL_NO;
591         *cmd_val = (hlen > 0 &&
592                         ((ipfw_insn_ip *)cmd)->addr.s_addr ==
593                         (src_ip.s_addr &
594                         ((ipfw_insn_ip *)cmd)->mask.s_addr));
595 }
596
597 void
598 check_to(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
599         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
600 {
601         struct in_addr dst_ip;
602         u_int hlen = 0;
603         struct mbuf *m = (*args)->m;
604         struct ip *ip = mtod(m, struct ip *);
605         dst_ip = ip->ip_dst;
606         if ((*args)->eh == NULL ||
607                 (m->m_pkthdr.len >= sizeof(struct ip) &&
608                  ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
609                 hlen = ip->ip_hl << 2;
610         }
611         *cmd_val = (hlen > 0 &&
612                         ((ipfw_insn_ip *)cmd)->addr.s_addr == dst_ip.s_addr);
613         *cmd_ctl = IP_FW_CTL_NO;
614 }
615
616 void
617 check_to_me(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
618         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
619 {
620         struct in_addr dst_ip;
621         u_int hlen = 0;
622         struct mbuf *m = (*args)->m;
623         struct ip *ip = mtod(m, struct ip *);
624         dst_ip = ip->ip_src;
625         if ((*args)->eh == NULL ||
626                 (m->m_pkthdr.len >= sizeof(struct ip) &&
627                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
628                 hlen = ip->ip_hl << 2;
629         }
630         *cmd_ctl = IP_FW_CTL_NO;
631         if (hlen > 0) {
632                 struct ifnet *tif;
633                 tif = INADDR_TO_IFP(&dst_ip);
634                 *cmd_val = (tif != NULL);
635         } else {
636                 *cmd_val = IP_FW_NOT_MATCH;
637         }
638 }
639
640 void
641 check_to_mask(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
642         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
643 {
644         struct in_addr dst_ip;
645         u_int hlen = 0;
646         struct mbuf *m = (*args)->m;
647         struct ip *ip = mtod(m, struct ip *);
648         dst_ip = ip->ip_src;
649         if ((*args)->eh == NULL ||
650                 (m->m_pkthdr.len >= sizeof(struct ip) &&
651                 ntohs((*args)->eh->ether_type) == ETHERTYPE_IP)) {
652                 hlen = ip->ip_hl << 2;
653         }
654
655         *cmd_ctl = IP_FW_CTL_NO;
656         *cmd_val = (hlen > 0 &&
657                         ((ipfw_insn_ip *)cmd)->addr.s_addr ==
658                         (dst_ip.s_addr &
659                         ((ipfw_insn_ip *)cmd)->mask.s_addr));
660 }
661
662 void
663 check_keep_state(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
664         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
665 {
666         struct ip_fw_state *state;
667         int limited = 0;
668
669         *cmd_ctl = IP_FW_CTL_NO;
670         state = lookup_state(*args, cmd, &limited, 1);
671         if (limited == 1) {
672                 *cmd_val = IP_FW_NOT_MATCH;
673         } else {
674                 if (state == NULL)
675                         state = install_state(*f, cmd, *args);
676
677                 if (state != NULL) {
678                         state->pcnt++;
679                         state->bcnt += ip_len;
680                         state->timestamp = time_second;
681                         *cmd_val = IP_FW_MATCH;
682                 } else {
683                         *cmd_val = IP_FW_NOT_MATCH;
684                 }
685         }
686 }
687
688 void
689 check_tag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
690         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
691 {
692         struct m_tag *mtag = m_tag_locate((*args)->m,
693                         MTAG_IPFW, cmd->arg1, NULL);
694         if (mtag == NULL) {
695                 mtag = m_tag_alloc(MTAG_IPFW,cmd->arg1, 0, M_NOWAIT);
696                 if (mtag != NULL)
697                         m_tag_prepend((*args)->m, mtag);
698
699         }
700         (*f)->pcnt++;
701         (*f)->bcnt += ip_len;
702         (*f)->timestamp = time_second;
703         *cmd_ctl = IP_FW_CTL_NEXT;
704 }
705
706 void
707 check_untag(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
708         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
709 {
710         struct m_tag *mtag = m_tag_locate((*args)->m,
711                         MTAG_IPFW, cmd->arg1, NULL);
712         if (mtag != NULL)
713                 m_tag_delete((*args)->m, mtag);
714
715         (*f)->pcnt++;
716         (*f)->bcnt += ip_len;
717         (*f)->timestamp = time_second;
718         *cmd_ctl = IP_FW_CTL_NEXT;
719 }
720
721 void
722 check_tagged(int *cmd_ctl, int *cmd_val, struct ip_fw_args **args,
723         struct ip_fw **f, ipfw_insn *cmd, uint16_t ip_len)
724 {
725         *cmd_ctl = IP_FW_CTL_NO;
726         if (m_tag_locate( (*args)->m, MTAG_IPFW,cmd->arg1, NULL) != NULL )
727                 *cmd_val = IP_FW_MATCH;
728         else
729                 *cmd_val = IP_FW_NOT_MATCH;
730 }
731
732 static void
733 ipfw_basic_add_state(struct ipfw_ioc_state *ioc_state)
734 {
735         struct ip_fw_state *state;
736         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
737         struct ipfw_state_context *state_ctx;
738         state_ctx = &ctx->state_ctx[hash_packet(&(ioc_state->flow_id))];
739         state = kmalloc(sizeof(struct ip_fw_state),
740                         M_IPFW3_BASIC, M_WAITOK | M_ZERO);
741         struct ip_fw *rule = ctx->ipfw_rule_chain;
742         while (rule != NULL) {
743                 if (rule->rulenum == ioc_state->rulenum) {
744                         break;
745                 }
746                 rule = rule->next;
747         }
748         if (rule == NULL)
749                 return;
750
751         state->stub = rule;
752
753         state->lifetime = ioc_state->lifetime == 0 ?
754                 state_lifetime : ioc_state->lifetime ;
755         state->timestamp = time_second;
756         state->expiry = ioc_state->expiry;
757         bcopy(&ioc_state->flow_id, &state->flow_id,
758                         sizeof(struct ipfw_flow_id));
759         //append the state into the state chian
760         if (state_ctx->last != NULL)
761                 state_ctx->last->next = state;
762         else
763                 state_ctx->state = state;
764
765         state_ctx->last = state;
766         state_ctx->count++;
767 }
768
769 /*
770  * if rule is NULL
771  *              flush all states
772  * else
773  *              flush states which stub is the rule
774  */
775 static void
776 ipfw_basic_flush_state(struct ip_fw *rule)
777 {
778         struct ipfw_state_context *state_ctx;
779         struct ip_fw_state *state,*the_state, *prev_state;
780         struct ipfw_context *ctx;
781         int i;
782
783         ctx = ipfw_ctx[mycpuid];
784         for (i = 0; i < state_hash_size; i++) {
785                 state_ctx = &ctx->state_ctx[i];
786                 if (state_ctx != NULL) {
787                         state = state_ctx->state;
788                         prev_state = NULL;
789                         while (state != NULL) {
790                                 if (rule != NULL && state->stub != rule) {
791                                         prev_state = state;
792                                         state = state->next;
793                                 } else {
794                                         if (prev_state == NULL)
795                                                 state_ctx->state = state->next;
796                                         else
797                                                 prev_state->next = state->next;
798
799                                         the_state = state;
800                                         state = state->next;
801                                         kfree(the_state, M_IPFW3_BASIC);
802                                         state_ctx->count--;
803                                         if (state == NULL)
804                                                 state_ctx->last = prev_state;
805
806                                 }
807                         }
808                 }
809         }
810 }
811
812 /*
813  * clean up expired state in every tick
814  */
815 static void
816 ipfw_cleanup_expired_state(netmsg_t nmsg)
817 {
818         struct ip_fw_state *state,*the_state,*prev_state;
819         struct ipfw_context *ctx = ipfw_ctx[mycpuid];
820         struct ipfw_state_context *state_ctx;
821         int i;
822
823         for (i = 0; i < state_hash_size; i++) {
824                 prev_state = NULL;
825                 state_ctx = &(ctx->state_ctx[i]);
826                 if (ctx->state_ctx != NULL) {
827                         state = state_ctx->state;
828                         while (state != NULL) {
829                                 if (IS_EXPIRED(state)) {
830                                         if (prev_state == NULL)
831                                                 state_ctx->state = state->next;
832                                         else
833                                                 prev_state->next = state->next;
834
835                                         the_state =state;
836                                         state = state->next;
837
838                                         if (the_state == state_ctx->last)
839                                                 state_ctx->last = NULL;
840
841
842                                         kfree(the_state, M_IPFW3_BASIC);
843                                         state_ctx->count--;
844                                 } else {
845                                         prev_state = state;
846                                         state = state->next;
847                                 }
848                         }
849                 }
850         }
851         ifnet_forwardmsg(&nmsg->lmsg, mycpuid + 1);
852 }
853
854 static void
855 ipfw_tick(void *dummy __unused)
856 {
857         struct lwkt_msg *lmsg = &ipfw_timeout_netmsg.lmsg;
858         KKASSERT(mycpuid == IPFW_CFGCPUID);
859
860         crit_enter();
861         KKASSERT(lmsg->ms_flags & MSGF_DONE);
862         if (IPFW_BASIC_LOADED) {
863                 lwkt_sendmsg_oncpu(IPFW_CFGPORT, lmsg);
864                 /* ipfw_timeout_netmsg's handler reset this callout */
865         }
866         crit_exit();
867
868         struct netmsg_base *msg;
869         struct netmsg_base the_msg;
870         msg = &the_msg;
871         bzero(msg,sizeof(struct netmsg_base));
872
873         netmsg_init(msg, NULL, &curthread->td_msgport, 0,
874                         ipfw_cleanup_expired_state);
875         ifnet_domsg(&msg->lmsg, 0);
876 }
877
878 static void
879 ipfw_tick_dispatch(netmsg_t nmsg)
880 {
881         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
882         KKASSERT(IPFW_BASIC_LOADED);
883
884         /* Reply ASAP */
885         crit_enter();
886         lwkt_replymsg(&nmsg->lmsg, 0);
887         crit_exit();
888
889         callout_reset(&ipfw_tick_callout,
890                         state_expiry_check_interval * hz, ipfw_tick, NULL);
891 }
892
893 static void
894 ipfw_basic_init_dispatch(netmsg_t nmsg)
895 {
896         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
897         KKASSERT(IPFW3_LOADED);
898
899         int error = 0;
900         callout_init_mp(&ipfw_tick_callout);
901         netmsg_init(&ipfw_timeout_netmsg, NULL, &netisr_adone_rport,
902                         MSGF_DROPABLE | MSGF_PRIORITY, ipfw_tick_dispatch);
903         callout_reset(&ipfw_tick_callout,
904                         state_expiry_check_interval * hz, ipfw_tick, NULL);
905         lwkt_replymsg(&nmsg->lmsg, error);
906         ip_fw_basic_loaded=1;
907 }
908
909 static int
910 ipfw_basic_init(void)
911 {
912         ipfw_basic_flush_state_prt = ipfw_basic_flush_state;
913         ipfw_basic_append_state_prt = ipfw_basic_add_state;
914
915         register_ipfw_module(MODULE_BASIC_ID, MODULE_BASIC_NAME);
916         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_COUNT,
917                         (filter_func)check_count);
918         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_SKIPTO,
919                         (filter_func)check_skipto);
920         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_FORWARD,
921                         (filter_func)check_forward);
922         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_KEEP_STATE,
923                         (filter_func)check_keep_state);
924         register_ipfw_filter_funcs(MODULE_BASIC_ID, O_BASIC_CHECK_STATE,
925                         (filter_func)check_check_state);
926
927         register_ipfw_filter_funcs(MODULE_BASIC_ID,
928                         O_BASIC_IN, (filter_func)check_in);
929         register_ipfw_filter_funcs(MODULE_BASIC_ID,
930                         O_BASIC_OUT, (filter_func)check_out);
931         register_ipfw_filter_funcs(MODULE_BASIC_ID,
932                         O_BASIC_VIA, (filter_func)check_via);
933         register_ipfw_filter_funcs(MODULE_BASIC_ID,
934                         O_BASIC_XMIT, (filter_func)check_via);
935         register_ipfw_filter_funcs(MODULE_BASIC_ID,
936                         O_BASIC_RECV, (filter_func)check_via);
937
938         register_ipfw_filter_funcs(MODULE_BASIC_ID,
939                         O_BASIC_PROTO, (filter_func)check_proto);
940         register_ipfw_filter_funcs(MODULE_BASIC_ID,
941                         O_BASIC_PROB, (filter_func)check_prob);
942         register_ipfw_filter_funcs(MODULE_BASIC_ID,
943                         O_BASIC_IP_SRC, (filter_func)check_from);
944         register_ipfw_filter_funcs(MODULE_BASIC_ID,
945                         O_BASIC_IP_SRC_ME, (filter_func)check_from_me);
946         register_ipfw_filter_funcs(MODULE_BASIC_ID,
947                         O_BASIC_IP_SRC_MASK, (filter_func)check_from_mask);
948         register_ipfw_filter_funcs(MODULE_BASIC_ID,
949                         O_BASIC_IP_DST, (filter_func)check_to);
950         register_ipfw_filter_funcs(MODULE_BASIC_ID,
951                         O_BASIC_IP_DST_ME, (filter_func)check_to_me);
952         register_ipfw_filter_funcs(MODULE_BASIC_ID,
953                         O_BASIC_IP_DST_MASK, (filter_func)check_to_mask);
954         register_ipfw_filter_funcs(MODULE_BASIC_ID,
955                         O_BASIC_TAG, (filter_func)check_tag);
956         register_ipfw_filter_funcs(MODULE_BASIC_ID,
957                         O_BASIC_UNTAG, (filter_func)check_untag);
958         register_ipfw_filter_funcs(MODULE_BASIC_ID,
959                         O_BASIC_TAGGED, (filter_func)check_tagged);
960
961         int cpu;
962         struct ipfw_context *ctx;
963
964         for (cpu = 0; cpu < ncpus; cpu++) {
965                 ctx = ipfw_ctx[cpu];
966                 if (ctx != NULL) {
967                         ctx->state_ctx = kmalloc(state_hash_size *
968                                         sizeof(struct ipfw_state_context),
969                                         M_IPFW3_BASIC, M_WAITOK | M_ZERO);
970                         ctx->state_hash_size = state_hash_size;
971                 }
972         }
973
974         struct netmsg_base smsg;
975         netmsg_init(&smsg, NULL, &curthread->td_msgport,
976                         0, ipfw_basic_init_dispatch);
977         lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
978         return 0;
979 }
980
981 static void
982 ipfw_basic_stop_dispatch(netmsg_t nmsg)
983 {
984         IPFW_ASSERT_CFGPORT(&curthread->td_msgport);
985         KKASSERT(IPFW3_LOADED);
986         int error = 0;
987         callout_stop(&ipfw_tick_callout);
988         netmsg_service_sync();
989         crit_enter();
990         lwkt_dropmsg(&ipfw_timeout_netmsg.lmsg);
991         crit_exit();
992         lwkt_replymsg(&nmsg->lmsg, error);
993         ip_fw_basic_loaded=0;
994 }
995
996 static int
997 ipfw_basic_stop(void)
998 {
999         int cpu,i;
1000         struct ipfw_state_context *state_ctx;
1001         struct ip_fw_state *state,*the_state;
1002         struct ipfw_context *ctx;
1003         if (unregister_ipfw_module(MODULE_BASIC_ID) ==0 ) {
1004                 ipfw_basic_flush_state_prt = NULL;
1005                 ipfw_basic_append_state_prt = NULL;
1006
1007                 for (cpu = 0; cpu < ncpus; cpu++) {
1008                         ctx = ipfw_ctx[cpu];
1009                         if (ctx != NULL) {
1010                                 for (i = 0; i < state_hash_size; i++) {
1011                                         state_ctx = &ctx->state_ctx[i];
1012                                         if (state_ctx != NULL) {
1013                                                 state = state_ctx->state;
1014                                                 while (state != NULL) {
1015                                                         the_state = state;
1016                                                         state = state->next;
1017                                                         if (the_state ==
1018                                                                 state_ctx->last)
1019                                                         state_ctx->last = NULL;
1020
1021                                                         kfree(the_state,
1022                                                                 M_IPFW3_BASIC);
1023                                                 }
1024                                         }
1025                                 }
1026                                 ctx->state_hash_size = 0;
1027                                 kfree(ctx->state_ctx, M_IPFW3_BASIC);
1028                                 ctx->state_ctx = NULL;
1029                         }
1030                 }
1031                 struct netmsg_base smsg;
1032                 netmsg_init(&smsg, NULL, &curthread->td_msgport,
1033                                 0, ipfw_basic_stop_dispatch);
1034                 return lwkt_domsg(IPFW_CFGPORT, &smsg.lmsg, 0);
1035         }
1036         return 1;
1037 }
1038
1039
1040 static int
1041 ipfw3_basic_modevent(module_t mod, int type, void *data)
1042 {
1043         int err;
1044         switch (type) {
1045                 case MOD_LOAD:
1046                         err = ipfw_basic_init();
1047                         break;
1048                 case MOD_UNLOAD:
1049                         err = ipfw_basic_stop();
1050                         break;
1051                 default:
1052                         err = 1;
1053         }
1054         return err;
1055 }
1056
/*
 * Module description handed to DECLARE_MODULE(): the string
 * "ipfw3_basic", the event handler ipfw3_basic_modevent() and a NULL
 * third member (presumably unused private data -- confirm against the
 * moduledata_t definition).
 */
static moduledata_t ipfw3_basic_mod = {
	"ipfw3_basic",
	ipfw3_basic_modevent,
	NULL
};
DECLARE_MODULE(ipfw3_basic, ipfw3_basic_mod, SI_SUB_PROTO_END, SI_ORDER_ANY);
/* ipfw3_basic depends on the ipfw3 module */
MODULE_DEPEND(ipfw3_basic, ipfw3, 1, 1, 1);
MODULE_VERSION(ipfw3_basic, 1);