/fbnet/command_runner/thrift_client.py

https://github.com/facebookincubator/FCR · Python · 101 lines · 64 code · 25 blank · 12 comment · 5 complexity · 95b170d4a86aaecf4468be945dc9b696 MD5 · raw file

  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. #
  5. # This source code is licensed under the MIT license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import asyncio
  8. from thrift.server.TAsyncioServer import ThriftClientProtocolFactory
  9. from .base_service import ServiceObj
  10. class AsyncioThriftClient(ServiceObj):
  11. """
  12. util class to get asyncio client for different services using asyncio
  13. get_hosts
  14. """
  15. _TIMEOUT = 60 # By default timeout after 60s
  16. def __init__(
  17. self, client_class, host, port, service=None, timeout=None, open_timeout=None
  18. ):
  19. super().__init__(service)
  20. self._client_class = client_class
  21. self._host = host
  22. self._port = port
  23. self._connected = False
  24. self._timeout = timeout
  25. self._open_timeout = open_timeout
  26. self._protocol = None
  27. self._transport = None
  28. self._client = None
  29. if self.service:
  30. self._register_counter("connected")
  31. self._register_counter("lookup.failed")
  32. def _format_counter(self, counter):
  33. return "thrift_client.{}.{}.{}".format(self._host, self._port, counter)
  34. def _inc_counter(self, counter):
  35. if self.service:
  36. c = self._format_counter(counter)
  37. self.inc_counter(c)
  38. def _register_counter(self, counter):
  39. c = self._format_counter(counter)
  40. self.service.stats_mgr.register_counter(c)
  41. async def _lookup_service(self):
  42. return self._host, self._port
  43. async def _get_timeouts(self):
  44. """Set the timeout for thrift calls"""
  45. return {"": self._timeout or self._TIMEOUT}
  46. async def open(self):
  47. host, port = await self._lookup_service()
  48. timeouts = await self._get_timeouts()
  49. conn_fut = self.loop.create_connection(
  50. ThriftClientProtocolFactory(self._client_class, timeouts=timeouts),
  51. host=host,
  52. port=port,
  53. )
  54. (transport, protocol) = await asyncio.wait_for(
  55. conn_fut, self._open_timeout, loop=self.loop
  56. )
  57. self._inc_counter("connected")
  58. self._protocol = protocol
  59. self._transport = transport
  60. self._client = protocol.client
  61. # hookup the close method to the client
  62. self._client.close = self.close
  63. self._connected = True
  64. return self._client
  65. def close(self):
  66. if self._protocol:
  67. self._protocol.close()
  68. if self._transport:
  69. self._transport.close()
  70. def __await__(self):
  71. return self.open().__await__()
  72. async def __aenter__(self):
  73. await self.open()
  74. return self._client
  75. async def __aexit__(self, exc_type, exc, tb):
  76. self.close()