/tle/util/events.py

https://github.com/cheran-senthil/TLE · Python · 164 lines · 102 code · 42 blank · 20 comment · 18 complexity · 566cd3bfe091c16ab56bb2b3274a8f65 MD5 · raw file

  1. import asyncio
  2. import logging
  3. from discord.ext import commands
  4. # Event types
  5. class Event:
  6. """Base class for events."""
  7. pass
  8. class ContestListRefresh(Event):
  9. def __init__(self, contests):
  10. self.contests = contests
  11. class RatingChangesUpdate(Event):
  12. def __init__(self, *, contest, rating_changes):
  13. self.contest = contest
  14. self.rating_changes = rating_changes
  15. # Event errors
  16. class EventError(commands.CommandError):
  17. pass
  18. class ListenerNotRegistered(EventError):
  19. def __init__(self, listener):
  20. super().__init__(f'Listener {listener.name} is not registered for event '
  21. f'{listener.event_cls.__name__}.')
  22. # Event system
  23. class EventSystem:
  24. """Rudimentary event system."""
  25. def __init__(self):
  26. self.listeners_by_event = {}
  27. self.futures_by_event = {}
  28. self.logger = logging.getLogger(self.__class__.__name__)
  29. def add_listener(self, listener):
  30. listeners = self.listeners_by_event.setdefault(listener.event_cls, set())
  31. listeners.add(listener)
  32. def remove_listener(self, listener):
  33. try:
  34. self.listeners_by_event[listener.event_cls].remove(listener)
  35. except KeyError:
  36. raise ListenerNotRegistered(listener)
  37. async def wait_for(self, event_cls, *, timeout=None):
  38. future = asyncio.get_running_loop().create_future()
  39. futures = self.futures_by_event.setdefault(event_cls, [])
  40. futures.append(future)
  41. return await asyncio.wait_for(future, timeout)
  42. def dispatch(self, event_cls, *args, **kwargs):
  43. self.logger.info(f'Dispatching event `{event_cls.__name__}`')
  44. event = event_cls(*args, **kwargs)
  45. for listener in self.listeners_by_event.get(event_cls, []):
  46. listener.trigger(event)
  47. futures = self.futures_by_event.pop(event_cls, [])
  48. for future in futures:
  49. if not future.done():
  50. future.set_result(event)
  51. # Listener
  52. def _ensure_coroutine_func(func):
  53. if not asyncio.iscoroutinefunction(func):
  54. raise TypeError('The listener function must be a coroutine function.')
  55. class Listener:
  56. """A listener for a particular event. A listener must have a name, the event it should listen
  57. to and a coroutine function `func` that is called when the event is dispatched.
  58. """
  59. def __init__(self, name, event_cls, func, *, with_lock=False):
  60. """`with_lock` controls whether execution of `func` should be guarded by an asyncio.Lock."""
  61. _ensure_coroutine_func(func)
  62. self.name = name
  63. self.event_cls = event_cls
  64. self.func = func
  65. self.lock = asyncio.Lock() if with_lock else None
  66. self.logger = logging.getLogger(self.__class__.__name__)
  67. def trigger(self, event):
  68. asyncio.create_task(self._trigger(event))
  69. async def _trigger(self, event):
  70. try:
  71. if self.lock:
  72. async with self.lock:
  73. await self.func(event)
  74. else:
  75. await self.func(event)
  76. except asyncio.CancelledError:
  77. raise
  78. except:
  79. self.logger.exception(f'Exception in listener `{self.name}`.')
  80. def __eq__(self, other):
  81. return (isinstance(other, Listener)
  82. and (self.event_cls, self.func) == (other.event_cls, other.func))
  83. def __hash__(self):
  84. return hash((self.event_cls, self.func))
  85. class ListenerSpec:
  86. """A descriptor intended to be an interface between an instance and its listeners. It creates
  87. the expected listener when `__get__` is called from an instance for the first time. No two
  88. listener specs in the same class should have the same name.
  89. """
  90. def __init__(self, name, event_cls, func, *, with_lock=False):
  91. """`with_lock` controls whether execution of `func` should be guarded by an asyncio.Lock."""
  92. _ensure_coroutine_func(func)
  93. self.name = name
  94. self.event_cls = event_cls
  95. self.func = func
  96. self.with_lock = with_lock
  97. def __get__(self, instance, owner):
  98. if instance is None:
  99. return self
  100. try:
  101. listeners = getattr(instance, '___listeners___')
  102. except AttributeError:
  103. listeners = instance.___listeners___ = {}
  104. if self.name not in listeners:
  105. # In Python <=3.7 iscoroutinefunction returns False for async functions wrapped by
  106. # functools.partial.
  107. # TODO: Use functools.partial when we move to Python 3.8.
  108. async def wrapper(event):
  109. return await self.func(instance, event)
  110. listeners[self.name] = Listener(self.name, self.event_cls, wrapper,
  111. with_lock=self.with_lock)
  112. return listeners[self.name]
  113. def listener(*, name, event_cls, with_lock=False):
  114. """Returns a decorator that creates a `Listener` with the given options."""
  115. def decorator(func):
  116. return Listener(name, event_cls, func, with_lock=with_lock)
  117. return decorator
  118. def listener_spec(*, name, event_cls, with_lock=False):
  119. """Returns a decorator that creates a `ListenerSpec` with the given options."""
  120. def decorator(func):
  121. return ListenerSpec(name, event_cls, func, with_lock=with_lock)
  122. return decorator