PageRenderTime 62ms CodeModel.GetById 2ms app.highlight 52ms RepoModel.GetById 1ms app.codeStats 1ms

/greentest/test_socket.py

https://bitbucket.org/lericson/gevent
Python | 1037 lines | 873 code | 96 blank | 68 comment | 34 complexity | d2c04bdda50751928f74dbe9508812eb MD5 | raw file
   1#!/usr/bin/env python
   2
   3from gevent import monkey
   4monkey.patch_all()
   5
   6import greentest
   7import test_support
   8
   9import socket
  10import select
  11import time
  12import thread, threading
  13import Queue
  14import sys
  15import array
  16from weakref import proxy
  17import signal
  18
  19PORT = 50007
  20HOST = 'localhost'
  21MSG = 'Michael Gilfix was here\n'
  22
  23class SocketTCPTest(greentest.TestCase):
  24
  25    def setUp(self):
  26        self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  27        self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  28        global PORT
  29        PORT = test_support.bind_port(self.serv, HOST, PORT)
  30        self.serv.listen(1)
  31
  32    def tearDown(self):
  33        self.serv.close()
  34        self.serv = None
  35
  36class SocketUDPTest(greentest.TestCase):
  37
  38    def setUp(self):
  39        self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  40        self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  41        global PORT
  42        PORT = test_support.bind_port(self.serv, HOST, PORT)
  43
  44    def tearDown(self):
  45        self.serv.close()
  46        self.serv = None
  47
  48class ThreadableTest:
  49    """Threadable Test class
  50
  51    The ThreadableTest class makes it easy to create a threaded
  52    client/server pair from an existing unit test. To create a
  53    new threaded class from an existing unit test, use multiple
  54    inheritance:
  55
  56        class NewClass (OldClass, ThreadableTest):
  57            pass
  58
  59    This class defines two new fixture functions with obvious
  60    purposes for overriding:
  61
  62        clientSetUp ()
  63        clientTearDown ()
  64
  65    Any new test functions within the class must then define
  66    tests in pairs, where the test name is preceeded with a
  67    '_' to indicate the client portion of the test. Ex:
  68
  69        def testFoo(self):
  70            # Server portion
  71
  72        def _testFoo(self):
  73            # Client portion
  74
  75    Any exceptions raised by the clients during their tests
  76    are caught and transferred to the main thread to alert
  77    the testing framework.
  78
  79    Note, the server setup function cannot call any blocking
  80    functions that rely on the client thread during setup,
  81    unless serverExplicityReady() is called just before
  82    the blocking call (such as in setting up a client/server
  83    connection and performing the accept() in setUp().
  84    """
  85
  86    def __init__(self):
  87        # Swap the true setup function
  88        self.__setUp = self.setUp
  89        self.__tearDown = self.tearDown
  90        self.setUp = self._setUp
  91        self.tearDown = self._tearDown
  92
  93    def serverExplicitReady(self):
  94        """This method allows the server to explicitly indicate that
  95        it wants the client thread to proceed. This is useful if the
  96        server is about to execute a blocking routine that is
  97        dependent upon the client thread during its setup routine."""
  98        self.server_ready.set()
  99
 100    def _setUp(self):
 101        #self.server_ready = threading.Event()
 102        #self.client_ready = threading.Event()
 103        #self.done = threading.Event()
 104        #self.queue = Queue.Queue(1)
 105        # monkey-patched threading.Event and threading.Queue work fine here too
 106        # but let's test gevent classes
 107        from gevent.event import Event
 108        from gevent.queue import Queue
 109        self.server_ready = Event()
 110        self.client_ready = Event()
 111        self.done = Event()
 112        self.queue = Queue(1)
 113
 114        # Do some munging to start the client test.
 115        methodname = self.id()
 116        i = methodname.rfind('.')
 117        methodname = methodname[i+1:]
 118        test_method = getattr(self, '_' + methodname)
 119        self.client_thread = thread.start_new_thread(
 120            self.clientRun, (test_method,))
 121
 122        self.__setUp()
 123        if not self.server_ready.isSet():
 124            self.server_ready.set()
 125        self.client_ready.wait()
 126
 127    def _tearDown(self):
 128        self.__tearDown()
 129        self.done.wait()
 130
 131        if not self.queue.empty():
 132            msg = self.queue.get()
 133            self.fail(msg)
 134
 135    def clientRun(self, test_func):
 136        self.server_ready.wait()
 137        self.client_ready.set()
 138        self.clientSetUp()
 139        if not callable(test_func):
 140            raise TypeError, "test_func must be a callable function"
 141        try:
 142            test_func()
 143        except Exception, strerror:
 144            self.queue.put(strerror)
 145        self.clientTearDown()
 146
 147    def clientSetUp(self):
 148        raise NotImplementedError, "clientSetUp must be implemented."
 149
 150    def clientTearDown(self):
 151        self.done.set()
 152        thread.exit()
 153
 154class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
 155
 156    def __init__(self, methodName='runTest'):
 157        SocketTCPTest.__init__(self, methodName=methodName)
 158        ThreadableTest.__init__(self)
 159
 160    def clientSetUp(self):
 161        self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 162
 163    def clientTearDown(self):
 164        self.cli.close()
 165        self.cli = None
 166        ThreadableTest.clientTearDown(self)
 167
 168class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
 169
 170    def __init__(self, methodName='runTest'):
 171        SocketUDPTest.__init__(self, methodName=methodName)
 172        ThreadableTest.__init__(self)
 173
 174    def clientSetUp(self):
 175        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 176
 177class SocketConnectedTest(ThreadedTCPSocketTest):
 178
 179    def __init__(self, methodName='runTest'):
 180        ThreadedTCPSocketTest.__init__(self, methodName=methodName)
 181
 182    def setUp(self):
 183        ThreadedTCPSocketTest.setUp(self)
 184        # Indicate explicitly we're ready for the client thread to
 185        # proceed and then perform the blocking call to accept
 186        self.serverExplicitReady()
 187        conn, addr = self.serv.accept()
 188        self.cli_conn = conn
 189
 190    def tearDown(self):
 191        self.cli_conn.close()
 192        self.cli_conn = None
 193        ThreadedTCPSocketTest.tearDown(self)
 194
 195    def clientSetUp(self):
 196        ThreadedTCPSocketTest.clientSetUp(self)
 197        self.cli.connect((HOST, PORT))
 198        self.serv_conn = self.cli
 199
 200    def clientTearDown(self):
 201        self.serv_conn.close()
 202        self.serv_conn = None
 203        ThreadedTCPSocketTest.clientTearDown(self)
 204
 205class SocketPairTest(greentest.TestCase, ThreadableTest):
 206
 207    def __init__(self, methodName='runTest'):
 208        greentest.TestCase.__init__(self, methodName=methodName)
 209        ThreadableTest.__init__(self)
 210
 211    def setUp(self):
 212        self.serv, self.cli = socket.socketpair()
 213
 214    def tearDown(self):
 215        self.serv.close()
 216        self.serv = None
 217
 218    def clientSetUp(self):
 219        pass
 220
 221    def clientTearDown(self):
 222        self.cli.close()
 223        self.cli = None
 224        ThreadableTest.clientTearDown(self)
 225
 226
 227#######################################################################
 228## Begin Tests
 229
 230class GeneralModuleTests(greentest.TestCase):
 231
 232    def test_weakref(self):
 233        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 234        p = proxy(s)
 235        self.assertEqual(p.fileno(), s.fileno())
 236        s.close()
 237        s = None
 238        try:
 239            p.fileno()
 240        except ReferenceError:
 241            pass
 242        else:
 243            self.fail('Socket proxy still exists')
 244
 245    def testSocketError(self):
 246        # Testing socket module exceptions
 247        def raise_error(*args, **kwargs):
 248            raise socket.error
 249        def raise_herror(*args, **kwargs):
 250            raise socket.herror
 251        def raise_gaierror(*args, **kwargs):
 252            raise socket.gaierror
 253        self.failUnlessRaises(socket.error, raise_error,
 254                              "Error raising socket exception.")
 255        self.failUnlessRaises(socket.error, raise_herror,
 256                              "Error raising socket exception.")
 257        self.failUnlessRaises(socket.error, raise_gaierror,
 258                              "Error raising socket exception.")
 259
 260    def testCrucialConstants(self):
 261        # Testing for mission critical constants
 262        socket.AF_INET
 263        socket.SOCK_STREAM
 264        socket.SOCK_DGRAM
 265        socket.SOCK_RAW
 266        socket.SOCK_RDM
 267        socket.SOCK_SEQPACKET
 268        socket.SOL_SOCKET
 269        socket.SO_REUSEADDR
 270
 271    def testHostnameRes(self):
 272        # Testing hostname resolution mechanisms
 273        hostname = socket.gethostname()
 274        try:
 275            ip = socket.gethostbyname(hostname)
 276        except socket.error:
 277            # Probably name lookup wasn't set up right; skip this test
 278            return
 279        self.assert_(ip.find('.') >= 0, "Error resolving host to ip.")
 280        try:
 281            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
 282        except socket.error:
 283            # Probably a similar problem as above; skip this test
 284            return
 285        all_host_names = [hostname, hname] + aliases
 286        fqhn = socket.getfqdn(ip)
 287        if not fqhn in all_host_names:
 288            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
 289
 290    def testRefCountGetNameInfo(self):
 291        # Testing reference count for getnameinfo
 292        import sys
 293        if hasattr(sys, "getrefcount"):
 294            try:
 295                # On some versions, this loses a reference
 296                orig = sys.getrefcount(__name__)
 297                socket.getnameinfo(__name__,0)
 298            except SystemError:
 299                if sys.getrefcount(__name__) <> orig:
 300                    self.fail("socket.getnameinfo loses a reference")
 301
 302    def testInterpreterCrash(self):
 303        # Making sure getnameinfo doesn't crash the interpreter
 304        try:
 305            # On some versions, this crashes the interpreter.
 306            socket.getnameinfo(('x', 0, 0, 0), 0)
 307        except socket.error:
 308            pass
 309
 310    def testNtoH(self):
 311        # This just checks that htons etc. are their own inverse,
 312        # when looking at the lower 16 or 32 bits.
 313        sizes = {socket.htonl: 32, socket.ntohl: 32,
 314                 socket.htons: 16, socket.ntohs: 16}
 315        for func, size in sizes.items():
 316            mask = (1L<<size) - 1
 317            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
 318                self.assertEqual(i & mask, func(func(i&mask)) & mask)
 319
 320            swapped = func(mask)
 321            self.assertEqual(swapped & mask, mask)
 322            self.assertRaises(OverflowError, func, 1L<<34)
 323
 324    def testGetServBy(self):
 325        eq = self.assertEqual
 326        # Find one service that exists, then check all the related interfaces.
 327        # I've ordered this by protocols that have both a tcp and udp
 328        # protocol, at least for modern Linuxes.
 329        if sys.platform in ('linux2', 'freebsd4', 'freebsd5', 'freebsd6',
 330                            'freebsd7', 'darwin'):
 331            # avoid the 'echo' service on this platform, as there is an
 332            # assumption breaking non-standard port/protocol entry
 333            services = ('daytime', 'qotd', 'domain')
 334        else:
 335            services = ('echo', 'daytime', 'domain')
 336        for service in services:
 337            try:
 338                port = socket.getservbyname(service, 'tcp')
 339                break
 340            except socket.error:
 341                pass
 342        else:
 343            raise socket.error
 344        # Try same call with optional protocol omitted
 345        port2 = socket.getservbyname(service)
 346        eq(port, port2)
 347        # Try udp, but don't barf it it doesn't exist
 348        try:
 349            udpport = socket.getservbyname(service, 'udp')
 350        except socket.error:
 351            udpport = None
 352        else:
 353            eq(udpport, port)
 354        # Now make sure the lookup by port returns the same service name
 355        eq(socket.getservbyport(port2), service)
 356        eq(socket.getservbyport(port, 'tcp'), service)
 357        if udpport is not None:
 358            eq(socket.getservbyport(udpport, 'udp'), service)
 359
 360    def testDefaultTimeout(self):
 361        # Testing default timeout
 362        # The default timeout should initially be None
 363        self.assertEqual(socket.getdefaulttimeout(), None)
 364        s = socket.socket()
 365        self.assertEqual(s.gettimeout(), None)
 366        s.close()
 367
 368        # Set the default timeout to 10, and see if it propagates
 369        socket.setdefaulttimeout(10)
 370        self.assertEqual(socket.getdefaulttimeout(), 10)
 371        s = socket.socket()
 372        self.assertEqual(s.gettimeout(), 10)
 373        s.close()
 374
 375        # Reset the default timeout to None, and see if it propagates
 376        socket.setdefaulttimeout(None)
 377        self.assertEqual(socket.getdefaulttimeout(), None)
 378        s = socket.socket()
 379        self.assertEqual(s.gettimeout(), None)
 380        s.close()
 381
 382        # Check that setting it to an invalid value raises ValueError
 383        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
 384
 385        # Check that setting it to an invalid type raises TypeError
 386        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
 387
 388    def testIPv4toString(self):
 389        if not hasattr(socket, 'inet_pton'):
 390            return # No inet_pton() on this platform
 391        from socket import inet_aton as f, inet_pton, AF_INET
 392        g = lambda a: inet_pton(AF_INET, a)
 393
 394        self.assertEquals('\x00\x00\x00\x00', f('0.0.0.0'))
 395        self.assertEquals('\xff\x00\xff\x00', f('255.0.255.0'))
 396        self.assertEquals('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
 397        self.assertEquals('\x01\x02\x03\x04', f('1.2.3.4'))
 398        self.assertEquals('\xff\xff\xff\xff', f('255.255.255.255'))
 399
 400        self.assertEquals('\x00\x00\x00\x00', g('0.0.0.0'))
 401        self.assertEquals('\xff\x00\xff\x00', g('255.0.255.0'))
 402        self.assertEquals('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
 403        self.assertEquals('\xff\xff\xff\xff', g('255.255.255.255'))
 404
 405    def testIPv6toString(self):
 406        if not hasattr(socket, 'inet_pton'):
 407            return # No inet_pton() on this platform
 408        try:
 409            from socket import inet_pton, AF_INET6, has_ipv6
 410            if not has_ipv6:
 411                return
 412        except ImportError:
 413            return
 414        f = lambda a: inet_pton(AF_INET6, a)
 415
 416        self.assertEquals('\x00' * 16, f('::'))
 417        self.assertEquals('\x00' * 16, f('0::0'))
 418        self.assertEquals('\x00\x01' + '\x00' * 14, f('1::'))
 419        self.assertEquals(
 420            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
 421            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
 422        )
 423
 424    def testStringToIPv4(self):
 425        if not hasattr(socket, 'inet_ntop'):
 426            return # No inet_ntop() on this platform
 427        from socket import inet_ntoa as f, inet_ntop, AF_INET
 428        g = lambda a: inet_ntop(AF_INET, a)
 429
 430        self.assertEquals('1.0.1.0', f('\x01\x00\x01\x00'))
 431        self.assertEquals('170.85.170.85', f('\xaa\x55\xaa\x55'))
 432        self.assertEquals('255.255.255.255', f('\xff\xff\xff\xff'))
 433        self.assertEquals('1.2.3.4', f('\x01\x02\x03\x04'))
 434
 435        self.assertEquals('1.0.1.0', g('\x01\x00\x01\x00'))
 436        self.assertEquals('170.85.170.85', g('\xaa\x55\xaa\x55'))
 437        self.assertEquals('255.255.255.255', g('\xff\xff\xff\xff'))
 438
 439    def testStringToIPv6(self):
 440        if not hasattr(socket, 'inet_ntop'):
 441            return # No inet_ntop() on this platform
 442        try:
 443            from socket import inet_ntop, AF_INET6, has_ipv6
 444            if not has_ipv6:
 445                return
 446        except ImportError:
 447            return
 448        f = lambda a: inet_ntop(AF_INET6, a)
 449
 450        self.assertEquals('::', f('\x00' * 16))
 451        self.assertEquals('::1', f('\x00' * 15 + '\x01'))
 452        self.assertEquals(
 453            'aef:b01:506:1001:ffff:9997:55:170',
 454            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
 455        )
 456
 457    # XXX The following don't test module-level functionality...
 458
 459    def testSockName(self):
 460        # Testing getsockname()
 461        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 462        sock.bind(("0.0.0.0", PORT+1))
 463        name = sock.getsockname()
 464        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
 465        # it reasonable to get the host's addr in addition to 0.0.0.0.
 466        # At least for eCos.  This is required for the S/390 to pass.
 467        my_ip_addr = socket.gethostbyname(socket.gethostname())
 468        self.assert_(name[0] in ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
 469        self.assertEqual(name[1], PORT+1)
 470
 471    def testGetSockOpt(self):
 472        # Testing getsockopt()
 473        # We know a socket should start without reuse==0
 474        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 475        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
 476        self.failIf(reuse != 0, "initial mode is reuse")
 477
 478    def testSetSockOpt(self):
 479        # Testing setsockopt()
 480        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 481        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 482        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
 483        self.failIf(reuse == 0, "failed to set reuse mode")
 484
 485    __timeout__ = 2
 486
 487    def testSendAfterClose(self):
 488        # testing send() after close() with timeout
 489        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 490        sock.settimeout(1)
 491        sock.close()
 492        self.assertRaises(socket.error, sock.send, "spam")
 493
 494    def testNewAttributes(self):
 495        # testing .family, .type and .protocol
 496        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 497        self.assertEqual(sock.family, socket.AF_INET)
 498        self.assertEqual(sock.type, socket.SOCK_STREAM)
 499        self.assertEqual(sock.proto, 0)
 500        sock.close()
 501
 502class BasicTCPTest(SocketConnectedTest):
 503
 504    def __init__(self, methodName='runTest'):
 505        SocketConnectedTest.__init__(self, methodName=methodName)
 506
 507    def testRecv(self):
 508        # Testing large receive over TCP
 509        msg = self.cli_conn.recv(1024)
 510        self.assertEqual(msg, MSG)
 511
 512    def _testRecv(self):
 513        self.serv_conn.send(MSG)
 514
 515    def testOverFlowRecv(self):
 516        # Testing receive in chunks over TCP
 517        seg1 = self.cli_conn.recv(len(MSG) - 3)
 518        seg2 = self.cli_conn.recv(1024)
 519        msg = seg1 + seg2
 520        self.assertEqual(msg, MSG)
 521
 522    def _testOverFlowRecv(self):
 523        self.serv_conn.send(MSG)
 524
 525    def testRecvFrom(self):
 526        # Testing large recvfrom() over TCP
 527        msg, addr = self.cli_conn.recvfrom(1024)
 528        self.assertEqual(msg, MSG)
 529
 530    def _testRecvFrom(self):
 531        self.serv_conn.send(MSG)
 532
 533    def testOverFlowRecvFrom(self):
 534        # Testing recvfrom() in chunks over TCP
 535        seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
 536        seg2, addr = self.cli_conn.recvfrom(1024)
 537        msg = seg1 + seg2
 538        self.assertEqual(msg, MSG)
 539
 540    def _testOverFlowRecvFrom(self):
 541        self.serv_conn.send(MSG)
 542
 543    def testSendAll(self):
 544        # Testing sendall() with a 2048 byte string over TCP
 545        msg = ''
 546        while 1:
 547            read = self.cli_conn.recv(1024)
 548            if not read:
 549                break
 550            msg += read
 551        self.assertEqual(msg, 'f' * 2048)
 552
 553    def _testSendAll(self):
 554        big_chunk = 'f' * 2048
 555        self.serv_conn.sendall(big_chunk)
 556
 557    def testFromFd(self):
 558        # Testing fromfd()
 559        if not hasattr(socket, "fromfd"):
 560            return # On Windows, this doesn't exist
 561        fd = self.cli_conn.fileno()
 562        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
 563        msg = sock.recv(1024)
 564        self.assertEqual(msg, MSG)
 565
 566    def _testFromFd(self):
 567        self.serv_conn.send(MSG)
 568
 569    def testShutdown(self):
 570        # Testing shutdown()
 571        msg = self.cli_conn.recv(1024)
 572        self.assertEqual(msg, MSG)
 573
 574    def _testShutdown(self):
 575        self.serv_conn.send(MSG)
 576        self.serv_conn.shutdown(2)
 577
 578class BasicUDPTest(ThreadedUDPSocketTest):
 579
 580    def __init__(self, methodName='runTest'):
 581        ThreadedUDPSocketTest.__init__(self, methodName=methodName)
 582
 583    def testSendtoAndRecv(self):
 584        # Testing sendto() and Recv() over UDP
 585        msg = self.serv.recv(len(MSG))
 586        self.assertEqual(msg, MSG)
 587
 588    def _testSendtoAndRecv(self):
 589        self.cli.sendto(MSG, 0, (HOST, PORT))
 590
 591    def testRecvFrom(self):
 592        # Testing recvfrom() over UDP
 593        msg, addr = self.serv.recvfrom(len(MSG))
 594        self.assertEqual(msg, MSG)
 595
 596    def _testRecvFrom(self):
 597        self.cli.sendto(MSG, 0, (HOST, PORT))
 598
 599    def testRecvFromNegative(self):
 600        # Negative lengths passed to recvfrom should give ValueError.
 601        self.assertRaises(ValueError, self.serv.recvfrom, -1)
 602
 603    def _testRecvFromNegative(self):
 604        self.cli.sendto(MSG, 0, (HOST, PORT))
 605
 606class TCPCloserTest(ThreadedTCPSocketTest):
 607
 608    def testClose(self):
 609        conn, addr = self.serv.accept()
 610        conn.close()
 611
 612        sd = self.cli
 613        read, write, err = select.select([sd], [], [], 1.0)
 614        self.assertEqual(read, [sd])
 615        self.assertEqual(sd.recv(1), '')
 616
 617    def _testClose(self):
 618        self.cli.connect((HOST, PORT))
 619        time.sleep(1.0)
 620
 621class BasicSocketPairTest(SocketPairTest):
 622
 623    def __init__(self, methodName='runTest'):
 624        SocketPairTest.__init__(self, methodName=methodName)
 625
 626    def testRecv(self):
 627        msg = self.serv.recv(1024)
 628        self.assertEqual(msg, MSG)
 629
 630    def _testRecv(self):
 631        self.cli.send(MSG)
 632
 633    def testSend(self):
 634        self.serv.send(MSG)
 635
 636    def _testSend(self):
 637        msg = self.cli.recv(1024)
 638        self.assertEqual(msg, MSG)
 639
 640class NonBlockingTCPTests(ThreadedTCPSocketTest):
 641
 642    def __init__(self, methodName='runTest'):
 643        ThreadedTCPSocketTest.__init__(self, methodName=methodName)
 644
 645    def testSetBlocking(self):
 646        # Testing whether set blocking works
 647        self.serv.setblocking(0)
 648        start = time.time()
 649        try:
 650            self.serv.accept()
 651        except socket.error:
 652            pass
 653        end = time.time()
 654        self.assert_((end - start) < 1.0, "Error setting non-blocking mode.")
 655
 656    def _testSetBlocking(self):
 657        pass
 658
 659    def testAccept(self):
 660        # Testing non-blocking accept
 661        self.serv.setblocking(0)
 662        try:
 663            conn, addr = self.serv.accept()
 664        except socket.error:
 665            pass
 666        else:
 667            self.fail("Error trying to do non-blocking accept.")
 668        read, write, err = select.select([self.serv], [], [])
 669        if self.serv in read:
 670            conn, addr = self.serv.accept()
 671        else:
 672            self.fail("Error trying to do accept after select.")
 673
 674    def _testAccept(self):
 675        time.sleep(0.1)
 676        self.cli.connect((HOST, PORT))
 677
 678    def testConnect(self):
 679        # Testing non-blocking connect
 680        conn, addr = self.serv.accept()
 681
 682    def _testConnect(self):
 683        self.cli.settimeout(10)
 684        self.cli.connect((HOST, PORT))
 685
 686    def testRecv(self):
 687        # Testing non-blocking recv
 688        conn, addr = self.serv.accept()
 689        conn.setblocking(0)
 690        try:
 691            msg = conn.recv(len(MSG))
 692        except socket.error:
 693            pass
 694        else:
 695            self.fail("Error trying to do non-blocking recv.")
 696        read, write, err = select.select([conn], [], [])
 697        if conn in read:
 698            msg = conn.recv(len(MSG))
 699            self.assertEqual(msg, MSG)
 700        else:
 701            self.fail("Error during select call to non-blocking socket.")
 702
 703    def _testRecv(self):
 704        self.cli.connect((HOST, PORT))
 705        time.sleep(0.1)
 706        self.cli.send(MSG)
 707
 708class FileObjectClassTestCase(SocketConnectedTest):
 709
 710    bufsize = -1 # Use default buffer size
 711
 712    def __init__(self, methodName='runTest'):
 713        SocketConnectedTest.__init__(self, methodName=methodName)
 714
 715    def setUp(self):
 716        SocketConnectedTest.setUp(self)
 717        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)
 718
 719    def tearDown(self):
 720        self.serv_file.close()
 721        self.assert_(self.serv_file.closed)
 722        self.serv_file = None
 723        SocketConnectedTest.tearDown(self)
 724
 725    def clientSetUp(self):
 726        SocketConnectedTest.clientSetUp(self)
 727        self.cli_file = self.serv_conn.makefile('wb')
 728
 729    def clientTearDown(self):
 730        self.cli_file.close()
 731        self.assert_(self.cli_file.closed)
 732        self.cli_file = None
 733        SocketConnectedTest.clientTearDown(self)
 734
 735    def testSmallRead(self):
 736        # Performing small file read test
 737        first_seg = self.serv_file.read(len(MSG)-3)
 738        second_seg = self.serv_file.read(3)
 739        msg = first_seg + second_seg
 740        self.assertEqual(msg, MSG)
 741
 742    def _testSmallRead(self):
 743        self.cli_file.write(MSG)
 744        self.cli_file.flush()
 745
 746    def testFullRead(self):
 747        # read until EOF
 748        msg = self.serv_file.read()
 749        self.assertEqual(msg, MSG)
 750
 751    def _testFullRead(self):
 752        self.cli_file.write(MSG)
 753        self.cli_file.close()
 754
 755    def testUnbufferedRead(self):
 756        # Performing unbuffered file read test
 757        buf = ''
 758        while 1:
 759            char = self.serv_file.read(1)
 760            if not char:
 761                break
 762            buf += char
 763        self.assertEqual(buf, MSG)
 764
 765    def _testUnbufferedRead(self):
 766        self.cli_file.write(MSG)
 767        self.cli_file.flush()
 768
 769    def testReadline(self):
 770        # Performing file readline test
 771        line = self.serv_file.readline()
 772        self.assertEqual(line, MSG)
 773
 774    def _testReadline(self):
 775        self.cli_file.write(MSG)
 776        self.cli_file.flush()
 777
 778    def testReadlineAfterRead(self):
 779        a_baloo_is = self.serv_file.read(len("A baloo is"))
 780        self.assertEqual("A baloo is", a_baloo_is)
 781        _a_bear = self.serv_file.read(len(" a bear"))
 782        self.assertEqual(" a bear", _a_bear)
 783        line = self.serv_file.readline()
 784        self.assertEqual("\n", line)
 785        line = self.serv_file.readline()
 786        self.assertEqual("A BALOO IS A BEAR.\n", line)
 787        line = self.serv_file.readline()
 788        self.assertEqual(MSG, line)
 789
 790    def _testReadlineAfterRead(self):
 791        self.cli_file.write("A baloo is a bear\n")
 792        self.cli_file.write("A BALOO IS A BEAR.\n")
 793        self.cli_file.write(MSG)
 794        self.cli_file.flush()
 795
 796    def testReadlineAfterReadNoNewline(self):
 797        end_of_ = self.serv_file.read(len("End Of "))
 798        self.assertEqual("End Of ", end_of_)
 799        line = self.serv_file.readline()
 800        self.assertEqual("Line", line)
 801
 802    def _testReadlineAfterReadNoNewline(self):
 803        self.cli_file.write("End Of Line")
 804
 805    def testClosedAttr(self):
 806        self.assert_(not self.serv_file.closed)
 807
 808    def _testClosedAttr(self):
 809        self.assert_(not self.cli_file.closed)
 810
 811class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
 812
 813    """Repeat the tests from FileObjectClassTestCase with bufsize==0.
 814
 815    In this case (and in this case only), it should be possible to
 816    create a file object, read a line from it, create another file
 817    object, read another line from it, without loss of data in the
 818    first file object's buffer.  Note that httplib relies on this
 819    when reading multiple requests from the same socket."""
 820
 821    bufsize = 0 # Use unbuffered mode
 822
 823    def testUnbufferedReadline(self):
 824        # Read a line, create a new file object, read another line with it
 825        line = self.serv_file.readline() # first line
 826        self.assertEqual(line, "A. " + MSG) # first line
 827        self.serv_file = self.cli_conn.makefile('rb', 0)
 828        line = self.serv_file.readline() # second line
 829        self.assertEqual(line, "B. " + MSG) # second line
 830
 831    def _testUnbufferedReadline(self):
 832        self.cli_file.write("A. " + MSG)
 833        self.cli_file.write("B. " + MSG)
 834        self.cli_file.flush()
 835
 836class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
 837
 838    bufsize = 1 # Default-buffered for reading; line-buffered for writing
 839
 840
 841class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
 842
 843    bufsize = 2 # Exercise the buffering code
 844
 845
 846class Urllib2FileobjectTest(greentest.TestCase):
 847
 848    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
 849    # it close the socket if the close c'tor argument is true
 850
 851    def testClose(self):
 852        class MockSocket:
 853            closed = False
 854            def flush(self): pass
 855            def close(self): self.closed = True
 856
 857        # must not close unless we request it: the original use of _fileobject
 858        # by module socket requires that the underlying socket not be closed until
 859        # the _socketobject that created the _fileobject is closed
 860        s = MockSocket()
 861        f = socket._fileobject(s)
 862        f.close()
 863        self.assert_(not s.closed)
 864
 865        s = MockSocket()
 866        f = socket._fileobject(s, close=True)
 867        f.close()
 868        self.assert_(s.closed)
 869
 870class TCPTimeoutTest(SocketTCPTest):
 871
 872    def testTCPTimeout(self):
 873        def raise_timeout(*args, **kwargs):
 874            self.serv.settimeout(1.0)
 875            self.serv.accept()
 876        self.failUnlessRaises(socket.timeout, raise_timeout,
 877                              "Error generating a timeout exception (TCP)")
 878
 879    def testTimeoutZero(self):
 880        ok = False
 881        try:
 882            self.serv.settimeout(0.0)
 883            foo = self.serv.accept()
 884        except socket.timeout:
 885            self.fail("caught timeout instead of error (TCP)")
 886        except socket.error:
 887            ok = True
 888        except:
 889            self.fail("caught unexpected exception (TCP)")
 890        if not ok:
 891            self.fail("accept() returned success when we did not expect it")
 892
 893    def testInterruptedTimeout(self):
 894        # XXX I don't know how to do this test on MSWindows or any other
 895        # plaform that doesn't support signal.alarm() or os.kill(), though
 896        # the bug should have existed on all platforms.
 897        if not hasattr(signal, "alarm"):
 898            return                  # can only test on *nix
 899        self.serv.settimeout(5.0)   # must be longer than alarm
 900        class Alarm(Exception):
 901            pass
 902        def alarm_handler(signal, frame):
 903            raise Alarm
 904        import gevent
 905        myalarm = gevent.signal(signal.SIGALRM, alarm_handler, signal.SIGALRM, None)
 906        try:
 907            try:
 908                signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
 909                try:
 910                    foo = self.serv.accept()
 911                except socket.timeout:
 912                    self.fail("caught timeout instead of Alarm")
 913                except Alarm:
 914                    pass
 915                except:
 916                    self.fail("caught other exception instead of Alarm")
 917                else:
 918                    self.fail("nothing caught")
 919                signal.alarm(0)         # shut off alarm
 920            except Alarm:
 921                self.fail("got Alarm in wrong place")
 922        finally:
 923            # no alarm can be pending.  Safe to restore old handler.
 924            myalarm.cancel()
 925
 926class UDPTimeoutTest(SocketUDPTest):
 927
 928    def testUDPTimeout(self):
 929        def raise_timeout(*args, **kwargs):
 930            self.serv.settimeout(1.0)
 931            self.serv.recv(1024)
 932        self.failUnlessRaises(socket.timeout, raise_timeout,
 933                              "Error generating a timeout exception (UDP)")
 934
 935    def testTimeoutZero(self):
 936        ok = False
 937        try:
 938            self.serv.settimeout(0.0)
 939            foo = self.serv.recv(1024)
 940        except socket.timeout:
 941            self.fail("caught timeout instead of error (UDP)")
 942        except socket.error:
 943            ok = True
 944        except:
 945            self.fail("caught unexpected exception (UDP)")
 946        if not ok:
 947            self.fail("recv() returned success when we did not expect it")
 948
 949class TestExceptions(greentest.TestCase):
 950
 951    def testExceptionTree(self):
 952        self.assert_(issubclass(socket.error, Exception))
 953        self.assert_(issubclass(socket.herror, socket.error))
 954        self.assert_(issubclass(socket.gaierror, socket.error))
 955        self.assert_(issubclass(socket.timeout, socket.error))
 956
 957class TestLinuxAbstractNamespace(greentest.TestCase):
 958
 959    UNIX_PATH_MAX = 108
 960
 961    def testLinuxAbstractNamespace(self):
 962        address = "\x00python-test-hello\x00\xff"
 963        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 964        s1.bind(address)
 965        s1.listen(1)
 966        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 967        s2.connect(s1.getsockname())
 968        s1.accept()
 969        self.assertEqual(s1.getsockname(), address)
 970        self.assertEqual(s2.getpeername(), address)
 971
 972    def testMaxName(self):
 973        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
 974        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 975        s.bind(address)
 976        self.assertEqual(s.getsockname(), address)
 977
 978    def testNameOverflow(self):
 979        address = "\x00" + "h" * self.UNIX_PATH_MAX
 980        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 981        self.assertRaises(socket.error, s.bind, address)
 982
 983
 984class BufferIOTest(SocketConnectedTest):
 985    """
 986    Test the buffer versions of socket.recv() and socket.send().
 987    """
 988    def __init__(self, methodName='runTest'):
 989        SocketConnectedTest.__init__(self, methodName=methodName)
 990
 991    def testRecvInto(self):
 992        buf = array.array('c', ' '*1024)
 993        nbytes = self.cli_conn.recv_into(buf)
 994        self.assertEqual(nbytes, len(MSG))
 995        msg = buf.tostring()[:len(MSG)]
 996        self.assertEqual(msg, MSG)
 997
 998    def _testRecvInto(self):
 999        buf = buffer(MSG)
1000        self.serv_conn.send(buf)
1001
1002    def testRecvFromInto(self):
1003        buf = array.array('c', ' '*1024)
1004        nbytes, addr = self.cli_conn.recvfrom_into(buf)
1005        self.assertEqual(nbytes, len(MSG))
1006        msg = buf.tostring()[:len(MSG)]
1007        self.assertEqual(msg, MSG)
1008
1009    def _testRecvFromInto(self):
1010        buf = buffer(MSG)
1011        self.serv_conn.send(buf)
1012
def test_main():
    """Collect the applicable test cases for this platform and run them."""
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest]
    if sys.platform != 'mac':
        tests.extend([ BasicUDPTest, UDPTimeoutTest ])

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
    ])
    if hasattr(socket, "socketpair"):
        tests.append(BasicSocketPairTest)
    # Abstract-namespace AF_UNIX addresses are Linux-only.  Python 2 builds
    # made on Linux 3.x kernels before 2.7.3 report 'linux3' rather than
    # 'linux2', so match on the prefix.
    if sys.platform.startswith('linux'):
        tests.append(TestLinuxAbstractNamespace)

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        # Clean up leftover threads even if run_unittest raises
        # (e.g. on a test failure).
        test_support.threading_cleanup(*thread_info)
1035
if __name__ == "__main__":
    # Running this file directly executes the whole suite.
    test_main()