net/ipv6/udp.c
changeset 2 d1f6d8b6f81c
parent 0 aa628870c1d3
--- a/net/ipv6/udp.c	Thu Apr 09 12:06:38 2009 +0200
+++ b/net/ipv6/udp.c	Thu Apr 09 12:07:21 2009 +0200
@@ -54,62 +54,91 @@
 	return udp_lib_get_port(sk, snum, ipv6_rcv_saddr_equal);
 }
 
+static inline int compute_score(struct sock *sk, struct net *net,
+				unsigned short hnum,
+				struct in6_addr *saddr, __be16 sport,
+				struct in6_addr *daddr, __be16 dport,
+				int dif)
+{
+	int score = -1;
+
+	if (net_eq(sock_net(sk), net) && sk->sk_hash == hnum &&
+			sk->sk_family == PF_INET6) {
+		struct ipv6_pinfo *np = inet6_sk(sk);
+		struct inet_sock *inet = inet_sk(sk);
+
+		score = 0;
+		if (inet->dport) {
+			if (inet->dport != sport)
+				return -1;
+			score++;
+		}
+		if (!ipv6_addr_any(&np->rcv_saddr)) {
+			if (!ipv6_addr_equal(&np->rcv_saddr, daddr))
+				return -1;
+			score++;
+		}
+		if (!ipv6_addr_any(&np->daddr)) {
+			if (!ipv6_addr_equal(&np->daddr, saddr))
+				return -1;
+			score++;
+		}
+		if (sk->sk_bound_dev_if) {
+			if (sk->sk_bound_dev_if != dif)
+				return -1;
+			score++;
+		}
+	}
+	return score;
+}
+
 static struct sock *__udp6_lib_lookup(struct net *net,
 				      struct in6_addr *saddr, __be16 sport,
 				      struct in6_addr *daddr, __be16 dport,
-				      int dif, struct hlist_head udptable[])
+				      int dif, struct udp_table *udptable)
 {
-	struct sock *sk, *result = NULL;
-	struct hlist_node *node;
+	struct sock *sk, *result;
+	struct hlist_nulls_node *node;
 	unsigned short hnum = ntohs(dport);
-	int badness = -1;
-
-	read_lock(&udp_hash_lock);
-	sk_for_each(sk, node, &udptable[udp_hashfn(net, hnum)]) {
-		struct inet_sock *inet = inet_sk(sk);
+	unsigned int hash = udp_hashfn(net, hnum);
+	struct udp_hslot *hslot = &udptable->hash[hash];
+	int score, badness;
 
-		if (net_eq(sock_net(sk), net) && sk->sk_hash == hnum &&
-				sk->sk_family == PF_INET6) {
-			struct ipv6_pinfo *np = inet6_sk(sk);
-			int score = 0;
-			if (inet->dport) {
-				if (inet->dport != sport)
-					continue;
-				score++;
-			}
-			if (!ipv6_addr_any(&np->rcv_saddr)) {
-				if (!ipv6_addr_equal(&np->rcv_saddr, daddr))
-					continue;
-				score++;
-			}
-			if (!ipv6_addr_any(&np->daddr)) {
-				if (!ipv6_addr_equal(&np->daddr, saddr))
-					continue;
-				score++;
-			}
-			if (sk->sk_bound_dev_if) {
-				if (sk->sk_bound_dev_if != dif)
-					continue;
-				score++;
-			}
-			if (score == 4) {
-				result = sk;
-				break;
-			} else if (score > badness) {
-				result = sk;
-				badness = score;
-			}
+	rcu_read_lock();
+begin:
+	result = NULL;
+	badness = -1;
+	sk_nulls_for_each_rcu(sk, node, &hslot->head) {
+		score = compute_score(sk, net, hnum, saddr, sport, daddr, dport, dif);
+		if (score > badness) {
+			result = sk;
+			badness = score;
 		}
 	}
-	if (result)
-		sock_hold(result);
-	read_unlock(&udp_hash_lock);
+	/*
+	 * if the nulls value we got at the end of this lookup is
+	 * not the expected one, we must restart lookup.
+	 * We probably met an item that was moved to another chain.
+	 */
+	if (get_nulls_value(node) != hash)
+		goto begin;
+
+	if (result) {
+		if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
+			result = NULL;
+		else if (unlikely(compute_score(result, net, hnum, saddr, sport,
+					daddr, dport, dif) < badness)) {
+			sock_put(result);
+			goto begin;
+		}
+	}
+	rcu_read_unlock();
 	return result;
 }
 
 static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
 					  __be16 sport, __be16 dport,
-					  struct hlist_head udptable[])
+					  struct udp_table *udptable)
 {
 	struct sock *sk;
 	struct ipv6hdr *iph = ipv6_hdr(skb);
@@ -253,7 +282,7 @@
 
 void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
 		    int type, int code, int offset, __be32 info,
-		    struct hlist_head udptable[]                    )
+		    struct udp_table *udptable)
 {
 	struct ipv6_pinfo *np;
 	struct ipv6hdr *hdr = (struct ipv6hdr*)skb->data;
@@ -289,7 +318,7 @@
 				 struct inet6_skb_parm *opt, int type,
 				 int code, int offset, __be32 info     )
 {
-	__udp6_lib_err(skb, opt, type, code, offset, info, udp_hash);
+	__udp6_lib_err(skb, opt, type, code, offset, info, &udp_table);
 }
 
 int udpv6_queue_rcv_skb(struct sock * sk, struct sk_buff *skb)
@@ -347,11 +376,11 @@
 				      __be16 rmt_port, struct in6_addr *rmt_addr,
 				      int dif)
 {
-	struct hlist_node *node;
+	struct hlist_nulls_node *node;
 	struct sock *s = sk;
 	unsigned short num = ntohs(loc_port);
 
-	sk_for_each_from(s, node) {
+	sk_nulls_for_each_from(s, node) {
 		struct inet_sock *inet = inet_sk(s);
 
 		if (!net_eq(sock_net(s), net))
@@ -388,14 +417,15 @@
  */
 static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
 		struct in6_addr *saddr, struct in6_addr *daddr,
-		struct hlist_head udptable[])
+		struct udp_table *udptable)
 {
 	struct sock *sk, *sk2;
 	const struct udphdr *uh = udp_hdr(skb);
+	struct udp_hslot *hslot = &udptable->hash[udp_hashfn(net, ntohs(uh->dest))];
 	int dif;
 
-	read_lock(&udp_hash_lock);
-	sk = sk_head(&udptable[udp_hashfn(net, ntohs(uh->dest))]);
+	spin_lock(&hslot->lock);
+	sk = sk_nulls_head(&hslot->head);
 	dif = inet6_iif(skb);
 	sk = udp_v6_mcast_next(net, sk, uh->dest, daddr, uh->source, saddr, dif);
 	if (!sk) {
@@ -404,7 +434,7 @@
 	}
 
 	sk2 = sk;
-	while ((sk2 = udp_v6_mcast_next(net, sk_next(sk2), uh->dest, daddr,
+	while ((sk2 = udp_v6_mcast_next(net, sk_nulls_next(sk2), uh->dest, daddr,
 					uh->source, saddr, dif))) {
 		struct sk_buff *buff = skb_clone(skb, GFP_ATOMIC);
 		if (buff) {
@@ -423,7 +453,7 @@
 		sk_add_backlog(sk, skb);
 	bh_unlock_sock(sk);
 out:
-	read_unlock(&udp_hash_lock);
+	spin_unlock(&hslot->lock);
 	return 0;
 }
 
@@ -461,7 +491,7 @@
 	return 0;
 }
 
-int __udp6_lib_rcv(struct sk_buff *skb, struct hlist_head udptable[],
+int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 		   int proto)
 {
 	struct sock *sk;
@@ -558,7 +588,7 @@
 
 static __inline__ int udpv6_rcv(struct sk_buff *skb)
 {
-	return __udp6_lib_rcv(skb, udp_hash, IPPROTO_UDP);
+	return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP);
 }
 
 /*
@@ -763,6 +793,9 @@
 	if (!fl.oif)
 		fl.oif = sk->sk_bound_dev_if;
 
+	if (!fl.oif)
+		fl.oif = np->sticky_pktinfo.ipi6_ifindex;
+
 	if (msg->msg_controllen) {
 		opt = &opt_space;
 		memset(opt, 0, sizeof(struct ipv6_txoptions));
@@ -819,7 +852,8 @@
 	if (final_p)
 		ipv6_addr_copy(&fl.fl6_dst, final_p);
 
-	if ((err = __xfrm_lookup(&dst, &fl, sk, XFRM_LOOKUP_WAIT)) < 0) {
+	err = __xfrm_lookup(sock_net(sk), &dst, &fl, sk, XFRM_LOOKUP_WAIT);
+	if (err < 0) {
 		if (err == -EREMOTE)
 			err = ip6_dst_blackhole(sk, &dst, &fl);
 		if (err < 0)
@@ -1022,7 +1056,7 @@
 static struct udp_seq_afinfo udp6_seq_afinfo = {
 	.name		= "udp6",
 	.family		= AF_INET6,
-	.hashtable	= udp_hash,
+	.udp_table	= &udp_table,
 	.seq_fops	= {
 		.owner	=	THIS_MODULE,
 	},
@@ -1064,7 +1098,8 @@
 	.sysctl_wmem	   = &sysctl_udp_wmem_min,
 	.sysctl_rmem	   = &sysctl_udp_rmem_min,
 	.obj_size	   = sizeof(struct udp6_sock),
-	.h.udp_hash	   = udp_hash,
+	.slab_flags	   = SLAB_DESTROY_BY_RCU,
+	.h.udp_table	   = &udp_table,
 #ifdef CONFIG_COMPAT
 	.compat_setsockopt = compat_udpv6_setsockopt,
 	.compat_getsockopt = compat_udpv6_getsockopt,