PageRenderTime 57ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/test/conftest.py

https://github.com/shazow/urllib3
Python | 343 lines | 249 code | 75 blank | 19 comment | 31 complexity | 085b9a172f8c3e97f9fbe0954b537602 MD5 | raw file
  1. import contextlib
  2. import platform
  3. import socket
  4. import ssl
  5. import sys
  6. import threading
  7. from pathlib import Path
  8. from typing import AbstractSet, Any, Dict, Generator, NamedTuple, Optional, Tuple
  9. import pytest
  10. import trustme
  11. from tornado import ioloop, web
  12. from dummyserver.handlers import TestingApp
  13. from dummyserver.proxy import ProxyHandler
  14. from dummyserver.server import HAS_IPV6, run_tornado_app
  15. from dummyserver.testcase import HTTPSDummyServerTestCase
  16. from urllib3.util import ssl_
  17. from .tz_stub import stub_timezone_ctx
  18. # The Python 3.8+ default loop on Windows breaks Tornado
  19. @pytest.fixture(scope="session", autouse=True)
  20. def configure_windows_event_loop() -> None:
  21. if sys.version_info >= (3, 8) and platform.system() == "Windows":
  22. import asyncio
  23. asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore[attr-defined]
  24. class ServerConfig(NamedTuple):
  25. scheme: str
  26. host: str
  27. port: int
  28. ca_certs: str
  29. @property
  30. def base_url(self) -> str:
  31. host = self.host
  32. if ":" in host:
  33. host = f"[{host}]"
  34. return f"{self.scheme}://{host}:{self.port}"
  35. def _write_cert_to_dir(
  36. cert: trustme.LeafCert, tmpdir: Path, file_prefix: str = "server"
  37. ) -> Dict[str, str]:
  38. cert_path = str(tmpdir / ("%s.pem" % file_prefix))
  39. key_path = str(tmpdir / ("%s.key" % file_prefix))
  40. cert.private_key_pem.write_to_path(key_path)
  41. cert.cert_chain_pems[0].write_to_path(cert_path)
  42. certs = {"keyfile": key_path, "certfile": cert_path}
  43. return certs
  44. @contextlib.contextmanager
  45. def run_server_in_thread(
  46. scheme: str, host: str, tmpdir: Path, ca: trustme.CA, server_cert: trustme.LeafCert
  47. ) -> Generator[ServerConfig, None, None]:
  48. ca_cert_path = str(tmpdir / "ca.pem")
  49. server_cert_path = str(tmpdir / "server.pem")
  50. server_key_path = str(tmpdir / "server.key")
  51. ca.cert_pem.write_to_path(ca_cert_path)
  52. server_cert.private_key_pem.write_to_path(server_key_path)
  53. server_cert.cert_chain_pems[0].write_to_path(server_cert_path)
  54. server_certs = {"keyfile": server_key_path, "certfile": server_cert_path}
  55. io_loop = ioloop.IOLoop.current()
  56. app = web.Application([(r".*", TestingApp)])
  57. server, port = run_tornado_app(app, io_loop, server_certs, scheme, host)
  58. server_thread = threading.Thread(target=io_loop.start)
  59. server_thread.start()
  60. yield ServerConfig("https", host, port, ca_cert_path)
  61. io_loop.add_callback(server.stop)
  62. io_loop.add_callback(io_loop.stop)
  63. server_thread.join()
  64. @contextlib.contextmanager
  65. def run_server_and_proxy_in_thread(
  66. proxy_scheme: str,
  67. proxy_host: str,
  68. tmpdir: Path,
  69. ca: trustme.CA,
  70. proxy_cert: trustme.LeafCert,
  71. server_cert: trustme.LeafCert,
  72. ) -> Generator[Tuple[ServerConfig, ServerConfig], None, None]:
  73. ca_cert_path = str(tmpdir / "ca.pem")
  74. ca.cert_pem.write_to_path(ca_cert_path)
  75. server_certs = _write_cert_to_dir(server_cert, tmpdir)
  76. proxy_certs = _write_cert_to_dir(proxy_cert, tmpdir, "proxy")
  77. io_loop = ioloop.IOLoop.current()
  78. app = web.Application([(r".*", TestingApp)])
  79. server_app, port = run_tornado_app(app, io_loop, server_certs, "https", "localhost")
  80. server_config = ServerConfig("https", "localhost", port, ca_cert_path)
  81. proxy = web.Application([(r".*", ProxyHandler)])
  82. proxy_app, proxy_port = run_tornado_app(
  83. proxy, io_loop, proxy_certs, proxy_scheme, proxy_host
  84. )
  85. proxy_config = ServerConfig(proxy_scheme, proxy_host, proxy_port, ca_cert_path)
  86. server_thread = threading.Thread(target=io_loop.start)
  87. server_thread.start()
  88. yield (proxy_config, server_config)
  89. io_loop.add_callback(server_app.stop)
  90. io_loop.add_callback(proxy_app.stop)
  91. io_loop.add_callback(io_loop.stop)
  92. server_thread.join()
  93. @pytest.fixture(params=["localhost", "127.0.0.1", "::1"])
  94. def loopback_host(request: Any) -> Generator[str, None, None]:
  95. host = request.param
  96. if host == "::1" and not HAS_IPV6:
  97. pytest.skip("Test requires IPv6 on loopback")
  98. yield host
  99. @pytest.fixture()
  100. def san_server(
  101. loopback_host: str, tmp_path_factory: pytest.TempPathFactory
  102. ) -> Generator[ServerConfig, None, None]:
  103. tmpdir = tmp_path_factory.mktemp("certs")
  104. ca = trustme.CA()
  105. server_cert = ca.issue_cert(loopback_host)
  106. with run_server_in_thread("https", loopback_host, tmpdir, ca, server_cert) as cfg:
  107. yield cfg
  108. @pytest.fixture()
  109. def no_san_server(
  110. loopback_host: str, tmp_path_factory: pytest.TempPathFactory
  111. ) -> Generator[ServerConfig, None, None]:
  112. tmpdir = tmp_path_factory.mktemp("certs")
  113. ca = trustme.CA()
  114. server_cert = ca.issue_cert(common_name=loopback_host)
  115. with run_server_in_thread("https", loopback_host, tmpdir, ca, server_cert) as cfg:
  116. yield cfg
  117. @pytest.fixture()
  118. def no_san_server_with_different_commmon_name(
  119. tmp_path_factory: pytest.TempPathFactory,
  120. ) -> Generator[ServerConfig, None, None]:
  121. tmpdir = tmp_path_factory.mktemp("certs")
  122. ca = trustme.CA()
  123. server_cert = ca.issue_cert(common_name="example.com")
  124. with run_server_in_thread("https", "localhost", tmpdir, ca, server_cert) as cfg:
  125. yield cfg
  126. @pytest.fixture
  127. def no_san_proxy_with_server(
  128. tmp_path_factory: pytest.TempPathFactory,
  129. ) -> Generator[Tuple[ServerConfig, ServerConfig], None, None]:
  130. tmpdir = tmp_path_factory.mktemp("certs")
  131. ca = trustme.CA()
  132. # only common name, no subject alternative names
  133. proxy_cert = ca.issue_cert(common_name="localhost")
  134. server_cert = ca.issue_cert("localhost")
  135. with run_server_and_proxy_in_thread(
  136. "https", "localhost", tmpdir, ca, proxy_cert, server_cert
  137. ) as cfg:
  138. yield cfg
  139. @pytest.fixture
  140. def no_localhost_san_server(
  141. tmp_path_factory: pytest.TempPathFactory,
  142. ) -> Generator[ServerConfig, None, None]:
  143. tmpdir = tmp_path_factory.mktemp("certs")
  144. ca = trustme.CA()
  145. # non localhost common name
  146. server_cert = ca.issue_cert("example.com")
  147. with run_server_in_thread("https", "localhost", tmpdir, ca, server_cert) as cfg:
  148. yield cfg
  149. @pytest.fixture
  150. def ipv4_san_proxy_with_server(
  151. tmp_path_factory: pytest.TempPathFactory,
  152. ) -> Generator[Tuple[ServerConfig, ServerConfig], None, None]:
  153. tmpdir = tmp_path_factory.mktemp("certs")
  154. ca = trustme.CA()
  155. # IP address in Subject Alternative Name
  156. proxy_cert = ca.issue_cert("127.0.0.1")
  157. server_cert = ca.issue_cert("localhost")
  158. with run_server_and_proxy_in_thread(
  159. "https", "127.0.0.1", tmpdir, ca, proxy_cert, server_cert
  160. ) as cfg:
  161. yield cfg
  162. @pytest.fixture
  163. def ipv6_san_proxy_with_server(
  164. tmp_path_factory: pytest.TempPathFactory,
  165. ) -> Generator[Tuple[ServerConfig, ServerConfig], None, None]:
  166. tmpdir = tmp_path_factory.mktemp("certs")
  167. ca = trustme.CA()
  168. # IP addresses in Subject Alternative Name
  169. proxy_cert = ca.issue_cert("::1")
  170. server_cert = ca.issue_cert("localhost")
  171. with run_server_and_proxy_in_thread(
  172. "https", "::1", tmpdir, ca, proxy_cert, server_cert
  173. ) as cfg:
  174. yield cfg
  175. @pytest.fixture
  176. def ipv4_san_server(
  177. tmp_path_factory: pytest.TempPathFactory,
  178. ) -> Generator[ServerConfig, None, None]:
  179. tmpdir = tmp_path_factory.mktemp("certs")
  180. ca = trustme.CA()
  181. # IP address in Subject Alternative Name
  182. server_cert = ca.issue_cert("127.0.0.1")
  183. with run_server_in_thread("https", "127.0.0.1", tmpdir, ca, server_cert) as cfg:
  184. yield cfg
  185. @pytest.fixture
  186. def ipv6_san_server(
  187. tmp_path_factory: pytest.TempPathFactory,
  188. ) -> Generator[ServerConfig, None, None]:
  189. if not HAS_IPV6:
  190. pytest.skip("Only runs on IPv6 systems")
  191. tmpdir = tmp_path_factory.mktemp("certs")
  192. ca = trustme.CA()
  193. # IP address in Subject Alternative Name
  194. server_cert = ca.issue_cert("::1")
  195. with run_server_in_thread("https", "::1", tmpdir, ca, server_cert) as cfg:
  196. yield cfg
  197. @pytest.fixture
  198. def ipv6_no_san_server(
  199. tmp_path_factory: pytest.TempPathFactory,
  200. ) -> Generator[ServerConfig, None, None]:
  201. if not HAS_IPV6:
  202. pytest.skip("Only runs on IPv6 systems")
  203. tmpdir = tmp_path_factory.mktemp("certs")
  204. ca = trustme.CA()
  205. # IP address in Common Name
  206. server_cert = ca.issue_cert(common_name="::1")
  207. with run_server_in_thread("https", "::1", tmpdir, ca, server_cert) as cfg:
  208. yield cfg
  209. @pytest.fixture
  210. def stub_timezone(request: pytest.FixtureRequest) -> Generator[None, None, None]:
  211. """
  212. A pytest fixture that runs the test with a stub timezone.
  213. """
  214. with stub_timezone_ctx(request.param): # type: ignore[attr-defined]
  215. yield
  216. @pytest.fixture(scope="session")
  217. def supported_tls_versions() -> AbstractSet[Optional[str]]:
  218. # We have to create an actual TLS connection
  219. # to test if the TLS version is not disabled by
  220. # OpenSSL config. Ubuntu 20.04 specifically
  221. # disables TLSv1 and TLSv1.1.
  222. tls_versions = set()
  223. _server = HTTPSDummyServerTestCase()
  224. _server._start_server()
  225. for _ssl_version_name in (
  226. "PROTOCOL_TLSv1",
  227. "PROTOCOL_TLSv1_1",
  228. "PROTOCOL_TLSv1_2",
  229. "PROTOCOL_TLS",
  230. ):
  231. _ssl_version = getattr(ssl, _ssl_version_name, 0)
  232. if _ssl_version == 0:
  233. continue
  234. _sock = socket.create_connection((_server.host, _server.port))
  235. try:
  236. _sock = ssl_.ssl_wrap_socket(
  237. _sock, cert_reqs=ssl.CERT_NONE, ssl_version=_ssl_version
  238. )
  239. except ssl.SSLError:
  240. pass
  241. else:
  242. tls_versions.add(_sock.version())
  243. _sock.close()
  244. _server._stop_server()
  245. return tls_versions
  246. @pytest.fixture(scope="function")
  247. def requires_tlsv1(supported_tls_versions: AbstractSet[str]) -> None:
  248. """Test requires TLSv1 available"""
  249. if not hasattr(ssl, "PROTOCOL_TLSv1") or "TLSv1" not in supported_tls_versions:
  250. pytest.skip("Test requires TLSv1")
  251. @pytest.fixture(scope="function")
  252. def requires_tlsv1_1(supported_tls_versions: AbstractSet[str]) -> None:
  253. """Test requires TLSv1.1 available"""
  254. if not hasattr(ssl, "PROTOCOL_TLSv1_1") or "TLSv1.1" not in supported_tls_versions:
  255. pytest.skip("Test requires TLSv1.1")
  256. @pytest.fixture(scope="function")
  257. def requires_tlsv1_2(supported_tls_versions: AbstractSet[str]) -> None:
  258. """Test requires TLSv1.2 available"""
  259. if not hasattr(ssl, "PROTOCOL_TLSv1_2") or "TLSv1.2" not in supported_tls_versions:
  260. pytest.skip("Test requires TLSv1.2")
  261. @pytest.fixture(scope="function")
  262. def requires_tlsv1_3(supported_tls_versions: AbstractSet[str]) -> None:
  263. """Test requires TLSv1.3 available"""
  264. if (
  265. not getattr(ssl, "HAS_TLSv1_3", False)
  266. or "TLSv1.3" not in supported_tls_versions
  267. ):
  268. pytest.skip("Test requires TLSv1.3")