/torch/distributed/elastic/metrics/api.py

https://github.com/ROCmSoftwarePlatform/pytorch · Python · 206 lines · 125 code · 40 blank · 41 comment · 17 complexity · 5246f2fc1963fffd77bf0c10e9dd1f2f MD5 · raw file

  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import abc
  8. import time
  9. import warnings
  10. from collections import namedtuple
  11. from functools import wraps
  12. from typing import Dict, Optional
  13. MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
  14. class MetricsConfig:
  15. __slots__ = ["params"]
  16. def __init__(self, params: Optional[Dict[str, str]] = None):
  17. self.params = params
  18. if self.params is None:
  19. self.params = {}
  20. class MetricHandler(abc.ABC):
  21. @abc.abstractmethod
  22. def emit(self, metric_data: MetricData):
  23. pass
  24. class ConsoleMetricHandler(MetricHandler):
  25. def emit(self, metric_data: MetricData):
  26. print(
  27. "[{}][{}]: {}={}".format(
  28. metric_data.timestamp,
  29. metric_data.group_name,
  30. metric_data.name,
  31. metric_data.value,
  32. )
  33. )
  34. class NullMetricHandler(MetricHandler):
  35. def emit(self, metric_data: MetricData):
  36. pass
  37. class MetricStream:
  38. def __init__(self, group_name: str, handler: MetricHandler):
  39. self.group_name = group_name
  40. self.handler = handler
  41. def add_value(self, metric_name: str, metric_value: int):
  42. self.handler.emit(
  43. MetricData(time.time(), self.group_name, metric_name, metric_value)
  44. )
  45. _metrics_map = {}
  46. _default_metrics_handler = NullMetricHandler() # type: MetricHandler
  47. # pyre-fixme[9]: group has type `str`; used as `None`.
  48. def configure(handler: MetricHandler, group: str = None):
  49. if group is None:
  50. global _default_metrics_handler
  51. # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
  52. # as `MetricHandler`.
  53. _default_metrics_handler = handler
  54. else:
  55. _metrics_map[group] = handler
  56. def getStream(group: str):
  57. if group in _metrics_map:
  58. handler = _metrics_map[group]
  59. else:
  60. handler = _default_metrics_handler
  61. return MetricStream(group, handler)
  62. def _get_metric_name(fn):
  63. qualname = fn.__qualname__
  64. split = qualname.split(".")
  65. if len(split) == 1:
  66. module = fn.__module__
  67. if module:
  68. return module.split(".")[-1] + "." + split[0]
  69. else:
  70. return split[0]
  71. else:
  72. return qualname
  73. def prof(fn=None, group: str = "torchelastic"):
  74. r"""
  75. @profile decorator publishes duration.ms, count, success, failure
  76. metrics for the function that it decorates. The metric name defaults
  77. to the qualified name (``class_name.def_name``) of the function.
  78. If the function does not belong to a class, it uses the leaf module name
  79. instead.
  80. Usage
  81. ::
  82. @metrics.prof
  83. def x():
  84. pass
  85. @metrics.prof(group="agent")
  86. def y():
  87. pass
  88. """
  89. def wrap(f):
  90. @wraps(f)
  91. def wrapper(*args, **kwargs):
  92. key = _get_metric_name(f)
  93. try:
  94. start = time.time()
  95. result = f(*args, **kwargs)
  96. put_metric(f"{key}.success", 1, group)
  97. except Exception:
  98. put_metric(f"{key}.failure", 1, group)
  99. raise
  100. finally:
  101. put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)
  102. return result
  103. return wrapper
  104. if fn:
  105. return wrap(fn)
  106. else:
  107. return wrap
  108. def profile(group=None):
  109. """
  110. @profile decorator adds latency and success/failure metrics to any given function.
  111. Usage
  112. ::
  113. @metrics.profile("my_metric_group")
  114. def some_function(<arguments>):
  115. """
  116. warnings.warn("Deprecated, use @prof instead", DeprecationWarning)
  117. def wrap(func):
  118. @wraps(func)
  119. def wrapper(*args, **kwargs):
  120. try:
  121. start_time = time.time()
  122. result = func(*args, **kwargs)
  123. publish_metric(group, "{}.success".format(func.__name__), 1)
  124. except Exception:
  125. publish_metric(group, "{}.failure".format(func.__name__), 1)
  126. raise
  127. finally:
  128. publish_metric(
  129. group,
  130. "{}.duration.ms".format(func.__name__),
  131. get_elapsed_time_ms(start_time),
  132. )
  133. return result
  134. return wrapper
  135. return wrap
  136. def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
  137. """
  138. Publishes a metric data point.
  139. Usage
  140. ::
  141. put_metric("metric_name", 1)
  142. put_metric("metric_name", 1, "metric_group_name")
  143. """
  144. getStream(metric_group).add_value(metric_name, metric_value)
  145. def publish_metric(metric_group: str, metric_name: str, metric_value: int):
  146. warnings.warn(
  147. "Deprecated, use put_metric(metric_group)(metric_name, metric_value) instead"
  148. )
  149. metric_stream = getStream(metric_group)
  150. metric_stream.add_value(metric_name, metric_value)
  151. def get_elapsed_time_ms(start_time_in_seconds: float):
  152. """
  153. Returns the elapsed time in millis from the given start time.
  154. """
  155. end_time = time.time()
  156. return int((end_time - start_time_in_seconds) * 1000)