/dohproxy/server_protocol.py

https://github.com/facebookexperimental/doh-proxy · Python · 190 lines · 139 code · 42 blank · 9 comment · 15 complexity · fa9d3b191dcbdb8ae5fcf8d53acb9f51 MD5 · raw file

  1. #!/usr/bin/env python3
  2. #
  3. # Copyright (c) 2018-present, Facebook, Inc.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. #
import asyncio
import struct
import time

import dns.edns
import dns.entropy
import dns.flags
import dns.message

from dohproxy import utils
  16. class DOHException(Exception):
  17. def body(self):
  18. return self.args[0]
  19. class DOHParamsException(DOHException):
  20. pass
  21. class DOHDNSException(DOHException):
  22. pass
  23. class DNSClient():
  24. DEFAULT_TIMEOUT = 10
  25. def __init__(self, upstream_resolver, upstream_port, logger=None):
  26. self.loop = asyncio.get_event_loop()
  27. self.upstream_resolver = upstream_resolver
  28. self.upstream_port = upstream_port
  29. if logger is None:
  30. logger = utils.configure_logger('DNSClient', 'DEBUG')
  31. self.logger = logger
  32. self.transport = None
  33. async def query(self, dnsq, clientip, timeout=DEFAULT_TIMEOUT,
  34. ecs=False):
  35. # (Potentially) modified copy of dnsq
  36. dnsq_mod = dns.message.from_wire(dnsq.to_wire())
  37. we_set_ecs = False
  38. if ecs:
  39. we_set_ecs = utils.set_dns_ecs(dnsq_mod, clientip)
  40. dnsr = await self.query_udp(dnsq_mod, clientip, timeout=timeout)
  41. if dnsr is None or (dnsr.flags & dns.flags.TC):
  42. dnsr = await self.query_tcp(dnsq_mod, clientip, timeout=timeout)
  43. if dnsr is not None and we_set_ecs:
  44. for option in dnsr.options:
  45. if isinstance(option, dns.edns.ECSOption):
  46. dnsr.options.remove(option)
  47. dnsr.edns = dnsq.edns
  48. return dnsr
  49. async def query_udp(self, dnsq, clientip, timeout=DEFAULT_TIMEOUT):
  50. qid = dnsq.id
  51. fut = asyncio.Future()
  52. transport, _ = await self.loop.create_datagram_endpoint(
  53. lambda: DNSClientProtocolUDP(
  54. dnsq, fut, clientip, logger=self.logger),
  55. remote_addr=(self.upstream_resolver, self.upstream_port))
  56. return await self._try_query(fut, qid, timeout, transport)
  57. async def query_tcp(self, dnsq, clientip, timeout=DEFAULT_TIMEOUT):
  58. qid = dnsq.id
  59. fut = asyncio.Future()
  60. transport, _ = await self.loop.create_connection(
  61. lambda: DNSClientProtocolTCP(
  62. dnsq, fut, clientip, logger=self.logger),
  63. self.upstream_resolver, self.upstream_port)
  64. return await self._try_query(fut, qid, timeout, transport)
  65. async def _try_query(self, fut, qid, timeout, transport):
  66. try:
  67. await asyncio.wait_for(fut, timeout)
  68. dnsr = fut.result()
  69. dnsr.id = qid
  70. except asyncio.TimeoutError:
  71. self.logger.debug('Request timed out')
  72. if transport:
  73. transport.close()
  74. dnsr = None
  75. return dnsr
  76. class DNSClientProtocol(asyncio.Protocol):
  77. def __init__(self, dnsq, fut, clientip, logger=None):
  78. self.transport = None
  79. self.dnsq = dnsq
  80. self.fut = fut
  81. self.clientip = clientip
  82. if logger is None:
  83. logger = utils.configure_logger('DNSClientProtocol', 'DEBUG')
  84. self.logger = logger
  85. def connection_lost(self, exc):
  86. pass
  87. def connection_made(self, transport):
  88. raise NotImplementedError()
  89. def data_received(self, data):
  90. raise NotImplementedError()
  91. def datagram_received(self, data, addr):
  92. raise NotImplementedError()
  93. def error_received(self, exc):
  94. raise NotImplementedError()
  95. def eof_received(self):
  96. raise NotImplementedError()
  97. def send_helper(self, transport):
  98. self.transport = transport
  99. self.dnsq.id = dns.entropy.random_16()
  100. self.logger.info(
  101. '[DNS] {} {}'.format(
  102. self.clientip,
  103. utils.dnsquery2log(self.dnsq)
  104. )
  105. )
  106. self.time_stamp = time.time()
  107. def receive_helper(self, dnsr):
  108. interval = int((time.time() - self.time_stamp) * 1000)
  109. log_message = (
  110. '[DNS] {} {} {}ms'.format(
  111. self.clientip,
  112. utils.dnsans2log(dnsr),
  113. interval
  114. )
  115. )
  116. if not self.fut.cancelled():
  117. self.logger.info(log_message)
  118. self.fut.set_result(dnsr)
  119. else:
  120. self.logger.info(log_message + '(CANCELLED)')
  121. class DNSClientProtocolUDP(DNSClientProtocol):
  122. def connection_made(self, transport):
  123. self.send_helper(transport)
  124. self.transport.sendto(self.dnsq.to_wire())
  125. def datagram_received(self, data, addr):
  126. dnsr = dns.message.from_wire(data)
  127. self.receive_helper(dnsr)
  128. self.transport.close()
  129. def error_received(self, exc):
  130. self.transport.close()
  131. self.logger.exception('Error received: ' + str(exc))
  132. class DNSClientProtocolTCP(DNSClientProtocol):
  133. def __init__(self, dnsq, fut, clientip, logger=None):
  134. super().__init__(dnsq, fut, clientip, logger=logger)
  135. self.buffer = bytes()
  136. def connection_made(self, transport):
  137. self.send_helper(transport)
  138. msg = self.dnsq.to_wire()
  139. tcpmsg = struct.pack('!H', len(msg)) + msg
  140. self.transport.write(tcpmsg)
  141. def data_received(self, data):
  142. self.buffer = utils.handle_dns_tcp_data(
  143. self.buffer + data, self.receive_helper
  144. )
  145. def eof_received(self):
  146. if len(self.buffer) > 0:
  147. self.logger.debug('Discard incomplete message')
  148. self.transport.close()