/edgedb/protocol/asyncio_proto.pyx

https://github.com/edgedb/edgedb-python · Cython · 141 lines · 89 code · 27 blank · 25 comment · 29 complexity · ba6e004c9b672562636243d8f2ecebb8 MD5 · raw file

  1. #
  2. # This source file is part of the EdgeDB open source project.
  3. #
  4. # Copyright 2016-present MagicStack Inc. and the EdgeDB authors.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import asyncio
  19. from edgedb import errors
  20. from edgedb.pgproto.pgproto cimport (
  21. WriteBuffer,
  22. ReadBuffer,
  23. )
  24. from . cimport protocol
  25. cdef class AsyncIOProtocol(protocol.SansIOProtocol):
  26. def __init__(self, con_params, loop, tls_compat=False):
  27. protocol.SansIOProtocol.__init__(self, con_params, tls_compat)
  28. self.loop = loop
  29. self.transport = None
  30. self.connected_fut = loop.create_future()
  31. self.disconnected_fut = None
  32. self.msg_waiter = None
  33. cpdef abort(self):
  34. self.connected = False
  35. self.terminate()
  36. if self.transport is not None:
  37. self.transport.close()
  38. self.transport = None
  39. cdef write(self, WriteBuffer buf):
  40. if self.transport is None:
  41. raise errors.ClientConnectionFailedTemporarilyError()
  42. self.transport.write(memoryview(buf))
  43. async def wait_for_message(self):
  44. if self.buffer.take_message():
  45. return
  46. try:
  47. self.msg_waiter = self.loop.create_future()
  48. await self.msg_waiter
  49. return
  50. except asyncio.CancelledError:
  51. # TODO: A proper cancellation requires server/protocol
  52. # support, which isn't yet available. Therefore,
  53. # we're disabling asyncio cancellation completely
  54. # until we can implement it properly.
  55. try:
  56. self.cancelled = True
  57. self.abort()
  58. finally:
  59. raise
  60. async def try_recv_eagerly(self):
  61. pass
  62. async def wait_for_connect(self):
  63. if self.connected_fut is not None:
  64. await self.connected_fut
  65. async def wait_for_disconnect(self):
  66. if not self.connected:
  67. return
  68. else:
  69. self.disconnected_fut = self.loop.create_future()
  70. try:
  71. await self.disconnected_fut
  72. except ConnectionError:
  73. pass
  74. finally:
  75. self.disconnected_fut = None
  76. def connection_made(self, transport):
  77. if self.transport is not None:
  78. raise RuntimeError('connection_made: invalid connection status')
  79. self.transport = transport
  80. self.connected_fut.set_result(True)
  81. self.connected_fut = None
  82. def connection_lost(self, exc):
  83. self.connected = False
  84. if self.connected_fut is not None and not self.connected_fut.done():
  85. self.connected_fut.set_exception(ConnectionAbortedError())
  86. if (
  87. self.disconnected_fut is not None
  88. and not self.disconnected_fut.done()
  89. ):
  90. self.disconnected_fut.set_exception(ConnectionResetError())
  91. if self.msg_waiter is not None and not self.msg_waiter.done():
  92. self.msg_waiter.set_exception(errors.ClientConnectionClosedError())
  93. self.msg_waiter = None
  94. if self.transport is not None:
  95. # With asyncio sslproto on CPython 3.10 or lower, a normal exit
  96. # (connection closed by peer) cannot set the transport._closed
  97. # properly, leading to false ResourceWarning. Let's fix that by
  98. # closing the transport again.
  99. if not self.transport.is_closing():
  100. self.transport.close()
  101. self.transport = None
  102. def pause_writing(self):
  103. pass
  104. def resume_writing(self):
  105. pass
  106. def data_received(self, data):
  107. self.buffer.feed_data(data)
  108. if (self.msg_waiter is not None and
  109. self.buffer.take_message() and
  110. not self.msg_waiter.done()):
  111. self.msg_waiter.set_result(True)
  112. self.msg_waiter = None
  113. def eof_received(self):
  114. pass