/src/System/IO/Socket_Windows.cpp

https://github.com/benlaurie/keyspace · C++ · 354 lines · 267 code · 74 blank · 13 comment · 57 complexity · 0e9b6e7219328d4e522eb42b18b50164 MD5 · raw file

  1. #ifdef PLATFORM_WINDOWS
  2. #include "Socket.h"
  3. #include "System/Log.h"
  4. #include "System/Common.h"
  5. #include "System/Platform.h"
  6. #include <winsock2.h>
  7. #ifndef INADDR_NONE
  8. #define INADDR_NONE 0xffffffff
  9. #endif
  10. unsigned long iftonl(const char* interface_);
  11. /*
  12. * The Socket implementation is tightly coupled with the
  13. * IOProcessor because Windows asynchronous mechanism and
  14. * the IO completion uses the same Windows specific OVERLAPPED
  15. * structures, therefore these functions are imported here.
  16. */
  17. bool IOProcessorRegisterSocket(FD& fd);
  18. bool IOProcessorUnregisterSocket(FD& fd);
  19. bool IOProcessorAccept(const FD& listeningFd, FD& fd);
  20. bool IOProcessorConnect(FD& fd, Endpoint& endpoint);
  21. extern unsigned SEND_BUFFER_SIZE;
  22. Socket::Socket()
  23. {
  24. fd = INVALID_FD;
  25. listening = false;
  26. }
  27. bool Socket::Create(Proto proto)
  28. {
  29. int ret, stype, ipproto;
  30. BOOL trueval = TRUE;
  31. if (fd.sock != INVALID_SOCKET)
  32. {
  33. Log_Trace("Called Create() on existing socket");
  34. return false;
  35. }
  36. type = proto;
  37. listening = false;
  38. if (proto == UDP)
  39. {
  40. stype = SOCK_DGRAM;
  41. ipproto = IPPROTO_UDP;
  42. }
  43. else
  44. {
  45. stype = SOCK_STREAM;
  46. ipproto = IPPROTO_TCP;
  47. }
  48. // create the socket with WSA_FLAG_OVERLAPPED to support async operations
  49. fd.sock = WSASocket(AF_INET, stype, ipproto, NULL, 0, WSA_FLAG_OVERLAPPED);
  50. if (fd.sock == INVALID_SOCKET)
  51. return false;
  52. if (setsockopt(fd.sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (char *)&trueval, sizeof(BOOL)))
  53. {
  54. ret = WSAGetLastError();
  55. Log_Trace("error = %d", ret);
  56. Close();
  57. return false;
  58. }
  59. if (setsockopt(fd.sock, SOL_SOCKET, SO_SNDBUF, (char *) &SEND_BUFFER_SIZE, sizeof(SEND_BUFFER_SIZE)))
  60. {
  61. ret = WSAGetLastError();
  62. Log_Trace("error = %d", ret);
  63. Close();
  64. return false;
  65. }
  66. // TODO set FD index too!
  67. IOProcessorRegisterSocket(fd);
  68. return true;
  69. }
  70. bool Socket::Bind(int port)
  71. {
  72. int ret;
  73. struct sockaddr_in sa;
  74. memset(&sa, 0, sizeof(sa));
  75. sa.sin_family = AF_INET;
  76. sa.sin_port = htons((uint16_t)port);
  77. sa.sin_addr.s_addr = htonl(INADDR_ANY);
  78. ret = bind(fd.sock, (struct sockaddr *)&sa, sizeof(struct sockaddr_in));
  79. if (ret < 0)
  80. {
  81. Log_Errno();
  82. Close();
  83. return false;
  84. }
  85. return true;
  86. }
  87. bool Socket::SetNonblocking()
  88. {
  89. u_long nonblocking;
  90. if (fd.sock == INVALID_SOCKET)
  91. {
  92. Log_Trace("SetNonblocking on invalid file descriptor");
  93. return false;
  94. }
  95. nonblocking = 1;
  96. if (ioctlsocket(fd.sock, FIONBIO, &nonblocking) == SOCKET_ERROR)
  97. return false;
  98. return true;
  99. }
  100. bool Socket::SetNodelay()
  101. {
  102. BOOL nodelay;
  103. if (fd.sock == INVALID_SOCKET)
  104. {
  105. Log_Trace("SetNodelay on invalid file descriptor");
  106. return false;
  107. }
  108. // Nagle algorithm is disabled if TCP_NODELAY is enabled.
  109. nodelay = TRUE;
  110. if (setsockopt(fd.sock, IPPROTO_TCP, TCP_NODELAY, (char *) &nodelay, sizeof(nodelay)) == SOCKET_ERROR)
  111. {
  112. Log_Trace("setsockopt() failed");
  113. return false;
  114. }
  115. return true;
  116. }
  117. bool Socket::Listen(int port, int backlog)
  118. {
  119. int ret;
  120. if (!Bind(port))
  121. return false;
  122. ret = listen(fd.sock, backlog);
  123. if (ret < 0)
  124. {
  125. Log_Errno();
  126. Close();
  127. return false;
  128. }
  129. listening = true;
  130. return true;
  131. }
  132. bool Socket::Accept(Socket *newSocket)
  133. {
  134. if (!IOProcessorAccept(fd, newSocket->fd))
  135. {
  136. Log_Errno();
  137. Close();
  138. return false;
  139. }
  140. // register the newly created socket
  141. IOProcessorRegisterSocket(newSocket->fd);
  142. return true;
  143. }
  144. bool Socket::Connect(Endpoint &endpoint)
  145. {
  146. if (!IOProcessorConnect(fd, endpoint))
  147. {
  148. Log_Errno();
  149. return false;
  150. }
  151. return true;
  152. }
  153. bool Socket::GetEndpoint(Endpoint &endpoint)
  154. {
  155. int ret;
  156. int len = ENDPOINT_SOCKADDR_SIZE;
  157. struct sockaddr* sa = (struct sockaddr*) endpoint.GetSockAddr();
  158. ret = getpeername(fd.sock, sa, &len);
  159. if (ret == SOCKET_ERROR)
  160. {
  161. ret = WSAGetLastError();
  162. Log_Trace("error = %d", ret);
  163. Close();
  164. return false;
  165. }
  166. return true;
  167. }
  168. const char* Socket::ToString(char s[ENDPOINT_STRING_SIZE])
  169. {
  170. Endpoint endpoint;
  171. if (!GetEndpoint(endpoint))
  172. return "";
  173. return endpoint.ToString(s);
  174. }
  175. bool Socket::SendTo(void *data, int count, const Endpoint &endpoint)
  176. {
  177. int ret;
  178. const struct sockaddr* sa = (const struct sockaddr*) ((Endpoint &) endpoint).GetSockAddr();
  179. ret = sendto(fd.sock, (const char*) data, count, 0,
  180. sa,
  181. ENDPOINT_SOCKADDR_SIZE);
  182. if (ret < 0)
  183. {
  184. Log_Errno();
  185. return false;
  186. }
  187. return true;
  188. }
  189. int Socket::Send(const char* data, int count, int timeout)
  190. {
  191. size_t left;
  192. int nwritten;
  193. left = count;
  194. while (left > 0)
  195. {
  196. if ((nwritten = send((SOCKET) fd.sock, (char*) data, count, 0)) == SOCKET_ERROR)
  197. {
  198. // TODO error handling
  199. if (WSAGetLastError() == WSAEWOULDBLOCK)
  200. return 0;
  201. return -1;
  202. }
  203. left -= nwritten;
  204. data += nwritten;
  205. }
  206. return count;
  207. }
  208. int Socket::Read(char* data, int count, int timeout)
  209. {
  210. int ret;
  211. ret = recv((SOCKET)fd.sock, (char *)data, count, 0);
  212. // TODO better error handling
  213. if (ret == SOCKET_ERROR)
  214. {
  215. if (WSAGetLastError() == WSAEWOULDBLOCK)
  216. return 0;
  217. return -1;
  218. }
  219. else if (ret == 0)
  220. {
  221. // graceful disconnection
  222. return -1;
  223. }
  224. return ret;
  225. }
  226. void Socket::Close()
  227. {
  228. int ret;
  229. if (fd.sock != INVALID_SOCKET)
  230. {
  231. IOProcessorUnregisterSocket(fd);
  232. ret = closesocket(fd.sock);
  233. if (ret < 0)
  234. Log_Errno();
  235. fd.sock = INVALID_SOCKET;
  236. fd.index = 0;
  237. }
  238. }
  239. unsigned long iftonl(const char* interface_)
  240. {
  241. int pos;
  242. int len;
  243. unsigned long a;
  244. unsigned long b;
  245. unsigned long c;
  246. unsigned long d;
  247. unsigned long addr;
  248. unsigned nread;
  249. nread = 0;
  250. pos = 0;
  251. len = strlen(interface_);
  252. a = strntouint64(interface_ + pos, len - pos, &nread);
  253. if (nread < 0 || a > 255)
  254. return INADDR_NONE;
  255. pos += nread;
  256. if (interface_[pos++] != '.')
  257. return INADDR_NONE;
  258. b = strntouint64(interface_ + pos, len - pos, &nread);
  259. if (nread < 0 || b > 255)
  260. return INADDR_NONE;
  261. pos += nread;
  262. if (interface_[pos++] != '.')
  263. return INADDR_NONE;
  264. c = strntouint64(interface_ + pos, len - pos, &nread);
  265. if (nread < 0 || c > 255)
  266. return INADDR_NONE;
  267. pos += nread;
  268. if (interface_[pos++] != '.')
  269. return INADDR_NONE;
  270. d = strntouint64(interface_ + pos, len - pos, &nread);
  271. if (nread < 0 || d > 255)
  272. return INADDR_NONE;
  273. pos += nread;
  274. if (interface_[pos] != '\0' &&
  275. interface_[pos] != ':')
  276. return INADDR_NONE;
  277. addr = (d & 0xff) << 24 | (c & 0xff) << 16 | (b & 0xff) << 8 | (a & 0xff);
  278. return addr;
  279. }
  280. #endif