/paasta_tools/async_utils.py

https://github.com/Yelp/paasta · Python · 105 lines · 74 code · 19 blank · 12 comment · 16 complexity · 8dc30f60b62b0ae990d843179634ff23 MD5 · raw file

  1. import asyncio
  2. import functools
  3. import time
  4. import weakref
  5. from collections import defaultdict
  6. from typing import AsyncIterable
  7. from typing import Awaitable
  8. from typing import Callable
  9. from typing import Dict
  10. from typing import List
  11. from typing import Optional
  12. from typing import TypeVar
# Generic type variable for the value produced by the awaitables that the
# decorators and helpers below wrap or consume.
T = TypeVar("T")
  14. # NOTE: this method is not thread-safe due to lack of locking while checking
  15. # and updating the cache
  16. def async_ttl_cache(
  17. ttl: Optional[float] = 300,
  18. cleanup_self: bool = False,
  19. *,
  20. cache: Optional[Dict] = None,
  21. ) -> Callable[
  22. [Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]] # wrapped # inner
  23. ]:
  24. async def call_or_get_from_cache(cache, async_func, args_for_key, args, kwargs):
  25. # Please note that anything which is put into `key` will be in the
  26. # cache forever, potentially causing memory leaks. The most common
  27. # case is the `self` arg pointing to a huge object. To mitigate that
  28. # we're using `args_for_key`, which is supposed not contain any huge
  29. # objects.
  30. key = functools._make_key(args_for_key, kwargs, typed=False)
  31. try:
  32. future, last_update = cache[key]
  33. if ttl is not None and time.time() - last_update > ttl:
  34. raise KeyError
  35. except KeyError:
  36. future = asyncio.ensure_future(async_func(*args, **kwargs))
  37. # set the timestamp to +infinity so that we always wait on the in-flight request.
  38. cache[key] = (future, float("Inf"))
  39. try:
  40. value = await future
  41. except Exception:
  42. # Only update the cache if it's the same future we awaited and
  43. # it hasn't already been updated by another coroutine
  44. # Note also that we use get() in case the key was deleted from the
  45. # cache by another coroutine
  46. if cache.get(key) == (future, float("Inf")):
  47. del cache[key]
  48. raise
  49. else:
  50. if cache.get(key) == (future, float("Inf")):
  51. cache[key] = (future, time.time())
  52. return value
  53. if cleanup_self:
  54. instance_caches: Dict = cache if cache is not None else defaultdict(dict)
  55. def on_delete(w):
  56. del instance_caches[w]
  57. def outer(wrapped):
  58. @functools.wraps(wrapped)
  59. async def inner(self, *args, **kwargs):
  60. w = weakref.ref(self, on_delete)
  61. self_cache = instance_caches[w]
  62. return await call_or_get_from_cache(
  63. self_cache, wrapped, args, (self,) + args, kwargs
  64. )
  65. return inner
  66. else:
  67. cache2: Dict = cache if cache is not None else {} # Should be Dict[Any, T] but that doesn't work.
  68. def outer(wrapped):
  69. @functools.wraps(wrapped)
  70. async def inner(*args, **kwargs):
  71. return await call_or_get_from_cache(cache2, wrapped, args, args, kwargs)
  72. return inner
  73. return outer
  74. async def aiter_to_list(aiter: AsyncIterable[T],) -> List[T]:
  75. return [x async for x in aiter]
  76. def async_timeout(
  77. seconds: int = 10,
  78. ) -> Callable[
  79. [Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]] # wrapped # inner
  80. ]:
  81. def outer(wrapped):
  82. @functools.wraps(wrapped)
  83. async def inner(*args, **kwargs):
  84. return await asyncio.wait_for(wrapped(*args, **kwargs), timeout=seconds)
  85. return inner
  86. return outer