/openrasp_iast/core/components/result_receiver.py

https://github.com/baidu-security/openrasp-iast · Python · 123 lines · 74 code · 10 blank · 39 comment · 18 complexity · e6df7a22749cd68af382791949db6334 MD5 · raw file

  1. #!/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. """
  4. Copyright 2017-2020 Baidu Inc.
  5. Licensed under the Apache License, Version 2.0 (the "License");
  6. you may not use this file except in compliance with the License.
  7. You may obtain a copy of the License at
  8. http://www.apache.org/licenses/LICENSE-2.0
  9. Unless required by applicable law or agreed to in writing, software
  10. distributed under the License is distributed on an "AS IS" BASIS,
  11. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. See the License for the specific language governing permissions and
  13. limitations under the License.
  14. """
  15. import time
  16. import asyncio
  17. import collections
  18. from core.components import exceptions
  19. from core.components.logger import Logger
  20. from core.components.config import Config
  21. from core.components.communicator import Communicator
  22. class RaspResultReceiver(object):
  23. """
  24. 缓存扫描请求的RaspResult, 并通知对应扫描进程获取结果
  25. """
  26. def __new__(cls):
  27. """
  28. 单例模式初始化
  29. """
  30. if not hasattr(cls, 'instance'):
  31. cls.instance = super(RaspResultReceiver, cls).__new__(cls)
  32. # request_id 为key ,每个item为一个list结构为: [获取到result的event, 过期时间, 获取到的结果(未获取前为None)]
  33. # 例如 {scan_request_id_1: [event_1, expire_time1, result_dict_1] , scan_request_id_2:[event_2, expire_time2, None] ...}
  34. cls.instance.rasp_result_collection = collections.OrderedDict()
  35. cls.instance.timeout = Config().get_config("scanner.request_timeout") * \
  36. (Config().get_config("scanner.retry_times") + 1)
  37. return cls.instance
  38. def register_result(self, req_id):
  39. """
  40. 注册待接收的扫描请求的结果id注册后调用wait_result等待返回结果
  41. Parameters:
  42. req_id - 结果的scan_request_id
  43. """
  44. expire_time = time.time() + (self.timeout * 2)
  45. self.rasp_result_collection[req_id] = [
  46. asyncio.Event(), expire_time, None]
  47. def add_result(self, rasp_result):
  48. """
  49. 添加一个RaspResult实例到缓存队列并触发对应的数据到达事件, 同时清空缓存中过期的实例
  50. 若RaspResult实例的id未通过register_result方法注册则直接丢弃
  51. Parameters:
  52. rasp_result - 待添加的RaspResult实例
  53. """
  54. scan_request_id = rasp_result.get_scan_request_id()
  55. try:
  56. self.rasp_result_collection[scan_request_id][2] = rasp_result
  57. self.rasp_result_collection[scan_request_id][0].set()
  58. except KeyError:
  59. Communicator().increase_value("dropped_rasp_result")
  60. Logger().warning("Drop no registered rasp result data: {}".format(str(rasp_result)))
  61. while True:
  62. try:
  63. key = next(iter(self.rasp_result_collection))
  64. except StopIteration:
  65. break
  66. if self.rasp_result_collection[key][1] < time.time():
  67. if type(self.rasp_result_collection[key][0]) is not dict:
  68. Logger().debug("Rasp result with id: {} timeout, dropped".format(key))
  69. self.rasp_result_collection.popitem(False)
  70. else:
  71. break
  72. async def wait_result(self, req_id):
  73. """
  74. 异步等待一个扫描请求的RaspResult结果
  75. Parameters:
  76. req_id - str, 等待请求的scan_request_id
  77. Returns:
  78. 获取到的扫描请求结果的RaspResult实例
  79. Rasise:
  80. exceptions.GetRaspResultFailed - 等待超时或请求id未使用register_result方法注册时引发此异常
  81. """
  82. try:
  83. expire_time = self.rasp_result_collection[req_id][1]
  84. event = self.rasp_result_collection[req_id][0]
  85. except KeyError:
  86. Logger().warning("Try to wait not exist result with request id " + req_id)
  87. raise exceptions.GetRaspResultFailed
  88. else:
  89. if type(event) == dict:
  90. return event
  91. timeout = expire_time - time.time()
  92. timeout = timeout if timeout > 0 else 0.01
  93. try:
  94. Logger().debug("Start waiting rasp result, id: " + req_id)
  95. await asyncio.wait_for(event.wait(), timeout=timeout)
  96. except asyncio.TimeoutError:
  97. Logger().warning("Timeout when wait rasp result, id: " + req_id)
  98. Communicator().increase_value("rasp_result_timeout")
  99. raise exceptions.GetRaspResultFailed
  100. else:
  101. result = self.rasp_result_collection.get(
  102. req_id, (None, None, None))[2]
  103. Logger().debug("Got rasp result, scan-request-id: {}".format(req_id, str(result)))
  104. return result