/neo/Network/protocol.py

https://github.com/CityOfZion/neo-python · Python · 101 lines · 70 code · 16 blank · 15 comment · 16 complexity · 4cbe6e3d9f6554c0b9870e9fb6486ca1 MD5 · raw file

  1. import asyncio
  2. import struct
  3. from typing import Optional
  4. from neo.Network.node import NeoNode
  5. from neo.Network.message import Message
  6. from asyncio.streams import StreamReader, StreamReaderProtocol, StreamWriter
  7. from asyncio import events
  8. from neo.logging import log_manager
  9. logger = log_manager.getLogger('network')
  10. class NeoProtocol(StreamReaderProtocol):
  11. def __init__(self, *args, quality_check=False, **kwargs):
  12. """
  13. Args:
  14. *args:
  15. quality_check (bool): there are times when we only establish a connection to check the quality of the node/address
  16. **kwargs:
  17. """
  18. self._stream_reader = StreamReader()
  19. self._stream_writer = None
  20. nodemanager = kwargs.pop('nodemanager')
  21. self.client = NeoNode(self, nodemanager, quality_check)
  22. self._loop = events.get_event_loop()
  23. super().__init__(self._stream_reader)
  24. def connection_made(self, transport: asyncio.transports.BaseTransport) -> None:
  25. super().connection_made(transport)
  26. self._stream_writer = StreamWriter(transport, self, self._stream_reader, self._loop)
  27. if self.client:
  28. asyncio.create_task(self.client.connection_made(transport))
  29. def connection_lost(self, exc: Optional[Exception] = None) -> None:
  30. if self.client:
  31. task = asyncio.create_task(self.client.connection_lost(exc))
  32. task.add_done_callback(lambda args: super(NeoProtocol, self).connection_lost(exc))
  33. else:
  34. super().connection_lost(exc)
  35. def eof_received(self) -> bool:
  36. self._stream_reader.feed_eof()
  37. self.connection_lost()
  38. return True
  39. # False == Do not keep connection open, this makes sure that `connection_lost` gets called.
  40. # return False
  41. async def send_message(self, message: Message) -> None:
  42. try:
  43. self._stream_writer.write(message.to_array())
  44. await self._stream_writer.drain()
  45. except ConnectionResetError:
  46. # print("connection reset")
  47. self.connection_lost(ConnectionResetError)
  48. except ConnectionError:
  49. # print("connection error")
  50. self.connection_lost(ConnectionError)
  51. except asyncio.CancelledError:
  52. # print("task cancelled, closing connection")
  53. self.connection_lost(asyncio.CancelledError)
  54. except Exception as e:
  55. # print(f"***** woah what happened here?! {traceback.format_exc()}")
  56. self.connection_lost()
  57. async def read_message(self, timeout: int = 30) -> Message:
  58. if timeout == 0:
  59. # avoid memleak. See: https://bugs.python.org/issue37042
  60. timeout = None
  61. async def _read():
  62. try:
  63. message_header = await self._stream_reader.readexactly(24)
  64. magic, command, payload_length, checksum = struct.unpack('I 12s I I',
  65. message_header) # uint32, 12byte-string, uint32, uint32
  66. payload_data = await self._stream_reader.readexactly(payload_length)
  67. payload, = struct.unpack('{}s'.format(payload_length), payload_data)
  68. except Exception:
  69. # ensures we break out of the main run() loop of Node, which triggers a disconnect callback to clean up
  70. self.client.disconnecting = True
  71. return None
  72. m = Message(magic, command.rstrip(b'\x00').decode('utf-8'), payload)
  73. if checksum != m.get_checksum(payload):
  74. logger.debug("Message checksum incorrect")
  75. return None
  76. else:
  77. return m
  78. try:
  79. return await asyncio.wait_for(_read(), timeout)
  80. except Exception:
  81. return None
  82. def disconnect(self) -> None:
  83. if self._stream_writer:
  84. self._stream_writer.close()