PageRenderTime 50ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 0ms

/net/vmw_vsock/virtio_transport_common.c

https://gitlab.com/CadeLaRen/linux
C | 992 lines | 785 code | 170 blank | 37 comment | 84 complexity | 899c9d40b31a6aaaf947d6170f6a9e0f MD5 | raw file
  1. /*
  2. * common code for virtio vsock
  3. *
  4. * Copyright (C) 2013-2015 Red Hat, Inc.
  5. * Author: Asias He <asias@redhat.com>
  6. * Stefan Hajnoczi <stefanha@redhat.com>
  7. *
  8. * This work is licensed under the terms of the GNU GPL, version 2.
  9. */
  10. #include <linux/spinlock.h>
  11. #include <linux/module.h>
  12. #include <linux/ctype.h>
  13. #include <linux/list.h>
  14. #include <linux/virtio.h>
  15. #include <linux/virtio_ids.h>
  16. #include <linux/virtio_config.h>
  17. #include <linux/virtio_vsock.h>
  18. #include <net/sock.h>
  19. #include <net/af_vsock.h>
  20. #define CREATE_TRACE_POINTS
  21. #include <trace/events/vsock_virtio_transport_common.h>
  22. /* How long to wait for graceful shutdown of a connection */
  23. #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
  24. static const struct virtio_transport *virtio_transport_get_ops(void)
  25. {
  26. const struct vsock_transport *t = vsock_core_get_transport();
  27. return container_of(t, struct virtio_transport, transport);
  28. }
  29. struct virtio_vsock_pkt *
  30. virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
  31. size_t len,
  32. u32 src_cid,
  33. u32 src_port,
  34. u32 dst_cid,
  35. u32 dst_port)
  36. {
  37. struct virtio_vsock_pkt *pkt;
  38. int err;
  39. pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
  40. if (!pkt)
  41. return NULL;
  42. pkt->hdr.type = cpu_to_le16(info->type);
  43. pkt->hdr.op = cpu_to_le16(info->op);
  44. pkt->hdr.src_cid = cpu_to_le64(src_cid);
  45. pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
  46. pkt->hdr.src_port = cpu_to_le32(src_port);
  47. pkt->hdr.dst_port = cpu_to_le32(dst_port);
  48. pkt->hdr.flags = cpu_to_le32(info->flags);
  49. pkt->len = len;
  50. pkt->hdr.len = cpu_to_le32(len);
  51. pkt->reply = info->reply;
  52. if (info->msg && len > 0) {
  53. pkt->buf = kmalloc(len, GFP_KERNEL);
  54. if (!pkt->buf)
  55. goto out_pkt;
  56. err = memcpy_from_msg(pkt->buf, info->msg, len);
  57. if (err)
  58. goto out;
  59. }
  60. trace_virtio_transport_alloc_pkt(src_cid, src_port,
  61. dst_cid, dst_port,
  62. len,
  63. info->type,
  64. info->op,
  65. info->flags);
  66. return pkt;
  67. out:
  68. kfree(pkt->buf);
  69. out_pkt:
  70. kfree(pkt);
  71. return NULL;
  72. }
  73. EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
/* Build and transmit one packet described by @info on behalf of @vsk.
 *
 * Flow control is applied here: the payload is first capped at the
 * default rx buffer size, then at whatever tx credit the peer grants,
 * so fewer bytes than info->pkt_len may actually be sent.
 * Returns the number of bytes handed to the transport, 0 for a
 * suppressed zero-length OP_RW, or a negative errno.
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	src_cid = vm_sockets_get_local_cid();
	src_port = vsk->local_addr.svm_port;
	/* remote_cid == 0 means "send to the socket's connected peer" */
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		/* Give back the credit reserved above. */
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return virtio_transport_get_ops()->send_pkt(pkt);
}
/* Account a newly queued rx packet; callers in this file hold vvs->rx_lock. */
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes += pkt->len;
}
/* Account a fully consumed rx packet and advance fwd_cnt, which is
 * later advertised to the peer for credit accounting.  Callers in
 * this file hold vvs->rx_lock.
 */
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}
/* Stamp an outgoing packet header with our current credit state
 * (fwd_cnt and buf_alloc), taken consistently under tx_lock.
 */
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->tx_lock);
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
/* Reserve up to @credit bytes of tx credit.
 *
 * Available credit is the peer's advertised buffer space minus bytes
 * currently in flight (tx_cnt - peer_fwd_cnt); the u32 subtraction is
 * wrap-safe.  The amount actually reserved (possibly less than
 * requested) is added to tx_cnt and returned.
 */
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
/* Return previously reserved but unused tx credit (undo get_credit). */
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
/* Send an OP_CREDIT_UPDATE packet advertising our current fwd_cnt and
 * buf_alloc (stamped by virtio_transport_inc_tx_pkt on the send path).
 * @hdr is never referenced here; the caller in this file passes NULL.
 */
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
					       int type,
					       struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
/* Copy up to @len queued rx bytes into @msg.
 *
 * rx_lock is dropped around memcpy_to_msg() because copying to
 * userspace may sleep; the caller's sk_lock prevents concurrent
 * dequeues meanwhile.  Fully consumed packets are unlinked, their
 * bytes folded into fwd_cnt, and a credit update is sent so the peer
 * learns about the freed buffer space.
 *
 * Returns the number of bytes copied, or a negative errno if the copy
 * faulted before anything was transferred.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);
		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;
		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);
		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;
		spin_lock_bh(&vvs->rx_lock);
		total += bytes;
		pkt->off += bytes;
		/* Packet fully consumed: account it and free it. */
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}
	spin_unlock_bh(&vvs->rx_lock);

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);

	return total;

out:
	/* Report a partial read if anything was copied before the fault. */
	if (total)
		err = total;
	return err;
}
/* Stream receive entry point.  MSG_PEEK is not implemented by this
 * transport, so peeking callers get -EOPNOTSUPP.
 */
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return -EOPNOTSUPP;

	return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
  217. s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
  218. {
  219. struct virtio_vsock_sock *vvs = vsk->trans;
  220. s64 bytes;
  221. spin_lock_bh(&vvs->rx_lock);
  222. bytes = vvs->rx_bytes;
  223. spin_unlock_bh(&vvs->rx_lock);
  224. return bytes;
  225. }
  226. EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
  227. static s64 virtio_transport_has_space(struct vsock_sock *vsk)
  228. {
  229. struct virtio_vsock_sock *vvs = vsk->trans;
  230. s64 bytes;
  231. bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
  232. if (bytes < 0)
  233. bytes = 0;
  234. return bytes;
  235. }
  236. s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
  237. {
  238. struct virtio_vsock_sock *vvs = vsk->trans;
  239. s64 bytes;
  240. spin_lock_bh(&vvs->tx_lock);
  241. bytes = virtio_transport_has_space(vsk);
  242. spin_unlock_bh(&vvs->tx_lock);
  243. return bytes;
  244. }
  245. EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
/* Allocate and initialise the per-socket transport state (vsk->trans).
 *
 * A child accepted from listener @psk inherits the parent's buffer
 * size settings and peer_buf_alloc; otherwise the compile-time
 * defaults are used.  buf_alloc (credit advertised to the peer)
 * starts at buf_size.  Returns 0 or -ENOMEM.
 */
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->buf_size = ptrans->buf_size;
		vvs->buf_size_min = ptrans->buf_size_min;
		vvs->buf_size_max = ptrans->buf_size_max;
		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	} else {
		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
	}

	vvs->buf_alloc = vvs->buf_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	INIT_LIST_HEAD(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
/* Current stream buffer size for @vsk. */
u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
/* Lower bound on the stream buffer size for @vsk. */
u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
/* Upper bound on the stream buffer size for @vsk. */
u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
/* Set the stream buffer size, capped at VIRTIO_VSOCK_MAX_BUF_SIZE.
 * The min/max bounds are dragged along so min <= size <= max keeps
 * holding, and buf_alloc (credit advertised to the peer) follows the
 * new size.
 */
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size_min)
		vvs->buf_size_min = val;
	if (val > vvs->buf_size_max)
		vvs->buf_size_max = val;
	vvs->buf_size = val;
	vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
/* Set the minimum buffer size (capped at VIRTIO_VSOCK_MAX_BUF_SIZE),
 * raising buf_size if needed so buf_size >= the new minimum.
 */
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val > vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
/* Set the maximum buffer size (capped at VIRTIO_VSOCK_MAX_BUF_SIZE),
 * lowering buf_size if needed so buf_size <= the new maximum.
 */
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
  324. int
  325. virtio_transport_notify_poll_in(struct vsock_sock *vsk,
  326. size_t target,
  327. bool *data_ready_now)
  328. {
  329. if (vsock_stream_has_data(vsk))
  330. *data_ready_now = true;
  331. else
  332. *data_ready_now = false;
  333. return 0;
  334. }
  335. EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
  336. int
  337. virtio_transport_notify_poll_out(struct vsock_sock *vsk,
  338. size_t target,
  339. bool *space_avail_now)
  340. {
  341. s64 free_space;
  342. free_space = vsock_stream_has_space(vsk);
  343. if (free_space > 0)
  344. *space_avail_now = true;
  345. else if (free_space == 0)
  346. *space_avail_now = false;
  347. return 0;
  348. }
  349. EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
/* No-op recv-init hook; always succeeds. */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
/* No-op recv pre-block hook; always succeeds. */
int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
/* No-op recv pre-dequeue hook; always succeeds. */
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
/* No-op recv post-dequeue hook; always succeeds. */
int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
/* No-op send-init hook; always succeeds. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
/* No-op send pre-block hook; always succeeds. */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
/* No-op send pre-enqueue hook; always succeeds. */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
/* No-op send post-enqueue hook; always succeeds. */
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
/* Receive high-water mark: equals the configured buffer size. */
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
/* Streams on this transport are always considered active. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
/* Stream connections are allowed to/from any cid and port. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
/* Datagram sockets are not supported by the virtio transport. */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
/* Datagrams are never allowed: this transport is stream-only. */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
/* Initiate a stream connection by sending an OP_REQUEST to the peer. */
int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
/* Translate RCV_SHUTDOWN/SEND_SHUTDOWN bits in @mode into the
 * corresponding VIRTIO_VSOCK_SHUTDOWN_* flags and send an
 * OP_SHUTDOWN packet to the peer.
 */
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
/* Queue up to @len bytes from @msg for transmission as an OP_RW
 * packet.  The send path may truncate to the peer's available credit,
 * so the return value can be less than @len (or negative errno).
 */
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
  471. void virtio_transport_destruct(struct vsock_sock *vsk)
  472. {
  473. struct virtio_vsock_sock *vvs = vsk->trans;
  474. kfree(vvs);
  475. }
  476. EXPORT_SYMBOL_GPL(virtio_transport_destruct);
/* Send an OP_RST for @vsk.  @pkt is the packet being responded to, or
 * NULL for a locally originated reset (reply is set accordingly).
 * No RST is sent in reply to an RST, to avoid reset ping-pong.
 */
static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}
/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	/* Build the reply with src/dst swapped relative to the offending
	 * packet; note that @pkt is reused to hold the reply from here on.
	 * NOTE(review): hdr cids are written with cpu_to_le64() in
	 * virtio_transport_alloc_pkt() but read back here with
	 * le32_to_cpu() — this relies on CIDs fitting in 32 bits; confirm.
	 */
	pkt = virtio_transport_alloc_pkt(&info, 0,
					 le32_to_cpu(pkt->hdr.dst_cid),
					 le32_to_cpu(pkt->hdr.dst_port),
					 le32_to_cpu(pkt->hdr.src_cid),
					 le32_to_cpu(pkt->hdr.src_port));
	if (!pkt)
		return -ENOMEM;

	return virtio_transport_get_ops()->send_pkt(pkt);
}
/* Block for up to @timeout jiffies until SOCK_DONE is set, waking
 * early if a signal is pending.  Used to honour SO_LINGER on close.
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT(wait);

		do {
			prepare_to_wait(sk_sleep(sk), &wait,
					TASK_INTERRUPTIBLE);
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE)))
				break;
		} while (!signal_pending(current) && timeout);

		finish_wait(sk_sleep(sk), &wait);
	}
}
/* Finish closing the socket: mark SOCK_DONE, record full peer
 * shutdown, and move to SS_DISCONNECTING once no readable data
 * remains.  If the close timeout work is pending, tear it down
 * (subject to @cancel_timeout), remove the socket, and drop the
 * reference taken when the work was scheduled.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = SS_DISCONNECTING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}
/* Delayed-work handler: if the peer never completed the shutdown
 * handshake within VSOCK_CLOSE_TIMEOUT, send a reset and force the
 * close locally.
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}
/* User context, vsk->sk is locked.
 *
 * Begin a graceful close.  Returns true when the caller may remove
 * the socket immediately, false when a close timeout has been
 * scheduled and removal will happen later (from the work handler or
 * the peer's RST/shutdown).
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == SS_CONNECTED ||
	      sk->sk_state == SS_DISCONNECTING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	/* Honour SO_LINGER unless the process is exiting. */
	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	/* Peer has not finished yet: arm the close timeout and keep a
	 * reference for the delayed work.
	 */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}
/* Socket release callback.  Stream sockets go through the graceful
 * close path, which decides whether the socket can be removed now or
 * must wait for the close timeout / peer response.
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock(sk);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
/* Handle a packet received while the socket is in SS_CONNECTING.
 *
 * OP_RESPONSE completes the handshake; OP_RST aborts with ECONNRESET;
 * anything else (except OP_INVALID, which is ignored) is a protocol
 * error that resets the connection with EPROTO.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = SS_CONNECTED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
/* Handle a packet received while the socket is in SS_CONNECTED.
 *
 * Packet ownership: an OP_RW packet is queued on the rx list and kept
 * (the early return skips the free at the bottom); every other op
 * falls through and the packet is freed here.
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;

		spin_lock_bh(&vvs->rx_lock);
		virtio_transport_inc_rx_pkt(vvs, pkt);
		list_add_tail(&pkt->list, &vvs->rx_queue);
		spin_unlock_bh(&vvs->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Credit state was already absorbed by the caller via
		 * virtio_transport_space_update(); just wake writers.
		 */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		/* Disconnect only when both directions are shut and all
		 * queued data has been consumed.
		 */
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = SS_DISCONNECTING;
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}
/* Handle a packet received in SS_DISCONNECTING: only an OP_RST is
 * meaningful and completes the close; everything else is ignored.
 */
static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}
/* Send an OP_RESPONSE back to the originator of a connection request
 * (@pkt), completing the handshake for the new child socket @vsk.
 */
static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.remote_cid = le32_to_cpu(pkt->hdr.src_cid),
		.remote_port = le32_to_cpu(pkt->hdr.src_port),
		.reply = true,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
/* Handle server socket
 *
 * An OP_REQUEST on a listening socket creates a child socket addressed
 * from the packet's src/dst, inserts it into the connected table,
 * queues it on the listener's accept queue, and answers with
 * OP_RESPONSE.  Non-REQUEST packets and full accept queues get a RST.
 */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset(vsk, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
			       sk->sk_type, 0);
	if (!child) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	sk->sk_ack_backlog++;

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = SS_CONNECTED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	/* Wake up the listener's accept() waiters. */
	sk->sk_data_ready(sk);
	return 0;
}
/* Absorb the peer's credit state carried in @pkt's header into our
 * per-socket accounting, and report whether tx space is now
 * available so the caller can wake writers.
 */
static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 *
 * Top-level receive dispatcher: look up the destination socket,
 * absorb credit state, and hand the packet to the state-specific
 * handler.  Ownership: the packet is always consumed — either queued
 * by the SS_CONNECTED handler or freed on every other path.
 *
 * NOTE(review): hdr cids are written with cpu_to_le64() on the send
 * side but read here with le32_to_cpu(); relies on CIDs fitting in
 * 32 bits — confirm.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	/* Only stream packets are supported; reset anything else. */
	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	space_available = virtio_transport_space_update(sk, pkt);

	lock_sock(sk);

	/* Update CID in case it has changed after a transport reset event */
	vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case VSOCK_SS_LISTEN:
		virtio_transport_recv_listen(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTING:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTED:
		/* recv_connected() takes ownership of pkt (it may queue it). */
		virtio_transport_recv_connected(sk, pkt);
		break;
	case SS_DISCONNECTING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		virtio_transport_free_pkt(pkt);
		break;
	}
	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
/* Free a packet and its payload buffer (buf may be NULL for
 * header-only packets; kfree(NULL) is a no-op).
 */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
  820. MODULE_LICENSE("GPL v2");
  821. MODULE_AUTHOR("Asias He");
  822. MODULE_DESCRIPTION("common code for virtio vsock");