PageRenderTime 57ms CodeModel.GetById 2ms app.highlight 47ms RepoModel.GetById 1ms app.codeStats 0ms

/lib-python/modified-2.5.2/test/test_socket.py

https://bitbucket.org/probreasoning/pypy
Python | 1012 lines | 860 code | 77 blank | 75 comment | 30 complexity | c6a5792138e299d95c2665d21606b460 MD5 | raw file
   1#!/usr/bin/env python
   2
   3import unittest
   4from test import test_support
   5
   6import socket
   7import select
   8import time
   9import thread, threading
  10import Queue
  11import sys, gc
  12import array
  13from weakref import proxy
  14import signal
  15
  16PORT = 50007
  17HOST = 'localhost'
  18MSG = 'Michael Gilfix was here\n'
  19
  20class SocketTCPTest(unittest.TestCase):
  21
  22    def setUp(self):
  23        self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  24        self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  25        global PORT
  26        PORT = test_support.bind_port(self.serv, HOST, PORT)
  27        self.serv.listen(1)
  28
  29    def tearDown(self):
  30        self.serv.close()
  31        self.serv = None
  32
  33class SocketUDPTest(unittest.TestCase):
  34
  35    def setUp(self):
  36        self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  37        self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  38        global PORT
  39        PORT = test_support.bind_port(self.serv, HOST, PORT)
  40
  41    def tearDown(self):
  42        self.serv.close()
  43        self.serv = None
  44
  45class ThreadableTest:
  46    """Threadable Test class
  47
  48    The ThreadableTest class makes it easy to create a threaded
  49    client/server pair from an existing unit test. To create a
  50    new threaded class from an existing unit test, use multiple
  51    inheritance:
  52
  53        class NewClass (OldClass, ThreadableTest):
  54            pass
  55
  56    This class defines two new fixture functions with obvious
  57    purposes for overriding:
  58
  59        clientSetUp ()
  60        clientTearDown ()
  61
  62    Any new test functions within the class must then define
  63    tests in pairs, where the test name is preceeded with a
  64    '_' to indicate the client portion of the test. Ex:
  65
  66        def testFoo(self):
  67            # Server portion
  68
  69        def _testFoo(self):
  70            # Client portion
  71
  72    Any exceptions raised by the clients during their tests
  73    are caught and transferred to the main thread to alert
  74    the testing framework.
  75
  76    Note, the server setup function cannot call any blocking
  77    functions that rely on the client thread during setup,
  78    unless serverExplicityReady() is called just before
  79    the blocking call (such as in setting up a client/server
  80    connection and performing the accept() in setUp().
  81    """
  82
  83    def __init__(self):
  84        # Swap the true setup function
  85        self.__setUp = self.setUp
  86        self.__tearDown = self.tearDown
  87        self.setUp = self._setUp
  88        self.tearDown = self._tearDown
  89
  90    def serverExplicitReady(self):
  91        """This method allows the server to explicitly indicate that
  92        it wants the client thread to proceed. This is useful if the
  93        server is about to execute a blocking routine that is
  94        dependent upon the client thread during its setup routine."""
  95        self.server_ready.set()
  96
  97    def _setUp(self):
  98        self.server_ready = threading.Event()
  99        self.client_ready = threading.Event()
 100        self.done = threading.Event()
 101        self.queue = Queue.Queue(1)
 102
 103        # Do some munging to start the client test.
 104        methodname = self.id()
 105        i = methodname.rfind('.')
 106        methodname = methodname[i+1:]
 107        test_method = getattr(self, '_' + methodname)
 108        self.client_thread = thread.start_new_thread(
 109            self.clientRun, (test_method,))
 110
 111        self.__setUp()
 112        if not self.server_ready.isSet():
 113            self.server_ready.set()
 114        self.client_ready.wait()
 115
 116    def _tearDown(self):
 117        self.__tearDown()
 118        self.done.wait()
 119
 120        if not self.queue.empty():
 121            msg = self.queue.get()
 122            self.fail(msg)
 123
 124    def clientRun(self, test_func):
 125        self.server_ready.wait()
 126        self.client_ready.set()
 127        self.clientSetUp()
 128        if not callable(test_func):
 129            raise TypeError, "test_func must be a callable function"
 130        try:
 131            test_func()
 132        except Exception, strerror:
 133            self.queue.put(strerror)
 134        self.clientTearDown()
 135
 136    def clientSetUp(self):
 137        raise NotImplementedError, "clientSetUp must be implemented."
 138
 139    def clientTearDown(self):
 140        self.done.set()
 141        thread.exit()
 142
 143class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
 144
 145    def __init__(self, methodName='runTest'):
 146        SocketTCPTest.__init__(self, methodName=methodName)
 147        ThreadableTest.__init__(self)
 148
 149    def clientSetUp(self):
 150        self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 151
 152    def clientTearDown(self):
 153        self.cli.close()
 154        self.cli = None
 155        ThreadableTest.clientTearDown(self)
 156
 157class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
 158
 159    def __init__(self, methodName='runTest'):
 160        SocketUDPTest.__init__(self, methodName=methodName)
 161        ThreadableTest.__init__(self)
 162
 163    def clientSetUp(self):
 164        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 165
 166class SocketConnectedTest(ThreadedTCPSocketTest):
 167
 168    def __init__(self, methodName='runTest'):
 169        ThreadedTCPSocketTest.__init__(self, methodName=methodName)
 170
 171    def setUp(self):
 172        ThreadedTCPSocketTest.setUp(self)
 173        # Indicate explicitly we're ready for the client thread to
 174        # proceed and then perform the blocking call to accept
 175        self.serverExplicitReady()
 176        conn, addr = self.serv.accept()
 177        self.cli_conn = conn
 178
 179    def tearDown(self):
 180        self.cli_conn.close()
 181        self.cli_conn = None
 182        ThreadedTCPSocketTest.tearDown(self)
 183
 184    def clientSetUp(self):
 185        ThreadedTCPSocketTest.clientSetUp(self)
 186        self.cli.connect((HOST, PORT))
 187        self.serv_conn = self.cli
 188
 189    def clientTearDown(self):
 190        self.serv_conn.close()
 191        self.serv_conn = None
 192        ThreadedTCPSocketTest.clientTearDown(self)
 193
 194class SocketPairTest(unittest.TestCase, ThreadableTest):
 195
 196    def __init__(self, methodName='runTest'):
 197        unittest.TestCase.__init__(self, methodName=methodName)
 198        ThreadableTest.__init__(self)
 199
 200    def setUp(self):
 201        self.serv, self.cli = socket.socketpair()
 202
 203    def tearDown(self):
 204        self.serv.close()
 205        self.serv = None
 206
 207    def clientSetUp(self):
 208        pass
 209
 210    def clientTearDown(self):
 211        self.cli.close()
 212        self.cli = None
 213        ThreadableTest.clientTearDown(self)
 214
 215
 216#######################################################################
 217## Begin Tests
 218
 219class GeneralModuleTests(unittest.TestCase):
 220
 221    def test_weakref(self):
 222        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 223        p = proxy(s)
 224        self.assertEqual(p.fileno(), s.fileno())
 225        s.close()
 226        s = None
 227        gc.collect()
 228        gc.collect()
 229        try:
 230            p.fileno()
 231        except ReferenceError:
 232            pass
 233        else:
 234            self.fail('Socket proxy still exists')
 235
 236    def testSocketError(self):
 237        # Testing socket module exceptions
 238        def raise_error(*args, **kwargs):
 239            raise socket.error
 240        def raise_herror(*args, **kwargs):
 241            raise socket.herror
 242        def raise_gaierror(*args, **kwargs):
 243            raise socket.gaierror
 244        self.failUnlessRaises(socket.error, raise_error,
 245                              "Error raising socket exception.")
 246        self.failUnlessRaises(socket.error, raise_herror,
 247                              "Error raising socket exception.")
 248        self.failUnlessRaises(socket.error, raise_gaierror,
 249                              "Error raising socket exception.")
 250
 251    def testCrucialConstants(self):
 252        # Testing for mission critical constants
 253        socket.AF_INET
 254        socket.SOCK_STREAM
 255        socket.SOCK_DGRAM
 256        socket.SOCK_RAW
 257        socket.SOCK_RDM
 258        socket.SOCK_SEQPACKET
 259        socket.SOL_SOCKET
 260        socket.SO_REUSEADDR
 261
 262    def testHostnameRes(self):
 263        # Testing hostname resolution mechanisms
 264        hostname = socket.gethostname()
 265        try:
 266            ip = socket.gethostbyname(hostname)
 267        except socket.error:
 268            # Probably name lookup wasn't set up right; skip this test
 269            return
 270        self.assert_(ip.find('.') >= 0, "Error resolving host to ip.")
 271        try:
 272            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
 273        except socket.error:
 274            # Probably a similar problem as above; skip this test
 275            return
 276        all_host_names = [hostname, hname] + aliases
 277        fqhn = socket.getfqdn(ip)
 278        if not fqhn in all_host_names:
 279            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
 280
 281    def testRefCountGetNameInfo(self):
 282        # Testing reference count for getnameinfo
 283        import sys
 284        if hasattr(sys, "getrefcount"):
 285            try:
 286                # On some versions, this loses a reference
 287                orig = sys.getrefcount(__name__)
 288                socket.getnameinfo(__name__,0)
 289            except SystemError:
 290                if sys.getrefcount(__name__) <> orig:
 291                    self.fail("socket.getnameinfo loses a reference")
 292
 293    def testInterpreterCrash(self):
 294        # Making sure getnameinfo doesn't crash the interpreter
 295        try:
 296            # On some versions, this crashes the interpreter.
 297            socket.getnameinfo(('x', 0, 0, 0), 0)
 298        except socket.error:
 299            pass
 300
 301    def testNtoH(self):
 302        # This just checks that htons etc. are their own inverse,
 303        # when looking at the lower 16 or 32 bits.
 304        sizes = {socket.htonl: 32, socket.ntohl: 32,
 305                 socket.htons: 16, socket.ntohs: 16}
 306        for func, size in sizes.items():
 307            mask = (1L<<size) - 1
 308            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
 309                self.assertEqual(i & mask, func(func(i&mask)) & mask)
 310                # should also accept ints, not only longs:
 311                self.assertEqual(i & mask, func(func(int(i&mask))) & mask)
 312
 313            swapped = func(mask)
 314            self.assertEqual(swapped & mask, mask)
 315            self.assertRaises(OverflowError, func, 1L<<34)
 316            if test_support.check_impl_detail(cpython=False):
 317                # XXX on 64-bit machines, it should also raise if it's an
 318                # out-of-bound int, not just an out-of-bound long, but
 319                # CPython 2.5/2.6 does not do it:
 320                self.assertRaises(OverflowError, func, 1<<34)
 321                # XXX the following check seems reasonable enough, but
 322                # it does not work on CPython either; disabled for now:
 323                #   self.assertRaises(OverflowError, func, 1<<size)
 324                #   self.assertRaises(OverflowError, func, 1L<<size)
 325                # the following was fixed on CPython >= 2.6:
 326                self.assertRaises((OverflowError, ValueError), func, -1)
 327
 328    def testGetServBy(self):
 329        eq = self.assertEqual
 330        # Find one service that exists, then check all the related interfaces.
 331        # I've ordered this by protocols that have both a tcp and udp
 332        # protocol, at least for modern Linuxes.
 333        if sys.platform in ('linux2', 'freebsd4', 'freebsd5', 'freebsd6',
 334                            'freebsd7', 'darwin'):
 335            # avoid the 'echo' service on this platform, as there is an
 336            # assumption breaking non-standard port/protocol entry
 337            services = ('daytime', 'qotd', 'domain')
 338        else:
 339            services = ('echo', 'daytime', 'domain')
 340        for service in services:
 341            try:
 342                port = socket.getservbyname(service, 'tcp')
 343                break
 344            except socket.error:
 345                pass
 346        else:
 347            raise socket.error
 348        # Try same call with optional protocol omitted
 349        port2 = socket.getservbyname(service)
 350        eq(port, port2)
 351        # Try udp, but don't barf it it doesn't exist
 352        try:
 353            udpport = socket.getservbyname(service, 'udp')
 354        except socket.error:
 355            udpport = None
 356        else:
 357            eq(udpport, port)
 358        # Now make sure the lookup by port returns the same service name
 359        eq(socket.getservbyport(port2), service)
 360        eq(socket.getservbyport(port, 'tcp'), service)
 361        if udpport is not None:
 362            eq(socket.getservbyport(udpport, 'udp'), service)
 363
 364    def testDefaultTimeout(self):
 365        # Testing default timeout
 366        # The default timeout should initially be None
 367        self.assertEqual(socket.getdefaulttimeout(), None)
 368        s = socket.socket()
 369        self.assertEqual(s.gettimeout(), None)
 370        s.close()
 371
 372        # Set the default timeout to 10, and see if it propagates
 373        socket.setdefaulttimeout(10)
 374        self.assertEqual(socket.getdefaulttimeout(), 10)
 375        s = socket.socket()
 376        self.assertEqual(s.gettimeout(), 10)
 377        s.close()
 378
 379        # Reset the default timeout to None, and see if it propagates
 380        socket.setdefaulttimeout(None)
 381        self.assertEqual(socket.getdefaulttimeout(), None)
 382        s = socket.socket()
 383        self.assertEqual(s.gettimeout(), None)
 384        s.close()
 385
 386        # Check that setting it to an invalid value raises ValueError
 387        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
 388
 389        # Check that setting it to an invalid type raises TypeError
 390        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
 391
 392    def testIPv4toString(self):
 393        if not hasattr(socket, 'inet_pton'):
 394            return # No inet_pton() on this platform
 395        from socket import inet_aton as f, inet_pton, AF_INET
 396        g = lambda a: inet_pton(AF_INET, a)
 397
 398        self.assertEquals('\x00\x00\x00\x00', f('0.0.0.0'))
 399        self.assertEquals('\xff\x00\xff\x00', f('255.0.255.0'))
 400        self.assertEquals('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
 401        self.assertEquals('\x01\x02\x03\x04', f('1.2.3.4'))
 402        self.assertEquals('\xff\xff\xff\xff', f('255.255.255.255'))
 403
 404        self.assertEquals('\x00\x00\x00\x00', g('0.0.0.0'))
 405        self.assertEquals('\xff\x00\xff\x00', g('255.0.255.0'))
 406        self.assertEquals('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
 407        self.assertEquals('\xff\xff\xff\xff', g('255.255.255.255'))
 408
 409    def testIPv6toString(self):
 410        if not hasattr(socket, 'inet_pton'):
 411            return # No inet_pton() on this platform
 412        try:
 413            from socket import inet_pton, AF_INET6, has_ipv6
 414            if not has_ipv6:
 415                return
 416        except ImportError:
 417            return
 418        f = lambda a: inet_pton(AF_INET6, a)
 419
 420        self.assertEquals('\x00' * 16, f('::'))
 421        self.assertEquals('\x00' * 16, f('0::0'))
 422        self.assertEquals('\x00\x01' + '\x00' * 14, f('1::'))
 423        self.assertEquals(
 424            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
 425            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
 426        )
 427
 428    def testStringToIPv4(self):
 429        if not hasattr(socket, 'inet_ntop'):
 430            return # No inet_ntop() on this platform
 431        from socket import inet_ntoa as f, inet_ntop, AF_INET
 432        g = lambda a: inet_ntop(AF_INET, a)
 433
 434        self.assertEquals('1.0.1.0', f('\x01\x00\x01\x00'))
 435        self.assertEquals('170.85.170.85', f('\xaa\x55\xaa\x55'))
 436        self.assertEquals('255.255.255.255', f('\xff\xff\xff\xff'))
 437        self.assertEquals('1.2.3.4', f('\x01\x02\x03\x04'))
 438
 439        self.assertEquals('1.0.1.0', g('\x01\x00\x01\x00'))
 440        self.assertEquals('170.85.170.85', g('\xaa\x55\xaa\x55'))
 441        self.assertEquals('255.255.255.255', g('\xff\xff\xff\xff'))
 442
 443    def testStringToIPv6(self):
 444        if not hasattr(socket, 'inet_ntop'):
 445            return # No inet_ntop() on this platform
 446        try:
 447            from socket import inet_ntop, AF_INET6, has_ipv6
 448            if not has_ipv6:
 449                return
 450        except ImportError:
 451            return
 452        f = lambda a: inet_ntop(AF_INET6, a)
 453
 454        self.assertEquals('::', f('\x00' * 16))
 455        self.assertEquals('::1', f('\x00' * 15 + '\x01'))
 456        self.assertEquals(
 457            'aef:b01:506:1001:ffff:9997:55:170',
 458            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
 459        )
 460
 461    # XXX The following don't test module-level functionality...
 462
 463    def testSockName(self):
 464        # Testing getsockname()
 465        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 466        sock.bind(("0.0.0.0", PORT+1))
 467        name = sock.getsockname()
 468        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
 469        # it reasonable to get the host's addr in addition to 0.0.0.0.
 470        # At least for eCos.  This is required for the S/390 to pass.
 471        my_ip_addr = socket.gethostbyname(socket.gethostname())
 472        self.assert_(name[0] in ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
 473        self.assertEqual(name[1], PORT+1)
 474
 475    def testGetSockOpt(self):
 476        # Testing getsockopt()
 477        # We know a socket should start without reuse==0
 478        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 479        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
 480        self.failIf(reuse != 0, "initial mode is reuse")
 481
 482    def testSetSockOpt(self):
 483        # Testing setsockopt()
 484        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 485        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 486        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
 487        self.failIf(reuse == 0, "failed to set reuse mode")
 488
 489    def testSendAfterClose(self):
 490        # testing send() after close() with timeout
 491        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 492        sock.settimeout(1)
 493        sock.close()
 494        self.assertRaises(socket.error, sock.send, "spam")
 495
 496    def testNewAttributes(self):
 497        # testing .family, .type and .protocol
 498        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 499        self.assertEqual(sock.family, socket.AF_INET)
 500        self.assertEqual(sock.type, socket.SOCK_STREAM)
 501        self.assertEqual(sock.proto, 0)
 502        sock.close()
 503
 504class BasicTCPTest(SocketConnectedTest):
 505
 506    def __init__(self, methodName='runTest'):
 507        SocketConnectedTest.__init__(self, methodName=methodName)
 508
 509    def testRecv(self):
 510        # Testing large receive over TCP
 511        msg = self.cli_conn.recv(1024)
 512        self.assertEqual(msg, MSG)
 513
 514    def _testRecv(self):
 515        self.serv_conn.send(MSG)
 516
 517    def testOverFlowRecv(self):
 518        # Testing receive in chunks over TCP
 519        seg1 = self.cli_conn.recv(len(MSG) - 3)
 520        seg2 = self.cli_conn.recv(1024)
 521        msg = seg1 + seg2
 522        self.assertEqual(msg, MSG)
 523
 524    def _testOverFlowRecv(self):
 525        self.serv_conn.send(MSG)
 526
 527    def testRecvFrom(self):
 528        # Testing large recvfrom() over TCP
 529        msg, addr = self.cli_conn.recvfrom(1024)
 530        self.assertEqual(msg, MSG)
 531
 532    def _testRecvFrom(self):
 533        self.serv_conn.send(MSG)
 534
 535    def testOverFlowRecvFrom(self):
 536        # Testing recvfrom() in chunks over TCP
 537        seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
 538        seg2, addr = self.cli_conn.recvfrom(1024)
 539        msg = seg1 + seg2
 540        self.assertEqual(msg, MSG)
 541
 542    def _testOverFlowRecvFrom(self):
 543        self.serv_conn.send(MSG)
 544
 545    def testSendAll(self):
 546        # Testing sendall() with a 2048 byte string over TCP
 547        msg = ''
 548        while 1:
 549            read = self.cli_conn.recv(1024)
 550            if not read:
 551                break
 552            msg += read
 553        self.assertEqual(msg, 'f' * 2048)
 554
 555    def _testSendAll(self):
 556        big_chunk = 'f' * 2048
 557        self.serv_conn.sendall(big_chunk)
 558
 559    def testFromFd(self):
 560        # Testing fromfd()
 561        if not hasattr(socket, "fromfd"):
 562            return # On Windows, this doesn't exist
 563        fd = self.cli_conn.fileno()
 564        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
 565        msg = sock.recv(1024)
 566        self.assertEqual(msg, MSG)
 567
 568    def _testFromFd(self):
 569        self.serv_conn.send(MSG)
 570
 571    def testShutdown(self):
 572        # Testing shutdown()
 573        msg = self.cli_conn.recv(1024)
 574        self.assertEqual(msg, MSG)
 575
 576    def _testShutdown(self):
 577        self.serv_conn.send(MSG)
 578        self.serv_conn.shutdown(2)
 579
 580class BasicUDPTest(ThreadedUDPSocketTest):
 581
 582    def __init__(self, methodName='runTest'):
 583        ThreadedUDPSocketTest.__init__(self, methodName=methodName)
 584
 585    def testSendtoAndRecv(self):
 586        # Testing sendto() and Recv() over UDP
 587        msg = self.serv.recv(len(MSG))
 588        self.assertEqual(msg, MSG)
 589
 590    def _testSendtoAndRecv(self):
 591        self.cli.sendto(MSG, 0, (HOST, PORT))
 592
 593    def testRecvFrom(self):
 594        # Testing recvfrom() over UDP
 595        msg, addr = self.serv.recvfrom(len(MSG))
 596        self.assertEqual(msg, MSG)
 597
 598    def _testRecvFrom(self):
 599        self.cli.sendto(MSG, 0, (HOST, PORT))
 600
 601    def testRecvFromNegative(self):
 602        # Negative lengths passed to recvfrom should give ValueError.
 603        self.assertRaises(ValueError, self.serv.recvfrom, -1)
 604
 605    def _testRecvFromNegative(self):
 606        self.cli.sendto(MSG, 0, (HOST, PORT))
 607
 608class TCPCloserTest(ThreadedTCPSocketTest):
 609
 610    def testClose(self):
 611        conn, addr = self.serv.accept()
 612        conn.close()
 613
 614        sd = self.cli
 615        read, write, err = select.select([sd], [], [], 1.0)
 616        self.assertEqual(read, [sd])
 617        self.assertEqual(sd.recv(1), '')
 618
 619    def _testClose(self):
 620        self.cli.connect((HOST, PORT))
 621        time.sleep(1.0)
 622
 623class BasicSocketPairTest(SocketPairTest):
 624
 625    def __init__(self, methodName='runTest'):
 626        SocketPairTest.__init__(self, methodName=methodName)
 627
 628    def testRecv(self):
 629        msg = self.serv.recv(1024)
 630        self.assertEqual(msg, MSG)
 631
 632    def _testRecv(self):
 633        self.cli.send(MSG)
 634
 635    def testSend(self):
 636        self.serv.send(MSG)
 637
 638    def _testSend(self):
 639        msg = self.cli.recv(1024)
 640        self.assertEqual(msg, MSG)
 641
 642class NonBlockingTCPTests(ThreadedTCPSocketTest):
 643
 644    def __init__(self, methodName='runTest'):
 645        ThreadedTCPSocketTest.__init__(self, methodName=methodName)
 646
 647    def testSetBlocking(self):
 648        # Testing whether set blocking works
 649        self.serv.setblocking(0)
 650        start = time.time()
 651        try:
 652            self.serv.accept()
 653        except socket.error:
 654            pass
 655        end = time.time()
 656        self.assert_((end - start) < 1.0, "Error setting non-blocking mode.")
 657
 658    def _testSetBlocking(self):
 659        pass
 660
 661    def testAccept(self):
 662        # Testing non-blocking accept
 663        self.serv.setblocking(0)
 664        try:
 665            conn, addr = self.serv.accept()
 666        except socket.error:
 667            pass
 668        else:
 669            self.fail("Error trying to do non-blocking accept.")
 670        read, write, err = select.select([self.serv], [], [])
 671        if self.serv in read:
 672            conn, addr = self.serv.accept()
 673        else:
 674            self.fail("Error trying to do accept after select.")
 675
 676    def _testAccept(self):
 677        time.sleep(0.1)
 678        self.cli.connect((HOST, PORT))
 679
 680    def testConnect(self):
 681        # Testing non-blocking connect
 682        conn, addr = self.serv.accept()
 683
 684    def _testConnect(self):
 685        self.cli.settimeout(10)
 686        self.cli.connect((HOST, PORT))
 687
 688    def testRecv(self):
 689        # Testing non-blocking recv
 690        conn, addr = self.serv.accept()
 691        conn.setblocking(0)
 692        try:
 693            msg = conn.recv(len(MSG))
 694        except socket.error:
 695            pass
 696        else:
 697            self.fail("Error trying to do non-blocking recv.")
 698        read, write, err = select.select([conn], [], [])
 699        if conn in read:
 700            msg = conn.recv(len(MSG))
 701            self.assertEqual(msg, MSG)
 702        else:
 703            self.fail("Error during select call to non-blocking socket.")
 704
 705    def _testRecv(self):
 706        self.cli.connect((HOST, PORT))
 707        time.sleep(0.1)
 708        self.cli.send(MSG)
 709
 710class FileObjectClassTestCase(SocketConnectedTest):
 711
 712    bufsize = -1 # Use default buffer size
 713
 714    def __init__(self, methodName='runTest'):
 715        SocketConnectedTest.__init__(self, methodName=methodName)
 716
 717    def setUp(self):
 718        SocketConnectedTest.setUp(self)
 719        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)
 720
 721    def tearDown(self):
 722        self.serv_file.close()
 723        self.assert_(self.serv_file.closed)
 724        self.serv_file = None
 725        SocketConnectedTest.tearDown(self)
 726
 727    def clientSetUp(self):
 728        SocketConnectedTest.clientSetUp(self)
 729        self.cli_file = self.serv_conn.makefile('wb')
 730
 731    def clientTearDown(self):
 732        self.cli_file.close()
 733        self.assert_(self.cli_file.closed)
 734        self.cli_file = None
 735        SocketConnectedTest.clientTearDown(self)
 736
 737    def testSmallRead(self):
 738        # Performing small file read test
 739        first_seg = self.serv_file.read(len(MSG)-3)
 740        second_seg = self.serv_file.read(3)
 741        msg = first_seg + second_seg
 742        self.assertEqual(msg, MSG)
 743
 744    def _testSmallRead(self):
 745        self.cli_file.write(MSG)
 746        self.cli_file.flush()
 747
 748    def testFullRead(self):
 749        # read until EOF
 750        msg = self.serv_file.read()
 751        self.assertEqual(msg, MSG)
 752
 753    def _testFullRead(self):
 754        self.cli_file.write(MSG)
 755        self.cli_file.close()
 756
 757    def testUnbufferedRead(self):
 758        # Performing unbuffered file read test
 759        buf = ''
 760        while 1:
 761            char = self.serv_file.read(1)
 762            if not char:
 763                break
 764            buf += char
 765        self.assertEqual(buf, MSG)
 766
 767    def _testUnbufferedRead(self):
 768        self.cli_file.write(MSG)
 769        self.cli_file.flush()
 770
 771    def testReadline(self):
 772        # Performing file readline test
 773        line = self.serv_file.readline()
 774        self.assertEqual(line, MSG)
 775
 776    def _testReadline(self):
 777        self.cli_file.write(MSG)
 778        self.cli_file.flush()
 779
 780    def testClosedAttr(self):
 781        self.assert_(not self.serv_file.closed)
 782
 783    def _testClosedAttr(self):
 784        self.assert_(not self.cli_file.closed)
 785
 786class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
 787
 788    """Repeat the tests from FileObjectClassTestCase with bufsize==0.
 789
 790    In this case (and in this case only), it should be possible to
 791    create a file object, read a line from it, create another file
 792    object, read another line from it, without loss of data in the
 793    first file object's buffer.  Note that httplib relies on this
 794    when reading multiple requests from the same socket."""
 795
 796    bufsize = 0 # Use unbuffered mode
 797
 798    def testUnbufferedReadline(self):
 799        # Read a line, create a new file object, read another line with it
 800        line = self.serv_file.readline() # first line
 801        self.assertEqual(line, "A. " + MSG) # first line
 802        self.serv_file = self.cli_conn.makefile('rb', 0)
 803        line = self.serv_file.readline() # second line
 804        self.assertEqual(line, "B. " + MSG) # second line
 805
 806    def _testUnbufferedReadline(self):
 807        self.cli_file.write("A. " + MSG)
 808        self.cli_file.write("B. " + MSG)
 809        self.cli_file.flush()
 810
 811class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
 812
 813    bufsize = 1 # Default-buffered for reading; line-buffered for writing
 814
 815
 816class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
 817
 818    bufsize = 2 # Exercise the buffering code
 819
 820
 821class Urllib2FileobjectTest(unittest.TestCase):
 822
 823    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
 824    # it close the socket if the close c'tor argument is true
 825
 826    def testClose(self):
 827        class MockSocket:
 828            closed = False
 829            def flush(self): pass
 830            def close(self): self.closed = True
 831            def _drop(self): pass
 832            def _reuse(self): pass
 833
 834        # must not close unless we request it: the original use of _fileobject
 835        # by module socket requires that the underlying socket not be closed until
 836        # the _socketobject that created the _fileobject is closed
 837        s = MockSocket()
 838        f = socket._fileobject(s)
 839        f.close()
 840        self.assert_(not s.closed)
 841
 842        s = MockSocket()
 843        f = socket._fileobject(s, close=True)
 844        f.close()
 845        self.assert_(s.closed)
 846
 847class TCPTimeoutTest(SocketTCPTest):
 848
 849    def testTCPTimeout(self):
 850        def raise_timeout(*args, **kwargs):
 851            self.serv.settimeout(1.0)
 852            self.serv.accept()
 853        self.failUnlessRaises(socket.timeout, raise_timeout,
 854                              "Error generating a timeout exception (TCP)")
 855
 856    def testTimeoutZero(self):
 857        ok = False
 858        try:
 859            self.serv.settimeout(0.0)
 860            foo = self.serv.accept()
 861        except socket.timeout:
 862            self.fail("caught timeout instead of error (TCP)")
 863        except socket.error:
 864            ok = True
 865        except:
 866            self.fail("caught unexpected exception (TCP)")
 867        if not ok:
 868            self.fail("accept() returned success when we did not expect it")
 869
 870    def testInterruptedTimeout(self):
 871        # XXX I don't know how to do this test on MSWindows or any other
 872        # plaform that doesn't support signal.alarm() or os.kill(), though
 873        # the bug should have existed on all platforms.
 874        if not hasattr(signal, "alarm"):
 875            return                  # can only test on *nix
 876        self.serv.settimeout(5.0)   # must be longer than alarm
 877        class Alarm(Exception):
 878            pass
 879        def alarm_handler(signal, frame):
 880            raise Alarm
 881        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
 882        try:
 883            signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
 884            try:
 885                foo = self.serv.accept()
 886            except socket.timeout:
 887                self.fail("caught timeout instead of Alarm")
 888            except Alarm:
 889                pass
 890            except:
 891                self.fail("caught other exception instead of Alarm")
 892            else:
 893                self.fail("nothing caught")
 894            signal.alarm(0)         # shut off alarm
 895        except Alarm:
 896            self.fail("got Alarm in wrong place")
 897        finally:
 898            # no alarm can be pending.  Safe to restore old handler.
 899            signal.signal(signal.SIGALRM, old_alarm)
 900
 901class UDPTimeoutTest(SocketTCPTest):
 902
 903    def testUDPTimeout(self):
 904        def raise_timeout(*args, **kwargs):
 905            self.serv.settimeout(1.0)
 906            self.serv.recv(1024)
 907        self.failUnlessRaises(socket.timeout, raise_timeout,
 908                              "Error generating a timeout exception (UDP)")
 909
 910    def testTimeoutZero(self):
 911        ok = False
 912        try:
 913            self.serv.settimeout(0.0)
 914            foo = self.serv.recv(1024)
 915        except socket.timeout:
 916            self.fail("caught timeout instead of error (UDP)")
 917        except socket.error:
 918            ok = True
 919        except:
 920            self.fail("caught unexpected exception (UDP)")
 921        if not ok:
 922            self.fail("recv() returned success when we did not expect it")
 923
 924class TestExceptions(unittest.TestCase):
 925
 926    def testExceptionTree(self):
 927        self.assert_(issubclass(socket.error, Exception))
 928        self.assert_(issubclass(socket.herror, socket.error))
 929        self.assert_(issubclass(socket.gaierror, socket.error))
 930        self.assert_(issubclass(socket.timeout, socket.error))
 931
 932class TestLinuxAbstractNamespace(unittest.TestCase):
 933
 934    UNIX_PATH_MAX = 108
 935
 936    def testLinuxAbstractNamespace(self):
 937        address = "\x00python-test-hello\x00\xff"
 938        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 939        s1.bind(address)
 940        s1.listen(1)
 941        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 942        s2.connect(s1.getsockname())
 943        s1.accept()
 944        self.assertEqual(s1.getsockname(), address)
 945        self.assertEqual(s2.getpeername(), address)
 946
 947    def testMaxName(self):
 948        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
 949        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 950        s.bind(address)
 951        self.assertEqual(s.getsockname(), address)
 952
 953    def testNameOverflow(self):
 954        address = "\x00" + "h" * self.UNIX_PATH_MAX
 955        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
 956        self.assertRaises(socket.error, s.bind, address)
 957
 958
 959class BufferIOTest(SocketConnectedTest):
 960    """
 961    Test the buffer versions of socket.recv() and socket.send().
 962    """
 963    def __init__(self, methodName='runTest'):
 964        SocketConnectedTest.__init__(self, methodName=methodName)
 965
 966    def testRecvInto(self):
 967        buf = array.array('c', ' '*1024)
 968        nbytes = self.cli_conn.recv_into(buf)
 969        self.assertEqual(nbytes, len(MSG))
 970        msg = buf.tostring()[:len(MSG)]
 971        self.assertEqual(msg, MSG)
 972
 973    def _testRecvInto(self):
 974        buf = buffer(MSG)
 975        self.serv_conn.send(buf)
 976
 977    def testRecvFromInto(self):
 978        buf = array.array('c', ' '*1024)
 979        nbytes, addr = self.cli_conn.recvfrom_into(buf)
 980        self.assertEqual(nbytes, len(MSG))
 981        msg = buf.tostring()[:len(MSG)]
 982        self.assertEqual(msg, MSG)
 983
 984    def _testRecvFromInto(self):
 985        buf = buffer(MSG)
 986        self.serv_conn.send(buf)
 987
 988def test_main():
 989    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
 990             TestExceptions, BufferIOTest]
 991    if sys.platform != 'mac':
 992        tests.extend([ BasicUDPTest, UDPTimeoutTest ])
 993
 994    tests.extend([
 995        NonBlockingTCPTests,
 996        FileObjectClassTestCase,
 997        UnbufferedFileObjectClassTestCase,
 998        LineBufferedFileObjectClassTestCase,
 999        SmallBufferedFileObjectClassTestCase,
1000        Urllib2FileobjectTest,
1001    ])
1002    if hasattr(socket, "socketpair"):
1003        tests.append(BasicSocketPairTest)
1004    if sys.platform == 'linux2':
1005        tests.append(TestLinuxAbstractNamespace)
1006
1007    thread_info = test_support.threading_setup()
1008    test_support.run_unittest(*tests)
1009    test_support.threading_cleanup(*thread_info)
1010
1011if __name__ == "__main__":
1012    test_main()