PageRenderTime 69ms CodeModel.GetById 36ms RepoModel.GetById 0ms app.codeStats 0ms

/tornado/test/httpserver_test.py

https://github.com/yinhm/tornado
Python | 277 lines | 245 code | 20 blank | 12 comment | 3 complexity | 5617e5304e065df64ac766de82d83d28 MD5 | raw file
  1. #!/usr/bin/env python
  2. from tornado import httpclient, simple_httpclient, netutil
  3. from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
  4. from tornado.httpserver import HTTPServer
  5. from tornado.httputil import HTTPHeaders
  6. from tornado.iostream import IOStream
  7. from tornado.simple_httpclient import SimpleAsyncHTTPClient
  8. from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
  9. from tornado.util import b, bytes_type
  10. from tornado.web import Application, RequestHandler
  11. import os
  12. import shutil
  13. import socket
  14. import sys
  15. import tempfile
  16. try:
  17. import ssl
  18. except ImportError:
  19. ssl = None
  20. class HelloWorldRequestHandler(RequestHandler):
  21. def initialize(self, protocol="http"):
  22. self.expected_protocol = protocol
  23. def get(self):
  24. assert self.request.protocol == self.expected_protocol
  25. self.finish("Hello world")
  26. def post(self):
  27. self.finish("Got %d bytes in POST" % len(self.request.body))
  28. class SSLTest(AsyncHTTPTestCase, LogTrapTestCase):
  29. def setUp(self):
  30. super(SSLTest, self).setUp()
  31. # Replace the client defined in the parent class.
  32. # Some versions of libcurl have deadlock bugs with ssl,
  33. # so always run these tests with SimpleAsyncHTTPClient.
  34. self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
  35. force_instance=True)
  36. def get_app(self):
  37. return Application([('/', HelloWorldRequestHandler,
  38. dict(protocol="https"))])
  39. def get_httpserver_options(self):
  40. # Testing keys were generated with:
  41. # openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
  42. test_dir = os.path.dirname(__file__)
  43. return dict(ssl_options=dict(
  44. certfile=os.path.join(test_dir, 'test.crt'),
  45. keyfile=os.path.join(test_dir, 'test.key')))
  46. def fetch(self, path, **kwargs):
  47. self.http_client.fetch(self.get_url(path).replace('http', 'https'),
  48. self.stop,
  49. validate_cert=False,
  50. **kwargs)
  51. return self.wait()
  52. def test_ssl(self):
  53. response = self.fetch('/')
  54. self.assertEqual(response.body, b("Hello world"))
  55. def test_large_post(self):
  56. response = self.fetch('/',
  57. method='POST',
  58. body='A'*5000)
  59. self.assertEqual(response.body, b("Got 5000 bytes in POST"))
  60. def test_non_ssl_request(self):
  61. # Make sure the server closes the connection when it gets a non-ssl
  62. # connection, rather than waiting for a timeout or otherwise
  63. # misbehaving.
  64. self.http_client.fetch(self.get_url("/"), self.stop,
  65. request_timeout=3600,
  66. connect_timeout=3600)
  67. response = self.wait()
  68. self.assertEqual(response.code, 599)
  69. if ssl is None:
  70. del SSLTest
  71. class MultipartTestHandler(RequestHandler):
  72. def post(self):
  73. self.finish({"header": self.request.headers["X-Header-Encoding-Test"],
  74. "argument": self.get_argument("argument"),
  75. "filename": self.request.files["files"][0].filename,
  76. "filebody": _unicode(self.request.files["files"][0]["body"]),
  77. })
  78. class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
  79. def set_request(self, request):
  80. self.__next_request = request
  81. def _on_connect(self, parsed):
  82. self.stream.write(self.__next_request)
  83. self.__next_request = None
  84. self.stream.read_until(b("\r\n\r\n"), self._on_headers)
  85. # This test is also called from wsgi_test
  86. class HTTPConnectionTest(AsyncHTTPTestCase, LogTrapTestCase):
  87. def get_handlers(self):
  88. return [("/multipart", MultipartTestHandler),
  89. ("/hello", HelloWorldRequestHandler)]
  90. def get_app(self):
  91. return Application(self.get_handlers())
  92. def raw_fetch(self, headers, body):
  93. conn = RawRequestHTTPConnection(self.io_loop, self.http_client,
  94. httpclient.HTTPRequest(self.get_url("/")),
  95. None, self.stop,
  96. 1024*1024)
  97. conn.set_request(
  98. b("\r\n").join(headers +
  99. [utf8("Content-Length: %d\r\n" % len(body))]) +
  100. b("\r\n") + body)
  101. response = self.wait()
  102. response.rethrow()
  103. return response
  104. def test_multipart_form(self):
  105. # Encodings here are tricky: Headers are latin1, bodies can be
  106. # anything (we use utf8 by default).
  107. response = self.raw_fetch([
  108. b("POST /multipart HTTP/1.0"),
  109. b("Content-Type: multipart/form-data; boundary=1234567890"),
  110. b("X-Header-encoding-test: \xe9"),
  111. ],
  112. b("\r\n").join([
  113. b("Content-Disposition: form-data; name=argument"),
  114. b(""),
  115. u"\u00e1".encode("utf-8"),
  116. b("--1234567890"),
  117. u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
  118. b(""),
  119. u"\u00fa".encode("utf-8"),
  120. b("--1234567890--"),
  121. b(""),
  122. ]))
  123. data = json_decode(response.body)
  124. self.assertEqual(u"\u00e9", data["header"])
  125. self.assertEqual(u"\u00e1", data["argument"])
  126. self.assertEqual(u"\u00f3", data["filename"])
  127. self.assertEqual(u"\u00fa", data["filebody"])
  128. def test_100_continue(self):
  129. # Run through a 100-continue interaction by hand:
  130. # When given Expect: 100-continue, we get a 100 response after the
  131. # headers, and then the real response after the body.
  132. stream = IOStream(socket.socket(), io_loop=self.io_loop)
  133. stream.connect(("localhost", self.get_http_port()), callback=self.stop)
  134. self.wait()
  135. stream.write(b("\r\n").join([b("POST /hello HTTP/1.1"),
  136. b("Content-Length: 1024"),
  137. b("Expect: 100-continue"),
  138. b("\r\n")]), callback=self.stop)
  139. self.wait()
  140. stream.read_until(b("\r\n\r\n"), self.stop)
  141. data = self.wait()
  142. self.assertTrue(data.startswith(b("HTTP/1.1 100 ")), data)
  143. stream.write(b("a") * 1024)
  144. stream.read_until(b("\r\n"), self.stop)
  145. first_line = self.wait()
  146. self.assertTrue(first_line.startswith(b("HTTP/1.1 200")), first_line)
  147. stream.read_until(b("\r\n\r\n"), self.stop)
  148. header_data = self.wait()
  149. headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
  150. stream.read_bytes(int(headers["Content-Length"]), self.stop)
  151. body = self.wait()
  152. self.assertEqual(body, b("Got 1024 bytes in POST"))
  153. class EchoHandler(RequestHandler):
  154. def get(self):
  155. self.write(recursive_unicode(self.request.arguments))
  156. class TypeCheckHandler(RequestHandler):
  157. def prepare(self):
  158. self.errors = {}
  159. fields = [
  160. ('method', str),
  161. ('uri', str),
  162. ('version', str),
  163. ('remote_ip', str),
  164. ('protocol', str),
  165. ('host', str),
  166. ('path', str),
  167. ('query', str),
  168. ]
  169. for field, expected_type in fields:
  170. self.check_type(field, getattr(self.request, field), expected_type)
  171. self.check_type('header_key', self.request.headers.keys()[0], str)
  172. self.check_type('header_value', self.request.headers.values()[0], str)
  173. self.check_type('cookie_key', self.request.cookies.keys()[0], str)
  174. self.check_type('cookie_value', self.request.cookies.values()[0].value, str)
  175. # secure cookies
  176. self.check_type('arg_key', self.request.arguments.keys()[0], str)
  177. self.check_type('arg_value', self.request.arguments.values()[0][0], bytes_type)
  178. def post(self):
  179. self.check_type('body', self.request.body, bytes_type)
  180. self.write(self.errors)
  181. def get(self):
  182. self.write(self.errors)
  183. def check_type(self, name, obj, expected_type):
  184. actual_type = type(obj)
  185. if expected_type != actual_type:
  186. self.errors[name] = "expected %s, got %s" % (expected_type,
  187. actual_type)
  188. class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
  189. def get_app(self):
  190. return Application([("/echo", EchoHandler),
  191. ("/typecheck", TypeCheckHandler),
  192. ])
  193. def test_query_string_encoding(self):
  194. response = self.fetch("/echo?foo=%C3%A9")
  195. data = json_decode(response.body)
  196. self.assertEqual(data, {u"foo": [u"\u00e9"]})
  197. def test_types(self):
  198. headers = {"Cookie": "foo=bar"}
  199. response = self.fetch("/typecheck?foo=bar", headers=headers)
  200. data = json_decode(response.body)
  201. self.assertEqual(data, {})
  202. response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
  203. data = json_decode(response.body)
  204. self.assertEqual(data, {})
  205. class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
  206. """HTTPServers can listen on Unix sockets too.
  207. Why would you want to do this? Nginx can proxy to backends listening
  208. on unix sockets, for one thing (and managing a namespace for unix
  209. sockets can be easier than managing a bunch of TCP port numbers).
  210. Unfortunately, there's no way to specify a unix socket in a url for
  211. an HTTP client, so we have to test this by hand.
  212. """
  213. def setUp(self):
  214. super(UnixSocketTest, self).setUp()
  215. self.tmpdir = tempfile.mkdtemp()
  216. def tearDown(self):
  217. shutil.rmtree(self.tmpdir)
  218. super(UnixSocketTest, self).tearDown()
  219. def test_unix_socket(self):
  220. sockfile = os.path.join(self.tmpdir, "test.sock")
  221. sock = netutil.bind_unix_socket(sockfile)
  222. app = Application([("/hello", HelloWorldRequestHandler)])
  223. server = HTTPServer(app, io_loop=self.io_loop)
  224. server.add_socket(sock)
  225. stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
  226. stream.connect(sockfile, self.stop)
  227. self.wait()
  228. stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
  229. stream.read_until(b("\r\n"), self.stop)
  230. response = self.wait()
  231. self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
  232. stream.read_until(b("\r\n\r\n"), self.stop)
  233. headers = HTTPHeaders.parse(self.wait().decode('latin1'))
  234. stream.read_bytes(int(headers["Content-Length"]), self.stop)
  235. body = self.wait()
  236. self.assertEqual(body, b("Hello world"))
  237. if not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin':
  238. del UnixSocketTest