/scripts/perf_benchmark.py

https://github.com/ethereum/lahja · Python · 152 lines · 129 code · 16 blank · 7 comment · 17 complexity · 337e11e3a039ef9d1a5ef131c7fcc985 MD5 · raw file

  1. import argparse
  2. import multiprocessing
  3. import os
  4. import tempfile
  5. from lahja import ConnectionConfig
  6. from lahja.tools.benchmark.backends import AsyncioBackend, BaseBackend, TrioBackend
  7. from lahja.tools.benchmark.constants import (
  8. DRIVER_ENDPOINT,
  9. REPORTER_ENDPOINT,
  10. ROOT_ENDPOINT,
  11. )
  12. from lahja.tools.benchmark.logging import setup_stderr_lahja_logging
  13. from lahja.tools.benchmark.process import (
  14. BroadcastConsumer,
  15. BroadcastDriver,
  16. ConsumerConfig,
  17. DriverProcessConfig,
  18. ReportingProcess,
  19. ReportingProcessConfig,
  20. RequestConsumer,
  21. RequestDriver,
  22. )
  23. from lahja.tools.benchmark.typing import ShutdownEvent
  24. from lahja.tools.benchmark.utils.config import (
  25. create_consumer_endpoint_configs,
  26. create_consumer_endpoint_name,
  27. )
  28. parser = argparse.ArgumentParser()
  29. parser.add_argument(
  30. "--num-processes",
  31. type=int,
  32. default=10,
  33. help="The number of processes listening for events",
  34. )
  35. parser.add_argument(
  36. "--num-events", type=int, default=100, help="The number of events propagated"
  37. )
  38. parser.add_argument(
  39. "--throttle",
  40. type=float,
  41. default=0.0,
  42. help="The time to wait between propagating events",
  43. )
  44. parser.add_argument(
  45. "--payload-bytes", type=int, default=1, help="The payload of each event in bytes"
  46. )
  47. parser.add_argument("--backend", action="append", help="The endpoint backend to use")
  48. parser.add_argument(
  49. "--mode",
  50. default="broadcast",
  51. choices=("broadcast", "request"),
  52. help="benchmarks request/response round trip",
  53. )
  54. parser.add_argument(
  55. "--enable-debug-logs",
  56. action="store_true",
  57. help="Turns on DEBUG level logging to stderr to the `lahja` namespaced loggers.",
  58. )
  59. async def run(args: argparse.Namespace, backend: BaseBackend):
  60. if args.mode == "broadcast":
  61. DriverClass = BroadcastDriver
  62. ConsumerClass = BroadcastConsumer
  63. elif args.mode == "request":
  64. DriverClass = RequestDriver
  65. ConsumerClass = RequestConsumer
  66. else:
  67. raise Exception(f"Unknown mode: '{args.mode}'")
  68. consumer_endpoint_configs = create_consumer_endpoint_configs(args.num_processes)
  69. (
  70. config.path.unlink()
  71. for config in consumer_endpoint_configs
  72. + tuple(
  73. ConnectionConfig.from_name(name)
  74. for name in (ROOT_ENDPOINT, REPORTER_ENDPOINT, DRIVER_ENDPOINT)
  75. )
  76. )
  77. root_config = ConnectionConfig.from_name(ROOT_ENDPOINT)
  78. async with backend.Endpoint.serve(root_config) as root:
  79. # The reporter process is collecting statistical events from all consumer processes
  80. # For some reason, doing this work in the main process didn't end so well which is
  81. # why it was moved into a dedicated process. Notice that this will slightly skew results
  82. # as the reporter process will also receive events which we don't account for
  83. reporting_config = ReportingProcessConfig(
  84. num_events=args.num_events,
  85. num_processes=args.num_processes,
  86. throttle=args.throttle,
  87. payload_bytes=args.payload_bytes,
  88. backend=backend,
  89. debug_logging=args.enable_debug_logs,
  90. )
  91. reporter = ReportingProcess(reporting_config)
  92. reporter.start()
  93. consumer_config = ConsumerConfig(
  94. num_events=args.num_events,
  95. backend=backend,
  96. debug_logging=args.enable_debug_logs,
  97. )
  98. for n in range(args.num_processes):
  99. consumer_process = ConsumerClass(
  100. create_consumer_endpoint_name(n), consumer_config
  101. )
  102. consumer_process.start()
  103. # In this benchmark, this is the only process that is flooding events
  104. driver_config = DriverProcessConfig(
  105. connected_endpoints=consumer_endpoint_configs,
  106. num_events=args.num_events,
  107. throttle=args.throttle,
  108. payload_bytes=args.payload_bytes,
  109. backend=backend,
  110. debug_logging=args.enable_debug_logs,
  111. )
  112. driver = DriverClass(driver_config)
  113. driver.start()
  114. await root.wait_for(ShutdownEvent)
  115. driver.stop()
  116. if __name__ == "__main__":
  117. args = parser.parse_args()
  118. # WARNING: The `fork` method does not work well with asyncio yet.
  119. # This might change with Python 3.8 (See https://bugs.python.org/issue22087#msg318140)
  120. multiprocessing.set_start_method("spawn")
  121. if args.enable_debug_logs:
  122. setup_stderr_lahja_logging()
  123. for backend_str in args.backend or ["asyncio"]:
  124. if backend_str == "asyncio":
  125. backend = AsyncioBackend()
  126. elif backend_str == "trio":
  127. backend = TrioBackend()
  128. else:
  129. raise Exception(f"Unrecognized backend: {args.backend}")
  130. original_dir = os.getcwd()
  131. with tempfile.TemporaryDirectory() as base_dir:
  132. os.chdir(base_dir)
  133. try:
  134. backend.run(run, args, backend)
  135. finally:
  136. os.chdir(original_dir)