
/net/ipv4/tcp_bpf.c

Possible License(s): GPL-2.0
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>
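
/* Copy data queued on the psock ingress list into the caller's iov,
 * up to @len bytes. With MSG_PEEK the scatterlist cursors and memory
 * accounting are left untouched; otherwise consumed pages are released
 * and fully drained messages are unlinked and freed. Returns the number
 * of bytes copied, or -EFAULT if copying to the iterator fails.
 */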
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
                      struct msghdr *msg, int len, int flags)
{
        struct iov_iter *iter = &msg->msg_iter;
        int peek = flags & MSG_PEEK;
        int i, ret, copied = 0;
        struct sk_msg *msg_rx;

        msg_rx = list_first_entry_or_null(&psock->ingress_msg,
                                          struct sk_msg, list);

        while (copied != len) {
                struct scatterlist *sge;

                if (unlikely(!msg_rx))
                        break;

                i = msg_rx->sg.start;
                do {
                        struct page *page;
                        int copy;

                        sge = sk_msg_elem(msg_rx, i);
                        copy = sge->length;
                        page = sg_page(sge);
                        if (copied + copy > len)
                                copy = len - copied;
                        ret = copy_page_to_iter(page, sge->offset, copy, iter);
                        if (ret != copy) {
                                msg_rx->sg.start = i;
                                return -EFAULT;
                        }

                        copied += copy;
                        if (likely(!peek)) {
                                sge->offset += copy;
                                sge->length -= copy;
                                sk_mem_uncharge(sk, copy);
                                msg_rx->sg.size -= copy;

                                if (!sge->length) {
                                        sk_msg_iter_var_next(i);
                                        if (!msg_rx->skb)
                                                put_page(page);
                                }
                        } else {
                                sk_msg_iter_var_next(i);
                        }

                        if (copied == len)
                                break;
                } while (i != msg_rx->sg.end);

                if (unlikely(peek)) {
                        if (msg_rx == list_last_entry(&psock->ingress_msg,
                                                      struct sk_msg, list))
                                break;
                        msg_rx = list_next_entry(msg_rx, list);
                        continue;
                }

                msg_rx->sg.start = i;
                if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
                        list_del(&msg_rx->list);
                        if (msg_rx->skb)
                                consume_skb(msg_rx->skb);
                        kfree(msg_rx);
                }

                msg_rx = list_first_entry_or_null(&psock->ingress_msg,
                                                  struct sk_msg, list);
        }

        return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
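
/* Ingress redirect: transfer up to @apply_bytes worth of scatterlist
 * entries from @msg to a newly allocated sk_msg, charge them to @sk and
 * queue the result on its psock so a later recvmsg() can consume it.
 * -ENOMEM is returned only if nothing could be transferred.
 */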
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
                           struct sk_msg *msg, u32 apply_bytes, int flags)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        u32 size, copied = 0;
        struct sk_msg *tmp;
        int i, ret = 0;

        tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
        if (unlikely(!tmp))
                return -ENOMEM;

        lock_sock(sk);
        tmp->sg.start = msg->sg.start;
        i = msg->sg.start;
        do {
                sge = sk_msg_elem(msg, i);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                if (!sk_wmem_schedule(sk, size)) {
                        if (!copied)
                                ret = -ENOMEM;
                        break;
                }

                sk_mem_charge(sk, size);
                sk_msg_xfer(tmp, msg, i, size);
                copied += size;
                if (sge->length)
                        get_page(sk_msg_page(tmp, i));
                sk_msg_iter_var_next(i);
                tmp->sg.end = i;
                if (apply) {
                        apply_bytes -= size;
                        if (!apply_bytes)
                                break;
                }
        } while (i != msg->sg.end);

        if (!ret) {
                msg->sg.start = i;
                sk_psock_queue_msg(psock, tmp);
                sk_psock_data_ready(sk, psock);
        } else {
                sk_msg_free(sk, tmp);
                kfree(tmp);
        }

        release_sock(sk);
        return ret;
}
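
/* Transmit the data described by @msg on @sk, starting at sg.start and
 * limited to @apply_bytes when an apply limit is in force. When a TLS
 * software TX context is present the pages are sent via
 * kernel_sendpage_locked() with MSG_SENDPAGE_NOPOLICY, otherwise via
 * do_tcp_sendpages(). Fully sent entries are released and, if @uncharge
 * is set, uncharged from the socket. Caller holds the socket lock.
 */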
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
                        int flags, bool uncharge)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        struct page *page;
        int size, ret = 0;
        u32 off;

        while (1) {
                bool has_tx_ulp;

                sge = sk_msg_elem(msg, msg->sg.start);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                off  = sge->offset;
                page = sg_page(sge);

                tcp_rate_check_app_limited(sk);
retry:
                has_tx_ulp = tls_sw_has_ctx_tx(sk);
                if (has_tx_ulp) {
                        flags |= MSG_SENDPAGE_NOPOLICY;
                        ret = kernel_sendpage_locked(sk,
                                                     page, off, size, flags);
                } else {
                        ret = do_tcp_sendpages(sk, page, off, size, flags);
                }

                if (ret <= 0)
                        return ret;
                if (apply)
                        apply_bytes -= ret;
                msg->sg.size -= ret;
                sge->offset += ret;
                sge->length -= ret;
                if (uncharge)
                        sk_mem_uncharge(sk, ret);
                if (ret != size) {
                        size -= ret;
                        off  += ret;
                        goto retry;
                }
                if (!sge->length) {
                        put_page(page);
                        sk_msg_iter_next(msg, start);
                        sg_init_table(sge, 1);
                        if (msg->sg.start == msg->sg.end)
                                break;
                }

                if (apply && !apply_bytes)
                        break;
        }

        return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
                               u32 apply_bytes, int flags, bool uncharge)
{
        int ret;

        lock_sock(sk);
        ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
        release_sock(sk);
        return ret;
}
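
/* Deliver @bytes from @msg to the redirect target @sk: either onto its
 * psock ingress queue when the msg is flagged for ingress, or out through
 * its locked TCP send path. If the target no longer has a psock the msg
 * is simply freed and 0 is returned.
 */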
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
                          u32 bytes, int flags)
{
        bool ingress = sk_msg_to_ingress(msg);
        struct sk_psock *psock = sk_psock_get(sk);
        int ret;

        if (unlikely(!psock)) {
                sk_msg_free(sk, msg);
                return 0;
        }

        ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
                        tcp_bpf_push_locked(sk, msg, bytes, flags, false);
        sk_psock_put(sk, psock);
        return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_STREAM_PARSER
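
/* Replacement for ->stream_memory_read(): data redirected to the psock
 * ingress list never lands on sk_receive_queue, so report it as readable
 * here to keep the TCP poll path working.
 */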
static bool tcp_bpf_stream_read(const struct sock *sk)
{
        struct sk_psock *psock;
        bool empty = true;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock))
                empty = list_empty(&psock->ingress_msg);
        rcu_read_unlock();
        return !empty;
}
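
/* Sleep until the psock ingress list or sk_receive_queue becomes non-empty,
 * the receive side is shut down, or @timeo expires. Returns non-zero when
 * there is something to read, zero on timeout.
 */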
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
                             int flags, long timeo, int *err)
{
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int ret = 0;

        if (sk->sk_shutdown & RCV_SHUTDOWN)
                return 1;

        if (!timeo)
                return ret;

        add_wait_queue(sk_sleep(sk), &wait);
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        ret = sk_wait_event(sk, &timeo,
                            !list_empty(&psock->ingress_msg) ||
                            !skb_queue_empty(&sk->sk_receive_queue), &wait);
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        remove_wait_queue(sk_sleep(sk), &wait);
        return ret;
}
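
/* Replacement for ->recvmsg() while a psock is attached: serve data queued
 * by the verdict program from the psock ingress list, and fall back to
 * tcp_recvmsg() when the psock is gone or only the normal receive queue
 * holds data.
 */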
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                           int nonblock, int flags, int *addr_len)
{
        struct sk_psock *psock;
        int copied, ret;

        if (unlikely(flags & MSG_ERRQUEUE))
                return inet_recv_error(sk, msg, len, addr_len);

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
        if (!skb_queue_empty(&sk->sk_receive_queue) &&
            sk_psock_queue_empty(psock)) {
                sk_psock_put(sk, psock);
                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
        }
        lock_sock(sk);
msg_bytes_ready:
        copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
        if (!copied) {
                int data, err = 0;
                long timeo;

                timeo = sock_rcvtimeo(sk, nonblock);
                data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
                if (data) {
                        if (!sk_psock_queue_empty(psock))
                                goto msg_bytes_ready;
                        release_sock(sk);
                        sk_psock_put(sk, psock);
                        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
                }
                if (err) {
                        ret = err;
                        goto out;
                }
                copied = -EAGAIN;
        }
        ret = copied;
out:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return ret;
}
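
/* Run the SK_MSG verdict program (unless a verdict is still pending) and
 * act on the result: __SK_PASS pushes the data out on @sk, __SK_REDIRECT
 * hands it to tcp_bpf_sendmsg_redir() on the selected socket, __SK_DROP
 * frees it and returns -EACCES. Messages smaller than cork_bytes are
 * parked on psock->cork until enough data has accumulated, and
 * apply_bytes limits how much of the message a single verdict covers.
 * *copied is adjusted so the caller reports the right byte count to
 * user space.
 */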
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
                                struct sk_msg *msg, int *copied, int flags)
{
        bool cork = false, enospc = sk_msg_full(msg);
        struct sock *sk_redir;
        u32 tosend, delta = 0;
        int ret;

more_data:
        if (psock->eval == __SK_NONE) {
                /* Track delta in msg size to add/subtract it on SK_DROP from
                 * returned to user copied size. This ensures user doesn't
                 * get a positive return code with msg_cut_data and SK_DROP
                 * verdict.
                 */
                delta = msg->sg.size;
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
                delta -= msg->sg.size;
        }

        if (msg->cork_bytes &&
            msg->cork_bytes > msg->sg.size && !enospc) {
                psock->cork_bytes = msg->cork_bytes - msg->sg.size;
                if (!psock->cork) {
                        psock->cork = kzalloc(sizeof(*psock->cork),
                                              GFP_ATOMIC | __GFP_NOWARN);
                        if (!psock->cork)
                                return -ENOMEM;
                }
                memcpy(psock->cork, msg, sizeof(*msg));
                return 0;
        }

        tosend = msg->sg.size;
        if (psock->apply_bytes && psock->apply_bytes < tosend)
                tosend = psock->apply_bytes;

        switch (psock->eval) {
        case __SK_PASS:
                ret = tcp_bpf_push(sk, msg, tosend, flags, true);
                if (unlikely(ret)) {
                        *copied -= sk_msg_free(sk, msg);
                        break;
                }
                sk_msg_apply_bytes(psock, tosend);
                break;
        case __SK_REDIRECT:
                sk_redir = psock->sk_redir;
                sk_msg_apply_bytes(psock, tosend);
                if (psock->cork) {
                        cork = true;
                        psock->cork = NULL;
                }
                sk_msg_return(sk, msg, tosend);
                release_sock(sk);
                ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
                lock_sock(sk);
                if (unlikely(ret < 0)) {
                        int free = sk_msg_free_nocharge(sk, msg);

                        if (!cork)
                                *copied -= free;
                }
                if (cork) {
                        sk_msg_free(sk, msg);
                        kfree(msg);
                        msg = NULL;
                        ret = 0;
                }
                break;
        case __SK_DROP:
        default:
                sk_msg_free_partial(sk, msg, tosend);
                sk_msg_apply_bytes(psock, tosend);
                *copied -= (tosend + delta);
                return -EACCES;
        }

        if (likely(!ret)) {
                if (!psock->apply_bytes) {
                        psock->eval = __SK_NONE;
                        if (psock->sk_redir) {
                                sock_put(psock->sk_redir);
                                psock->sk_redir = NULL;
                        }
                }
                if (msg &&
                    msg->sg.data[msg->sg.start].page_link &&
                    msg->sg.data[msg->sg.start].length)
                        goto more_data;
        }
        return ret;
}
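
/* Replacement for ->sendmsg(): copy user data into an sk_msg (or the
 * pending cork buffer) and feed it to tcp_bpf_send_verdict(). Falls back
 * to tcp_sendmsg() when no psock is attached and waits for send memory
 * like the regular TCP path.
 */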
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        struct sk_msg tmp, *msg_tx = NULL;
        int copied = 0, err = 0;
        struct sk_psock *psock;
        long timeo;
        int flags;

        /* Don't let internal do_tcp_sendpages() flags through */
        flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
        flags |= MSG_NO_SHARED_FRAGS;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendmsg(sk, msg, size);

        lock_sock(sk);
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        while (msg_data_left(msg)) {
                bool enospc = false;
                u32 copy, osize;

                if (sk->sk_err) {
                        err = -sk->sk_err;
                        goto out_err;
                }

                copy = msg_data_left(msg);
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
                if (psock->cork) {
                        msg_tx = psock->cork;
                } else {
                        msg_tx = &tmp;
                        sk_msg_init(msg_tx);
                }

                osize = msg_tx->sg.size;
                err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
                if (err) {
                        if (err != -ENOSPC)
                                goto wait_for_memory;
                        enospc = true;
                        copy = msg_tx->sg.size - osize;
                }

                err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
                                               copy);
                if (err < 0) {
                        sk_msg_trim(sk, msg_tx, osize);
                        goto out_err;
                }

                copied += copy;
                if (psock->cork_bytes) {
                        if (size > psock->cork_bytes)
                                psock->cork_bytes = 0;
                        else
                                psock->cork_bytes -= size;
                        if (psock->cork_bytes && !enospc)
                                goto out_err;
                        /* All cork bytes are accounted, rerun the prog. */
                        psock->eval = __SK_NONE;
                        psock->cork_bytes = 0;
                }

                err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
                if (unlikely(err < 0))
                        goto out_err;
                continue;
wait_for_sndbuf:
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
                err = sk_stream_wait_memory(sk, &timeo);
                if (err) {
                        if (msg_tx && msg_tx != psock->cork)
                                sk_msg_free(sk, msg_tx);
                        goto out_err;
                }
        }
out_err:
        if (err < 0)
                err = sk_stream_error(sk, msg->msg_flags, err);
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}
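
/* Replacement for ->sendpage(): add the page reference to the current
 * sk_msg (or cork buffer), charge it to the socket and run the same
 * verdict path as tcp_bpf_sendmsg(). Falls back to tcp_sendpage() when
 * no psock is attached.
 */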
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
                            size_t size, int flags)
{
        struct sk_msg tmp, *msg = NULL;
        int err = 0, copied = 0;
        struct sk_psock *psock;
        bool enospc = false;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendpage(sk, page, offset, size, flags);

        lock_sock(sk);
        if (psock->cork) {
                msg = psock->cork;
        } else {
                msg = &tmp;
                sk_msg_init(msg);
        }

        /* Catch case where ring is full and sendpage is stalled. */
        if (unlikely(sk_msg_full(msg)))
                goto out_err;

        sk_msg_page_add(msg, page, size, offset);
        sk_mem_charge(sk, size);
        copied = size;
        if (sk_msg_full(msg))
                enospc = true;
        if (psock->cork_bytes) {
                if (size > psock->cork_bytes)
                        psock->cork_bytes = 0;
                else
                        psock->cork_bytes -= size;
                if (psock->cork_bytes && !enospc)
                        goto out_err;
                /* All cork bytes are accounted, rerun the prog. */
                psock->eval = __SK_NONE;
                psock->cork_bytes = 0;
        }

        err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}
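
/* While a psock is attached, sockmap installs a modified struct proto on
 * the socket. There is one variant per address family and, within each
 * family, a BASE config (receive hooks only) and a TX config that also
 * overrides sendmsg/sendpage when a msg_parser program is in use.
 */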
enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
        TCP_BPF_NUM_PROTS,
};

enum {
        TCP_BPF_BASE,
        TCP_BPF_TX,
        TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
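
/* Build one family's row of tcp_bpf_prots from the stock proto: BASE
 * overrides unhash, close, recvmsg and stream_memory_read; TX copies BASE
 * and additionally overrides sendmsg and sendpage.
 */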
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
                                   struct proto *base)
{
        prot[TCP_BPF_BASE]                      = *base;
        prot[TCP_BPF_BASE].unhash               = sock_map_unhash;
        prot[TCP_BPF_BASE].close                = sock_map_close;
        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;

        prot[TCP_BPF_TX]                        = prot[TCP_BPF_BASE];
        prot[TCP_BPF_TX].sendmsg                = tcp_bpf_sendmsg;
        prot[TCP_BPF_TX].sendpage               = tcp_bpf_sendpage;
}
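
/* The IPv6 row cannot be built at init time (tcpv6_prot may come from a
 * module), so build it lazily from the proto the first IPv6 socket is
 * using and rebuild it whenever that pointer changes, serialized by
 * tcpv6_prot_lock.
 */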
static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
        if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
                spin_lock_bh(&tcpv6_prot_lock);
                if (likely(ops != tcpv6_prot_saved)) {
                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
                        smp_store_release(&tcpv6_prot_saved, ops);
                }
                spin_unlock_bh(&tcpv6_prot_lock);
        }
}

static int __init tcp_bpf_v4_build_proto(void)
{
        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
        return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
        /* In order to avoid retpoline, we make assumptions when we call
         * into ops if e.g. a psock is not present. Make sure they are
         * indeed valid assumptions.
         */
        return ops->recvmsg  == tcp_recvmsg &&
               ops->sendmsg  == tcp_sendmsg &&
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
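
/* Select the proto variant to install on @sk: family row plus BASE or TX
 * column depending on whether a msg_parser program is attached. For IPv6
 * sockets, first check that the current ops are the stock TCP ones and
 * make sure the IPv6 row has been built.
 */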
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

        if (sk->sk_family == AF_INET6) {
                if (tcp_bpf_assert_proto_ops(psock->sk_proto))
                        return ERR_PTR(-EINVAL);

                tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
        }

        return &tcp_bpf_prots[family][config];
}

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        struct proto *prot = newsk->sk_prot;

        if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
                newsk->sk_prot = sk->sk_prot_creator;
}

#endif /* CONFIG_BPF_STREAM_PARSER */