PageRenderTime 29ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/dns/_asyncio_backend.py

http://github.com/rthalley/dnspython
Python | 149 lines | 111 code | 34 blank | 4 comment | 18 complexity | c2481ca08a5b53af27b1f8d3d68ec4c1 MD5 | raw file
Possible License(s): 0BSD (note: the file's own copyright header cites the ISC license)
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. """asyncio library query support"""
  3. import socket
  4. import asyncio
  5. import sys
  6. import dns._asyncbackend
  7. import dns.exception
  8. _is_win32 = sys.platform == 'win32'
  9. def _get_running_loop():
  10. try:
  11. return asyncio.get_running_loop()
  12. except AttributeError: # pragma: no cover
  13. return asyncio.get_event_loop()
  14. class _DatagramProtocol:
  15. def __init__(self):
  16. self.transport = None
  17. self.recvfrom = None
  18. def connection_made(self, transport):
  19. self.transport = transport
  20. def datagram_received(self, data, addr):
  21. if self.recvfrom:
  22. self.recvfrom.set_result((data, addr))
  23. self.recvfrom = None
  24. def error_received(self, exc): # pragma: no cover
  25. if self.recvfrom and not self.recvfrom.done():
  26. self.recvfrom.set_exception(exc)
  27. def connection_lost(self, exc):
  28. if self.recvfrom and not self.recvfrom.done():
  29. self.recvfrom.set_exception(exc)
  30. def close(self):
  31. self.transport.close()
  32. async def _maybe_wait_for(awaitable, timeout):
  33. if timeout:
  34. try:
  35. return await asyncio.wait_for(awaitable, timeout)
  36. except asyncio.TimeoutError:
  37. raise dns.exception.Timeout(timeout=timeout)
  38. else:
  39. return await awaitable
  40. class DatagramSocket(dns._asyncbackend.DatagramSocket):
  41. def __init__(self, family, transport, protocol):
  42. self.family = family
  43. self.transport = transport
  44. self.protocol = protocol
  45. async def sendto(self, what, destination, timeout): # pragma: no cover
  46. # no timeout for asyncio sendto
  47. self.transport.sendto(what, destination)
  48. async def recvfrom(self, size, timeout):
  49. # ignore size as there's no way I know to tell protocol about it
  50. done = _get_running_loop().create_future()
  51. assert self.protocol.recvfrom is None
  52. self.protocol.recvfrom = done
  53. await _maybe_wait_for(done, timeout)
  54. return done.result()
  55. async def close(self):
  56. self.protocol.close()
  57. async def getpeername(self):
  58. return self.transport.get_extra_info('peername')
  59. async def getsockname(self):
  60. return self.transport.get_extra_info('sockname')
  61. class StreamSocket(dns._asyncbackend.StreamSocket):
  62. def __init__(self, af, reader, writer):
  63. self.family = af
  64. self.reader = reader
  65. self.writer = writer
  66. async def sendall(self, what, timeout):
  67. self.writer.write(what)
  68. return await _maybe_wait_for(self.writer.drain(), timeout)
  69. async def recv(self, size, timeout):
  70. return await _maybe_wait_for(self.reader.read(size),
  71. timeout)
  72. async def close(self):
  73. self.writer.close()
  74. try:
  75. await self.writer.wait_closed()
  76. except AttributeError: # pragma: no cover
  77. pass
  78. async def getpeername(self):
  79. return self.writer.get_extra_info('peername')
  80. async def getsockname(self):
  81. return self.writer.get_extra_info('sockname')
  82. class Backend(dns._asyncbackend.Backend):
  83. def name(self):
  84. return 'asyncio'
  85. async def make_socket(self, af, socktype, proto=0,
  86. source=None, destination=None, timeout=None,
  87. ssl_context=None, server_hostname=None):
  88. if destination is None and socktype == socket.SOCK_DGRAM and \
  89. _is_win32:
  90. raise NotImplementedError('destinationless datagram sockets '
  91. 'are not supported by asyncio '
  92. 'on Windows')
  93. loop = _get_running_loop()
  94. if socktype == socket.SOCK_DGRAM:
  95. transport, protocol = await loop.create_datagram_endpoint(
  96. _DatagramProtocol, source, family=af,
  97. proto=proto, remote_addr=destination)
  98. return DatagramSocket(af, transport, protocol)
  99. elif socktype == socket.SOCK_STREAM:
  100. (r, w) = await _maybe_wait_for(
  101. asyncio.open_connection(destination[0],
  102. destination[1],
  103. ssl=ssl_context,
  104. family=af,
  105. proto=proto,
  106. local_addr=source,
  107. server_hostname=server_hostname),
  108. timeout)
  109. return StreamSocket(af, r, w)
  110. raise NotImplementedError('unsupported socket ' +
  111. f'type {socktype}') # pragma: no cover
  112. async def sleep(self, interval):
  113. await asyncio.sleep(interval)
  114. def datagram_connection_required(self):
  115. return _is_win32