PageRenderTime 69ms CodeModel.GetById 40ms RepoModel.GetById 1ms app.codeStats 0ms

/src/network/connection.c

https://github.com/q3k/uhub
C | 328 lines | 271 code | 37 blank | 20 comment | 49 complexity | 736dc1fdc68ecb3db7516af8569d2db6 MD5 | raw file
  1. /*
  2. * uhub - A tiny ADC p2p connection hub
  3. * Copyright (C) 2007-2010, Jan Vidar Krey
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation; either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. #include "uhub.h"
  20. #include "network/common.h"
  21. #ifdef SSL_SUPPORT
  /*
   * Progress of the TLS layer of a connection (kept in con->ssl_state).
   * net_con_callback() switches on this value to decide how socket events
   * are handled for connections that have an SSL object attached.
   */
  enum uhub_tls_state
  {
      tls_st_none          = 0, /* events are forwarded unchanged */
      tls_st_error         = 1, /* set by handle_openssl_error() on fatal errors */
      tls_st_accepting     = 2, /* set by net_con_ssl_accept() (server side) */
      tls_st_connecting    = 3, /* set by net_con_ssl_connect() (client side) */
      tls_st_connected     = 4, /* handshake finished successfully */
      tls_st_disconnecting = 5, /* events are dropped by net_con_callback() */
  };
  31. static int handle_openssl_error(struct net_connection* con, int ret)
  32. {
  33. uhub_assert(con);
  34. int error = SSL_get_error(con->ssl, ret);
  35. switch (error)
  36. {
  37. case SSL_ERROR_ZERO_RETURN:
  38. LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_ZERO_RETURN", ret, error);
  39. con->ssl_state = tls_st_error;
  40. return -1;
  41. case SSL_ERROR_WANT_READ:
  42. LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_READ", ret, error);
  43. con->flags |= NET_WANT_SSL_READ;
  44. net_con_update(con, NET_EVENT_READ);
  45. return 0;
  46. case SSL_ERROR_WANT_WRITE:
  47. LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_WRITE", ret, error);
  48. con->flags |= NET_WANT_SSL_WRITE;
  49. net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE);
  50. return 0;
  51. case SSL_ERROR_SYSCALL:
  52. LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SYSCALL", ret, error);
  53. /* if ret == 0, connection closed, if ret == -1, check with errno */
  54. if (ret == 0)
  55. return -1;
  56. else
  57. return -net_error();
  58. case SSL_ERROR_SSL:
  59. LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SSL", ret, error);
  60. /* internal openssl error */
  61. con->ssl_state = tls_st_error;
  62. return -1;
  63. }
  64. return -1;
  65. }
  66. ssize_t net_con_ssl_accept(struct net_connection* con)
  67. {
  68. uhub_assert(con);
  69. con->ssl_state = tls_st_accepting;
  70. ssize_t ret = SSL_accept(con->ssl);
  71. #ifdef NETWORK_DUMP_DEBUG
  72. LOG_PROTO("SSL_accept() ret=%d", ret);
  73. #endif
  74. if (ret > 0)
  75. {
  76. net_con_update(con, NET_EVENT_READ);
  77. con->ssl_state = tls_st_connected;
  78. }
  79. else
  80. {
  81. return handle_openssl_error(con, ret);
  82. }
  83. return ret;
  84. }
  85. ssize_t net_con_ssl_connect(struct net_connection* con)
  86. {
  87. uhub_assert(con);
  88. con->ssl_state = tls_st_connecting;
  89. ssize_t ret = SSL_connect(con->ssl);
  90. #ifdef NETWORK_DUMP_DEBUG
  91. LOG_PROTO("SSL_connect() ret=%d", ret);
  92. #endif
  93. if (ret > 0)
  94. {
  95. con->ssl_state = tls_st_connected;
  96. net_con_update(con, NET_EVENT_READ);
  97. }
  98. else
  99. {
  100. return handle_openssl_error(con, ret);
  101. }
  102. return ret;
  103. }
  104. ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssl_mode, SSL_CTX* ssl_ctx)
  105. {
  106. uhub_assert(con);
  107. SSL* ssl = 0;
  108. if (ssl_mode == net_con_ssl_mode_server)
  109. {
  110. ssl = SSL_new(ssl_ctx);
  111. SSL_set_fd(ssl, con->sd);
  112. net_con_set_ssl(con, ssl);
  113. return net_con_ssl_accept(con);
  114. }
  115. else
  116. {
  117. ssl = SSL_new(SSL_CTX_new(TLSv1_method()));
  118. SSL_set_fd(ssl, con->sd);
  119. net_con_set_ssl(con, ssl);
  120. return net_con_ssl_connect(con);
  121. }
  122. }
  123. #endif /* SSL_SUPPORT */
  124. ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len)
  125. {
  126. int ret;
  127. #ifdef SSL_SUPPORT
  128. if (!con->ssl)
  129. {
  130. #endif
  131. ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL);
  132. if (ret == -1)
  133. {
  134. if (net_error() == EWOULDBLOCK || net_error() == EINTR)
  135. return 0;
  136. return -1;
  137. }
  138. #ifdef SSL_SUPPORT
  139. }
  140. else
  141. {
  142. con->write_len = len;
  143. ret = SSL_write(con->ssl, buf, len);
  144. LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
  145. if (ret <= 0)
  146. {
  147. return -handle_openssl_error(con, ret);
  148. }
  149. }
  150. #endif
  151. return ret;
  152. }
  153. ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len)
  154. {
  155. int ret;
  156. #ifdef SSL_SUPPORT
  157. if (!net_con_is_ssl(con))
  158. {
  159. #endif
  160. ret = net_recv(con->sd, buf, len, 0);
  161. if (ret == -1)
  162. {
  163. if (net_error() == EWOULDBLOCK || net_error() == EINTR)
  164. return 0;
  165. return -net_error();
  166. }
  167. else if (ret == 0)
  168. {
  169. return -1;
  170. }
  171. #ifdef SSL_SUPPORT
  172. }
  173. else
  174. {
  175. if (con->ssl_state == tls_st_error)
  176. return -1;
  177. ret = SSL_read(con->ssl, buf, len);
  178. LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
  179. if (ret > 0)
  180. {
  181. net_con_update(con, NET_EVENT_READ);
  182. }
  183. else
  184. {
  185. return -handle_openssl_error(con, ret);
  186. }
  187. }
  188. #endif
  189. return ret;
  190. }
  191. ssize_t net_con_peek(struct net_connection* con, void* buf, size_t len)
  192. {
  193. int ret = net_recv(con->sd, buf, len, MSG_PEEK);
  194. if (ret == -1)
  195. {
  196. if (net_error() == EWOULDBLOCK || net_error() == EINTR)
  197. return 0;
  198. return -net_error();
  199. }
  200. else if (ret == 0)
  201. return -1;
  202. return ret;
  203. }
  204. #ifdef SSL_SUPPORT
  205. int net_con_is_ssl(struct net_connection* con)
  206. {
  207. return con->ssl != 0;
  208. }
  209. SSL* net_con_get_ssl(struct net_connection* con)
  210. {
  211. return con->ssl;
  212. }
  213. void net_con_set_ssl(struct net_connection* con, SSL* ssl)
  214. {
  215. con->ssl = ssl;
  216. }
  217. #endif /* SSL_SUPPORT */
  218. int net_con_get_sd(struct net_connection* con)
  219. {
  220. return con->sd;
  221. }
  222. void* net_con_get_ptr(struct net_connection* con)
  223. {
  224. return con->ptr;
  225. }
  /*
   * Free the connection object itself with hub_free().
   * Only the struct is released here. NOTE(review): the socket and any attached
   * SSL object are presumably released before this is called; confirm with callers.
   */
  void net_con_destroy(struct net_connection* con)
  {
      struct net_connection* doomed = con;
      hub_free(doomed);
  }
  230. void net_con_callback(struct net_connection* con, int events)
  231. {
  232. if (con->flags & NET_CLEANUP)
  233. return;
  234. if (events == NET_EVENT_TIMEOUT)
  235. {
  236. LOG_TRACE("net_con_callback(%p, TIMEOUT", con);
  237. con->callback(con, events, con->ptr);
  238. return;
  239. }
  240. #ifdef SSL_SUPPORT
  241. if (!con->ssl)
  242. {
  243. #endif
  244. con->callback(con, events, con->ptr);
  245. #ifdef SSL_SUPPORT
  246. }
  247. else
  248. {
  249. #ifdef NETWORK_DUMP_DEBUG
  250. LOG_PROTO("net_con_event: events=%d, con=%p, state=%d", events, con, con->ssl_state);
  251. #endif
  252. switch (con->ssl_state)
  253. {
  254. case tls_st_none:
  255. con->callback(con, events, con->ptr);
  256. break;
  257. case tls_st_error:
  258. con->callback(con, NET_EVENT_READ, con->ptr);
  259. break;
  260. case tls_st_accepting:
  261. if (net_con_ssl_accept(con) < 0)
  262. {
  263. con->callback(con, NET_EVENT_READ, con->ptr);
  264. }
  265. break;
  266. case tls_st_connecting:
  267. if (net_con_ssl_connect(con) < 0)
  268. {
  269. con->callback(con, NET_EVENT_READ, con->ptr);
  270. }
  271. break;
  272. case tls_st_connected:
  273. LOG_PROTO("tls_st_connected, events=%s%s, ssl_flags=%s%s", (events & NET_EVENT_READ ? "R" : ""), (events & NET_EVENT_WRITE ? "W" : ""), con->flags & NET_WANT_SSL_READ ? "R" : "", con->flags & NET_WANT_SSL_WRITE ? "W" : "");
  274. if (events & NET_EVENT_WRITE && con->flags & NET_WANT_SSL_READ)
  275. {
  276. con->callback(con, events & NET_EVENT_READ, con->ptr);
  277. return;
  278. }
  279. if (events & NET_EVENT_READ && con->flags & NET_WANT_SSL_WRITE)
  280. {
  281. con->callback(con, events & NET_EVENT_READ, con->ptr);
  282. return;
  283. }
  284. con->callback(con, events, con->ptr);
  285. break;
  286. case tls_st_disconnecting:
  287. return;
  288. }
  289. }
  290. #endif
  291. }