/tests/web/websocket.py

https://bitbucket.org/prologic/circuits/ · Python · 538 lines · 344 code · 85 blank · 109 comment · 76 complexity · b9a576367a2cd129dd80d32cbd852123 MD5 · raw file

  1. """
  2. websocket - WebSocket client library for Python
  3. Copyright (C) 2010 Hiroki Ohtani(liris)
  4. This library is free software; you can redistribute it and/or
  5. modify it under the terms of the GNU Lesser General Public
  6. License as published by the Free Software Foundation; either
  7. version 2.1 of the License, or (at your option) any later version.
  8. This library is distributed in the hope that it will be useful,
  9. but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  11. Lesser General Public License for more details.
  12. You should have received a copy of the GNU Lesser General Public
  13. License along with this library; if not, write to the Free Software
  14. Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
  15. """
  16. import socket
  17. import random
  18. import struct
  19. from hashlib import md5
  20. import logging
  21. from .helpers import urlparse
# Logger shared by the whole module. getLogger() is called without a name,
# so this is the root logger: enableTrace() below adds its handler to, and
# changes the level of, the root logger.
logger = logging.getLogger()
  23. class WebSocketException(Exception):
  24. pass
  25. class ConnectionClosedException(WebSocketException):
  26. pass
# Timeout handed to the socket by create_connection() when the caller does
# not supply one; changed through setdefaulttimeout().
default_timeout = None
# Switched by enableTrace(); when true, handshake and frame traffic is
# written to the module logger at DEBUG level.
traceEnabled = False
  29. def enableTrace(tracable):
  30. """
  31. turn on/off the tracability.
  32. """
  33. global traceEnabled
  34. traceEnabled = tracable
  35. if tracable:
  36. if not logger.handlers:
  37. logger.addHandler(logging.StreamHandler())
  38. logger.setLevel(logging.DEBUG)
  39. def setdefaulttimeout(timeout):
  40. """
  41. Set the global timeout setting to connect.
  42. """
  43. global default_timeout
  44. default_timeout = timeout
  45. def getdefaulttimeout():
  46. """
  47. Return the global timeout setting to connect.
  48. """
  49. return default_timeout
  50. def _parse_url(url):
  51. """
  52. parse url and the result is tuple of
  53. (hostname, port, resource path and the flag of secure mode)
  54. """
  55. parsed = urlparse(url)
  56. if parsed.hostname:
  57. hostname = parsed.hostname
  58. else:
  59. raise ValueError("hostname is invalid")
  60. port = 0
  61. if parsed.port:
  62. port = parsed.port
  63. is_secure = False
  64. if parsed.scheme == "ws":
  65. if not port:
  66. port = 80
  67. elif parsed.scheme == "wss":
  68. is_secure = True
  69. if not port:
  70. port = 443
  71. else:
  72. raise ValueError("scheme %s is invalid" % parsed.scheme)
  73. if parsed.path:
  74. resource = parsed.path
  75. else:
  76. resource = "/"
  77. return (hostname, port, resource, is_secure)
  78. def create_connection(url, timeout=None, **options):
  79. """
  80. connect to url and return websocket object.
  81. Connect to url and return the WebSocket object.
  82. Passing optional timeout parameter will set the timeout on the socket.
  83. If no timeout is supplied, the global default timeout setting returned
  84. by getdefauttimeout() is used.
  85. """
  86. websock = WebSocket()
  87. websock.settimeout(timeout is not None and timeout or default_timeout)
  88. websock.connect(url, **options)
  89. return websock
  90. _MAX_INTEGER = (1 << 32) - 1
  91. _AVAILABLE_KEY_CHARS = list(range(0x21, 0x2f + 1)).extend(
  92. list(range(0x3a, 0x7e + 1))
  93. )
  94. _MAX_CHAR_BYTE = (1 << 8) - 1
  95. _MAX_ASCII_BYTE = (1 << 7) - 1
  96. # ref. Websocket gets an update, and it breaks stuff.
  97. # http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
  98. def _create_sec_websocket_key():
  99. spaces_n = random.randint(1, 12)
  100. max_n = _MAX_INTEGER / spaces_n
  101. number_n = random.randint(0, int(max_n))
  102. product_n = number_n * spaces_n
  103. key_n = str(product_n)
  104. for i in range(random.randint(1, 12)):
  105. c = random.choice(_AVAILABLE_KEY_CHARS)
  106. pos = random.randint(0, len(key_n))
  107. key_n = key_n[0:pos] + chr(c) + key_n[pos:]
  108. for i in range(spaces_n):
  109. pos = random.randint(1, len(key_n)-1)
  110. key_n = key_n[0:pos] + " " + key_n[pos:]
  111. return number_n, key_n
  112. def _create_key3():
  113. return "".join([chr(random.randint(0, _MAX_ASCII_BYTE)) for i in range(8)])
  114. HEADERS_TO_CHECK = {
  115. "upgrade": "websocket",
  116. "connection": "upgrade",
  117. }
  118. HEADERS_TO_EXIST_FOR_HYBI00 = [
  119. "sec-websocket-origin",
  120. "sec-websocket-location",
  121. ]
  122. HEADERS_TO_EXIST_FOR_HIXIE75 = [
  123. "websocket-origin",
  124. "websocket-location",
  125. ]
  126. class _SSLSocketWrapper(object):
  127. def __init__(self, sock):
  128. self.ssl = socket.ssl(sock)
  129. def recv(self, bufsize):
  130. return self.ssl.read(bufsize)
  131. def send(self, payload):
  132. return self.ssl.write(payload)
  133. class WebSocket(object):
  134. """
  135. Low level WebSocket interface.
  136. This class is based on
  137. The WebSocket protocol draft-hixie-thewebsocketprotocol-76
  138. http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
  139. We can connect to the websocket server and send/recieve data.
  140. The following example is a echo client.
  141. >>> import websocket
  142. >>> ws = websocket.WebSocket()
  143. >>> ws.Connect("ws://localhost:8080/echo")
  144. >>> ws.send("Hello, Server")
  145. >>> ws.recv()
  146. 'Hello, Server'
  147. >>> ws.close()
  148. """
  149. def __init__(self):
  150. """
  151. Initalize WebSocket object.
  152. """
  153. self.connected = False
  154. self.io_sock = self.sock = socket.socket()
  155. def settimeout(self, timeout):
  156. """
  157. Set the timeout to the websocket.
  158. """
  159. self.sock.settimeout(timeout)
  160. def gettimeout(self):
  161. """
  162. Get the websocket timeout.
  163. """
  164. return self.sock.gettimeout()
  165. def connect(self, url, **options):
  166. """
  167. Connect to url. url is websocket url scheme.
  168. ie. ws://host:port/resource
  169. """
  170. hostname, port, resource, is_secure = _parse_url(url)
  171. # TODO: we need to support proxy
  172. self.sock.connect((hostname, port))
  173. if is_secure:
  174. self.io_sock = _SSLSocketWrapper(self.sock)
  175. self._handshake(hostname, port, resource, **options)
  176. def _handshake(self, host, port, resource, **options):
  177. sock = self.io_sock
  178. headers = []
  179. headers.append("GET %s HTTP/1.1" % resource)
  180. headers.append("Upgrade: WebSocket")
  181. headers.append("Connection: Upgrade")
  182. if port == 80:
  183. hostport = host
  184. else:
  185. hostport = "%s:%d" % (host, port)
  186. headers.append("Host: %s" % hostport)
  187. headers.append("Origin: %s" % hostport)
  188. number_1, key_1 = _create_sec_websocket_key()
  189. headers.append("Sec-WebSocket-Key1: %s" % key_1)
  190. number_2, key_2 = _create_sec_websocket_key()
  191. headers.append("Sec-WebSocket-Key2: %s" % key_2)
  192. if "header" in options:
  193. headers.extend(options["header"])
  194. headers.append("")
  195. key3 = _create_key3()
  196. headers.append(key3)
  197. header_str = "\r\n".join(headers)
  198. sock.send(header_str.encode('utf-8'))
  199. if traceEnabled:
  200. logger.debug("--- request header ---")
  201. logger.debug(header_str)
  202. logger.debug("-----------------------")
  203. status, resp_headers = self._read_headers()
  204. if status != 101:
  205. self.close()
  206. raise WebSocketException("Handshake Status %d" % status)
  207. success, secure = self._validate_header(resp_headers)
  208. if not success:
  209. self.close()
  210. raise WebSocketException("Invalid WebSocket Header")
  211. if secure:
  212. resp = self._get_resp()
  213. if not self._validate_resp(number_1, number_2, key3, resp):
  214. self.close()
  215. raise WebSocketException("challenge-response error")
  216. self.connected = True
  217. def _validate_resp(self, number_1, number_2, key3, resp):
  218. challenge = struct.pack("!I", number_1)
  219. challenge += struct.pack("!I", number_2)
  220. challenge += key3.encode('utf-8')
  221. digest = md5(challenge).digest()
  222. return resp == digest
  223. def _get_resp(self):
  224. result = self._recv(16)
  225. if traceEnabled:
  226. logger.debug("--- challenge response result ---")
  227. logger.debug(repr(result))
  228. logger.debug("---------------------------------")
  229. return result
  230. def _validate_header(self, headers):
  231. #TODO: check other headers
  232. for key, value in HEADERS_TO_CHECK.items():
  233. v = headers.get(key, None)
  234. if value != v:
  235. return False, False
  236. success = 0
  237. for key in HEADERS_TO_EXIST_FOR_HYBI00:
  238. if key in headers:
  239. success += 1
  240. if success == len(HEADERS_TO_EXIST_FOR_HYBI00):
  241. return True, True
  242. elif success != 0:
  243. return False, True
  244. success = 0
  245. for key in HEADERS_TO_EXIST_FOR_HIXIE75:
  246. if key in headers:
  247. success += 1
  248. if success == len(HEADERS_TO_EXIST_FOR_HIXIE75):
  249. return True, False
  250. return False, False
  251. def _read_headers(self):
  252. status = None
  253. headers = {}
  254. if traceEnabled:
  255. logger.debug("--- response header ---")
  256. while True:
  257. line = self._recv_line()
  258. if line == b"\r\n":
  259. break
  260. line = line.strip()
  261. if traceEnabled:
  262. logger.debug(line)
  263. if not status:
  264. status_info = line.split(b" ", 2)
  265. status = int(status_info[1])
  266. else:
  267. kv = line.split(b":", 1)
  268. if len(kv) == 2:
  269. key, value = kv
  270. headers[key.lower().decode('utf-8')] \
  271. = value.strip().lower().decode('utf-8')
  272. else:
  273. raise WebSocketException("Invalid header")
  274. if traceEnabled:
  275. logger.debug("-----------------------")
  276. return status, headers
  277. def send(self, payload):
  278. """
  279. Send the data as string. payload must be utf-8 string or unicoce.
  280. """
  281. if isinstance(payload, str):
  282. payload = payload.encode("utf-8")
  283. data = b"".join([b"\x00", payload, b"\xff"])
  284. self.io_sock.send(data)
  285. if traceEnabled:
  286. logger.debug("send: " + repr(data))
  287. def recv(self):
  288. """
  289. Reeive utf-8 string data from the server.
  290. """
  291. b = self._recv(1)
  292. if enableTrace:
  293. logger.debug("recv frame: " + repr(b))
  294. frame_type = ord(b)
  295. if frame_type == 0x00:
  296. bytes = []
  297. while True:
  298. b = self._recv(1)
  299. if b == b"\xff":
  300. break
  301. else:
  302. bytes.append(b)
  303. return b"".join(bytes)
  304. elif 0x80 < frame_type < 0xff:
  305. # which frame type is valid?
  306. length = self._read_length()
  307. bytes = self._recv_strict(length)
  308. return bytes
  309. elif frame_type == 0xff:
  310. self._recv(1)
  311. self._closeInternal()
  312. return None
  313. else:
  314. raise WebSocketException("Invalid frame type")
  315. def _read_length(self):
  316. length = 0
  317. while True:
  318. b = ord(self._recv(1))
  319. length = length * (1 << 7) + (b & 0x7f)
  320. if b < 0x80:
  321. break
  322. return length
  323. def close(self):
  324. """
  325. Close Websocket object
  326. """
  327. if self.connected:
  328. try:
  329. self.io_sock.send("\xff\x00")
  330. timeout = self.sock.gettimeout()
  331. self.sock.settimeout(1)
  332. try:
  333. result = self._recv(2)
  334. if result != "\xff\x00":
  335. logger.error("bad closing Handshake")
  336. except:
  337. pass
  338. self.sock.settimeout(timeout)
  339. self.sock.shutdown(socket.SHUT_RDWR)
  340. except:
  341. pass
  342. self._closeInternal()
  343. def _closeInternal(self):
  344. self.connected = False
  345. self.sock.close()
  346. self.io_sock = self.sock
  347. def _recv(self, bufsize):
  348. bytes = self.io_sock.recv(bufsize)
  349. if not bytes:
  350. raise ConnectionClosedException()
  351. return bytes
  352. def _recv_strict(self, bufsize):
  353. remaining = bufsize
  354. bytes = ""
  355. while remaining:
  356. bytes += self._recv(remaining)
  357. remaining = bufsize - len(bytes)
  358. return bytes
  359. def _recv_line(self):
  360. line = []
  361. while True:
  362. c = self._recv(1)
  363. line.append(c)
  364. if c == b"\n":
  365. break
  366. return b"".join(line)
  367. class WebSocketApp(object):
  368. """
  369. Higher level of APIs are provided.
  370. The interface is like JavaScript WebSocket object.
  371. """
  372. def __init__(self, url,
  373. on_open=None, on_message=None, on_error=None,
  374. on_close=None):
  375. """
  376. url: websocket url.
  377. on_open: callable object which is called at opening websocket.
  378. this function has one argument. The arugment is this class object.
  379. on_message: callbale object which is called when recieved data.
  380. on_message has 2 arguments.
  381. The 1st arugment is this class object.
  382. The passing 2nd arugment is utf-8 string which we get from the server.
  383. on_error: callable object which is called when we get error.
  384. on_error has 2 arguments.
  385. The 1st arugment is this class object.
  386. The passing 2nd arugment is exception object.
  387. on_close: callable object which is called when closed the connection.
  388. this function has one argument. The arugment is this class object.
  389. """
  390. self.url = url
  391. self.on_open = on_open
  392. self.on_message = on_message
  393. self.on_error = on_error
  394. self.on_close = on_close
  395. self.sock = None
  396. def send(self, data):
  397. """
  398. send message. data must be utf-8 string or unicode.
  399. """
  400. self.sock.send(data)
  401. def close(self):
  402. """
  403. close websocket connection.
  404. """
  405. self.sock.close()
  406. def run_forever(self):
  407. """
  408. run event loop for WebSocket framework.
  409. This loop is infinite loop and is alive during websocket is available.
  410. """
  411. if self.sock:
  412. raise WebSocketException("socket is already opened")
  413. try:
  414. self.sock = WebSocket()
  415. self.sock.connect(self.url)
  416. self._run_with_no_err(self.on_open)
  417. while True:
  418. data = self.sock.recv()
  419. if data is None:
  420. break
  421. self._run_with_no_err(self.on_message, data)
  422. except Exception as e:
  423. self._run_with_no_err(self.on_error, e)
  424. finally:
  425. self.sock.close()
  426. self._run_with_no_err(self.on_close)
  427. self.sock = None
  428. def _run_with_no_err(self, callback, *args):
  429. if callback:
  430. try:
  431. callback(self, *args)
  432. except Exception as e:
  433. if logger.isEnabledFor(logging.DEBUG):
  434. logger.error(e)
  435. if __name__ == "__main__":
  436. enableTrace(True)
  437. #ws = create_connection("ws://localhost:8080/echo")
  438. ws = create_connection("ws://localhost:5000/chat")
  439. print("Sending 'Hello, World'...")
  440. ws.send("Hello, World")
  441. print("Sent")
  442. print("Receiving...")
  443. result = ws.recv()
  444. print("Received '%s'" % result)
  445. ws.close()