PageRenderTime 140ms CodeModel.GetById 11ms app.highlight 114ms RepoModel.GetById 1ms app.codeStats 1ms

/Lib/test/test_socket.py

http://unladen-swallow.googlecode.com/
Python | 1245 lines | 1094 code | 81 blank | 70 comment | 38 complexity | d251cb5f4a40c370d1f7223e6a1f21e9 MD5 | raw file
   1#!/usr/bin/env python
   2
   3import unittest
   4from test import test_support
   5
   6import errno
   7import socket
   8import select
   9import thread, threading
  10import time
  11import traceback
  12import Queue
  13import sys
  14import os
  15import array
  16from weakref import proxy
  17import signal
  18
  19HOST = test_support.HOST
  20MSG = 'Michael Gilfix was here\n'
  21
  22class SocketTCPTest(unittest.TestCase):
  23
  24    def setUp(self):
  25        self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  26        self.port = test_support.bind_port(self.serv)
  27        self.serv.listen(1)
  28
  29    def tearDown(self):
  30        self.serv.close()
  31        self.serv = None
  32
  33class SocketUDPTest(unittest.TestCase):
  34
  35    def setUp(self):
  36        self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  37        self.port = test_support.bind_port(self.serv)
  38
  39    def tearDown(self):
  40        self.serv.close()
  41        self.serv = None
  42
  43class ThreadableTest:
  44    """Threadable Test class
  45
  46    The ThreadableTest class makes it easy to create a threaded
  47    client/server pair from an existing unit test. To create a
  48    new threaded class from an existing unit test, use multiple
  49    inheritance:
  50
  51        class NewClass (OldClass, ThreadableTest):
  52            pass
  53
  54    This class defines two new fixture functions with obvious
  55    purposes for overriding:
  56
  57        clientSetUp ()
  58        clientTearDown ()
  59
  60    Any new test functions within the class must then define
  61    tests in pairs, where the test name is preceeded with a
  62    '_' to indicate the client portion of the test. Ex:
  63
  64        def testFoo(self):
  65            # Server portion
  66
  67        def _testFoo(self):
  68            # Client portion
  69
  70    Any exceptions raised by the clients during their tests
  71    are caught and transferred to the main thread to alert
  72    the testing framework.
  73
  74    Note, the server setup function cannot call any blocking
  75    functions that rely on the client thread during setup,
  76    unless serverExplicitReady() is called just before
  77    the blocking call (such as in setting up a client/server
  78    connection and performing the accept() in setUp().
  79    """
  80
  81    def __init__(self):
  82        # Swap the true setup function
  83        self.__setUp = self.setUp
  84        self.__tearDown = self.tearDown
  85        self.setUp = self._setUp
  86        self.tearDown = self._tearDown
  87
  88    def serverExplicitReady(self):
  89        """This method allows the server to explicitly indicate that
  90        it wants the client thread to proceed. This is useful if the
  91        server is about to execute a blocking routine that is
  92        dependent upon the client thread during its setup routine."""
  93        self.server_ready.set()
  94
  95    def _setUp(self):
  96        self.server_ready = threading.Event()
  97        self.client_ready = threading.Event()
  98        self.done = threading.Event()
  99        self.queue = Queue.Queue(1)
 100
 101        # Do some munging to start the client test.
 102        methodname = self.id()
 103        i = methodname.rfind('.')
 104        methodname = methodname[i+1:]
 105        test_method = getattr(self, '_' + methodname)
 106        self.client_thread = thread.start_new_thread(
 107            self.clientRun, (test_method,))
 108
 109        self.__setUp()
 110        if not self.server_ready.is_set():
 111            self.server_ready.set()
 112        self.client_ready.wait()
 113
 114    def _tearDown(self):
 115        self.__tearDown()
 116        self.done.wait()
 117
 118        if not self.queue.empty():
 119            msg = self.queue.get()
 120            self.fail(msg)
 121
 122    def clientRun(self, test_func):
 123        self.server_ready.wait()
 124        self.client_ready.set()
 125        self.clientSetUp()
 126        if not callable(test_func):
 127            raise TypeError, "test_func must be a callable function"
 128        try:
 129            test_func()
 130        except Exception, strerror:
 131            self.queue.put(strerror)
 132        self.clientTearDown()
 133
 134    def clientSetUp(self):
 135        raise NotImplementedError, "clientSetUp must be implemented."
 136
 137    def clientTearDown(self):
 138        self.done.set()
 139        thread.exit()
 140
 141class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
 142
 143    def __init__(self, methodName='runTest'):
 144        SocketTCPTest.__init__(self, methodName=methodName)
 145        ThreadableTest.__init__(self)
 146
 147    def clientSetUp(self):
 148        self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 149
 150    def clientTearDown(self):
 151        self.cli.close()
 152        self.cli = None
 153        ThreadableTest.clientTearDown(self)
 154
 155class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
 156
 157    def __init__(self, methodName='runTest'):
 158        SocketUDPTest.__init__(self, methodName=methodName)
 159        ThreadableTest.__init__(self)
 160
 161    def clientSetUp(self):
 162        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 163
 164class SocketConnectedTest(ThreadedTCPSocketTest):
 165
 166    def __init__(self, methodName='runTest'):
 167        ThreadedTCPSocketTest.__init__(self, methodName=methodName)
 168
 169    def setUp(self):
 170        ThreadedTCPSocketTest.setUp(self)
 171        # Indicate explicitly we're ready for the client thread to
 172        # proceed and then perform the blocking call to accept
 173        self.serverExplicitReady()
 174        conn, addr = self.serv.accept()
 175        self.cli_conn = conn
 176
 177    def tearDown(self):
 178        self.cli_conn.close()
 179        self.cli_conn = None
 180        ThreadedTCPSocketTest.tearDown(self)
 181
 182    def clientSetUp(self):
 183        ThreadedTCPSocketTest.clientSetUp(self)
 184        self.cli.connect((HOST, self.port))
 185        self.serv_conn = self.cli
 186
 187    def clientTearDown(self):
 188        self.serv_conn.close()
 189        self.serv_conn = None
 190        ThreadedTCPSocketTest.clientTearDown(self)
 191
 192class SocketPairTest(unittest.TestCase, ThreadableTest):
 193
 194    def __init__(self, methodName='runTest'):
 195        unittest.TestCase.__init__(self, methodName=methodName)
 196        ThreadableTest.__init__(self)
 197
 198    def setUp(self):
 199        self.serv, self.cli = socket.socketpair()
 200
 201    def tearDown(self):
 202        self.serv.close()
 203        self.serv = None
 204
 205    def clientSetUp(self):
 206        pass
 207
 208    def clientTearDown(self):
 209        self.cli.close()
 210        self.cli = None
 211        ThreadableTest.clientTearDown(self)
 212
 213
 214#######################################################################
 215## Begin Tests
 216
 217class GeneralModuleTests(unittest.TestCase):
 218
 219    def test_weakref(self):
 220        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 221        p = proxy(s)
 222        self.assertEqual(p.fileno(), s.fileno())
 223        s.close()
 224        s = None
 225        try:
 226            p.fileno()
 227        except ReferenceError:
 228            pass
 229        else:
 230            self.fail('Socket proxy still exists')
 231
 232    def testSocketError(self):
 233        # Testing socket module exceptions
 234        def raise_error(*args, **kwargs):
 235            raise socket.error
 236        def raise_herror(*args, **kwargs):
 237            raise socket.herror
 238        def raise_gaierror(*args, **kwargs):
 239            raise socket.gaierror
 240        self.failUnlessRaises(socket.error, raise_error,
 241                              "Error raising socket exception.")
 242        self.failUnlessRaises(socket.error, raise_herror,
 243                              "Error raising socket exception.")
 244        self.failUnlessRaises(socket.error, raise_gaierror,
 245                              "Error raising socket exception.")
 246
 247    def testCrucialConstants(self):
 248        # Testing for mission critical constants
 249        socket.AF_INET
 250        socket.SOCK_STREAM
 251        socket.SOCK_DGRAM
 252        socket.SOCK_RAW
 253        socket.SOCK_RDM
 254        socket.SOCK_SEQPACKET
 255        socket.SOL_SOCKET
 256        socket.SO_REUSEADDR
 257
 258    def testHostnameRes(self):
 259        # Testing hostname resolution mechanisms
 260        hostname = socket.gethostname()
 261        try:
 262            ip = socket.gethostbyname(hostname)
 263        except socket.error:
 264            # Probably name lookup wasn't set up right; skip this test
 265            return
 266        self.assert_(ip.find('.') >= 0, "Error resolving host to ip.")
 267        try:
 268            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
 269        except socket.error:
 270            # Probably a similar problem as above; skip this test
 271            return
 272        all_host_names = [hostname, hname] + aliases
 273        fqhn = socket.getfqdn(ip)
 274        if not fqhn in all_host_names:
 275            self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
 276
 277    def testRefCountGetNameInfo(self):
 278        # Testing reference count for getnameinfo
 279        if hasattr(sys, "getrefcount"):
 280            try:
 281                # On some versions, this loses a reference
 282                orig = sys.getrefcount(__name__)
 283                socket.getnameinfo(__name__,0)
 284            except TypeError:
 285                if sys.getrefcount(__name__) <> orig:
 286                    self.fail("socket.getnameinfo loses a reference")
 287
 288    def testInterpreterCrash(self):
 289        # Making sure getnameinfo doesn't crash the interpreter
 290        try:
 291            # On some versions, this crashes the interpreter.
 292            socket.getnameinfo(('x', 0, 0, 0), 0)
 293        except socket.error:
 294            pass
 295
 296    def testNtoH(self):
 297        # This just checks that htons etc. are their own inverse,
 298        # when looking at the lower 16 or 32 bits.
 299        sizes = {socket.htonl: 32, socket.ntohl: 32,
 300                 socket.htons: 16, socket.ntohs: 16}
 301        for func, size in sizes.items():
 302            mask = (1L<<size) - 1
 303            for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
 304                self.assertEqual(i & mask, func(func(i&mask)) & mask)
 305
 306            swapped = func(mask)
 307            self.assertEqual(swapped & mask, mask)
 308            self.assertRaises(OverflowError, func, 1L<<34)
 309
 310    def testNtoHErrors(self):
 311        good_values = [ 1, 2, 3, 1L, 2L, 3L ]
 312        bad_values = [ -1, -2, -3, -1L, -2L, -3L ]
 313        for k in good_values:
 314            socket.ntohl(k)
 315            socket.ntohs(k)
 316            socket.htonl(k)
 317            socket.htons(k)
 318        for k in bad_values:
 319            self.assertRaises(OverflowError, socket.ntohl, k)
 320            self.assertRaises(OverflowError, socket.ntohs, k)
 321            self.assertRaises(OverflowError, socket.htonl, k)
 322            self.assertRaises(OverflowError, socket.htons, k)
 323
 324    def testGetServBy(self):
 325        eq = self.assertEqual
 326        # Find one service that exists, then check all the related interfaces.
 327        # I've ordered this by protocols that have both a tcp and udp
 328        # protocol, at least for modern Linuxes.
 329        if sys.platform in ('linux2', 'freebsd4', 'freebsd5', 'freebsd6',
 330                            'freebsd7', 'freebsd8', 'darwin'):
 331            # avoid the 'echo' service on this platform, as there is an
 332            # assumption breaking non-standard port/protocol entry
 333            services = ('daytime', 'qotd', 'domain')
 334        else:
 335            services = ('echo', 'daytime', 'domain')
 336        for service in services:
 337            try:
 338                port = socket.getservbyname(service, 'tcp')
 339                break
 340            except socket.error:
 341                pass
 342        else:
 343            raise socket.error
 344        # Try same call with optional protocol omitted
 345        port2 = socket.getservbyname(service)
 346        eq(port, port2)
 347        # Try udp, but don't barf it it doesn't exist
 348        try:
 349            udpport = socket.getservbyname(service, 'udp')
 350        except socket.error:
 351            udpport = None
 352        else:
 353            eq(udpport, port)
 354        # Now make sure the lookup by port returns the same service name
 355        eq(socket.getservbyport(port2), service)
 356        eq(socket.getservbyport(port, 'tcp'), service)
 357        if udpport is not None:
 358            eq(socket.getservbyport(udpport, 'udp'), service)
 359
 360    def testDefaultTimeout(self):
 361        # Testing default timeout
 362        # The default timeout should initially be None
 363        self.assertEqual(socket.getdefaulttimeout(), None)
 364        s = socket.socket()
 365        self.assertEqual(s.gettimeout(), None)
 366        s.close()
 367
 368        # Set the default timeout to 10, and see if it propagates
 369        socket.setdefaulttimeout(10)
 370        self.assertEqual(socket.getdefaulttimeout(), 10)
 371        s = socket.socket()
 372        self.assertEqual(s.gettimeout(), 10)
 373        s.close()
 374
 375        # Reset the default timeout to None, and see if it propagates
 376        socket.setdefaulttimeout(None)
 377        self.assertEqual(socket.getdefaulttimeout(), None)
 378        s = socket.socket()
 379        self.assertEqual(s.gettimeout(), None)
 380        s.close()
 381
 382        # Check that setting it to an invalid value raises ValueError
 383        self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
 384
 385        # Check that setting it to an invalid type raises TypeError
 386        self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
 387
 388    def testIPv4toString(self):
 389        if not hasattr(socket, 'inet_pton'):
 390            return # No inet_pton() on this platform
 391        from socket import inet_aton as f, inet_pton, AF_INET
 392        g = lambda a: inet_pton(AF_INET, a)
 393
 394        self.assertEquals('\x00\x00\x00\x00', f('0.0.0.0'))
 395        self.assertEquals('\xff\x00\xff\x00', f('255.0.255.0'))
 396        self.assertEquals('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
 397        self.assertEquals('\x01\x02\x03\x04', f('1.2.3.4'))
 398        self.assertEquals('\xff\xff\xff\xff', f('255.255.255.255'))
 399
 400        self.assertEquals('\x00\x00\x00\x00', g('0.0.0.0'))
 401        self.assertEquals('\xff\x00\xff\x00', g('255.0.255.0'))
 402        self.assertEquals('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
 403        self.assertEquals('\xff\xff\xff\xff', g('255.255.255.255'))
 404
 405    def testIPv6toString(self):
 406        if not hasattr(socket, 'inet_pton'):
 407            return # No inet_pton() on this platform
 408        try:
 409            from socket import inet_pton, AF_INET6, has_ipv6
 410            if not has_ipv6:
 411                return
 412        except ImportError:
 413            return
 414        f = lambda a: inet_pton(AF_INET6, a)
 415
 416        self.assertEquals('\x00' * 16, f('::'))
 417        self.assertEquals('\x00' * 16, f('0::0'))
 418        self.assertEquals('\x00\x01' + '\x00' * 14, f('1::'))
 419        self.assertEquals(
 420            '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
 421            f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
 422        )
 423
 424    def testStringToIPv4(self):
 425        if not hasattr(socket, 'inet_ntop'):
 426            return # No inet_ntop() on this platform
 427        from socket import inet_ntoa as f, inet_ntop, AF_INET
 428        g = lambda a: inet_ntop(AF_INET, a)
 429
 430        self.assertEquals('1.0.1.0', f('\x01\x00\x01\x00'))
 431        self.assertEquals('170.85.170.85', f('\xaa\x55\xaa\x55'))
 432        self.assertEquals('255.255.255.255', f('\xff\xff\xff\xff'))
 433        self.assertEquals('1.2.3.4', f('\x01\x02\x03\x04'))
 434
 435        self.assertEquals('1.0.1.0', g('\x01\x00\x01\x00'))
 436        self.assertEquals('170.85.170.85', g('\xaa\x55\xaa\x55'))
 437        self.assertEquals('255.255.255.255', g('\xff\xff\xff\xff'))
 438
 439    def testStringToIPv6(self):
 440        if not hasattr(socket, 'inet_ntop'):
 441            return # No inet_ntop() on this platform
 442        try:
 443            from socket import inet_ntop, AF_INET6, has_ipv6
 444            if not has_ipv6:
 445                return
 446        except ImportError:
 447            return
 448        f = lambda a: inet_ntop(AF_INET6, a)
 449
 450        self.assertEquals('::', f('\x00' * 16))
 451        self.assertEquals('::1', f('\x00' * 15 + '\x01'))
 452        self.assertEquals(
 453            'aef:b01:506:1001:ffff:9997:55:170',
 454            f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
 455        )
 456
 457    # XXX The following don't test module-level functionality...
 458
 459    def testSockName(self):
 460        # Testing getsockname().  Use a temporary socket to elicit an unused
 461        # ephemeral port that we can use later in the test.
 462        tempsock = socket.socket()
 463        tempsock.bind(("0.0.0.0", 0))
 464        (host, port) = tempsock.getsockname()
 465        tempsock.close()
 466        del tempsock
 467
 468        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 469        sock.bind(("0.0.0.0", port))
 470        name = sock.getsockname()
 471        # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
 472        # it reasonable to get the host's addr in addition to 0.0.0.0.
 473        # At least for eCos.  This is required for the S/390 to pass.
 474        my_ip_addr = socket.gethostbyname(socket.gethostname())
 475        self.assert_(name[0] in ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
 476        self.assertEqual(name[1], port)
 477
 478    def testGetSockOpt(self):
 479        # Testing getsockopt()
 480        # We know a socket should start without reuse==0
 481        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 482        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
 483        self.failIf(reuse != 0, "initial mode is reuse")
 484
 485    def testSetSockOpt(self):
 486        # Testing setsockopt()
 487        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 488        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 489        reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
 490        self.failIf(reuse == 0, "failed to set reuse mode")
 491
 492    def testSendAfterClose(self):
 493        # testing send() after close() with timeout
 494        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 495        sock.settimeout(1)
 496        sock.close()
 497        self.assertRaises(socket.error, sock.send, "spam")
 498
 499    def testNewAttributes(self):
 500        # testing .family, .type and .protocol
 501        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 502        self.assertEqual(sock.family, socket.AF_INET)
 503        self.assertEqual(sock.type, socket.SOCK_STREAM)
 504        self.assertEqual(sock.proto, 0)
 505        sock.close()
 506
 507    def test_sock_ioctl(self):
 508        if os.name != "nt":
 509            return
 510        self.assert_(hasattr(socket.socket, 'ioctl'))
 511        self.assert_(hasattr(socket, 'SIO_RCVALL'))
 512        self.assert_(hasattr(socket, 'RCVALL_ON'))
 513        self.assert_(hasattr(socket, 'RCVALL_OFF'))
 514
 515
 516class BasicTCPTest(SocketConnectedTest):
 517
 518    def __init__(self, methodName='runTest'):
 519        SocketConnectedTest.__init__(self, methodName=methodName)
 520
 521    def testRecv(self):
 522        # Testing large receive over TCP
 523        msg = self.cli_conn.recv(1024)
 524        self.assertEqual(msg, MSG)
 525
 526    def _testRecv(self):
 527        self.serv_conn.send(MSG)
 528
 529    def testOverFlowRecv(self):
 530        # Testing receive in chunks over TCP
 531        seg1 = self.cli_conn.recv(len(MSG) - 3)
 532        seg2 = self.cli_conn.recv(1024)
 533        msg = seg1 + seg2
 534        self.assertEqual(msg, MSG)
 535
 536    def _testOverFlowRecv(self):
 537        self.serv_conn.send(MSG)
 538
 539    def testRecvFrom(self):
 540        # Testing large recvfrom() over TCP
 541        msg, addr = self.cli_conn.recvfrom(1024)
 542        self.assertEqual(msg, MSG)
 543
 544    def _testRecvFrom(self):
 545        self.serv_conn.send(MSG)
 546
 547    def testOverFlowRecvFrom(self):
 548        # Testing recvfrom() in chunks over TCP
 549        seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
 550        seg2, addr = self.cli_conn.recvfrom(1024)
 551        msg = seg1 + seg2
 552        self.assertEqual(msg, MSG)
 553
 554    def _testOverFlowRecvFrom(self):
 555        self.serv_conn.send(MSG)
 556
 557    def testSendAll(self):
 558        # Testing sendall() with a 2048 byte string over TCP
 559        msg = ''
 560        while 1:
 561            read = self.cli_conn.recv(1024)
 562            if not read:
 563                break
 564            msg += read
 565        self.assertEqual(msg, 'f' * 2048)
 566
 567    def _testSendAll(self):
 568        big_chunk = 'f' * 2048
 569        self.serv_conn.sendall(big_chunk)
 570
 571    def testFromFd(self):
 572        # Testing fromfd()
 573        if not hasattr(socket, "fromfd"):
 574            return # On Windows, this doesn't exist
 575        fd = self.cli_conn.fileno()
 576        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
 577        msg = sock.recv(1024)
 578        self.assertEqual(msg, MSG)
 579
 580    def _testFromFd(self):
 581        self.serv_conn.send(MSG)
 582
 583    def testShutdown(self):
 584        # Testing shutdown()
 585        msg = self.cli_conn.recv(1024)
 586        self.assertEqual(msg, MSG)
 587        # wait for _testShutdown to finish: on OS X, when the server
 588        # closes the connection the client also becomes disconnected,
 589        # and the client's shutdown call will fail. (Issue #4397.)
 590        self.done.wait()
 591
 592    def _testShutdown(self):
 593        self.serv_conn.send(MSG)
 594        self.serv_conn.shutdown(2)
 595
 596class BasicUDPTest(ThreadedUDPSocketTest):
 597
 598    def __init__(self, methodName='runTest'):
 599        ThreadedUDPSocketTest.__init__(self, methodName=methodName)
 600
 601    def testSendtoAndRecv(self):
 602        # Testing sendto() and Recv() over UDP
 603        msg = self.serv.recv(len(MSG))
 604        self.assertEqual(msg, MSG)
 605
 606    def _testSendtoAndRecv(self):
 607        self.cli.sendto(MSG, 0, (HOST, self.port))
 608
 609    def testRecvFrom(self):
 610        # Testing recvfrom() over UDP
 611        msg, addr = self.serv.recvfrom(len(MSG))
 612        self.assertEqual(msg, MSG)
 613
 614    def _testRecvFrom(self):
 615        self.cli.sendto(MSG, 0, (HOST, self.port))
 616
 617    def testRecvFromNegative(self):
 618        # Negative lengths passed to recvfrom should give ValueError.
 619        self.assertRaises(ValueError, self.serv.recvfrom, -1)
 620
 621    def _testRecvFromNegative(self):
 622        self.cli.sendto(MSG, 0, (HOST, self.port))
 623
 624class TCPCloserTest(ThreadedTCPSocketTest):
 625
 626    def testClose(self):
 627        conn, addr = self.serv.accept()
 628        conn.close()
 629
 630        sd = self.cli
 631        read, write, err = select.select([sd], [], [], 1.0)
 632        self.assertEqual(read, [sd])
 633        self.assertEqual(sd.recv(1), '')
 634
 635    def _testClose(self):
 636        self.cli.connect((HOST, self.port))
 637        time.sleep(1.0)
 638
 639class BasicSocketPairTest(SocketPairTest):
 640
 641    def __init__(self, methodName='runTest'):
 642        SocketPairTest.__init__(self, methodName=methodName)
 643
 644    def testRecv(self):
 645        msg = self.serv.recv(1024)
 646        self.assertEqual(msg, MSG)
 647
 648    def _testRecv(self):
 649        self.cli.send(MSG)
 650
 651    def testSend(self):
 652        self.serv.send(MSG)
 653
 654    def _testSend(self):
 655        msg = self.cli.recv(1024)
 656        self.assertEqual(msg, MSG)
 657
 658class NonBlockingTCPTests(ThreadedTCPSocketTest):
 659
 660    def __init__(self, methodName='runTest'):
 661        ThreadedTCPSocketTest.__init__(self, methodName=methodName)
 662
 663    def testSetBlocking(self):
 664        # Testing whether set blocking works
 665        self.serv.setblocking(0)
 666        start = time.time()
 667        try:
 668            self.serv.accept()
 669        except socket.error:
 670            pass
 671        end = time.time()
 672        self.assert_((end - start) < 1.0, "Error setting non-blocking mode.")
 673
 674    def _testSetBlocking(self):
 675        pass
 676
 677    def testAccept(self):
 678        # Testing non-blocking accept
 679        self.serv.setblocking(0)
 680        try:
 681            conn, addr = self.serv.accept()
 682        except socket.error:
 683            pass
 684        else:
 685            self.fail("Error trying to do non-blocking accept.")
 686        read, write, err = select.select([self.serv], [], [])
 687        if self.serv in read:
 688            conn, addr = self.serv.accept()
 689        else:
 690            self.fail("Error trying to do accept after select.")
 691
 692    def _testAccept(self):
 693        time.sleep(0.1)
 694        self.cli.connect((HOST, self.port))
 695
 696    def testConnect(self):
 697        # Testing non-blocking connect
 698        conn, addr = self.serv.accept()
 699
 700    def _testConnect(self):
 701        self.cli.settimeout(10)
 702        self.cli.connect((HOST, self.port))
 703
 704    def testRecv(self):
 705        # Testing non-blocking recv
 706        conn, addr = self.serv.accept()
 707        conn.setblocking(0)
 708        try:
 709            msg = conn.recv(len(MSG))
 710        except socket.error:
 711            pass
 712        else:
 713            self.fail("Error trying to do non-blocking recv.")
 714        read, write, err = select.select([conn], [], [])
 715        if conn in read:
 716            msg = conn.recv(len(MSG))
 717            self.assertEqual(msg, MSG)
 718        else:
 719            self.fail("Error during select call to non-blocking socket.")
 720
 721    def _testRecv(self):
 722        self.cli.connect((HOST, self.port))
 723        time.sleep(0.1)
 724        self.cli.send(MSG)
 725
 726class FileObjectClassTestCase(SocketConnectedTest):
 727
 728    bufsize = -1 # Use default buffer size
 729
 730    def __init__(self, methodName='runTest'):
 731        SocketConnectedTest.__init__(self, methodName=methodName)
 732
 733    def setUp(self):
 734        SocketConnectedTest.setUp(self)
 735        self.serv_file = self.cli_conn.makefile('rb', self.bufsize)
 736
 737    def tearDown(self):
 738        self.serv_file.close()
 739        self.assert_(self.serv_file.closed)
 740        self.serv_file = None
 741        SocketConnectedTest.tearDown(self)
 742
 743    def clientSetUp(self):
 744        SocketConnectedTest.clientSetUp(self)
 745        self.cli_file = self.serv_conn.makefile('wb')
 746
 747    def clientTearDown(self):
 748        self.cli_file.close()
 749        self.assert_(self.cli_file.closed)
 750        self.cli_file = None
 751        SocketConnectedTest.clientTearDown(self)
 752
 753    def testSmallRead(self):
 754        # Performing small file read test
 755        first_seg = self.serv_file.read(len(MSG)-3)
 756        second_seg = self.serv_file.read(3)
 757        msg = first_seg + second_seg
 758        self.assertEqual(msg, MSG)
 759
 760    def _testSmallRead(self):
 761        self.cli_file.write(MSG)
 762        self.cli_file.flush()
 763
 764    def testFullRead(self):
 765        # read until EOF
 766        msg = self.serv_file.read()
 767        self.assertEqual(msg, MSG)
 768
 769    def _testFullRead(self):
 770        self.cli_file.write(MSG)
 771        self.cli_file.close()
 772
 773    def testUnbufferedRead(self):
 774        # Performing unbuffered file read test
 775        buf = ''
 776        while 1:
 777            char = self.serv_file.read(1)
 778            if not char:
 779                break
 780            buf += char
 781        self.assertEqual(buf, MSG)
 782
 783    def _testUnbufferedRead(self):
 784        self.cli_file.write(MSG)
 785        self.cli_file.flush()
 786
 787    def testReadline(self):
 788        # Performing file readline test
 789        line = self.serv_file.readline()
 790        self.assertEqual(line, MSG)
 791
 792    def _testReadline(self):
 793        self.cli_file.write(MSG)
 794        self.cli_file.flush()
 795
 796    def testReadlineAfterRead(self):
 797        a_baloo_is = self.serv_file.read(len("A baloo is"))
 798        self.assertEqual("A baloo is", a_baloo_is)
 799        _a_bear = self.serv_file.read(len(" a bear"))
 800        self.assertEqual(" a bear", _a_bear)
 801        line = self.serv_file.readline()
 802        self.assertEqual("\n", line)
 803        line = self.serv_file.readline()
 804        self.assertEqual("A BALOO IS A BEAR.\n", line)
 805        line = self.serv_file.readline()
 806        self.assertEqual(MSG, line)
 807
 808    def _testReadlineAfterRead(self):
 809        self.cli_file.write("A baloo is a bear\n")
 810        self.cli_file.write("A BALOO IS A BEAR.\n")
 811        self.cli_file.write(MSG)
 812        self.cli_file.flush()
 813
 814    def testReadlineAfterReadNoNewline(self):
 815        end_of_ = self.serv_file.read(len("End Of "))
 816        self.assertEqual("End Of ", end_of_)
 817        line = self.serv_file.readline()
 818        self.assertEqual("Line", line)
 819
 820    def _testReadlineAfterReadNoNewline(self):
 821        self.cli_file.write("End Of Line")
 822
 823    def testClosedAttr(self):
 824        self.assert_(not self.serv_file.closed)
 825
 826    def _testClosedAttr(self):
 827        self.assert_(not self.cli_file.closed)
 828
 829class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
 830
 831    """Repeat the tests from FileObjectClassTestCase with bufsize==0.
 832
 833    In this case (and in this case only), it should be possible to
 834    create a file object, read a line from it, create another file
 835    object, read another line from it, without loss of data in the
 836    first file object's buffer.  Note that httplib relies on this
 837    when reading multiple requests from the same socket."""
 838
 839    bufsize = 0 # Use unbuffered mode
 840
 841    def testUnbufferedReadline(self):
 842        # Read a line, create a new file object, read another line with it
 843        line = self.serv_file.readline() # first line
 844        self.assertEqual(line, "A. " + MSG) # first line
 845        self.serv_file = self.cli_conn.makefile('rb', 0)
 846        line = self.serv_file.readline() # second line
 847        self.assertEqual(line, "B. " + MSG) # second line
 848
 849    def _testUnbufferedReadline(self):
 850        self.cli_file.write("A. " + MSG)
 851        self.cli_file.write("B. " + MSG)
 852        self.cli_file.flush()
 853
 854class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
 855
 856    bufsize = 1 # Default-buffered for reading; line-buffered for writing
 857
 858
 859class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
 860
 861    bufsize = 2 # Exercise the buffering code
 862
 863
 864class NetworkConnectionTest(object):
 865    """Prove network connection."""
 866    def clientSetUp(self):
 867        # We're inherited below by BasicTCPTest2, which also inherits
 868        # BasicTCPTest, which defines self.port referenced below.
 869        self.cli = socket.create_connection((HOST, self.port))
 870        self.serv_conn = self.cli
 871
 872class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
 873    """Tests that NetworkConnection does not break existing TCP functionality.
 874    """
 875
 876class NetworkConnectionNoServer(unittest.TestCase):
 877    def testWithoutServer(self):
 878        port = test_support.find_unused_port()
 879        self.failUnlessRaises(
 880            socket.error,
 881            lambda: socket.create_connection((HOST, port))
 882        )
 883
 884class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
 885
 886    def __init__(self, methodName='runTest'):
 887        SocketTCPTest.__init__(self, methodName=methodName)
 888        ThreadableTest.__init__(self)
 889
 890    def clientSetUp(self):
 891        pass
 892
 893    def clientTearDown(self):
 894        self.cli.close()
 895        self.cli = None
 896        ThreadableTest.clientTearDown(self)
 897
 898    def _justAccept(self):
 899        conn, addr = self.serv.accept()
 900
 901    testFamily = _justAccept
 902    def _testFamily(self):
 903        self.cli = socket.create_connection((HOST, self.port), timeout=30)
 904        self.assertEqual(self.cli.family, 2)
 905
 906    testTimeoutDefault = _justAccept
 907    def _testTimeoutDefault(self):
 908        # passing no explicit timeout uses socket's global default
 909        self.assert_(socket.getdefaulttimeout() is None)
 910        socket.setdefaulttimeout(42)
 911        try:
 912            self.cli = socket.create_connection((HOST, self.port))
 913        finally:
 914            socket.setdefaulttimeout(None)
 915        self.assertEquals(self.cli.gettimeout(), 42)
 916
 917    testTimeoutNone = _justAccept
 918    def _testTimeoutNone(self):
 919        # None timeout means the same as sock.settimeout(None)
 920        self.assert_(socket.getdefaulttimeout() is None)
 921        socket.setdefaulttimeout(30)
 922        try:
 923            self.cli = socket.create_connection((HOST, self.port), timeout=None)
 924        finally:
 925            socket.setdefaulttimeout(None)
 926        self.assertEqual(self.cli.gettimeout(), None)
 927
 928    testTimeoutValueNamed = _justAccept
 929    def _testTimeoutValueNamed(self):
 930        self.cli = socket.create_connection((HOST, self.port), timeout=30)
 931        self.assertEqual(self.cli.gettimeout(), 30)
 932
 933    testTimeoutValueNonamed = _justAccept
 934    def _testTimeoutValueNonamed(self):
 935        self.cli = socket.create_connection((HOST, self.port), 30)
 936        self.assertEqual(self.cli.gettimeout(), 30)
 937
 938class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
 939
 940    def __init__(self, methodName='runTest'):
 941        SocketTCPTest.__init__(self, methodName=methodName)
 942        ThreadableTest.__init__(self)
 943
 944    def clientSetUp(self):
 945        pass
 946
 947    def clientTearDown(self):
 948        self.cli.close()
 949        self.cli = None
 950        ThreadableTest.clientTearDown(self)
 951
 952    def testInsideTimeout(self):
 953        conn, addr = self.serv.accept()
 954        time.sleep(3)
 955        conn.send("done!")
 956    testOutsideTimeout = testInsideTimeout
 957
 958    def _testInsideTimeout(self):
 959        self.cli = sock = socket.create_connection((HOST, self.port))
 960        data = sock.recv(5)
 961        self.assertEqual(data, "done!")
 962
 963    def _testOutsideTimeout(self):
 964        self.cli = sock = socket.create_connection((HOST, self.port), timeout=1)
 965        self.failUnlessRaises(socket.timeout, lambda: sock.recv(5))
 966
 967
 968class Urllib2FileobjectTest(unittest.TestCase):
 969
 970    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
 971    # it close the socket if the close c'tor argument is true
 972
 973    def testClose(self):
 974        class MockSocket:
 975            closed = False
 976            def flush(self): pass
 977            def close(self): self.closed = True
 978
 979        # must not close unless we request it: the original use of _fileobject
 980        # by module socket requires that the underlying socket not be closed until
 981        # the _socketobject that created the _fileobject is closed
 982        s = MockSocket()
 983        f = socket._fileobject(s)
 984        f.close()
 985        self.assert_(not s.closed)
 986
 987        s = MockSocket()
 988        f = socket._fileobject(s, close=True)
 989        f.close()
 990        self.assert_(s.closed)
 991
 992class TCPTimeoutTest(SocketTCPTest):
 993
 994    def testTCPTimeout(self):
 995        def raise_timeout(*args, **kwargs):
 996            self.serv.settimeout(1.0)
 997            self.serv.accept()
 998        self.failUnlessRaises(socket.timeout, raise_timeout,
 999                              "Error generating a timeout exception (TCP)")
1000
1001    def testTimeoutZero(self):
1002        ok = False
1003        try:
1004            self.serv.settimeout(0.0)
1005            foo = self.serv.accept()
1006        except socket.timeout:
1007            self.fail("caught timeout instead of error (TCP)")
1008        except socket.error:
1009            ok = True
1010        except:
1011            self.fail("caught unexpected exception (TCP)")
1012        if not ok:
1013            self.fail("accept() returned success when we did not expect it")
1014
1015    def testInterruptedTimeout(self):
1016        # XXX I don't know how to do this test on MSWindows or any other
1017        # plaform that doesn't support signal.alarm() or os.kill(), though
1018        # the bug should have existed on all platforms.
1019        if not hasattr(signal, "alarm"):
1020            return                  # can only test on *nix
1021        self.serv.settimeout(5.0)   # must be longer than alarm
1022        class Alarm(Exception):
1023            pass
1024        def alarm_handler(signal, frame):
1025            raise Alarm
1026        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
1027        try:
1028            signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
1029            try:
1030                foo = self.serv.accept()
1031            except socket.timeout:
1032                self.fail("caught timeout instead of Alarm")
1033            except Alarm:
1034                pass
1035            except:
1036                self.fail("caught other exception instead of Alarm:"
1037                          " %s(%s):\n%s" %
1038                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
1039            else:
1040                self.fail("nothing caught")
1041            finally:
1042                signal.alarm(0)         # shut off alarm
1043        except Alarm:
1044            self.fail("got Alarm in wrong place")
1045        finally:
1046            # no alarm can be pending.  Safe to restore old handler.
1047            signal.signal(signal.SIGALRM, old_alarm)
1048
class UDPTimeoutTest(SocketUDPTest):
    """Timeout behaviour of recv() on the bound, unconnected UDP socket.

    This must use the UDP fixture.  On a listening TCP socket these checks
    would also pass, but UDP would never be exercised.
    """

    def testUDPTimeout(self):
        # Nothing is ever sent to self.serv, so recv() has to time out.
        # Checked by hand so the failure message is actually reported:
        # failUnlessRaises() would forward a message string to the callable
        # as an extra argument instead of using it.
        self.serv.settimeout(1.0)
        try:
            self.serv.recv(1024)
        except socket.timeout:
            pass
        else:
            self.fail("Error generating a timeout exception (UDP)")

    def testTimeoutZero(self):
        # A zero timeout means non-blocking mode: with no datagram queued
        # recv() must raise a plain socket.error, not socket.timeout.
        # socket.timeout subclasses socket.error, so it is tested first.
        self.serv.settimeout(0.0)
        try:
            self.serv.recv(1024)
        except socket.timeout:
            self.fail("caught timeout instead of error (UDP)")
        except socket.error:
            pass
        except Exception:
            self.fail("caught unexpected exception (UDP)")
        else:
            self.fail("recv() returned success when we did not expect it")
1071
class TestExceptions(unittest.TestCase):

    def testExceptionTree(self):
        # socket.error derives from Exception; the resolver errors and
        # timeout are all refinements of socket.error.
        self.assertTrue(issubclass(socket.error, Exception))
        for exc_class in (socket.herror, socket.gaierror, socket.timeout):
            self.assertTrue(issubclass(exc_class, socket.error))
1079
1080class TestLinuxAbstractNamespace(unittest.TestCase):
1081
1082    UNIX_PATH_MAX = 108
1083
1084    def testLinuxAbstractNamespace(self):
1085        address = "\x00python-test-hello\x00\xff"
1086        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1087        s1.bind(address)
1088        s1.listen(1)
1089        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1090        s2.connect(s1.getsockname())
1091        s1.accept()
1092        self.assertEqual(s1.getsockname(), address)
1093        self.assertEqual(s2.getpeername(), address)
1094
1095    def testMaxName(self):
1096        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
1097        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1098        s.bind(address)
1099        self.assertEqual(s.getsockname(), address)
1100
1101    def testNameOverflow(self):
1102        address = "\x00" + "h" * self.UNIX_PATH_MAX
1103        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1104        self.assertRaises(socket.error, s.bind, address)
1105
1106
class BufferIOTest(SocketConnectedTest):
    """Exercise recv_into() and recvfrom_into() on a writable array buffer,
    with the peer sending a read-only buffer() view of MSG."""

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecvInto(self):
        target = array.array('c', ' ' * 1024)
        received = self.cli_conn.recv_into(target)
        self.assertEqual(received, len(MSG))
        self.assertEqual(target.tostring()[:len(MSG)], MSG)

    def _testRecvInto(self):
        self.serv_conn.send(buffer(MSG))

    def testRecvFromInto(self):
        target = array.array('c', ' ' * 1024)
        received, addr = self.cli_conn.recvfrom_into(target)
        self.assertEqual(received, len(MSG))
        self.assertEqual(target.tostring()[:len(MSG)], MSG)

    def _testRecvFromInto(self):
        self.serv_conn.send(buffer(MSG))
1135
1136
# TIPC service type, plus the lower/upper bounds of the instance range the
# TIPC tests bind with TIPC_ADDR_NAMESEQ.
TIPC_STYPE, TIPC_LOWER, TIPC_UPPER = 2000, 200, 210
1140
def isTipcAvailable():
    """Return True when TIPC sockets can be used here.

    Two conditions must hold: the socket module exposes AF_TIPC, and a
    "tipc " entry appears in /proc/modules.  The kernel module is not
    loaded automatically on Ubuntu and probably other Linux distros, so
    in verbose mode a hint on loading it is printed.
    """
    modules_path = "/proc/modules"
    if not (hasattr(socket, "AF_TIPC") and os.path.isfile(modules_path)):
        return False
    with open(modules_path) as modules:
        loaded = any(entry.startswith("tipc ") for entry in modules)
    if loaded:
        return True
    if test_support.verbose:
        print("TIPC module is not loaded, please 'sudo modprobe tipc'")
    return False
1158
class TIPCTest(unittest.TestCase):
    """Round trip over TIPC reliably-delivered-message (SOCK_RDM) sockets."""

    def testRDM(self):
        # The server binds the whole [TIPC_LOWER, TIPC_UPPER] name sequence
        # and the client sends to the instance midway through it.  Both
        # sockets are closed even if an assertion fails.
        srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        try:
            cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
            try:
                srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                           TIPC_LOWER, TIPC_UPPER)
                srv.bind(srvaddr)

                # Floor division: the instance number must be an integer.
                sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                            TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
                cli.sendto(MSG, sendaddr)

                msg, recvaddr = srv.recvfrom(1024)

                self.assertEqual(cli.getsockname(), recvaddr)
                self.assertEqual(msg, MSG)
            finally:
                cli.close()
        finally:
            srv.close()
1177
1178
class TIPCThreadableTest(unittest.TestCase, ThreadableTest):
    """Stream round trip over TIPC with a threaded client/server pair."""

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                   TIPC_LOWER, TIPC_UPPER)
        self.srv.bind(srvaddr)
        self.srv.listen(5)
        self.serverExplicitReady()
        self.conn, self.connaddr = self.srv.accept()

    def tearDown(self):
        # Release the accepted connection and the listening socket.
        self.conn.close()
        self.srv.close()

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        # Floor division: the instance number must be an integer.
        addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        self.cli.connect(addr)
        self.cliaddr = self.cli.getsockname()

    def clientTearDown(self):
        # _testStream closes self.cli itself, but a failing client test may
        # never get that far.  Closing an already-closed socket is harmless.
        # ThreadableTest.clientTearDown must always run.
        cli = getattr(self, 'cli', None)
        try:
            if cli is not None:
                cli.close()
        finally:
            ThreadableTest.clientTearDown(self)

    def testStream(self):
        msg = self.conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()
1213
1214
def test_main():
    """Collect the socket test cases that apply to this platform and run them.

    Thread state is recorded before the run and cleaned up afterwards.  The
    cleanup runs even if run_unittest() raises.
    """
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest, BasicTCPTest2]
    if sys.platform != 'mac':
        tests.extend([BasicUDPTest, UDPTimeoutTest])

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
        NetworkConnectionNoServer,
        NetworkConnectionAttributesTest,
        NetworkConnectionBehaviourTest,
    ])
    if hasattr(socket, "socketpair"):
        tests.append(BasicSocketPairTest)
    # Match on the prefix: older interpreters report 'linux3' on Linux 3.x
    # kernels, and newer ones report plain 'linux'.
    if sys.platform.startswith('linux'):
        tests.append(TestLinuxAbstractNamespace)
    if isTipcAvailable():
        tests.extend([TIPCTest, TIPCThreadableTest])

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        test_support.threading_cleanup(*thread_info)


if __name__ == "__main__":
    test_main()