/fstmerge/examples/eXe/rev3426-3513/right-branch-3513/twisted/test/test_protocols.py
Python | 315 lines | 312 code | 0 blank | 3 comment | 1 complexity | 4241c0872ee69d45523dbcc8458a9920 MD5 | raw file
- """
- Test cases for twisted.protocols package.
- """
- from twisted.trial import unittest
- from twisted.protocols import basic, wire, portforward
- from twisted.internet import reactor, protocol, defer
- import string, struct
- import StringIO
class StringIOWithoutClosing(StringIO.StringIO):
    """
    In-memory file whose close() does nothing, so whatever was written to it
    can still be inspected with getvalue() after close() has been called.
    """

    def close(self):
        """Deliberately ignore the request to close; keep the buffer."""
class LineTester(basic.LineReceiver):
    """
    LineReceiver that records everything it receives and interprets some
    lines as commands acting on the receiver itself.

    Commands recognised by lineReceived:
      ''            switch to raw mode
      'pause'       pause producing and schedule a resume on the reactor
      'rawpause'    pause, switch to raw mode, add an empty raw slot to the
                    record and schedule a resume on the reactor
      'stop'        stop producing
      'len N'       store int(N) as self.length, the raw byte count
      'produce...'  register this protocol as the transport's producer
      'unproduce...' unregister the transport's producer
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def connectionMade(self):
        """Begin the connection with an empty record of received data."""
        self.received = []

    def lineReceived(self, line):
        """Record line, then perform the command it spells, if any."""
        self.received.append(line)

        # Commands that must match the whole line.
        if line == '':
            self.setRawMode()
            return
        if line == 'pause':
            self.pauseProducing()
            reactor.callLater(0, self.resumeProducing)
            return
        if line == 'rawpause':
            self.pauseProducing()
            self.setRawMode()
            # Slot that rawDataReceived will fill in once data flows again.
            self.received.append('')
            reactor.callLater(0, self.resumeProducing)
            return
        if line == 'stop':
            self.stopProducing()
            return

        # Commands recognised by their prefix.
        if line.startswith('len '):
            self.length = int(line[4:])
        elif line.startswith('produce'):
            self.transport.registerProducer(self, False)
        elif line.startswith('unproduce'):
            self.transport.unregisterProducer()

    def rawDataReceived(self, data):
        """
        Append up to self.length characters of data to the last recorded
        entry; once the count reaches zero, return to line mode and hand
        over whatever was left of data.
        """
        wanted = self.length
        chunk = data[:wanted]
        leftover = data[wanted:]
        self.length -= len(chunk)
        self.received[-1] += chunk
        if self.length == 0:
            self.setLineMode(leftover)

    def lineLengthExceeded(self, line):
        """
        Drop the first MAX_LENGTH + 1 characters of an over-long line and
        resume line mode with the remainder, if there is any remainder.
        """
        cutoff = self.MAX_LENGTH + 1
        if len(line) > cutoff:
            self.setLineMode(line[cutoff:])
class LineOnlyTester(basic.LineOnlyReceiver):
    """
    LineOnlyReceiver that simply keeps every received line in self.received.
    """

    MAX_LENGTH = 64
    delimiter = '\n'

    def connectionMade(self):
        """Reset the record of received lines for the new connection."""
        self.received = []

    def lineReceived(self, line):
        """Remember line, in arrival order."""
        self.received.append(line)
class WireTestCase(unittest.TestCase):
    """
    Tests for the small protocols in twisted.protocols.wire, each driven
    through an in-memory transport.
    """

    def _connect(self, proto):
        """
        Connect proto to a FileWrapper around a fresh in-memory file and
        return that file so the test can read what the protocol wrote.
        """
        sink = StringIOWithoutClosing()
        proto.makeConnection(protocol.FileWrapper(sink))
        return sink

    def testEcho(self):
        """Echo writes back everything it receives, in order."""
        echo = wire.Echo()
        sink = self._connect(echo)
        for word in ("hello", "world", "how", "are", "you"):
            echo.dataReceived(word)
        self.failUnlessEqual(sink.getvalue(), "helloworldhowareyou")

    def testWho(self):
        """Who writes its user line as soon as it is connected."""
        sink = self._connect(wire.Who())
        self.failUnlessEqual(sink.getvalue(), "root\r\n")

    def testQOTD(self):
        """QOTD writes its quote as soon as it is connected."""
        sink = self._connect(wire.QOTD())
        self.failUnlessEqual(
            sink.getvalue(),
            "An apple a day keeps the doctor away.\r\n")

    def testDiscard(self):
        """Discard writes nothing, whatever it receives."""
        discard = wire.Discard()
        sink = self._connect(discard)
        for word in ("hello", "world", "how", "are", "you"):
            discard.dataReceived(word)
        self.failUnlessEqual(sink.getvalue(), "")
class LineReceiverTestCase(unittest.TestCase):
    """
    Tests for basic.LineReceiver, exercised through LineTester: mixed
    line/raw input, pausing, stopping and producer registration.
    """

    # Input for testBuffer.  LineTester only switches to raw mode when it
    # receives an EMPTY line, so every raw payload below must be preceded by
    # one.  The fixture is spelled with explicit '\n' escapes rather than a
    # triple-quoted literal so that the significant empty lines cannot be
    # lost when the file is reformatted or has its blank lines stripped.
    buffer = (
        'len 10\n'
        '\n'                    # -> raw mode, 10 characters follow
        '0123456789len 5\n'
        '\n'                    # -> raw mode, 5 characters follow
        '1234\n'
        'len 20\n'
        'foo 123\n'
        '\n'                    # -> raw mode, 20 characters follow
        '0123456789\n'
        '012345678len 0\n'
        'foo 5\n'
        '\n'                    # -> raw mode with nothing to collect
        # 70 characters: longer than LineTester.MAX_LENGTH (64).
        '1234567890123456789012345678901234567890123456789012345678901234567890\n'
        'len 1\n'
        '\n'                    # -> raw mode, 1 character follows
        'a')

    # What LineTester.received must hold after consuming buffer.
    output = ['len 10', '0123456789', 'len 5', '1234\n',
              'len 20', 'foo 123', '0123456789\n012345678',
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']

    def _deliver(self, data, packet_size):
        """
        Connect a fresh LineTester to an in-memory transport, feed it data in
        consecutive slices of packet_size characters, and return it.

        The slice count is len(data) // packet_size + 1: the extra one covers
        the final partial slice, and when the length is an exact multiple of
        packet_size the last slice is empty and is still delivered.  Floor
        division is written explicitly so the count is always an integer.
        """
        proto = LineTester()
        proto.makeConnection(protocol.FileWrapper(StringIOWithoutClosing()))
        for i in range(len(data) // packet_size + 1):
            proto.dataReceived(data[i * packet_size:(i + 1) * packet_size])
        return proto

    def testBuffer(self):
        """The mixed line/raw fixture is parsed the same for every packet size."""
        for packet_size in range(1, 10):
            a = self._deliver(self.buffer, packet_size)
            self.failUnlessEqual(self.output, a.received)

    pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n'
    pause_output1 = ['twiddle1', 'twiddle2', 'pause']
    pause_output2 = pause_output1+['twiddle3']

    def testPausing(self):
        """
        Nothing after 'pause' is delivered until the reactor has run the
        resume that LineTester scheduled.
        """
        for packet_size in range(1, 10):
            a = self._deliver(self.pause_buf, packet_size)
            self.failUnlessEqual(self.pause_output1, a.received)
            reactor.iterate(0)
            self.failUnlessEqual(self.pause_output2, a.received)

    rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n'
    rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '']
    rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345', 'twiddle3']

    def testRawPausing(self):
        """
        Like testPausing, but the receiver is switched to raw mode while
        paused; the raw payload only shows up after the reactor iteration.
        """
        for packet_size in range(1, 10):
            a = self._deliver(self.rawpause_buf, packet_size)
            self.failUnlessEqual(self.rawpause_output1, a.received)
            reactor.iterate(0)
            self.failUnlessEqual(self.rawpause_output2, a.received)

    stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n'
    stop_output = ['twiddle1', 'twiddle2', 'stop']

    def testStopProducing(self):
        """No line after 'stop' is delivered."""
        for packet_size in range(1, 10):
            a = self._deliver(self.stop_buf, packet_size)
            self.failUnlessEqual(self.stop_output, a.received)

    def testLineReceiverAsProducer(self):
        """
        Registering and unregistering the protocol as its transport's
        producer does not disturb line delivery.
        """
        a = LineTester()
        t = StringIOWithoutClosing()
        a.makeConnection(protocol.FileWrapper(t))
        a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n')
        self.assertEquals(a.received, ['produce', 'hello world', 'unproduce', 'goodbye'])
class LineOnlyReceiverTestCase(unittest.TestCase):
    """
    Tests for basic.LineOnlyReceiver, exercised through LineOnlyTester.
    """

    buffer = """foo
bleakness
desolation
plastic forks
"""

    def _makeReceiver(self):
        """Return a LineOnlyTester connected to an in-memory transport."""
        sink = StringIOWithoutClosing()
        receiver = LineOnlyTester()
        receiver.makeConnection(protocol.FileWrapper(sink))
        return receiver

    def testBuffer(self):
        """Feeding the fixture one character at a time yields its lines."""
        receiver = self._makeReceiver()
        for char in self.buffer:
            receiver.dataReceived(char)
        # The fixture ends with the delimiter, so split() leaves a trailing
        # empty element that is not a received line.
        expected = self.buffer.split('\n')[:-1]
        self.failUnlessEqual(receiver.received, expected)

    def testLineTooLong(self):
        """dataReceived returns something other than None for an over-long line."""
        receiver = self._makeReceiver()
        outcome = receiver.dataReceived('x'*200)
        self.failIfEqual(outcome, None)
class TestMixin:
    """
    Mixin for the string-receiver test protocols: collects every received
    string in self.received and flags a lost connection in self.closed.
    """

    MAX_LENGTH = 50
    closed = 0

    def connectionMade(self):
        """Start the connection with no strings received."""
        self.received = list()

    def stringReceived(self, s):
        """Store s after the strings received so far."""
        self.received.append(s)

    def connectionLost(self, reason):
        """Mark this instance as closed; reason is not looked at."""
        self.closed = 1
class TestNetstring(TestMixin, basic.NetstringReceiver):
    """NetstringReceiver with TestMixin's recording behaviour mixed in."""
class LPTestCaseMixin:
    """
    Shared test logic for the string-receiver test cases.

    Users of the mixin set the class attribute protocol to the protocol
    class under test and illegal_strings to inputs that must make the
    transport report closed.
    """

    illegal_strings = []
    protocol = None

    def getProtocol(self):
        """Return a new self.protocol instance on an in-memory transport."""
        sink = StringIOWithoutClosing()
        instance = self.protocol()
        # Inside a method the bare name "protocol" is the module-level
        # import, not the class attribute of the same name used just above.
        instance.makeConnection(protocol.FileWrapper(sink))
        return instance

    def testIllegal(self):
        """
        Feed each illegal string, one character at a time, to a fresh
        protocol and check that its transport ends up closed.
        """
        for bad in self.illegal_strings:
            receiver = self.getProtocol()
            for char in bad:
                receiver.dataReceived(char)
            self.assertEquals(receiver.transport.closed, 1)
class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Tests for basic.NetstringReceiver via TestNetstring; the illegal-input
    test is inherited from LPTestCaseMixin.
    """

    protocol = TestNetstring

    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]

    illegal_strings = [
        '9999999999999999999999',
        'abc',
        '4:abcde',
        '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',
    ]

    def testBuffer(self):
        """
        Whatever sendString wrote is decoded back into the original strings,
        for every packet size from 1 to 9.
        """
        for packet_size in range(1, 10):
            sink = StringIOWithoutClosing()
            receiver = TestNetstring()
            # Raised above the class default so the 515-character string fits.
            receiver.MAX_LENGTH = 699
            receiver.makeConnection(protocol.FileWrapper(sink))
            for payload in self.strings:
                receiver.sendString(payload)
            encoded = sink.getvalue()
            # Stepping by packet_size yields only the non-empty slices.
            for start in range(0, len(encoded), packet_size):
                receiver.dataReceived(encoded[start:start + packet_size])
            self.assertEquals(receiver.received, self.strings)
class TestInt32(TestMixin, basic.Int32StringReceiver):
    """Int32StringReceiver with TestMixin's recording behaviour mixed in."""

    MAX_LENGTH = 50
class Int32TestCase(unittest.TestCase, LPTestCaseMixin):
    """
    Tests for basic.Int32StringReceiver via TestInt32; the illegal-input
    test is inherited from LPTestCaseMixin.
    """

    protocol = TestInt32
    strings = ["a", "b" * 16]
    illegal_strings = ["\x10\x00\x00\x00aaaaaa"]
    partial_strings = ["\x00\x00\x00", "hello there", ""]

    def testPartial(self):
        """None of the partial inputs results in a received string."""
        for fragment in self.partial_strings:
            receiver = self.getProtocol()
            receiver.MAX_LENGTH = 99999999
            for char in fragment:
                receiver.dataReceived(char)
            self.assertEquals(receiver.received, [])

    def testReceive(self):
        """
        Each string, prefixed with its length packed as "!i" and fed one
        character at a time, is received intact and in order.
        """
        receiver = self.getProtocol()
        for payload in self.strings:
            framed = struct.pack("!i", len(payload)) + payload
            for char in framed:
                receiver.dataReceived(char)
        self.assertEquals(receiver.received, self.strings)
class OnlyProducerTransport(object):
    """
    Minimal stand-in for a transport: it tracks whether it has been paused
    and keeps every chunk passed to write() in self.data.
    """

    disconnecting = False
    paused = False

    def __init__(self):
        """Start with no written data."""
        self.data = []

    def write(self, bytes):
        """Keep bytes as the next element of self.data."""
        self.data.append(bytes)

    def pauseProducing(self):
        """Flag the transport as paused."""
        self.paused = True

    def resumeProducing(self):
        """Clear the paused flag."""
        self.paused = False
class ConsumingProtocol(basic.LineReceiver):
    """
    LineReceiver that writes each received line to its transport and then
    pauses itself.
    """

    def lineReceived(self, line):
        """Write line to the transport, then pause producing."""
        transport = self.transport
        transport.write(line)
        self.pauseProducing()
class ProducerTestCase(unittest.TestCase):
    """
    Tests for pausing and resuming a LineReceiver, using ConsumingProtocol
    on top of OnlyProducerTransport.
    """

    def _checkPaused(self, transport, proto, expected):
        """
        Check that transport.paused and then proto.paused are both true
        (expected true) or both false (expected false).
        """
        if expected:
            check = self.failUnless
        else:
            check = self.failIf
        check(transport.paused)
        check(proto.paused)

    def testPauseResume(self):
        """
        The protocol pauses after each line it delivers; resumeProducing
        releases at most one further line from what is already buffered.
        """
        p = ConsumingProtocol()
        t = OnlyProducerTransport()
        p.makeConnection(t)

        # An incomplete line: nothing written, nothing paused.
        p.dataReceived('hello, ')
        self.failIf(t.data)
        self._checkPaused(t, p, False)

        # Completing the line delivers it and pauses.
        p.dataReceived('world\r\n')
        self.assertEquals(t.data, ['hello, world'])
        self._checkPaused(t, p, True)

        # Nothing is buffered, so resuming just unpauses.
        p.resumeProducing()
        self._checkPaused(t, p, False)

        # Two lines at once: only the first is delivered before the pause.
        p.dataReceived('hello\r\nworld\r\n')
        self.assertEquals(t.data, ['hello, world', 'hello'])
        self._checkPaused(t, p, True)

        # Resuming delivers the buffered line and pauses again, so the line
        # that arrives next stays buffered.
        p.resumeProducing()
        p.dataReceived('goodbye\r\n')
        self.assertEquals(t.data, ['hello, world', 'hello', 'world'])
        self._checkPaused(t, p, True)

        p.resumeProducing()
        self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
        self._checkPaused(t, p, True)

        # The buffer is drained now; this resume leaves everything unpaused.
        p.resumeProducing()
        self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye'])
        self._checkPaused(t, p, False)
class Portforwarding(unittest.TestCase):
    """
    Test for twisted.protocols.portforward.
    """

    def testPortforward(self):
        """
        Send data to an Echo server through a ProxyFactory and check that the
        client gets the same data back.

        Returns a Deferred gathering the results of stopping both listening
        ports.
        """
        # The "real" server: a single shared Echo instance, so the test can
        # reach its transport at the end.
        serverProtocol = wire.Echo()
        realServerFactory = protocol.ServerFactory()
        realServerFactory.protocol = lambda: serverProtocol
        # Listen on port 0; the port actually in use is read back below with
        # getHost().port.
        realServerPort = reactor.listenTCP(0, realServerFactory,
                                           interface='127.0.0.1')

        # The proxy, pointed at the real server's port.
        proxyServerFactory = portforward.ProxyFactory('127.0.0.1',
                                                      realServerPort.getHost().port)
        proxyServerPort = reactor.listenTCP(0, proxyServerFactory,
                                            interface='127.0.0.1')

        nBytes = 1000
        received = []

        # The client: writes nBytes 'x' characters on connect.  Using
        # list.extend as dataReceived stores the incoming data one character
        # per list element, so len(received) counts characters.
        clientProtocol = protocol.Protocol()
        clientProtocol.dataReceived = received.extend
        clientProtocol.connectionMade = lambda: clientProtocol.transport.write('x' * nBytes)
        clientFactory = protocol.ClientFactory()
        clientFactory.protocol = lambda: clientProtocol

        # Connect to the proxy, not to the real server.
        reactor.connectTCP('127.0.0.1', proxyServerPort.getHost().port,
                           clientFactory)

        # Spin the reactor until everything has arrived, giving up after 100
        # iterations so a failure cannot hang the test.
        c = 0
        while len(received) < nBytes and c < 100:
            reactor.iterate(0.01)
            c += 1

        self.assertEquals(''.join(received), 'x' * nBytes)

        # Tear down: drop both connections, then stop both listening ports.
        clientProtocol.transport.loseConnection()
        serverProtocol.transport.loseConnection()
        return defer.gatherResults([
            defer.maybeDeferred(realServerPort.stopListening),
            defer.maybeDeferred(proxyServerPort.stopListening)])