PageRenderTime 64ms CodeModel.GetById 8ms app.highlight 50ms RepoModel.GetById 2ms app.codeStats 0ms

/lib/concurrence/io/socket.py

https://bitbucket.org/incubaid/pylabs-core-6.0
Python | 354 lines | 339 code | 11 blank | 4 comment | 10 complexity | 85b906bb3bdbc40734ca77ae3b794c37 MD5 | raw file
  1# Copyright (C) 2009, Hyves (Startphone Ltd.)
  2#
  3# This module is part of the Concurrence Framework and is released under
  4# the New BSD License: http://www.opensource.org/licenses/bsd-license.php
  5
  6import logging
  7import _socket
  8import types
  9import os
 10
 11from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ENOENT, EAGAIN
 12
 13import _io
 14
 15from concurrence import Tasklet, FileDescriptorEvent, TIMEOUT_CURRENT
 16from concurrence.io import IOStream
 17
 18DEFAULT_BACKLOG = 512
 19XMOD = 8
 20
 21_interceptor = None
 22
 23class Socket(IOStream):
 24    log = logging.getLogger('Socket')
 25
 26    __slots__ = ['socket', 'fd', '_readable', '_writable', 'state']
 27
 28    STATE_INIT = 0
 29    STATE_LISTENING = 1
 30    STATE_CONNECTING = 2
 31    STATE_CONNECTED = 3
 32    STATE_CLOSING = 4
 33    STATE_CLOSED = 5
 34
 35    _x = 0
 36
 37    def __init__(self, socket, state = STATE_INIT):
 38        """don't call directly pls use one of the provided classmethod to create a socket"""
 39        self.socket = socket
 40
 41        if _socket.AF_INET == socket.family:
 42            #always set the nodelay option on tcp sockets. This turns off the Nagle algorithm
 43            #we don't need this because in concurrence we are always buffering ourselves
 44            #before sending out data, so no need to let the tcp stack do it again and possibly delay
 45            #sending
 46            try:
 47                self.socket.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1)
 48            except:
 49                self.log.warn("could not set TCP_NODELAY")
 50
 51        #concurrence sockets are always non-blocking, this is the whole idea :-) :
 52        self.socket.setblocking(0)
 53        self.fd = self.socket.fileno()
 54        self._readable = None #will be created lazily
 55        self._writable = None #will be created lazily
 56        self.state = state
 57
 58    @classmethod
 59    def set_interceptor(cls, interceptor):
 60        global _interceptor
 61        _interceptor = interceptor
 62
 63    @classmethod
 64    def from_address(cls, addr):
 65        """Creates a new socket from the given address. If the addr is a tuple (host, port)
 66        a normal tcp socket is assumed. if addr is a string, a UNIX Domain socket is assumed"""
 67        if _interceptor is not None:
 68            return _interceptor(addr)
 69        elif type(addr) == types.StringType:
 70            return cls(_socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM))
 71        else:
 72            return cls(_socket.socket(_socket.AF_INET, _socket.SOCK_STREAM))
 73
 74    @classmethod
 75    def new(cls):
 76        return cls(_socket.socket(_socket.AF_INET, _socket.SOCK_STREAM))
 77
 78    @classmethod
 79    def server(cls, addr, backlog = DEFAULT_BACKLOG, reuse_address = True):
 80        s = cls.from_address(addr)
 81        s.set_reuse_address(reuse_address)
 82        s.bind(addr)
 83        s.listen(backlog)
 84        return s
 85
 86    @classmethod
 87    def connect(cls, addr, timeout = TIMEOUT_CURRENT):
 88        """creates a new socket and connects it to the given address.
 89        returns the connected socket"""
 90        socket = cls.from_address(addr)
 91        socket._connect(addr, timeout)
 92        return socket
 93
 94    @classmethod
 95    def from_file_descriptor(cls, fd, socket_family = _socket.AF_UNIX, socket_type = _socket.SOCK_STREAM, socket_state = STATE_INIT):
 96        return cls(_socket.fromfd(fd, socket_family, socket_type), socket_state)
 97
 98    def _get_readable(self):
 99        if self._readable is None:
100            self._readable = FileDescriptorEvent(self.fd, 'r')
101        return self._readable
102
103    def _set_readable(self, readable):
104        self._readable = readable
105
106    readable = property(_get_readable, _set_readable)
107
108    def _get_writable(self):
109        if self._writable is None:
110            self._writable = FileDescriptorEvent(self.fd, 'w')
111        return self._writable
112
113    def _set_writable(self, writable):
114        self._writable = writable
115
116    writable = property(_get_writable, _set_writable)
117
118    def fileno(self):
119        return self.fd
120
121    def set_reuse_address(self, reuse_address):
122        self.socket.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, int(reuse_address))
123
124    def set_send_buffer_size(self, n):
125        self.socket.setsockopt(_socket.SOL_SOCKET, _socket.SO_SNDBUF, n)
126
127    def set_recv_buffer_size(self, n):
128        self.socket.setsockopt(_socket.SOL_SOCKET, _socket.SO_RCVBUF, n)
129
130    def bind(self, addr):
131        self.socket.bind(addr)
132
133    def listen(self, backlog = DEFAULT_BACKLOG):
134        self.socket.listen(backlog)
135        self.state = self.STATE_LISTENING
136
137    def accept(self):
138        """waits on a listening socket, returns a new socket_class instance
139        for the incoming connection"""
140        assert self.state == self.STATE_LISTENING, "make sure socket is listening before calling accept"
141        while True:
142            #we need a loop because sometimes we become readable and still not a valid
143            #connection was accepted, in which case we return here and wait some more.
144            self.readable.wait()
145            try:
146                s, _ = self.socket.accept()
147            except _socket.error, (errno, _):
148                if errno in [EAGAIN, EWOULDBLOCK]:
149                    #this can happen when more than one process received readability on the same socket (forked/cloned/dupped)
150                    #in that case 1 process will do the accept, the others receive this error, and should continue waiting for
151                    #readability
152                    continue
153                else:
154                    raise
155
156            return self.__class__(s, self.STATE_CONNECTED)
157
158    def accept_iter(self):
159        while True:
160            try:
161                yield self.accept()
162            except Exception:
163                self.log.exception("in accept_iter")
164                Tasklet.sleep(1.0) #prevent hogging
165
166    def _connect(self, addr, timeout = TIMEOUT_CURRENT):
167        assert self.state == self.STATE_INIT, "make sure socket is not already connected or closed"
168        try:
169            err = self.socket.connect_ex(addr)
170            serr = self.socket.getsockopt(_socket.SOL_SOCKET, _socket.SO_ERROR)
171        except:
172            self.log.exception("unexpected exception thrown by connect_ex")
173            raise
174        if err == 0 and serr == 0:
175            self.state = self.STATE_CONNECTED
176        elif err == EINPROGRESS and serr != 0:
177            raise IOError(serr, os.strerror(serr))
178        elif err == EINPROGRESS and serr == 0:
179            self.state = self.STATE_CONNECTING
180            try:
181                self.writable.wait(timeout = timeout)
182                self.state = self.STATE_CONNECTED
183            except:
184                self.state = self.STATE_INIT
185                raise
186        else:
187            #some other error,
188            #unix domain socket that does not exist, Cannot assign requested address etc etc
189            raise _io.error_from_errno(IOError)
190
191    def write(self, buffer, timeout = TIMEOUT_CURRENT, assume_writable = True):
192        """Writes as many bytes as possible from the given buffer to this socket.
193        The buffer position is updated according to the number of bytes succesfully written to the socket.
194        This method returns the total number of bytes written. This method could possible write 0 bytes"""
195        assert self.state == self.STATE_CONNECTED, "socket must be connected in order to write to it"
196
197        Socket._x += 1
198        if Socket._x % XMOD == 0:
199            assume_writable = False
200
201        #by default assume that we can write to the socket without blocking
202        if assume_writable:
203            bytes_written, _ = buffer.send(self.fd) #write to fd from buffer
204            if bytes_written < 0 and _io.get_errno() == EAGAIN:
205                #nope, need to wait before sending our data
206                assume_writable = False
207            #else if error != EAGAIN, assume_writable will stay True, and we fall trough and raise error below
208
209        #if we cannot assume write-ability we will wait until data can be written again
210        if not assume_writable:
211            self.writable.wait(timeout = timeout)
212            bytes_written, _ = buffer.send(self.fd) #write to fd from buffer
213
214        #print 'bw', bytes_written, buffer.capacity
215        #
216        if bytes_written < 0:
217            raise _io.error_from_errno(IOError)
218        else:
219            return bytes_written
220
221    def read(self, buffer, timeout = TIMEOUT_CURRENT, assume_readable = True):
222        """Reads as many bytes as possible the socket into the given buffer.
223        The buffer position is updated according to the number of bytes read from the socket.
224        This method could possible read 0 bytes. The method returns the total number of bytes read"""
225        assert self.state == self.STATE_CONNECTED, "socket must be connected in order to read from it"
226
227        Socket._x += 1
228        if Socket._x % XMOD == 0:
229            assume_readable = False
230
231        #by default assume that we can read from the socket without blocking
232        if assume_readable:
233            bytes_read, _ = buffer.recv(self.fd) #read from fd to
234            if bytes_read < 0 and _io.get_errno() == EAGAIN:
235                #nope, need to wait before reading our data
236                assume_readable = False
237            #else if error != EAGAIN, assume_readable will stay True, and we fall trough and raise error below
238
239        #if we cannot assume readability we will wait until data can be read again
240        if not assume_readable:
241            self.readable.wait(timeout = timeout)
242            bytes_read, _ = buffer.recv(self.fd) #read from fd to
243
244        #print 'br', bytes_read, buffer.capacity
245        #
246        if bytes_read < 0:
247            raise _io.error_from_errno(IOError)
248        else:
249            return bytes_read
250
251    def write_socket(self, socket, timeout = TIMEOUT_CURRENT):
252        """writes a socket trough this socket"""
253        self.writable.wait(timeout = timeout)
254        _io.msgsendfd(self.fd, socket.fd)
255
256    def read_socket(self, socket_class = None, socket_family =  _socket.AF_INET, socket_type = _socket.SOCK_STREAM, socket_state = STATE_INIT, timeout = TIMEOUT_CURRENT):
257        """reads a socket from this socket"""
258        self.readable.wait(timeout = timeout)
259        fd = _io.msgrecvfd(self.fd)
260        return (socket_class or self.__class__).from_file_descriptor(fd, socket_family, socket_type, socket_state)
261
262    def is_closed(self):
263        return self.state == self.STATE_CLOSED
264
265    def close(self):
266        assert self.state in [self.STATE_CONNECTED, self.STATE_LISTENING]
267        self.state = self.STATE_CLOSING
268        if self._readable is not None:
269            self._readable.close()
270        if self._writable is not None:
271            self._writable.close()
272        self.socket.close()
273        del self.socket
274        del self._readable
275        del self._writable
276        self.state = self.STATE_CLOSED
277
class SocketServer(object):
    """Accept connections on a listening Socket and run a handler for each one.

    Each accepted socket is handled in a new tasklet.
    """
    log = logging.getLogger('SocketServer')

    def __init__(self, endpoint, handler = None):
        """Configure the server.

        :param endpoint: an address (passed to Socket.from_address when the
            socket is created) or an existing Socket to accept on.
        :param handler: callable invoked with each accepted socket; if it
            returns None, the socket is closed after it finishes.
        """
        self._addr = None
        self._socket = None
        if isinstance(endpoint, Socket):
            self._socket = endpoint
        else:
            self._addr = endpoint
        self._handler = handler
        self._reuseaddress = True
        self._handler_task_name = 'socket_handler'
        self._accept_task = None
        self._accept_task_name = 'socket_acceptor'

    @property
    def socket(self):
        """The accepting Socket, or None if not created yet."""
        return self._socket

    def _handle_accept(self, accepted_socket):
        """Run the handler on `accepted_socket`.

        Any exception other than TaskletExit is logged. Unless the handler
        returned a non-None value, the socket is closed afterwards if it is not
        already closed.
        """
        # NOTE(review): TaskletExit is not imported in this module; presumably it
        # is provided as a builtin by Stackless/concurrence — confirm.
        result = None
        try:
            result = self._handler(accepted_socket)
        except TaskletExit:
            raise
        except:
            self.log.exception("unhandled exception in socket handler")
        finally:
            if result is None and not accepted_socket.is_closed():
                try:
                    accepted_socket.close()
                except TaskletExit:
                    raise
                except:
                    self.log.exception("unhandled exception while forcefully closing client")

    def _create_socket(self):
        """Return the accepting socket, creating it from the address if needed."""
        if self._socket is None:
            if self._addr is None:
                # Raised explicitly (not via assert) so the check survives python -O.
                raise AssertionError("address must be set or accepting socket must be explicitly set")
            self._socket = Socket.from_address(self._addr)
            self._socket.set_reuse_address(self._reuseaddress)
        return self._socket

    def _accept_task_loop(self):
        """Accept one connection and hand it to the handler in a new tasklet.

        Presumably called repeatedly by the Tasklet.loop started in serve().
        """
        accepted_socket = self._socket.accept()
        Tasklet.new(self._handle_accept, self._handler_task_name)(accepted_socket)

    def bind(self):
        """Create the socket if needed and bind it to the configured address."""
        socket = self._create_socket()
        socket.bind(self._addr)

    def listen(self, backlog = DEFAULT_BACKLOG):
        """Create the socket if needed and start listening on it."""
        socket = self._create_socket()
        socket.listen(backlog)

    def serve(self):
        """Start accepting incoming connections.

        If no socket exists yet, one is created, bound to the configured
        address and put in listening mode. A daemon tasklet then accepts
        connections and dispatches each one to the handler.
        """
        # Validate the handler before touching the network, so that a
        # misconfigured server does not leave a bound, listening socket behind.
        # Raised explicitly (not via assert) so the check survives python -O.
        if not callable(self._handler):
            raise AssertionError("handler not set or not callable")

        if self._socket is None:
            self.bind()
            self.listen()

        self._accept_task = Tasklet.loop(self._accept_task_loop, name = self._accept_task_name, daemon = True)()

    def close(self):
        """Stop the accept tasklet (if running) and close the accepting socket.

        Safe to call when serve() was never called or the socket is already
        closed.
        """
        accept_task, self._accept_task = self._accept_task, None
        try:
            if accept_task is not None:
                accept_task.kill()
        finally:
            # Release the socket even if killing the accept tasklet failed.
            if self._socket is not None and not self._socket.is_closed():
                self._socket.close()
