/robot/NLU.py

https://github.com/wzpan/wukong-robot · Python · 218 lines · 82 code · 29 blank · 107 comment · 12 complexity · 49ba16b3f5b0ab98890e911241537059 MD5 · raw file

  1. # -*- coding: utf-8-*-
  2. from .sdk import unit
  3. from robot import logging
  4. from abc import ABCMeta, abstractmethod
  5. logger = logging.getLogger(__name__)
  6. class AbstractNLU(object):
  7. """
  8. Generic parent class for all NLU engines
  9. """
  10. __metaclass__ = ABCMeta
  11. @classmethod
  12. def get_config(cls):
  13. return {}
  14. @classmethod
  15. def get_instance(cls):
  16. profile = cls.get_config()
  17. instance = cls(**profile)
  18. return instance
  19. @abstractmethod
  20. def parse(self, query, **args):
  21. """
  22. 进行 NLU 解析
  23. :param query: 用户的指令字符串
  24. :param **args: 可选的参数
  25. """
  26. return None
  27. @abstractmethod
  28. def getIntent(self, parsed):
  29. """
  30. 提取意图
  31. :param parsed: 解析结果
  32. :returns: 意图数组
  33. """
  34. return None
  35. @abstractmethod
  36. def hasIntent(self, parsed, intent):
  37. """
  38. 判断是否包含某个意图
  39. :param parsed: 解析结果
  40. :param intent: 意图的名称
  41. :returns: True: 包含; False: 不包含
  42. """
  43. return False
  44. @abstractmethod
  45. def getSlots(self, parsed, intent):
  46. """
  47. 提取某个意图的所有词槽
  48. :param parsed: 解析结果
  49. :param intent: 意图的名称
  50. :returns: 词槽列表你可以通过 name 属性筛选词槽
  51. 再通过 normalized_word 属性取出相应的值
  52. """
  53. return None
  54. @abstractmethod
  55. def getSlotWords(self, parsed, intent, name):
  56. """
  57. 找出命中某个词槽的内容
  58. :param parsed: 解析结果
  59. :param intent: 意图的名称
  60. :param name: 词槽名
  61. :returns: 命中该词槽的值的列表
  62. """
  63. return None
  64. @abstractmethod
  65. def getSay(self, parsed, intent):
  66. """
  67. 提取回复文本
  68. :param parsed: 解析结果
  69. :param intent: 意图的名称
  70. :returns: 回复文本
  71. """
  72. return ""
  73. class UnitNLU(AbstractNLU):
  74. """
  75. 百度UNIT的NLU API.
  76. """
  77. SLUG = "unit"
  78. def __init__(self):
  79. super(self.__class__, self).__init__()
  80. @classmethod
  81. def get_config(cls):
  82. """
  83. 百度UNIT的配置
  84. 无需配置所以返回 {}
  85. """
  86. return {}
  87. def parse(self, query, **args):
  88. """
  89. 使用百度 UNIT 进行 NLU 解析
  90. :param query: 用户的指令字符串
  91. :param **args: UNIT 的相关参数
  92. - service_id: UNIT service_id
  93. - api_key: UNIT apk_key
  94. - secret_key: UNIT secret_key
  95. :returns: UNIT 解析结果如果解析失败返回 None
  96. """
  97. if 'service_id' not in args or \
  98. 'api_key' not in args or \
  99. 'secret_key' not in args:
  100. logger.critical('{} NLU 失败:参数错误!'.format(self.SLUG))
  101. return None
  102. return unit.getUnit(query,
  103. args['service_id'],
  104. args['api_key'],
  105. args['secret_key'])
  106. def getIntent(self, parsed):
  107. """
  108. 提取意图
  109. :param parsed: 解析结果
  110. :returns: 意图数组
  111. """
  112. return unit.getIntent(parsed)
  113. def hasIntent(self, parsed, intent):
  114. """
  115. 判断是否包含某个意图
  116. :param parsed: UNIT 解析结果
  117. :param intent: 意图的名称
  118. :returns: True: 包含; False: 不包含
  119. """
  120. return unit.hasIntent(parsed, intent)
  121. def getSlots(self, parsed, intent):
  122. """
  123. 提取某个意图的所有词槽
  124. :param parsed: UNIT 解析结果
  125. :param intent: 意图的名称
  126. :returns: 词槽列表你可以通过 name 属性筛选词槽
  127. 再通过 normalized_word 属性取出相应的值
  128. """
  129. return unit.getSlots(parsed, intent)
  130. def getSlotWords(self, parsed, intent, name):
  131. """
  132. 找出命中某个词槽的内容
  133. :param parsed: UNIT 解析结果
  134. :param intent: 意图的名称
  135. :param name: 词槽名
  136. :returns: 命中该词槽的值的列表
  137. """
  138. return unit.getSlotWords(parsed, intent, name)
  139. def getSay(self, parsed, intent):
  140. """
  141. 提取 UNIT 的回复文本
  142. :param parsed: UNIT 解析结果
  143. :param intent: 意图的名称
  144. :returns: UNIT 的回复文本
  145. """
  146. return unit.getSay(parsed, intent)
  147. def get_engine_by_slug(slug=None):
  148. """
  149. Returns:
  150. An NLU Engine implementation available on the current platform
  151. Raises:
  152. ValueError if no speaker implementation is supported on this platform
  153. """
  154. if not slug or type(slug) is not str:
  155. raise TypeError("无效的 NLU slug '%s'", slug)
  156. selected_engines = list(filter(lambda engine: hasattr(engine, "SLUG") and
  157. engine.SLUG == slug, get_engines()))
  158. if len(selected_engines) == 0:
  159. raise ValueError("错误:找不到名为 {} 的 NLU 引擎".format(slug))
  160. else:
  161. if len(selected_engines) > 1:
  162. logger.warning("注意: 有多个 NLU 名称与指定的引擎名 {} 匹配").format(slug)
  163. engine = selected_engines[0]
  164. logger.info("使用 {} NLU 引擎".format(engine.SLUG))
  165. return engine.get_instance()
  166. def get_engines():
  167. def get_subclasses(cls):
  168. subclasses = set()
  169. for subclass in cls.__subclasses__():
  170. subclasses.add(subclass)
  171. subclasses.update(get_subclasses(subclass))
  172. return subclasses
  173. return [engine for engine in
  174. list(get_subclasses(AbstractNLU))
  175. if hasattr(engine, 'SLUG') and engine.SLUG]