diff options
-rw-r--r-- | net/netlink/af_netlink.c | 69 |
1 files changed, 51 insertions, 18 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 444ed223ee4..58d4ca42ac3 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -60,21 +60,24 @@ #include <net/scm.h> #define Nprintk(a...) +#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) struct netlink_sock { /* struct sock has to be the first member of netlink_sock */ struct sock sk; u32 pid; - unsigned int groups; u32 dst_pid; u32 dst_group; + u32 flags; + u32 subscriptions; + u32 ngroups; + unsigned long *groups; unsigned long state; wait_queue_head_t wait; struct netlink_callback *cb; spinlock_t cb_lock; void (*data_ready)(struct sock *sk, int bytes); struct module *module; - u32 flags; }; #define NETLINK_KERNEL_SOCKET 0x1 @@ -101,6 +104,7 @@ struct netlink_table { struct nl_pid_hash hash; struct hlist_head mc_list; unsigned int nl_nonroot; + unsigned int groups; struct module *module; int registered; }; @@ -138,6 +142,7 @@ static void netlink_sock_destruct(struct sock *sk) BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc)); BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc)); BUG_TRAP(!nlk_sk(sk)->cb); + BUG_TRAP(!nlk_sk(sk)->groups); } /* This lock without WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on SMP. @@ -333,7 +338,7 @@ static void netlink_remove(struct sock *sk) netlink_table_grab(); if (sk_del_node_init(sk)) nl_table[sk->sk_protocol].hash.entries--; - if (nlk_sk(sk)->groups) + if (nlk_sk(sk)->subscriptions) __sk_del_bind_node(sk); netlink_table_ungrab(); } @@ -369,6 +374,8 @@ static int __netlink_create(struct socket *sock, int protocol) static int netlink_create(struct socket *sock, int protocol) { struct module *module = NULL; + struct netlink_sock *nlk; + unsigned int groups; int err = 0; sock->state = SS_UNCONNECTED; @@ -392,15 +399,23 @@ static int netlink_create(struct socket *sock, int protocol) module = nl_table[protocol].module; else err = -EPROTONOSUPPORT; + groups = nl_table[protocol].groups; netlink_unlock_table(); - if (err) - goto out; + if (err || (err = __netlink_create(sock, protocol) < 0)) + goto out_module; + + nlk = nlk_sk(sock->sk); - if ((err = __netlink_create(sock, protocol) < 0)) + nlk->groups = kmalloc(NLGRPSZ(groups), GFP_KERNEL); + if (nlk->groups == NULL) { + err = -ENOMEM; goto out_module; + } + memset(nlk->groups, 0, NLGRPSZ(groups)); + nlk->ngroups = groups; - nlk_sk(sock->sk)->module = module; + nlk->module = module; out: return err; @@ -437,7 +452,7 @@ static int netlink_release(struct socket *sock) skb_queue_purge(&sk->sk_write_queue); - if (nlk->pid && !nlk->groups) { + if (nlk->pid && !nlk->subscriptions) { struct netlink_notify n = { .protocol = sk->sk_protocol, .pid = nlk->pid, @@ -455,6 +470,9 @@ static int netlink_release(struct socket *sock) netlink_table_ungrab(); } + kfree(nlk->groups); + nlk->groups = NULL; + sock_put(sk); return 0; } @@ -503,6 +521,18 @@ static inline int netlink_capable(struct socket *sock, unsigned int flag) capable(CAP_NET_ADMIN); } +static void +netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions) +{ + struct netlink_sock *nlk = nlk_sk(sk); + + if (nlk->subscriptions && !subscriptions) + __sk_del_bind_node(sk); + else if (!nlk->subscriptions && subscriptions) + sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list); + nlk->subscriptions = subscriptions; +} + static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len) { struct sock *sk = sock->sk; @@ -528,15 +558,14 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len return err; } - if (!nladdr->nl_groups && !nlk->groups) + if (!nladdr->nl_groups && !(u32)nlk->groups[0]) return 0; netlink_table_grab(); - if (nlk->groups && !nladdr->nl_groups) - __sk_del_bind_node(sk); - else if (!nlk->groups && nladdr->nl_groups) - sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list); - nlk->groups = nladdr->nl_groups; + netlink_update_subscriptions(sk, nlk->subscriptions + + hweight32(nladdr->nl_groups) - + hweight32(nlk->groups[0])); + nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; netlink_table_ungrab(); return 0; @@ -590,7 +619,7 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, int *addr nladdr->nl_groups = netlink_group_mask(nlk->dst_group); } else { nladdr->nl_pid = nlk->pid; - nladdr->nl_groups = nlk->groups; + nladdr->nl_groups = nlk->groups[0]; } return 0; } @@ -791,7 +820,8 @@ static inline int do_one_broadcast(struct sock *sk, if (p->exclude_sk == sk) goto out; - if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group))) + if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups || + !test_bit(p->group - 1, nlk->groups)) goto out; if (p->failure) { @@ -887,7 +917,8 @@ static inline int do_one_set_err(struct sock *sk, if (sk == p->exclude_sk) goto out; - if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group))) + if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups || + !test_bit(p->group - 1, nlk->groups)) goto out; sk->sk_err = p->code; @@ -1112,6 +1143,7 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct nlk->flags |= NETLINK_KERNEL_SOCKET; netlink_table_grab(); + nl_table[unit].groups = 32; nl_table[unit].module = module; nl_table[unit].registered = 1; netlink_table_ungrab(); @@ -1358,7 +1390,8 @@ static int netlink_seq_show(struct seq_file *seq, void *v) s, s->sk_protocol, nlk->pid, - nlk->groups, + nlk->flags & NETLINK_KERNEL_SOCKET ? + 0 : (unsigned int)nlk->groups[0], atomic_read(&s->sk_rmem_alloc), atomic_read(&s->sk_wmem_alloc), nlk->cb, |