/edb/testbase/protocol/protocol.pyx

https://github.com/edgedb/edgedb · Cython · 145 lines · 99 code · 28 blank · 18 comment · 18 complexity · faca2c0e7430e724b883490938db0e33 MD5 · raw file

  1. #
  2. # This source file is part of the EdgeDB open source project.
  3. #
  4. # Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import asyncio
  19. import re
  20. import time
  21. from edgedb import con_utils
  22. from edgedb.protocol.asyncio_proto cimport AsyncIOProtocol
  23. from edgedb.protocol.protocol cimport ReadBuffer, WriteBuffer
  24. from . import messages
  class Protocol(AsyncIOProtocol):
      """Local subclass of edgedb-python's ``AsyncIOProtocol``.

      It adds no behavior of its own. ``new_connection()`` instantiates it
      as ``Protocol(params, loop)`` for every connection attempt.
      """
  cdef class Connection:
      """Test-side client connection speaking the raw EdgeDB binary protocol.

      Wraps a protocol instance and its transport. It exposes low-level
      helpers for sending client messages and for receiving or asserting
      server messages.

      ``LogMessage`` messages received from the server are never returned by
      :meth:`recv`. They are accumulated in ``self.inbox`` instead.
      """

      def __init__(self, pr, tr):
          # NOTE(review): as a ``cdef class`` these attributes are presumably
          # declared in a companion .pxd file -- confirm before renaming.
          self._protocol = pr
          self._transport = tr
          self.inbox = []  # LogMessage instances skipped by recv()

      async def connect(self):
          """Await the wrapped protocol's ``connect()`` coroutine."""
          await self._protocol.connect()

      async def sync(self):
          """Send ``Sync`` and return the server's reported transaction state.

          Raises:
              AssertionError: if the next non-log reply is not
                  ``ReadyForCommand``.
          """
          await self.send(messages.Sync())
          reply = await self.recv()
          if not isinstance(reply, messages.ReadyForCommand):
              raise AssertionError(
                  f'invalid response for Sync request: {reply!r}')
          return reply.transaction_state

      async def recv(self):
          """Wait for and return the next server message that is not a log.

          Any ``LogMessage`` received along the way is appended to
          ``self.inbox`` and skipped.
          """
          while True:
              await self._protocol.wait_for_message()
              mtype = self._protocol.buffer.get_message_type()
              data = self._protocol.buffer.consume_message()
              msg = messages.ServerMessage.parse(mtype, data)
              if isinstance(msg, messages.LogMessage):
                  self.inbox.append(msg)
                  continue
              return msg

      async def recv_match(self, msgcls, **fields):
          """Receive the next message and assert its type and field values.

          Args:
              msgcls: expected message class (checked with ``isinstance``).
              **fields: expected attribute values. A ``str`` expectation is
                  a regular expression applied with ``re.match``, so it is
                  anchored at the start of the value only. Any other
                  expectation is compared with ``==``.

          Returns:
              The received message, so callers can inspect it further.

          Raises:
              AssertionError: on a message type mismatch or any field
                  mismatch.
          """
          message = await self.recv()
          if not isinstance(message, msgcls):
              raise AssertionError(
                  f'expected for {msgcls.__name__} message, received '
                  f'{type(message).__name__}')

          for fieldname, expected in fields.items():
              val = getattr(message, fieldname)
              if isinstance(expected, str):
                  # Check the type first. re.match() raises TypeError (not an
                  # assertion failure) when the value is None or bytes.
                  if not isinstance(val, str) or not re.match(expected, val):
                      raise AssertionError(
                          f'{msgcls.__name__}.{fieldname} value {val!r} '
                          f'does not match expected regexp {expected!r}')
              elif expected != val:
                  raise AssertionError(
                      f'{msgcls.__name__}.{fieldname} value {val!r} '
                      f'does not equal to expected {expected!r}')

          return message

      async def send(self, *msgs: messages.ClientMessage):
          """Serialize each message with ``dump()`` and write it out.

          Every message goes through its own ``WriteBuffer`` and a separate
          ``write()`` call on the protocol.
          """
          cdef WriteBuffer buf
          for msg in msgs:
              out = msg.dump()
              buf = WriteBuffer.new()
              buf.write_bytes(out)
              self._protocol.write(buf)

      async def aclose(self):
          """Tear the connection down by calling ``abort()`` on the protocol."""
          # TODO: Fix when edgedb-python implements proper cancellation
          self._protocol.abort()
  async def new_connection(
      dsn: str = None,
      *,
      host: str = None,
      port: int = None,
      user: str = None,
      password: str = None,
      admin: str = None,
      database: str = None,
      timeout: float = 60,
  ):
      """Open a raw protocol :class:`Connection` to an EdgeDB server.

      Connection arguments are resolved with
      ``con_utils.parse_connect_arguments``. The resolved addresses are tried
      in order:

      - a ``str`` address is treated as a Unix socket path;
      - any other address is unpacked into ``loop.create_connection()``.

      *timeout* is a total budget in seconds shared by all attempts. The time
      spent on each attempt is subtracted from it. Once the budget is
      exhausted, the remaining addresses fail immediately with
      ``asyncio.TimeoutError``.

      ``Connection.connect()`` is not called here.

      Returns:
          A :class:`Connection` for the first address that was reached.

      Raises:
          OSError, ConnectionError, asyncio.TimeoutError: the error from the
              last failed attempt, if no address could be reached.
          ConnectionError: if no addresses were resolved at all.
      """
      addrs, params, _config = con_utils.parse_connect_arguments(
          dsn=dsn, host=host, port=port, user=user, password=password,
          admin=admin, database=database,
          timeout=timeout, command_timeout=None, server_settings=None)

      loop = asyncio.get_running_loop()

      # ``params`` and ``loop`` are the same for every attempt, so a single
      # factory serves all addresses.
      def protocol_factory():
          return Protocol(params, loop)

      last_error = None
      for addr in addrs:
          try:
              if timeout <= 0:
                  raise asyncio.TimeoutError
              if isinstance(addr, str):
                  connector = loop.create_unix_connection(
                      protocol_factory, addr)
              else:
                  connector = loop.create_connection(
                      protocol_factory, *addr)
              before = time.monotonic()
              try:
                  tr, pr = await asyncio.wait_for(connector, timeout=timeout)
              finally:
                  # Charge this attempt against the shared budget whether it
                  # succeeded, failed or timed out.
                  timeout -= time.monotonic() - before
              return Connection(pr, tr)
          except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
              last_error = ex

      if last_error is None:
          # Only reachable when ``addrs`` is empty. Without this guard,
          # ``raise None`` would surface as a confusing TypeError.
          raise ConnectionError('no addresses to connect to')
      raise last_error