/aiormq/tools.py

https://github.com/mosquito/aiormq · Python · 100 lines · 70 code · 29 blank · 1 comment · 11 complexity · 0db54a77ada562d308782c4db0f85fc4 MD5 · raw file

  1. import asyncio
  2. from functools import wraps
  3. from types import TracebackType
  4. from typing import (
  5. Any, AsyncContextManager, Awaitable, Callable, Coroutine, Optional, Type,
  6. TypeVar, Union,
  7. )
  8. from yarl import URL
  9. from aiormq.abc import TimeoutType
# Generic type variable used in the signatures of the helpers in this module.
T = TypeVar("T")
  11. def censor_url(url: URL) -> URL:
  12. if url.password is not None:
  13. return url.with_password("******")
  14. return url
  15. def shield(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
  16. @wraps(func)
  17. def wrap(*args: Any, **kwargs: Any) -> Awaitable[T]:
  18. return asyncio.shield(func(*args, **kwargs))
  19. return wrap
  20. def awaitable(
  21. func: Callable[..., Union[T, Awaitable[T]]],
  22. ) -> Callable[..., Coroutine[Any, Any, T]]:
  23. # Avoid python 3.8+ warning
  24. if asyncio.iscoroutinefunction(func):
  25. return func # type: ignore
  26. @wraps(func)
  27. async def wrap(*args: Any, **kwargs: Any) -> T:
  28. result = func(*args, **kwargs)
  29. if hasattr(result, "__await__"):
  30. return await result # type: ignore
  31. if asyncio.iscoroutine(result) or asyncio.isfuture(result):
  32. return await result # type: ignore
  33. return result # type: ignore
  34. return wrap
  35. class Countdown:
  36. __slots__ = "loop", "deadline"
  37. def __init__(self, timeout: TimeoutType = None):
  38. self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
  39. self.deadline: TimeoutType = None
  40. if timeout is not None:
  41. self.deadline = self.loop.time() + timeout
  42. def get_timeout(self) -> TimeoutType:
  43. if self.deadline is None:
  44. return None
  45. current = self.loop.time()
  46. if current >= self.deadline:
  47. raise asyncio.TimeoutError
  48. return self.deadline - current
  49. def __call__(self, coro: Awaitable[T]) -> Awaitable[T]:
  50. if self.deadline is None:
  51. return coro
  52. return asyncio.wait_for(coro, timeout=self.get_timeout())
  53. def enter_context(
  54. self, ctx: AsyncContextManager[T],
  55. ) -> AsyncContextManager[T]:
  56. return CountdownContext(self, ctx)
  57. class CountdownContext(AsyncContextManager):
  58. def __init__(self, countdown: Countdown, ctx: AsyncContextManager):
  59. self.countdown = countdown
  60. self.ctx = ctx
  61. def __aenter__(self) -> Awaitable[T]:
  62. if self.countdown.deadline is None:
  63. return self.ctx.__aenter__()
  64. return self.countdown(self.ctx.__aenter__())
  65. def __aexit__(
  66. self, exc_type: Optional[Type[BaseException]],
  67. exc_val: Optional[BaseException], exc_tb: Optional[TracebackType],
  68. ) -> Awaitable[Any]:
  69. if self.countdown.deadline is None:
  70. return self.ctx.__aexit__(exc_type, exc_val, exc_tb)
  71. return self.countdown(self.ctx.__aexit__(exc_type, exc_val, exc_tb))