PageRenderTime 81ms CodeModel.GetById 2ms app.highlight 72ms RepoModel.GetById 1ms app.codeStats 0ms

/lib-python/2.5.2/test/test_socket.py

https://bitbucket.org/amauryfa/pypy-sepcomp
Python | 995 lines | 852 code | 77 blank | 66 comment | 29 complexity | fd41aa2dac297dffc082fcdab149c9e4 MD5 | raw file
  1#!/usr/bin/env python
  2
  3import unittest
  4from test import test_support
  5
  6import socket
  7import select
  8import time
  9import thread, threading
 10import Queue
 11import sys
 12import array
 13from weakref import proxy
 14import signal
 15
 16PORT = 50007
 17HOST = 'localhost'
 18MSG = 'Michael Gilfix was here\n'
 19
 20class SocketTCPTest(unittest.TestCase):
 21
 22    def setUp(self):
 23        self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 24        self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 25        global PORT
 26        PORT = test_support.bind_port(self.serv, HOST, PORT)
 27        self.serv.listen(1)
 28
 29    def tearDown(self):
 30        self.serv.close()
 31        self.serv = None
 32
 33class SocketUDPTest(unittest.TestCase):
 34
 35    def setUp(self):
 36        self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 37        self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 38        global PORT
 39        PORT = test_support.bind_port(self.serv, HOST, PORT)
 40
 41    def tearDown(self):
 42        self.serv.close()
 43        self.serv = None
 44
 45class ThreadableTest:
 46    """Threadable Test class
 47
 48    The ThreadableTest class makes it easy to create a threaded
 49    client/server pair from an existing unit test. To create a
 50    new threaded class from an existing unit test, use multiple
 51    inheritance:
 52
 53        class NewClass (OldClass, ThreadableTest):
 54            pass
 55
 56    This class defines two new fixture functions with obvious
 57    purposes for overriding:
 58
 59        clientSetUp ()
 60        clientTearDown ()
 61
 62    Any new test functions within the class must then define
 63    tests in pairs, where the test name is preceeded with a
 64    '_' to indicate the client portion of the test. Ex:
 65
 66        def testFoo(self):
 67            # Server portion
 68
 69        def _testFoo(self):
 70            # Client portion
 71
 72    Any exceptions raised by the clients during their tests
 73    are caught and transferred to the main thread to alert
 74    the testing framework.
 75
 76    Note, the server setup function cannot call any blocking
 77    functions that rely on the client thread during setup,
 78    unless serverExplicityReady() is called just before
 79    the blocking call (such as in setting up a client/server
 80    connection and performing the accept() in setUp().
 81    """
 82
 83    def __init__(self):
 84        # Swap the true setup function
 85        self.__setUp = self.setUp
 86        self.__tearDown = self.tearDown
 87        self.setUp = self._setUp
 88        self.tearDown = self._tearDown
 89
 90    def serverExplicitReady(self):
 91        """This method allows the server to explicitly indicate that
 92        it wants the client thread to proceed. This is useful if the
 93        server is about to execute a blocking routine that is
 94        dependent upon the client thread during its setup routine."""
 95        self.server_ready.set()
 96
 97    def _setUp(self):
 98        self.server_ready = threading.Event()
 99        self.client_ready = threading.Event()
100        self.done = threading.Event()
101        self.queue = Queue.Queue(1)
102
103        # Do some munging to start the client test.
104        methodname = self.id()
105        i = methodname.rfind('.')
106        methodname = methodname[i+1:]
107        test_method = getattr(self, '_' + methodname)
108        self.client_thread = thread.start_new_thread(
109            self.clientRun, (test_method,))
110
111        self.__setUp()
112        if not self.server_ready.isSet():
113            self.server_ready.set()
114        self.client_ready.wait()
115
116    def _tearDown(self):
117        self.__tearDown()
118        self.done.wait()
119
120        if not self.queue.empty():
121            msg = self.queue.get()
122            self.fail(msg)
123
124    def clientRun(self, test_func):
125        self.server_ready.wait()
126        self.client_ready.set()
127        self.clientSetUp()
128        if not callable(test_func):
129            raise TypeError, "test_func must be a callable function"
130        try:
131            test_func()
132        except Exception, strerror:
133            self.queue.put(strerror)
134        self.clientTearDown()
135
136    def clientSetUp(self):
137        raise NotImplementedError, "clientSetUp must be implemented."
138
139    def clientTearDown(self):
140        self.done.set()
141        thread.exit()
142
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """Listening TCP server plus a client thread that owns ``self.cli``.

    The client socket is created unconnected; each test decides when to
    connect it.
    """

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        family, kind = socket.AF_INET, socket.SOCK_STREAM
        self.cli = socket.socket(family, kind)

    def clientTearDown(self):
        client = self.cli
        client.close()
        self.cli = None
        # Signals completion to the main thread and ends this thread.
        ThreadableTest.clientTearDown(self)
156
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """Bound UDP server socket plus a client thread that owns ``self.cli``."""

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def clientTearDown(self):
        # Like ThreadedTCPSocketTest: release the client socket explicitly,
        # then let ThreadableTest signal completion and end the thread.
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
165
class SocketConnectedTest(ThreadedTCPSocketTest):
    """Fixture with an established TCP connection between the threads.

    Server side: ``self.cli_conn`` is the socket returned by accept().
    Client side: ``self.serv_conn`` is the connected client socket (the
    same object as ``self.cli``).
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # accept() blocks until the client connects, so the client thread
        # has to be released before calling it.
        self.serverExplicitReady()
        self.cli_conn = self.serv.accept()[0]

    def tearDown(self):
        self.cli_conn.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        server_address = (HOST, PORT)
        self.cli.connect(server_address)
        self.serv_conn = self.cli

    def clientTearDown(self):
        self.serv_conn.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)
193
class SocketPairTest(unittest.TestCase, ThreadableTest):
    """Fixture built on socket.socketpair().

    ``self.serv`` and ``self.cli`` are the two connected ends.  The server
    end is closed by tearDown(), the client end by clientTearDown().
    """

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        pair = socket.socketpair()
        self.serv, self.cli = pair

    def tearDown(self):
        self.serv.close()
        self.serv = None

    def clientSetUp(self):
        # Both ends already exist; nothing to prepare on the client side.
        return None

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)
214
215
216#######################################################################
217## Begin Tests
218
219class GeneralModuleTests(unittest.TestCase):
220
221    def test_weakref(self):
222        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
223        p = proxy(s)
224        self.assertEqual(p.fileno(), s.fileno())
225        s.close()
226        s = None
227        try:
228            p.fileno()
229        except ReferenceError:
230            pass
231        else:
232            self.fail('Socket proxy still exists')
233
234    def testSocketError(self):
235        # Testing socket module exceptions
236        def raise_error(*args, **kwargs):
237            raise socket.error
238        def raise_herror(*args, **kwargs):
239            raise socket.herror
240        def raise_gaierror(*args, **kwargs):
241            raise socket.gaierror
242        self.failUnlessRaises(socket.error, raise_error,
243                              "Error raising socket exception.")
244        self.failUnlessRaises(socket.error, raise_herror,
245                              "Error raising socket exception.")
246        self.failUnlessRaises(socket.error, raise_gaierror,
247                              "Error raising socket exception.")
248
249    def testCrucialConstants(self):
250        # Testing for mission critical constants
251        socket.AF_INET
252        socket.SOCK_STREAM
253        socket.SOCK_DGRAM
254        socket.SOCK_RAW
255        socket.SOCK_RDM
256        socket.SOCK_SEQPACKET
257        socket.SOL_SOCKET
258        socket.SO_REUSEADDR
259
260    def testHostnameRes(self):
261        # Testing hostname resolution mechanisms
262        hostname = socket.gethostname()
263        try:
264            ip = socket.gethostbyname(hostname)
265        except socket.error:
266            # Probably name lookup wasn't set up right; skip this test
267            return
268        self.assert_(ip.find('.') >= 0, "Error resolving host to ip.")
269        try:
270            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
271        except socket.error:
272            # Probably a similar problem as above; skip this test
273            return
274        all_host_names = [hostname, hname] + aliases
275        fqhn = socket.getfqdn(ip)
276        if not fqhn in all_host_names:
277            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
278
279    def testRefCountGetNameInfo(self):
280        # Testing reference count for getnameinfo
281        import sys
282        if hasattr(sys, "getrefcount"):
283            try:
284                # On some versions, this loses a reference
285                orig = sys.getrefcount(__name__)
286                socket.getnameinfo(__name__,0)
287            except SystemError:
288                if sys.getrefcount(__name__) <> orig:
289                    self.fail("socket.getnameinfo loses a reference")
290
291    def testInterpreterCrash(self):
292        # Making sure getnameinfo doesn't crash the interpreter
293        try:
294            # On some versions, this crashes the interpreter.
295            socket.getnameinfo(('x', 0, 0, 0), 0)
296        except socket.error:
297            pass
298
299    def testNtoH(self):
300        # This just checks that htons etc. are their own inverse,
301        # when looking at the lower 16 or 32 bits.
302        sizes = {socket.htonl: 32, socket.ntohl: 32,
303                 socket.htons: 16, socket.ntohs: 16}
304        for func, size in sizes.items():
305            mask = (1L<<size) - 1
306            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
307                self.assertEqual(i & mask, func(func(i&mask)) & mask)
308
309            swapped = func(mask)
310            self.assertEqual(swapped & mask, mask)
311            self.assertRaises(OverflowError, func, 1L<<34)
312
313    def testGetServBy(self):
314        eq = self.assertEqual
315        # Find one service that exists, then check all the related interfaces.
316        # I've ordered this by protocols that have both a tcp and udp
317        # protocol, at least for modern Linuxes.
318        if sys.platform in ('linux2', 'freebsd4', 'freebsd5', 'freebsd6',
319                            'freebsd7', 'darwin'):
320            # avoid the 'echo' service on this platform, as there is an
321            # assumption breaking non-standard port/protocol entry
322            services = ('daytime', 'qotd', 'domain')
323        else:
324            services = ('echo', 'daytime', 'domain')
325        for service in services:
326            try:
327                port = socket.getservbyname(service, 'tcp')
328                break
329            except socket.error:
330                pass
331        else:
332            raise socket.error
333        # Try same call with optional protocol omitted
334        port2 = socket.getservbyname(service)
335        eq(port, port2)
336        # Try udp, but don't barf it it doesn't exist
337        try:
338            udpport = socket.getservbyname(service, 'udp')
339        except socket.error:
340            udpport = None
341        else:
342            eq(udpport, port)
343        # Now make sure the lookup by port returns the same service name
344        eq(socket.getservbyport(port2), service)
345        eq(socket.getservbyport(port, 'tcp'), service)
346        if udpport is not None:
347            eq(socket.getservbyport(udpport, 'udp'), service)
348
349    def testDefaultTimeout(self):
350        # Testing default timeout
351        # The default timeout should initially be None
352        self.assertEqual(socket.getdefaulttimeout(), None)
353        s = socket.socket()
354        self.assertEqual(s.gettimeout(), None)
355        s.close()
356
357        # Set the default timeout to 10, and see if it propagates
358        socket.setdefaulttimeout(10)
359        self.assertEqual(socket.getdefaulttimeout(), 10)
360        s = socket.socket()
361        self.assertEqual(s.gettimeout(), 10)
362        s.close()
363
364        # Reset the default timeout to None, and see if it propagates
365        socket.setdefaulttimeout(None)
366        self.assertEqual(socket.getdefaulttimeout(), None)
367        s = socket.socket()
368        self.assertEqual(s.gettimeout(), None)
369        s.close()
370
371        # Check that setting it to an invalid value raises ValueError
372        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
373
374        # Check that setting it to an invalid type raises TypeError
375        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
376
377    def testIPv4toString(self):
378        if not hasattr(socket, 'inet_pton'):
379            return # No inet_pton() on this platform
380        from socket import inet_aton as f, inet_pton, AF_INET
381        g = lambda a: inet_pton(AF_INET, a)
382
383        self.assertEquals('\x00\x00\x00\x00', f('0.0.0.0'))
384        self.assertEquals('\xff\x00\xff\x00', f('255.0.255.0'))
385        self.assertEquals('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
386        self.assertEquals('\x01\x02\x03\x04', f('1.2.3.4'))
387        self.assertEquals('\xff\xff\xff\xff', f('255.255.255.255'))
388
389        self.assertEquals('\x00\x00\x00\x00', g('0.0.0.0'))
390        self.assertEquals('\xff\x00\xff\x00', g('255.0.255.0'))
391        self.assertEquals('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
392        self.assertEquals('\xff\xff\xff\xff', g('255.255.255.255'))
393
394    def testIPv6toString(self):
395        if not hasattr(socket, 'inet_pton'):
396            return # No inet_pton() on this platform
397        try:
398            from socket import inet_pton, AF_INET6, has_ipv6
399            if not has_ipv6:
400                return
401        except ImportError:
402            return
403        f = lambda a: inet_pton(AF_INET6, a)
404
405        self.assertEquals('\x00' * 16, f('::'))
406        self.assertEquals('\x00' * 16, f('0::0'))
407        self.assertEquals('\x00\x01' + '\x00' * 14, f('1::'))
408        self.assertEquals(
409            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
410            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
411        )
412
413    def testStringToIPv4(self):
414        if not hasattr(socket, 'inet_ntop'):
415            return # No inet_ntop() on this platform
416        from socket import inet_ntoa as f, inet_ntop, AF_INET
417        g = lambda a: inet_ntop(AF_INET, a)
418
419        self.assertEquals('1.0.1.0', f('\x01\x00\x01\x00'))
420        self.assertEquals('170.85.170.85', f('\xaa\x55\xaa\x55'))
421        self.assertEquals('255.255.255.255', f('\xff\xff\xff\xff'))
422        self.assertEquals('1.2.3.4', f('\x01\x02\x03\x04'))
423
424        self.assertEquals('1.0.1.0', g('\x01\x00\x01\x00'))
425        self.assertEquals('170.85.170.85', g('\xaa\x55\xaa\x55'))
426        self.assertEquals('255.255.255.255', g('\xff\xff\xff\xff'))
427
428    def testStringToIPv6(self):
429        if not hasattr(socket, 'inet_ntop'):
430            return # No inet_ntop() on this platform
431        try:
432            from socket import inet_ntop, AF_INET6, has_ipv6
433            if not has_ipv6:
434                return
435        except ImportError:
436            return
437        f = lambda a: inet_ntop(AF_INET6, a)
438
439        self.assertEquals('::', f('\x00' * 16))
440        self.assertEquals('::1', f('\x00' * 15 + '\x01'))
441        self.assertEquals(
442            'aef:b01:506:1001:ffff:9997:55:170',
443            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
444        )
445
446    # XXX The following don't test module-level functionality...
447
448    def testSockName(self):
449        # Testing getsockname()
450        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
451        sock.bind(("0.0.0.0", PORT+1))
452        name = sock.getsockname()
453        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
454        # it reasonable to get the host's addr in addition to 0.0.0.0.
455        # At least for eCos.  This is required for the S/390 to pass.
456        my_ip_addr = socket.gethostbyname(socket.gethostname())
457        self.assert_(name[0] in ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
458        self.assertEqual(name[1], PORT+1)
459
460    def testGetSockOpt(self):
461        # Testing getsockopt()
462        # We know a socket should start without reuse==0
463        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
464        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
465        self.failIf(reuse != 0, "initial mode is reuse")
466
467    def testSetSockOpt(self):
468        # Testing setsockopt()
469        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
470        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
471        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
472        self.failIf(reuse == 0, "failed to set reuse mode")
473
474    def testSendAfterClose(self):
475        # testing send() after close() with timeout
476        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
477        sock.settimeout(1)
478        sock.close()
479        self.assertRaises(socket.error, sock.send, "spam")
480
481    def testNewAttributes(self):
482        # testing .family, .type and .protocol
483        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
484        self.assertEqual(sock.family, socket.AF_INET)
485        self.assertEqual(sock.type, socket.SOCK_STREAM)
486        self.assertEqual(sock.proto, 0)
487        sock.close()
488
class BasicTCPTest(SocketConnectedTest):
    """send()/recv()/recvfrom()/sendall()/fromfd()/shutdown() over an
    established TCP connection.  The client portion sends on
    ``self.serv_conn``; the server portion reads from ``self.cli_conn``."""

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Testing large receive over TCP
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)

    def _testRecv(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecv(self):
        # Testing receive in chunks over TCP
        seg1 = self.cli_conn.recv(len(MSG) - 3)
        seg2 = self.cli_conn.recv(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG)

    def testRecvFrom(self):
        # Testing large recvfrom() over TCP
        msg, addr = self.cli_conn.recvfrom(1024)
        self.assertEqual(msg, MSG)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG)

    def testOverFlowRecvFrom(self):
        # Testing recvfrom() in chunks over TCP
        seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
        seg2, addr = self.cli_conn.recvfrom(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG)

    def testSendAll(self):
        # Testing sendall() with a 2048 byte string over TCP.
        # recv() only returns '' once the client closes its end (done by
        # SocketConnectedTest.clientTearDown), so read until then.
        chunks = []
        while True:
            read = self.cli_conn.recv(1024)
            if not read:
                break
            chunks.append(read)
        self.assertEqual(''.join(chunks), 'f' * 2048)

    def _testSendAll(self):
        big_chunk = 'f' * 2048
        self.serv_conn.sendall(big_chunk)

    def testFromFd(self):
        # Testing fromfd()
        if not hasattr(socket, "fromfd"):
            return # On Windows, this doesn't exist
        fd = self.cli_conn.fileno()
        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
        try:
            msg = sock.recv(1024)
            self.assertEqual(msg, MSG)
        finally:
            # fromfd() duplicates the descriptor, so this socket owns its
            # own fd that has to be released separately from cli_conn.
            sock.close()

    def _testFromFd(self):
        self.serv_conn.send(MSG)

    def testShutdown(self):
        # Testing shutdown()
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)

    def _testShutdown(self):
        self.serv_conn.send(MSG)
        self.serv_conn.shutdown(2)  # 2: shut down both directions
564
class BasicUDPTest(ThreadedUDPSocketTest):
    """Datagram exchange: the client thread sendto()s MSG to the bound
    server socket, which reads it with recv()/recvfrom()."""

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def testSendtoAndRecv(self):
        # Testing sendto() and Recv() over UDP
        self.assertEqual(self.serv.recv(len(MSG)), MSG)

    def _testSendtoAndRecv(self):
        dest = (HOST, PORT)
        self.cli.sendto(MSG, 0, dest)

    def testRecvFrom(self):
        # Testing recvfrom() over UDP
        data = self.serv.recvfrom(len(MSG))[0]
        self.assertEqual(data, MSG)

    def _testRecvFrom(self):
        dest = (HOST, PORT)
        self.cli.sendto(MSG, 0, dest)

    def testRecvFromNegative(self):
        # A negative buffer size passed to recvfrom() must raise ValueError.
        self.assertRaises(ValueError, self.serv.recvfrom, -1)

    def _testRecvFromNegative(self):
        dest = (HOST, PORT)
        self.cli.sendto(MSG, 0, dest)
592
class TCPCloserTest(ThreadedTCPSocketTest):
    """Closing the accepted connection must make the client readable,
    with recv() returning the empty string."""

    def testClose(self):
        conn = self.serv.accept()[0]
        conn.close()

        client = self.cli
        readable = select.select([client], [], [], 1.0)[0]
        self.assertEqual(readable, [client])
        self.assertEqual(client.recv(1), '')

    def _testClose(self):
        self.cli.connect((HOST, PORT))
        # Keep the client thread alive while the server side checks it.
        time.sleep(1.0)
607
class BasicSocketPairTest(SocketPairTest):
    """Send MSG across a socketpair() in each direction."""

    def __init__(self, methodName='runTest'):
        SocketPairTest.__init__(self, methodName=methodName)

    def testRecv(self):
        self.assertEqual(self.serv.recv(1024), MSG)

    def _testRecv(self):
        self.cli.send(MSG)

    def testSend(self):
        self.serv.send(MSG)

    def _testSend(self):
        self.assertEqual(self.cli.recv(1024), MSG)
626
class NonBlockingTCPTests(ThreadedTCPSocketTest):
    """Non-blocking accept()/recv() behaviour on TCP sockets.

    Every connection accepted by the server side is closed before the
    test method returns, instead of being left for garbage collection.
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def testSetBlocking(self):
        # Testing whether set blocking works
        self.serv.setblocking(0)
        start = time.time()
        try:
            self.serv.accept()
        except socket.error:
            pass
        end = time.time()
        self.assert_((end - start) < 1.0, "Error setting non-blocking mode.")

    def _testSetBlocking(self):
        pass

    def testAccept(self):
        # Testing non-blocking accept
        self.serv.setblocking(0)
        try:
            conn, addr = self.serv.accept()
        except socket.error:
            pass
        else:
            conn.close()
            self.fail("Error trying to do non-blocking accept.")
        read, write, err = select.select([self.serv], [], [])
        if self.serv in read:
            conn, addr = self.serv.accept()
            conn.close()
        else:
            self.fail("Error trying to do accept after select.")

    def _testAccept(self):
        time.sleep(0.1)
        self.cli.connect((HOST, PORT))

    def testConnect(self):
        # Testing non-blocking connect
        conn, addr = self.serv.accept()
        conn.close()

    def _testConnect(self):
        self.cli.settimeout(10)
        self.cli.connect((HOST, PORT))

    def testRecv(self):
        # Testing non-blocking recv
        conn, addr = self.serv.accept()
        try:
            conn.setblocking(0)
            try:
                msg = conn.recv(len(MSG))
            except socket.error:
                pass
            else:
                self.fail("Error trying to do non-blocking recv.")
            read, write, err = select.select([conn], [], [])
            if conn in read:
                msg = conn.recv(len(MSG))
                self.assertEqual(msg, MSG)
            else:
                self.fail("Error during select call to non-blocking socket.")
        finally:
            conn.close()

    def _testRecv(self):
        self.cli.connect((HOST, PORT))
        # Give the server time to observe the empty, non-blocking recv().
        time.sleep(0.1)
        self.cli.send(MSG)
694
class FileObjectClassTestCase(SocketConnectedTest):
    """Exercise socket.makefile() file objects over a connected TCP pair.

    The server side reads through ``self.serv_file`` (opened 'rb' with
    ``bufsize``); the client side writes through ``self.cli_file``
    (opened 'wb').  Subclasses override ``bufsize`` to rerun the tests
    in other buffering modes.
    """

    bufsize = -1 # Use default buffer size

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def setUp(self):
        SocketConnectedTest.setUp(self)
        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)

    def tearDown(self):
        try:
            self.serv_file.close()
            self.assert_(self.serv_file.closed)
        finally:
            # Close the accepted connection and the listening socket even
            # when the closed-flag check above fails.
            self.serv_file = None
            SocketConnectedTest.tearDown(self)

    def clientSetUp(self):
        SocketConnectedTest.clientSetUp(self)
        self.cli_file = self.serv_conn.makefile('wb')

    def clientTearDown(self):
        self.cli_file.close()
        self.assert_(self.cli_file.closed)
        self.cli_file = None
        SocketConnectedTest.clientTearDown(self)

    def testSmallRead(self):
        # Performing small file read test
        first_seg = self.serv_file.read(len(MSG)-3)
        second_seg = self.serv_file.read(3)
        msg = first_seg + second_seg
        self.assertEqual(msg, MSG)

    def _testSmallRead(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testFullRead(self):
        # read until EOF
        msg = self.serv_file.read()
        self.assertEqual(msg, MSG)

    def _testFullRead(self):
        self.cli_file.write(MSG)
        self.cli_file.close()

    def testUnbufferedRead(self):
        # Performing unbuffered file read test: one byte at a time to EOF
        buf = ''
        while 1:
            char = self.serv_file.read(1)
            if not char:
                break
            buf += char
        self.assertEqual(buf, MSG)

    def _testUnbufferedRead(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testReadline(self):
        # Performing file readline test
        line = self.serv_file.readline()
        self.assertEqual(line, MSG)

    def _testReadline(self):
        self.cli_file.write(MSG)
        self.cli_file.flush()

    def testClosedAttr(self):
        self.assert_(not self.serv_file.closed)

    def _testClosedAttr(self):
        self.assert_(not self.cli_file.closed)
770
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):

    """Repeat the tests from FileObjectClassTestCase with bufsize==0.

    In this case (and in this case only), it should be possible to
    create a file object, read a line from it, create another file
    object, read another line from it, without loss of data in the
    first file object's buffer.  Note that httplib relies on this
    when reading multiple requests from the same socket."""

    bufsize = 0 # Use unbuffered mode

    def testUnbufferedReadline(self):
        # Read a line, create a new file object, read another line with it
        line = self.serv_file.readline() # first line
        self.assertEqual(line, "A. " + MSG) # first line
        # Close the first file object explicitly rather than relying on
        # garbage collection (not immediate on every implementation).
        # Per the class docstring no buffered data is lost, and a
        # makefile() file object does not close the connection it wraps.
        self.serv_file.close()
        self.serv_file = self.cli_conn.makefile('rb', 0)
        line = self.serv_file.readline() # second line
        self.assertEqual(line, "B. " + MSG) # second line

    def _testUnbufferedReadline(self):
        self.cli_file.write("A. " + MSG)
        self.cli_file.write("B. " + MSG)
        self.cli_file.flush()
795
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Repeat the FileObjectClassTestCase tests with bufsize == 1."""

    # Default-buffered for reading; line-buffered for writing
    bufsize = 1
799
800
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
    """Repeat the FileObjectClassTestCase tests with a tiny read buffer."""

    # Exercise the buffering code
    bufsize = 2
804
805
class Urllib2FileobjectTest(unittest.TestCase):

    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
    # it close the socket if the close c'tor argument is true

    def testClose(self):
        class MockSocket:
            closed = False
            def flush(self):
                pass
            def close(self):
                self.closed = True

        # must not close unless we request it: the original use of _fileobject
        # by module socket requires that the underlying socket not be closed until
        # the _socketobject that created the _fileobject is closed
        sock = MockSocket()
        fobj = socket._fileobject(sock)
        fobj.close()
        self.assert_(not sock.closed)

        # With close=True, closing the file object closes the socket as well.
        sock = MockSocket()
        fobj = socket._fileobject(sock, close=True)
        fobj.close()
        self.assert_(sock.closed)
829
class TCPTimeoutTest(SocketTCPTest):
    """Timeout behaviour of accept() on a listening TCP socket."""

    def testTCPTimeout(self):
        def raise_timeout(*args, **kwargs):
            self.serv.settimeout(1.0)
            self.serv.accept()
        self.failUnlessRaises(socket.timeout, raise_timeout,
                              "Error generating a timeout exception (TCP)")

    def testTimeoutZero(self):
        # A zero timeout means non-blocking: accept() must raise a plain
        # socket.error, not socket.timeout.
        ok = False
        try:
            self.serv.settimeout(0.0)
            foo = self.serv.accept()
        except socket.timeout:
            self.fail("caught timeout instead of error (TCP)")
        except socket.error:
            ok = True
        except:
            self.fail("caught unexpected exception (TCP)")
        if not ok:
            self.fail("accept() returned success when we did not expect it")

    def testInterruptedTimeout(self):
        # XXX I don't know how to do this test on MSWindows or any other
        # platform that doesn't support signal.alarm() or os.kill(), though
        # the bug should have existed on all platforms.
        if not hasattr(signal, "alarm"):
            return                  # can only test on *nix
        self.serv.settimeout(5.0)   # must be longer than alarm
        class Alarm(Exception):
            pass
        def alarm_handler(signum, frame):
            raise Alarm
        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
        try:
            signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
            try:
                foo = self.serv.accept()
            except socket.timeout:
                self.fail("caught timeout instead of Alarm")
            except Alarm:
                pass
            except:
                self.fail("caught other exception instead of Alarm: %s(%s)"
                          % sys.exc_info()[:2])
            else:
                self.fail("nothing caught")
            finally:
                # Cancel the alarm on every path, including the self.fail()
                # paths above, so no SIGALRM can arrive after the original
                # handler is restored below.
                signal.alarm(0)
        except Alarm:
            self.fail("got Alarm in wrong place")
        finally:
            # no alarm can be pending.  Safe to restore old handler.
            signal.signal(signal.SIGALRM, old_alarm)
883
class UDPTimeoutTest(SocketUDPTest):
    """Timeout behaviour of recv() on the bound UDP socket self.serv.

    Derives from SocketUDPTest so that self.serv really is a SOCK_DGRAM
    socket; with a TCP listener, recv() fails for an unrelated reason and
    the UDP code path is never exercised.
    """

    def testUDPTimeout(self):
        # "Error generating a timeout exception (UDP)": recv() with no
        # datagram queued must give up with socket.timeout once the 1 second
        # timeout expires.
        def raise_timeout():
            self.serv.settimeout(1.0)
            self.serv.recv(1024)
        self.failUnlessRaises(socket.timeout, raise_timeout)

    def testTimeoutZero(self):
        # A timeout of 0.0 must make recv() fail immediately with a plain
        # socket.error when nothing has arrived, never with the
        # socket.timeout subclass.
        self.serv.settimeout(0.0)
        try:
            self.serv.recv(1024)
        except socket.timeout:
            self.fail("caught timeout instead of error (UDP)")
        except socket.error:
            pass
        except Exception:
            self.fail("caught unexpected exception (UDP)")
        else:
            self.fail("recv() returned success when we did not expect it")
906
class TestExceptions(unittest.TestCase):
    """Check the inheritance tree of the socket module's exception types."""

    def testExceptionTree(self):
        # socket.error is the root; it must itself be an Exception ...
        self.assert_(issubclass(socket.error, Exception))
        # ... and the more specific errors must all derive from it, so a
        # single `except socket.error` catches every one of them.
        for specific in (socket.herror, socket.gaierror, socket.timeout):
            self.assert_(issubclass(specific, socket.error))
914
915class TestLinuxAbstractNamespace(unittest.TestCase):
916
917    UNIX_PATH_MAX = 108
918
919    def testLinuxAbstractNamespace(self):
920        address = "\x00python-test-hello\x00\xff"
921        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
922        s1.bind(address)
923        s1.listen(1)
924        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
925        s2.connect(s1.getsockname())
926        s1.accept()
927        self.assertEqual(s1.getsockname(), address)
928        self.assertEqual(s2.getpeername(), address)
929
930    def testMaxName(self):
931        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
932        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
933        s.bind(address)
934        self.assertEqual(s.getsockname(), address)
935
936    def testNameOverflow(self):
937        address = "\x00" + "h" * self.UNIX_PATH_MAX
938        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
939        self.assertRaises(socket.error, s.bind, address)
940
941
class BufferIOTest(SocketConnectedTest):
    """
    Test the buffer versions of socket.recv() and socket.send().
    """
    # Each testX reads through cli_conn into a pre-allocated array using a
    # *_into method; the paired _testX (presumably run by ThreadableTest as
    # the other end of the connection -- confirm there) sends MSG through
    # serv_conn wrapped in a buffer object.  No __init__ override is needed:
    # SocketConnectedTest.__init__ already accepts methodName.

    def testRecvInto(self):
        buf = array.array('c', ' '*1024)
        nbytes = self.cli_conn.recv_into(buf)
        self._assertReceivedMsg(buf, nbytes)

    def _testRecvInto(self):
        self._sendMsgAsBuffer()

    def testRecvFromInto(self):
        # recvfrom_into() also returns the sender address; it is not checked.
        buf = array.array('c', ' '*1024)
        nbytes, addr = self.cli_conn.recvfrom_into(buf)
        self._assertReceivedMsg(buf, nbytes)

    def _testRecvFromInto(self):
        self._sendMsgAsBuffer()

    def _assertReceivedMsg(self, buf, nbytes):
        """Assert that exactly MSG was written at the start of *buf*."""
        self.assertEqual(nbytes, len(MSG))
        # Only the first nbytes of the 1024-byte array were overwritten.
        msg = buf.tostring()[:len(MSG)]
        self.assertEqual(msg, MSG)

    def _sendMsgAsBuffer(self):
        """Send MSG over serv_conn, passing a buffer object to send()."""
        self.serv_conn.send(buffer(MSG))
970
def test_main():
    """Run every socket test class that applies to the current platform.

    The thread snapshot taken by threading_setup() is handed back to
    threading_cleanup() even when run_unittest() raises.
    """
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest]
    if sys.platform != 'mac':
        # The UDP tests are not run when sys.platform is 'mac'.
        tests.extend([BasicUDPTest, UDPTimeoutTest])

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
    ])
    if hasattr(socket, "socketpair"):
        tests.append(BasicSocketPairTest)
    # The abstract AF_UNIX namespace is Linux-only.  sys.platform is not
    # always exactly 'linux2' on Linux (builds may report 'linux3' or just
    # 'linux'), so match on the prefix.
    if sys.platform.startswith('linux'):
        tests.append(TestLinuxAbstractNamespace)

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        test_support.threading_cleanup(*thread_info)

if __name__ == "__main__":
    test_main()