PageRenderTime 28ms CodeModel.GetById 1ms app.highlight 21ms RepoModel.GetById 1ms app.codeStats 0ms

/tests/web/websocket.py

https://bitbucket.org/prologic/circuits/
Python | 538 lines | 465 code | 30 blank | 43 comment | 18 complexity | b9a576367a2cd129dd80d32cbd852123 MD5 | raw file
  1"""
  2websocket - WebSocket client library for Python
  3
  4Copyright (C) 2010 Hiroki Ohtani(liris)
  5
  6    This library is free software; you can redistribute it and/or
  7    modify it under the terms of the GNU Lesser General Public
  8    License as published by the Free Software Foundation; either
  9    version 2.1 of the License, or (at your option) any later version.
 10
 11    This library is distributed in the hope that it will be useful,
 12    but WITHOUT ANY WARRANTY; without even the implied warranty of
 13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 14    Lesser General Public License for more details.
 15
 16    You should have received a copy of the GNU Lesser General Public
 17    License along with this library; if not, write to the Free Software
 18    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 19
 20"""
 21
 22
 23import socket
 24import random
 25import struct
 26from hashlib import md5
 27import logging
 28
 29from .helpers import urlparse
 30
 31
 32logger = logging.getLogger()
 33
 34
 35class WebSocketException(Exception):
 36    pass
 37
 38
 39class ConnectionClosedException(WebSocketException):
 40    pass
 41
 42default_timeout = None
 43traceEnabled = False
 44
 45
 46def enableTrace(tracable):
 47    """
 48    turn on/off the tracability.
 49    """
 50    global traceEnabled
 51    traceEnabled = tracable
 52    if tracable:
 53        if not logger.handlers:
 54            logger.addHandler(logging.StreamHandler())
 55        logger.setLevel(logging.DEBUG)
 56
 57
 58def setdefaulttimeout(timeout):
 59    """
 60    Set the global timeout setting to connect.
 61    """
 62    global default_timeout
 63    default_timeout = timeout
 64
 65
 66def getdefaulttimeout():
 67    """
 68    Return the global timeout setting to connect.
 69    """
 70    return default_timeout
 71
 72
 73def _parse_url(url):
 74    """
 75    parse url and the result is tuple of
 76    (hostname, port, resource path and the flag of secure mode)
 77    """
 78    parsed = urlparse(url)
 79    if parsed.hostname:
 80        hostname = parsed.hostname
 81    else:
 82        raise ValueError("hostname is invalid")
 83    port = 0
 84    if parsed.port:
 85        port = parsed.port
 86
 87    is_secure = False
 88    if parsed.scheme == "ws":
 89        if not port:
 90            port = 80
 91    elif parsed.scheme == "wss":
 92        is_secure = True
 93        if not port:
 94            port = 443
 95    else:
 96        raise ValueError("scheme %s is invalid" % parsed.scheme)
 97
 98    if parsed.path:
 99        resource = parsed.path
100    else:
101        resource = "/"
102
103    return (hostname, port, resource, is_secure)
104
105
def create_connection(url, timeout=None, **options):
    """
    Connect to url and return the connected WebSocket object.

    url: websocket url, e.g. ``ws://host:port/resource``.
    timeout: socket timeout in seconds. If it is ``None``, the global default
      returned by getdefaulttimeout() is used. Any other value, including a
      falsy one such as 0, is passed to the socket unchanged.
    options: forwarded to WebSocket.connect (e.g. ``header=[...]``).

    If connecting or the handshake fails, the socket is closed before the
    exception propagates, so no file descriptor is leaked.
    """
    websock = WebSocket()
    # The former `timeout is not None and timeout or default_timeout`
    # idiom replaced an explicit 0 with the default.
    websock.settimeout(default_timeout if timeout is None else timeout)
    try:
        websock.connect(url, **options)
    except Exception:
        websock.close()
        raise
    return websock
119
# Largest 32-bit unsigned value; key numbers must fit in the struct "!I"
# fields used to build the handshake challenge.
_MAX_INTEGER = (1 << 32) - 1
# Code points of the printable non-digit characters (0x21-0x2F and
# 0x3A-0x7E) that are sprinkled into Sec-WebSocket-Key1/Key2 values.
# (The previous `list(...).extend(...)` evaluated to None, because
# list.extend returns None, which broke every handshake.)
_AVAILABLE_KEY_CHARS = (
    list(range(0x21, 0x2f + 1)) + list(range(0x3a, 0x7e + 1))
)
_MAX_CHAR_BYTE = (1 << 8) - 1
_MAX_ASCII_BYTE = (1 << 7) - 1
126
# ref. Websocket gets an update, and it breaks stuff.
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html


def _create_sec_websocket_key():
    """
    Build one hixie-76 Sec-WebSocket-Key header value.

    Returns (number, key). ``key`` is ``number * spaces`` in decimal, with
    1-12 random non-digit characters inserted at arbitrary positions and
    ``spaces`` (1-12) spaces inserted strictly inside the string. The caller
    keeps ``number`` to compute the expected challenge response.
    """
    space_count = random.randint(1, 12)
    number = random.randint(0, _MAX_INTEGER // space_count)
    key = str(number * space_count)

    # Scatter filler characters anywhere, including at either end.
    for _ in range(random.randint(1, 12)):
        filler = chr(random.choice(_AVAILABLE_KEY_CHARS))
        where = random.randint(0, len(key))
        key = key[:where] + filler + key[where:]

    # Insert the spaces, never as the first or last character.
    for _ in range(space_count):
        where = random.randint(1, len(key) - 1)
        key = key[:where] + " " + key[where:]

    return number, key
146
147
def _create_key3():
    """
    Return the random 8-character "key3" sent after the handshake headers.

    Every character's code point is drawn from 0.._MAX_ASCII_BYTE.
    """
    return "".join(
        chr(random.randint(0, _MAX_ASCII_BYTE)) for _ in range(8)
    )
150
# Response headers that must be present with exactly these values; the
# response parser lower-cases both names and values before comparison.
HEADERS_TO_CHECK = dict(
    upgrade="websocket",
    connection="upgrade",
)

# If all of these are present, the reply is treated as hixie-76/hybi-00.
# A 16-byte challenge response is then read after the headers.
HEADERS_TO_EXIST_FOR_HYBI00 = [
    "sec-websocket-origin",
    "sec-websocket-location",
]

# If all of these are present, the reply is treated as hixie-75. No
# challenge response follows.
HEADERS_TO_EXIST_FOR_HIXIE75 = [
    "websocket-origin",
    "websocket-location",
]
165
166
class _SSLSocketWrapper(object):
    """
    Minimal TLS adapter exposing the recv/send calls WebSocket uses.

    The previous implementation called ``socket.ssl``, which does not exist
    in Python 3, so every wss:// connection failed with AttributeError.

    sock: an already-connected TCP socket to wrap.
    hostname: optional server name, used for SNI and for checking the
      certificate's host name. When it is omitted, host name checking is
      turned off, but the certificate chain is still verified.
    context: optional pre-configured ``ssl.SSLContext`` (for example, one
      that trusts a self-signed test certificate). It defaults to
      ``ssl.create_default_context()``.
    """

    def __init__(self, sock, hostname=None, context=None):
        # Imported lazily so that builds without the ssl module only fail
        # when a wss:// connection is actually requested.
        import ssl

        if context is None:
            context = ssl.create_default_context()
            if hostname is None:
                # A host name cannot be checked without knowing it.
                context.check_hostname = False
        self.ssl = context.wrap_socket(sock, server_hostname=hostname)

    def recv(self, bufsize):
        """Read up to bufsize decrypted bytes."""
        return self.ssl.recv(bufsize)

    def send(self, payload):
        """Encrypt and send payload; returns the number of bytes sent."""
        return self.ssl.send(payload)

    def close(self):
        """Close the TLS socket."""
        self.ssl.close()
177
178
class WebSocket(object):
    """
    Low level WebSocket interface.
    This class is based on
      The WebSocket protocol draft-hixie-thewebsocketprotocol-76
      http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76

    We can connect to the websocket server and send/receive data.
    The following example is an echo client (recv() returns bytes).

    >>> import websocket
    >>> ws = websocket.WebSocket()
    >>> ws.connect("ws://localhost:8080/echo")
    >>> ws.send("Hello, Server")
    >>> ws.recv()
    b'Hello, Server'
    >>> ws.close()
    """
    def __init__(self):
        """
        Initialize the WebSocket object with a fresh, unconnected TCP socket.
        """
        self.connected = False
        # io_sock carries the frames. connect() replaces it with a TLS
        # wrapper for wss:// urls, while sock remains the raw TCP socket.
        self.io_sock = self.sock = socket.socket()

    def settimeout(self, timeout):
        """
        Set the timeout of the underlying TCP socket.
        """
        self.sock.settimeout(timeout)

    def gettimeout(self):
        """
        Get the timeout of the underlying TCP socket.
        """
        return self.sock.gettimeout()

    def connect(self, url, **options):
        """
        Connect to url and perform the opening handshake.

        url: websocket url, ie. ws://host:port/resource (or wss://...).
        options: ``header`` -- optional list of extra request header lines.

        Raises ValueError for a malformed url, and WebSocketException when
        the server rejects the handshake.
        """
        hostname, port, resource, is_secure = _parse_url(url)
        # TODO: we need to support proxy
        self.sock.connect((hostname, port))
        if is_secure:
            self.io_sock = _SSLSocketWrapper(self.sock)
        self._handshake(hostname, port, resource, **options)

    def _handshake(self, host, port, resource, **options):
        """
        Send the opening handshake request and validate the server's reply.

        Sets ``connected`` on success. On failure it closes the socket and
        raises WebSocketException.
        """
        headers = []
        headers.append("GET %s HTTP/1.1" % resource)
        headers.append("Upgrade: WebSocket")
        headers.append("Connection: Upgrade")
        if port == 80:
            hostport = host
        else:
            hostport = "%s:%d" % (host, port)
        headers.append("Host: %s" % hostport)
        headers.append("Origin: %s" % hostport)

        number_1, key_1 = _create_sec_websocket_key()
        headers.append("Sec-WebSocket-Key1: %s" % key_1)
        number_2, key_2 = _create_sec_websocket_key()
        headers.append("Sec-WebSocket-Key2: %s" % key_2)
        if "header" in options:
            headers.extend(options["header"])

        # The empty entry produces the blank line that ends the headers.
        # The 8-character key3 follows it directly.
        headers.append("")
        key3 = _create_key3()
        headers.append(key3)

        header_str = "\r\n".join(headers)
        self._send_all(header_str.encode('utf-8'))
        if traceEnabled:
            logger.debug("--- request header ---")
            logger.debug(header_str)
            logger.debug("-----------------------")

        status, resp_headers = self._read_headers()

        if status != 101:
            self.close()
            raise WebSocketException("Handshake Status %d" % status)
        success, secure = self._validate_header(resp_headers)
        if not success:
            self.close()
            raise WebSocketException("Invalid WebSocket Header")

        if secure:
            resp = self._get_resp()

            if not self._validate_resp(number_1, number_2, key3, resp):
                self.close()
                raise WebSocketException("challenge-response error")

        self.connected = True

    def _validate_resp(self, number_1, number_2, key3, resp):
        """
        Check the challenge response. It must equal the MD5 digest of both
        key numbers (packed as big-endian 32-bit ints) followed by key3.
        """
        challenge = struct.pack("!I", number_1)
        challenge += struct.pack("!I", number_2)
        challenge += key3.encode('utf-8')
        digest = md5(challenge).digest()

        return resp == digest

    def _get_resp(self):
        """
        Read the 16-byte challenge response sent after the headers.
        """
        result = self._recv_strict(16)
        if traceEnabled:
            logger.debug("--- challenge response result ---")
            logger.debug(repr(result))
            logger.debug("---------------------------------")

        return result

    def _validate_header(self, headers):
        """
        Classify the response headers.

        Returns (success, secure):
          (True, True)   -- hybi-00/hixie-76 reply; a challenge response follows
          (True, False)  -- hixie-75 reply; no challenge response
          (False, True)  -- only some of the hybi-00 headers are present
          (False, False) -- required values missing or unknown protocol
        """
        #TODO: check other headers
        for key, value in HEADERS_TO_CHECK.items():
            v = headers.get(key, None)
            if value != v:
                return False, False

        success = 0
        for key in HEADERS_TO_EXIST_FOR_HYBI00:
            if key in headers:
                success += 1
        if success == len(HEADERS_TO_EXIST_FOR_HYBI00):
            return True, True
        elif success != 0:
            return False, True

        success = 0
        for key in HEADERS_TO_EXIST_FOR_HIXIE75:
            if key in headers:
                success += 1
        if success == len(HEADERS_TO_EXIST_FOR_HIXIE75):
            return True, False

        return False, False

    def _read_headers(self):
        """
        Read the HTTP status line and headers, up to the blank line.

        Returns (status, headers), where headers maps lower-cased names to
        lower-cased, stripped values. Raises WebSocketException on a header
        line without a colon.
        """
        status = None
        headers = {}
        if traceEnabled:
            logger.debug("--- response header ---")

        while True:
            line = self._recv_line()
            if line == b"\r\n":
                break
            line = line.strip()
            if traceEnabled:
                logger.debug(line)
            if not status:
                status_info = line.split(b" ", 2)
                status = int(status_info[1])
            else:
                kv = line.split(b":", 1)
                if len(kv) == 2:
                    key, value = kv
                    headers[key.lower().decode('utf-8')] \
                        = value.strip().lower().decode('utf-8')
                else:
                    raise WebSocketException("Invalid header")

        if traceEnabled:
            logger.debug("-----------------------")

        return status, headers

    def send(self, payload):
        """
        Send payload as one 0x00 ... 0xFF framed message.

        payload: str (encoded as utf-8) or bytes.
        """
        if isinstance(payload, str):
            payload = payload.encode("utf-8")
        data = b"".join([b"\x00", payload, b"\xff"])
        self._send_all(data)
        if traceEnabled:
            logger.debug("send: " + repr(data))

    def recv(self):
        """
        Receive one frame from the server.

        Returns the frame payload as bytes (text frames are not decoded).
        Returns None when the server sent a closing frame; the connection is
        closed in that case.

        Raises WebSocketException for an unknown frame type, and
        ConnectionClosedException if the peer closes the socket.
        """
        type_byte = self._recv(1)

        # Previously `if enableTrace:` (a function, always true) logged
        # every frame regardless of the trace setting.
        if traceEnabled:
            logger.debug("recv frame: " + repr(type_byte))
        frame_type = ord(type_byte)

        if frame_type == 0x00:
            # Sentinel-delimited frame: read until the 0xFF terminator.
            payload = bytearray()
            while True:
                b = self._recv(1)
                if b == b"\xff":
                    break
                payload += b
            return bytes(payload)
        elif 0x80 <= frame_type < 0xff:
            # High bit set: length-prefixed frame (0x80 is included).
            length = self._read_length()
            return self._recv_strict(length)
        elif frame_type == 0xff:
            # Closing frame: 0xFF followed by one more byte.
            self._recv(1)
            self._closeInternal()
            return None
        else:
            raise WebSocketException("Invalid frame type")

    def _read_length(self):
        """
        Decode a frame length: big-endian base-128 digits, where the high bit
        of each byte marks that another byte follows.
        """
        length = 0
        while True:
            b = ord(self._recv(1))
            length = length * (1 << 7) + (b & 0x7f)
            if b < 0x80:
                break

        return length

    def close(self):
        """
        Close the WebSocket object.

        When connected, it first attempts a best-effort closing handshake:
        it sends 0xFF 0x00, waits up to one second for the same bytes back,
        and then shuts the socket down. Socket errors during this handshake
        are ignored. The socket is closed in every case.
        """
        if self.connected:
            try:
                # Must be bytes: the old str literal raised TypeError (which
                # was swallowed), so the close frame was never sent.
                self._send_all(b"\xff\x00")
                timeout = self.sock.gettimeout()
                self.sock.settimeout(1)
                try:
                    result = self._recv_strict(2)
                    if result != b"\xff\x00":
                        logger.error("bad closing Handshake")
                except (OSError, WebSocketException):
                    pass
                self.sock.settimeout(timeout)
                self.sock.shutdown(socket.SHUT_RDWR)
            except (OSError, WebSocketException):
                pass
        self._closeInternal()

    def _closeInternal(self):
        """
        Mark the connection closed and release the socket(s).
        """
        self.connected = False
        if self.io_sock is not self.sock:
            # The TLS wrapper may own its own descriptor, so close it too
            # when it offers close().
            close_io = getattr(self.io_sock, "close", None)
            if close_io is not None:
                close_io()
        self.sock.close()
        self.io_sock = self.sock

    def _send_all(self, data):
        """
        Write all of data to io_sock, retrying after partial sends.
        """
        view = memoryview(data)
        while view:
            sent = self.io_sock.send(view)
            view = view[sent:]

    def _recv(self, bufsize):
        """
        Read up to bufsize bytes. Raises ConnectionClosedException when the
        peer has closed the socket.
        """
        data = self.io_sock.recv(bufsize)

        if not data:
            raise ConnectionClosedException()
        return data

    def _recv_strict(self, bufsize):
        """
        Read exactly bufsize bytes, looping over short reads.
        """
        # Accumulate bytes; the old str accumulator raised TypeError.
        buf = bytearray()
        while len(buf) < bufsize:
            buf += self._recv(bufsize - len(buf))

        return bytes(buf)

    def _recv_line(self):
        """
        Read one line, including its trailing b"\\n", one byte at a time.
        """
        line = []
        while True:
            c = self._recv(1)
            line.append(c)
            if c == b"\n":
                break
        return b"".join(line)
452
453
class WebSocketApp(object):
    """
    Higher level of APIs are provided.
    The interface is like JavaScript WebSocket object.
    """
    def __init__(self, url,
                 on_open=None, on_message=None, on_error=None,
                 on_close=None):
        """
        url: websocket url.
        on_open: callable object which is called when the websocket opens.
          It takes one argument: this WebSocketApp object.
        on_message: callable object which is called when data is received.
          It takes 2 arguments: this WebSocketApp object and the received
          payload, as returned by WebSocket.recv().
        on_error: callable object which is called when we get an error.
          It takes 2 arguments: this WebSocketApp object and the exception.
        on_close: callable object which is called when the connection closes.
          It takes one argument: this WebSocketApp object.
        """
        self.url = url
        self.on_open = on_open
        self.on_message = on_message
        self.on_error = on_error
        self.on_close = on_close
        self.sock = None

    def send(self, data):
        """
        Send a message. data must be a utf-8 string or unicode.
        """
        self.sock.send(data)

    def close(self):
        """
        Close the websocket connection.

        Does nothing when no connection is open, i.e. before run_forever()
        has created one or after it has finished.
        """
        if self.sock is not None:
            self.sock.close()

    def run_forever(self):
        """
        Run the event loop for the WebSocket.

        Connects, calls on_open, then passes each received payload to
        on_message until the server closes the connection. An exception
        from connecting or receiving is passed to on_error. on_close always
        runs at the end, after which the app can be run again.
        """
        if self.sock:
            raise WebSocketException("socket is already opened")
        try:
            self.sock = WebSocket()
            self.sock.connect(self.url)
            self._run_with_no_err(self.on_open)
            while True:
                data = self.sock.recv()
                if data is None:
                    break
                self._run_with_no_err(self.on_message, data)
        except Exception as e:
            self._run_with_no_err(self.on_error, e)
        finally:
            # self.sock is still None if WebSocket() itself failed.
            if self.sock is not None:
                self.sock.close()
            self._run_with_no_err(self.on_close)
            self.sock = None

    def _run_with_no_err(self, callback, *args):
        """
        Call callback(self, *args) when a callback is set.

        Exceptions raised by the callback are swallowed. They are logged only
        when the logger is enabled for DEBUG.
        """
        if callback:
            try:
                callback(self, *args)
            except Exception as e:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.error(e)
526
527
if __name__ == "__main__":
    # Manual smoke test: talk to a chat server expected on localhost:5000.
    enableTrace(True)
    ws = create_connection("ws://localhost:5000/chat")

    print("Sending 'Hello, World'...")
    ws.send("Hello, World")
    print("Sent")

    print("Receiving...")
    result = ws.recv()
    print("Received '%s'" % result)

    ws.close()