PageRenderTime 52ms CodeModel.GetById 13ms app.highlight 33ms RepoModel.GetById 1ms app.codeStats 1ms

/tags/release-0.0.0-rc0/hive/external/service/lib/py/thrift/server/TNonblockingServer.py

#
Python | 309 lines | 246 code | 15 blank | 48 comment | 26 complexity | a3fe5c3abdd08e59501932eee8549f11 MD5 | raw file
  1#
  2# Licensed to the Apache Software Foundation (ASF) under one
  3# or more contributor license agreements. See the NOTICE file
  4# distributed with this work for additional information
  5# regarding copyright ownership. The ASF licenses this file
  6# to you under the Apache License, Version 2.0 (the
  7# "License"); you may not use this file except in compliance
  8# with the License. You may obtain a copy of the License at
  9#
 10#   http://www.apache.org/licenses/LICENSE-2.0
 11#
 12# Unless required by applicable law or agreed to in writing,
 13# software distributed under the License is distributed on an
 14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 15# KIND, either express or implied. See the License for the
 16# specific language governing permissions and limitations
 17# under the License.
 18#
 19"""Implementation of non-blocking server.
 20
 21The main idea of the server is receiving and sending requests
 22only from the main thread.
 23
 24It also works as a thread pool server in terms of tasks, not connections.
 25"""
 26import threading
 27import socket
 28import Queue
 29import select
 30import struct
 31import logging
 32
 33from thrift.transport import TTransport
 34from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
 35
 36__all__ = ['TNonblockingServer']
 37
 38class Worker(threading.Thread):
 39    """Worker is a small helper to process incoming connection."""
 40    def __init__(self, queue):
 41        threading.Thread.__init__(self)
 42        self.queue = queue
 43
 44    def run(self):
 45        """Process queries from task queue, stop if processor is None."""
 46        while True:
 47            try:
 48                processor, iprot, oprot, otrans, callback = self.queue.get()
 49                if processor is None:
 50                    break
 51                processor.process(iprot, oprot)
 52                callback(True, otrans.getvalue())
 53            except Exception:
 54                logging.exception("Exception while processing request")
 55                callback(False, '')
 56
 57WAIT_LEN = 0
 58WAIT_MESSAGE = 1
 59WAIT_PROCESS = 2
 60SEND_ANSWER = 3
 61CLOSED = 4
 62
 63def locked(func):
 64    "Decorator which locks self.lock."
 65    def nested(self, *args, **kwargs):
 66        self.lock.acquire()
 67        try:
 68            return func(self, *args, **kwargs)
 69        finally:
 70            self.lock.release()
 71    return nested
 72
 73def socket_exception(func):
 74    "Decorator close object on socket.error."
 75    def read(self, *args, **kwargs):
 76        try:
 77            return func(self, *args, **kwargs)
 78        except socket.error:
 79            self.close()
 80    return read
 81
 82class Connection:
 83    """Basic class is represented connection.
 84    
 85    It can be in state:
 86        WAIT_LEN --- connection is reading request len.
 87        WAIT_MESSAGE --- connection is reading request.
 88        WAIT_PROCESS --- connection has just read whole request and 
 89            waits for call ready routine.
 90        SEND_ANSWER --- connection is sending answer string (including length
 91            of answer).
 92        CLOSED --- socket was closed and connection should be deleted.
 93    """
 94    def __init__(self, new_socket, wake_up):
 95        self.socket = new_socket
 96        self.socket.setblocking(False)
 97        self.status = WAIT_LEN
 98        self.len = 0
 99        self.message = ''
100        self.lock = threading.Lock()
101        self.wake_up = wake_up
102
103    def _read_len(self):
104        """Reads length of request.
105        
106        It's really paranoic routine and it may be replaced by 
107        self.socket.recv(4)."""
108        read = self.socket.recv(4 - len(self.message))
109        if len(read) == 0:
110            # if we read 0 bytes and self.message is empty, it means client close 
111            # connection
112            if len(self.message) != 0:
113                logging.error("can't read frame size from socket")
114            self.close()
115            return
116        self.message += read
117        if len(self.message) == 4:
118            self.len, = struct.unpack('!i', self.message)
119            if self.len < 0:
120                logging.error("negative frame size, it seems client"\
121                    " doesn't use FramedTransport")
122                self.close()
123            elif self.len == 0:
124                logging.error("empty frame, it's really strange")
125                self.close()
126            else:
127                self.message = ''
128                self.status = WAIT_MESSAGE
129
130    @socket_exception
131    def read(self):
132        """Reads data from stream and switch state."""
133        assert self.status in (WAIT_LEN, WAIT_MESSAGE)
134        if self.status == WAIT_LEN:
135            self._read_len()
136            # go back to the main loop here for simplicity instead of
137            # falling through, even though there is a good chance that
138            # the message is already available
139        elif self.status == WAIT_MESSAGE:
140            read = self.socket.recv(self.len - len(self.message))
141            if len(read) == 0:
142                logging.error("can't read frame from socket (get %d of %d bytes)" %
143                    (len(self.message), self.len))
144                self.close()
145                return
146            self.message += read
147            if len(self.message) == self.len:
148                self.status = WAIT_PROCESS
149
150    @socket_exception
151    def write(self):
152        """Writes data from socket and switch state."""
153        assert self.status == SEND_ANSWER
154        sent = self.socket.send(self.message)
155        if sent == len(self.message):
156            self.status = WAIT_LEN
157            self.message = ''
158            self.len = 0
159        else:
160            self.message = self.message[sent:]
161
162    @locked
163    def ready(self, all_ok, message):
164        """Callback function for switching state and waking up main thread.
165        
166        This function is the only function witch can be called asynchronous.
167        
168        The ready can switch Connection to three states:
169            WAIT_LEN if request was oneway.
170            SEND_ANSWER if request was processed in normal way.
171            CLOSED if request throws unexpected exception.
172        
173        The one wakes up main thread.
174        """
175        assert self.status == WAIT_PROCESS
176        if not all_ok:
177            self.close()
178            self.wake_up()
179            return
180        self.len = ''
181        self.message = struct.pack('!i', len(message)) + message
182        if len(message) == 0:
183            # it was a oneway request, do not write answer
184            self.status = WAIT_LEN
185        else:
186            self.status = SEND_ANSWER
187        self.wake_up()
188
189    @locked
190    def is_writeable(self):
191        "Returns True if connection should be added to write list of select."
192        return self.status == SEND_ANSWER
193
194    # it's not necessary, but...
195    @locked
196    def is_readable(self):
197        "Returns True if connection should be added to read list of select."
198        return self.status in (WAIT_LEN, WAIT_MESSAGE)
199
200    @locked
201    def is_closed(self):
202        "Returns True if connection is closed."
203        return self.status == CLOSED
204
205    def fileno(self):
206        "Returns the file descriptor of the associated socket."
207        return self.socket.fileno()
208
209    def close(self):
210        "Closes connection"
211        self.status = CLOSED
212        self.socket.close()
213
class TNonblockingServer:
    """Non-blocking server.

    All socket reads and writes happen in the thread that calls handle()
    (or serve()); complete requests are put on a task queue that Worker
    threads consume.
    """
    def __init__(self, processor, lsocket, inputProtocolFactory=None,
            outputProtocolFactory=None, threads=10):
        """Store the configuration; nothing is started until prepare().

        processor is handed to the workers, which call its process()
        method.  lsocket is the listening transport; it must offer
        listen(), accept(), close() and a ``handle`` attribute with a
        fileno() method.  inputProtocolFactory defaults to a
        TBinaryProtocolFactory and outputProtocolFactory defaults to the
        input factory.  threads is the number of worker threads.
        """
        self.processor = processor
        self.socket = lsocket
        self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
        self.out_protocol = outputProtocolFactory or self.in_protocol
        self.threads = int(threads)
        self.clients = {}
        self.tasks = Queue.Queue()
        # Socket pair used by wake_up() to interrupt a pending select().
        self._read, self._write = socket.socketpair()
        self.prepared = False

    def setNumThreads(self, num):
        """Set the number of worker threads that should be created."""
        # implement ThreadPool interface
        assert not self.prepared, "You can't change number of threads for working server"
        # Coerce like __init__ does so self.threads is always an int.
        self.threads = int(num)

    def prepare(self):
        """Prepares the server for serving requests.

        Starts listening and launches the worker threads as daemons.
        """
        self.socket.listen()
        for _ in xrange(self.threads):
            thread = Worker(self.tasks)
            thread.setDaemon(True)
            thread.start()
        self.prepared = True

    def wake_up(self):
        """Wake up main thread.

        The server usually waits in a select call and we should terminate
        it.  The simplest way is using a socketpair.

        Select always waits to read from the first socket of the
        socketpair.

        In this case, we can just write anything to the second socket of
        the socketpair."""
        self._write.send('1')

    def _select(self):
        """Does select on open connections.

        Connections that report closed are dropped from self.clients.
        Returns the (readable, writable, exceptional) lists produced by
        select.select().
        """
        readable = [self.socket.handle.fileno(), self._read.fileno()]
        writable = []
        # Work on a snapshot: closed connections are removed from
        # self.clients inside the loop.
        for i, connection in list(self.clients.items()):
            if connection.is_readable():
                readable.append(connection.fileno())
            if connection.is_writeable():
                writable.append(connection.fileno())
            if connection.is_closed():
                del self.clients[i]
        return select.select(readable, writable, readable)

    def handle(self):
        """Handle requests.

        Performs one select() round: drains the wake-up socket, accepts
        new clients, reads from and writes to ready connections, queues
        fully read requests for the workers and drops connections that
        select reported as exceptional.

        WARNING! You must call prepare BEFORE calling handle.
        """
        assert self.prepared, "You have to call prepare before handle"
        rset, wset, xset = self._select()
        for readable in rset:
            if readable == self._read.fileno():
                # don't care i just need to clean readable flag
                self._read.recv(1024)
            elif readable == self.socket.handle.fileno():
                client = self.socket.accept().handle
                self.clients[client.fileno()] = Connection(client, self.wake_up)
            else:
                connection = self.clients[readable]
                connection.read()
                if connection.status == WAIT_PROCESS:
                    itransport = TTransport.TMemoryBuffer(connection.message)
                    otransport = TTransport.TMemoryBuffer()
                    iprot = self.in_protocol.getProtocol(itransport)
                    oprot = self.out_protocol.getProtocol(otransport)
                    self.tasks.put([self.processor, iprot, oprot,
                                    otransport, connection.ready])
        for writeable in wset:
            self.clients[writeable].write()
        for oob in xset:
            # The exceptional set is selected from the read list, which
            # also contains the listening and wake-up descriptors; those
            # have no entry in self.clients, so look the key up defensively
            # instead of letting a KeyError end the serve loop.
            connection = self.clients.pop(oob, None)
            if connection is None:
                logging.error(
                    "exceptional condition on server descriptor %s", oob)
            else:
                connection.close()

    def close(self):
        """Closes the server.

        Queues one stop task per worker thread and closes the listening
        socket.
        """
        for _ in xrange(self.threads):
            self.tasks.put([None, None, None, None, None])
        self.socket.close()
        self.prepared = False

    def serve(self):
        """Serve forever."""
        self.prepare()
        while True:
            self.handle()