PageRenderTime 1997ms CodeModel.GetById 33ms RepoModel.GetById 0ms app.codeStats 0ms

/autobahn/asyncio/component.py

https://github.com/tavendo/AutobahnPython
Python | 420 lines | 356 code | 17 blank | 47 comment | 14 complexity | 3d2cf0217a7ad955076bbb93b62c8167 MD5 | raw file
  1. ###############################################################################
  2. #
  3. # The MIT License (MIT)
  4. #
  5. # Copyright (c) Crossbar.io Technologies GmbH
  6. #
  7. # Permission is hereby granted, free of charge, to any person obtaining a copy
  8. # of this software and associated documentation files (the "Software"), to deal
  9. # in the Software without restriction, including without limitation the rights
  10. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  11. # copies of the Software, and to permit persons to whom the Software is
  12. # furnished to do so, subject to the following conditions:
  13. #
  14. # The above copyright notice and this permission notice shall be included in
  15. # all copies or substantial portions of the Software.
  16. #
  17. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  18. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  19. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  20. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  21. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  22. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  23. # THE SOFTWARE.
  24. #
  25. ###############################################################################
  26. import asyncio
  27. import ssl
  28. import signal
  29. from functools import wraps
  30. import txaio
  31. txaio.use_asyncio() # noqa
  32. from autobahn.asyncio.websocket import WampWebSocketClientFactory
  33. from autobahn.asyncio.rawsocket import WampRawSocketClientFactory
  34. from autobahn.wamp import component
  35. from autobahn.wamp.exception import TransportLost
  36. from autobahn.asyncio.wamp import Session
  37. from autobahn.wamp.serializer import create_transport_serializers, create_transport_serializer
  38. __all__ = ('Component', 'run')
  39. def _unique_list(seq):
  40. """
  41. Return a list with unique elements from sequence, preserving order.
  42. """
  43. seen = set()
  44. return [x for x in seq if x not in seen and not seen.add(x)]
  45. def _camel_case_from_snake_case(snake):
  46. parts = snake.split('_')
  47. return parts[0] + ''.join(s.capitalize() for s in parts[1:])
  48. def _create_transport_factory(loop, transport, session_factory):
  49. """
  50. Create a WAMP-over-XXX transport factory.
  51. """
  52. if transport.type == 'websocket':
  53. serializers = create_transport_serializers(transport)
  54. factory = WampWebSocketClientFactory(
  55. session_factory,
  56. url=transport.url,
  57. serializers=serializers,
  58. proxy=transport.proxy, # either None or a dict with host, port
  59. )
  60. elif transport.type == 'rawsocket':
  61. serializer = create_transport_serializer(transport.serializers[0])
  62. factory = WampRawSocketClientFactory(session_factory, serializer=serializer)
  63. else:
  64. assert(False), 'should not arrive here'
  65. # set the options one at a time so we can give user better feedback
  66. for k, v in transport.options.items():
  67. try:
  68. factory.setProtocolOptions(**{k: v})
  69. except (TypeError, KeyError):
  70. # this allows us to document options as snake_case
  71. # until everything internally is upgraded from
  72. # camelCase
  73. try:
  74. factory.setProtocolOptions(
  75. **{_camel_case_from_snake_case(k): v}
  76. )
  77. except (TypeError, KeyError):
  78. raise ValueError(
  79. "Unknown {} transport option: {}={}".format(transport.type, k, v)
  80. )
  81. return factory
  82. class Component(component.Component):
  83. """
  84. A component establishes a transport and attached a session
  85. to a realm using the transport for communication.
  86. The transports a component tries to use can be configured,
  87. as well as the auto-reconnect strategy.
  88. """
  89. log = txaio.make_logger()
  90. session_factory = Session
  91. """
  92. The factory of the session we will instantiate.
  93. """
  94. def _is_ssl_error(self, e):
  95. """
  96. Internal helper.
  97. """
  98. return isinstance(e, ssl.SSLError)
  99. def _check_native_endpoint(self, endpoint):
  100. if isinstance(endpoint, dict):
  101. if 'tls' in endpoint:
  102. tls = endpoint['tls']
  103. if isinstance(tls, (dict, bool)):
  104. pass
  105. elif isinstance(tls, ssl.SSLContext):
  106. pass
  107. else:
  108. raise ValueError(
  109. "'tls' configuration must be a dict, bool or "
  110. "SSLContext instance"
  111. )
  112. else:
  113. raise ValueError(
  114. "'endpoint' configuration must be a dict or IStreamClientEndpoint"
  115. " provider"
  116. )
  117. # async function
  118. def _connect_transport(self, loop, transport, session_factory, done):
  119. """
  120. Create and connect a WAMP-over-XXX transport.
  121. """
  122. factory = _create_transport_factory(loop, transport, session_factory)
  123. # XXX the rest of this should probably be factored into its
  124. # own method (or three!)...
  125. if transport.proxy:
  126. timeout = transport.endpoint.get('timeout', 10) # in seconds
  127. if type(timeout) != int:
  128. raise ValueError('invalid type {} for timeout in client endpoint configuration'.format(type(timeout)))
  129. # do we support HTTPS proxies?
  130. f = loop.create_connection(
  131. protocol_factory=factory,
  132. host=transport.proxy['host'],
  133. port=transport.proxy['port'],
  134. )
  135. time_f = asyncio.ensure_future(asyncio.wait_for(f, timeout=timeout))
  136. return self._wrap_connection_future(transport, done, time_f)
  137. elif transport.endpoint['type'] == 'tcp':
  138. version = transport.endpoint.get('version', 4)
  139. if version not in [4, 6]:
  140. raise ValueError('invalid IP version {} in client endpoint configuration'.format(version))
  141. host = transport.endpoint['host']
  142. if type(host) != str:
  143. raise ValueError('invalid type {} for host in client endpoint configuration'.format(type(host)))
  144. port = transport.endpoint['port']
  145. if type(port) != int:
  146. raise ValueError('invalid type {} for port in client endpoint configuration'.format(type(port)))
  147. timeout = transport.endpoint.get('timeout', 10) # in seconds
  148. if type(timeout) != int:
  149. raise ValueError('invalid type {} for timeout in client endpoint configuration'.format(type(timeout)))
  150. tls = transport.endpoint.get('tls', None)
  151. tls_hostname = None
  152. # create a TLS enabled connecting TCP socket
  153. if tls:
  154. if isinstance(tls, dict):
  155. for k in tls.keys():
  156. if k not in ["hostname", "trust_root"]:
  157. raise ValueError("Invalid key '{}' in 'tls' config".format(k))
  158. hostname = tls.get('hostname', host)
  159. if type(hostname) != str:
  160. raise ValueError('invalid type {} for hostname in TLS client endpoint configuration'.format(hostname))
  161. cert_fname = tls.get('trust_root', None)
  162. tls_hostname = hostname
  163. tls = True
  164. if cert_fname is not None:
  165. tls = ssl.create_default_context(
  166. purpose=ssl.Purpose.SERVER_AUTH,
  167. cafile=cert_fname,
  168. )
  169. elif isinstance(tls, ssl.SSLContext):
  170. # tls=<an SSLContext> is valid
  171. tls_hostname = host
  172. elif tls in [False, True]:
  173. if tls:
  174. tls_hostname = host
  175. else:
  176. raise RuntimeError('unknown type {} for "tls" configuration in transport'.format(type(tls)))
  177. f = loop.create_connection(
  178. protocol_factory=factory,
  179. host=host,
  180. port=port,
  181. ssl=tls,
  182. server_hostname=tls_hostname,
  183. )
  184. time_f = asyncio.ensure_future(asyncio.wait_for(f, timeout=timeout))
  185. return self._wrap_connection_future(transport, done, time_f)
  186. elif transport.endpoint['type'] == 'unix':
  187. path = transport.endpoint['path']
  188. timeout = int(transport.endpoint.get('timeout', 10)) # in seconds
  189. f = loop.create_unix_connection(
  190. protocol_factory=factory,
  191. path=path,
  192. )
  193. time_f = asyncio.ensure_future(asyncio.wait_for(f, timeout=timeout))
  194. return self._wrap_connection_future(transport, done, time_f)
  195. else:
  196. assert(False), 'should not arrive here'
  197. def _wrap_connection_future(self, transport, done, conn_f):
  198. def on_connect_success(result):
  199. # async connect call returns a 2-tuple
  200. transport, proto = result
  201. # in the case where we .abort() the transport / connection
  202. # during setup, we still get on_connect_success but our
  203. # transport is already closed (this will happen if
  204. # e.g. there's an "open handshake timeout") -- I don't
  205. # know if there's a "better" way to detect this? #python
  206. # doesn't know of one, anyway
  207. if transport.is_closing():
  208. if not txaio.is_called(done):
  209. reason = getattr(proto, "_onclose_reason", "Connection already closed")
  210. txaio.reject(done, TransportLost(reason))
  211. return
  212. # if e.g. an SSL handshake fails, we will have
  213. # successfully connected (i.e. get here) but need to
  214. # 'listen' for the "connection_lost" from the underlying
  215. # protocol in case of handshake failure .. so we wrap
  216. # it. Also, we don't increment transport.success_count
  217. # here on purpose (because we might not succeed).
  218. # XXX double-check that asyncio behavior on TLS handshake
  219. # failures is in fact as described above
  220. orig = proto.connection_lost
  221. @wraps(orig)
  222. def lost(fail):
  223. rtn = orig(fail)
  224. if not txaio.is_called(done):
  225. # asyncio will call connection_lost(None) in case of
  226. # a transport failure, in which case we create an
  227. # appropriate exception
  228. if fail is None:
  229. fail = TransportLost("failed to complete connection")
  230. txaio.reject(done, fail)
  231. return rtn
  232. proto.connection_lost = lost
  233. def on_connect_failure(err):
  234. transport.connect_failures += 1
  235. # failed to establish a connection in the first place
  236. txaio.reject(done, err)
  237. txaio.add_callbacks(conn_f, on_connect_success, None)
  238. # the errback is added as a second step so it gets called if
  239. # there as an error in on_connect_success itself.
  240. txaio.add_callbacks(conn_f, None, on_connect_failure)
  241. return conn_f
  242. # async function
  243. def start(self, loop=None):
  244. """
  245. This starts the Component, which means it will start connecting
  246. (and re-connecting) to its configured transports. A Component
  247. runs until it is "done", which means one of:
  248. - There was a "main" function defined, and it completed successfully;
  249. - Something called ``.leave()`` on our session, and we left successfully;
  250. - ``.stop()`` was called, and completed successfully;
  251. - none of our transports were able to connect successfully (failure);
  252. :returns: a Future which will resolve (to ``None``) when we are
  253. "done" or with an error if something went wrong.
  254. """
  255. if loop is None:
  256. self.log.warn("Using default loop")
  257. loop = asyncio.get_event_loop()
  258. return self._start(loop=loop)
  259. def run(components, start_loop=True, log_level='info'):
  260. """
  261. High-level API to run a series of components.
  262. This will only return once all the components have stopped
  263. (including, possibly, after all re-connections have failed if you
  264. have re-connections enabled). Under the hood, this calls
  265. XXX fixme for asyncio
  266. -- if you wish to manage the loop yourself, use the
  267. :meth:`autobahn.asyncio.component.Component.start` method to start
  268. each component yourself.
  269. :param components: the Component(s) you wish to run
  270. :type components: instance or list of :class:`autobahn.asyncio.component.Component`
  271. :param start_loop: When ``True`` (the default) this method
  272. start a new asyncio loop.
  273. :type start_loop: bool
  274. :param log_level: a valid log-level (or None to avoid calling start_logging)
  275. :type log_level: string
  276. """
  277. # actually, should we even let people "not start" the logging? I'm
  278. # not sure that's wise... (double-check: if they already called
  279. # txaio.start_logging() what happens if we call it again?)
  280. if log_level is not None:
  281. txaio.start_logging(level=log_level)
  282. loop = asyncio.get_event_loop()
  283. if loop.is_closed():
  284. asyncio.set_event_loop(asyncio.new_event_loop())
  285. loop = asyncio.get_event_loop()
  286. txaio.config.loop = loop
  287. log = txaio.make_logger()
  288. # see https://github.com/python/asyncio/issues/341 asyncio has
  289. # "odd" handling of KeyboardInterrupt when using Tasks (as
  290. # run_until_complete does). Another option is to just resture
  291. # default SIGINT handling, which is to exit:
  292. # import signal
  293. # signal.signal(signal.SIGINT, signal.SIG_DFL)
  294. @asyncio.coroutine
  295. def nicely_exit(signal):
  296. log.info("Shutting down due to {signal}", signal=signal)
  297. try:
  298. tasks = asyncio.Task.all_tasks()
  299. except AttributeError:
  300. # this changed with python >= 3.7
  301. tasks = asyncio.all_tasks()
  302. for task in tasks:
  303. # Do not cancel the current task.
  304. try:
  305. current_task = asyncio.Task.current_task()
  306. except AttributeError:
  307. current_task = asyncio.current_task()
  308. if task is not current_task:
  309. task.cancel()
  310. def cancel_all_callback(fut):
  311. try:
  312. fut.result()
  313. except asyncio.CancelledError:
  314. log.debug("All task cancelled")
  315. except Exception as e:
  316. log.error("Error while shutting down: {exception}", exception=e)
  317. finally:
  318. loop.stop()
  319. fut = asyncio.gather(*tasks)
  320. fut.add_done_callback(cancel_all_callback)
  321. try:
  322. loop.add_signal_handler(signal.SIGINT, lambda: asyncio.ensure_future(nicely_exit("SIGINT")))
  323. loop.add_signal_handler(signal.SIGTERM, lambda: asyncio.ensure_future(nicely_exit("SIGTERM")))
  324. except NotImplementedError:
  325. # signals are not available on Windows
  326. pass
  327. def done_callback(loop, arg):
  328. loop.stop()
  329. # returns a future; could run_until_complete() but see below
  330. component._run(loop, components, done_callback)
  331. if start_loop:
  332. try:
  333. loop.run_forever()
  334. # this is probably more-correct, but then you always get
  335. # "Event loop stopped before Future completed":
  336. # loop.run_until_complete(f)
  337. except asyncio.CancelledError:
  338. pass
  339. # finally:
  340. # signal.signal(signal.SIGINT, signal.SIG_DFL)
  341. # signal.signal(signal.SIGTERM, signal.SIG_DFL)
  342. # Close the event loop at the end, otherwise an exception is
  343. # thrown. https://bugs.python.org/issue23548
  344. loop.close()