/docqa/server/web_searcher.py

https://github.com/allenai/document-qa

from typing import List, Dict, Optional
import logging
import ujson
import asyncio
from aiohttp import ClientSession
from os.path import exists

BING_API = "https://api.cognitive.microsoft.com/bing/"

class AsyncWebSearcher(object):
    """ Runs search requests and returns the results """

    def __init__(self, bing_api, bing_version, loop=None):
        if bing_api is None or not isinstance(bing_api, str):
            raise ValueError("Need a string Bing API key")
        self.bing_api = bing_api
        self.url = BING_API + bing_version + "/search"
        self.cl_sess = ClientSession(headers={"Ocp-Apim-Subscription-Key": self.bing_api}, loop=loop)
    async def run_search(self, question: str, n_docs: int) -> List[Dict]:
        # avoid quoting the entire question, some triviaqa questions have this form
        # TODO is this the right place to do this?
        question = question.strip("\"' ")
        async with self.cl_sess.get(url=self.url, params=dict(count=n_docs, q=question, mkt="en-US")) as resp:
            data = await resp.json()
            if resp.status != 200:
                raise ValueError("Web search error %s" % data)
            if "webPages" not in data:
                # Bing omits the "webPages" field entirely when there are no hits
                return []
            else:
                return data["webPages"]["value"]

    def close(self):
        self.cl_sess.close()
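
# Illustrative usage sketch, not part of the original file: drives a single
# search through AsyncWebSearcher above. The "v7.0" version string and the
# api_key value are placeholder assumptions, not values taken from this repo.
async def _example_search(api_key: str):
    searcher = AsyncWebSearcher(api_key, "v7.0")
    try:
        # each hit is an entry from Bing's "webPages" list; "name", "url" and
        # "snippet" are the fields a caller would typically read
        for page in await searcher.run_search("who wrote hamlet", n_docs=5):
            print(page["name"], page["url"])
    finally:
        searcher.close()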

class ExtractedWebDoc(object):
    def __init__(self, url: str, text: str):
        self.url = url
        self.text = text

class AsyncBoilerpipeCliExtractor(object):
    """
    Downloads documents from URLs and returns the extracted text.

    TriviaQA used boilerpipe (https://github.com/kohlschutter/boilerpipe) to extract the
    "main" pieces of text from web documents. As far as I can tell there is no complete
    Python re-implementation, so for the moment we shell out to a jar file (boilerpipe.jar)
    that downloads the given URLs and runs them through boilerpipe's extraction code
    using multiple threads.
    """
    JAR = "docqa/server/boilerpipe.jar"

    def __init__(self, n_threads: int=10, timeout: Optional[int]=None,
                 process_additional_timeout: Optional[int]=5):
        """
        :param n_threads: Number of threads to use when downloading urls
        :param timeout: Time limit, in seconds, for downloading the urls; downloads that are
                        still hanging when the limit is reached are returned as errors
        :param process_additional_timeout: How long to wait for the downloading sub-process to
                                           return, in addition to `timeout`. If this timeout is
                                           hit no results will be returned, so this is a last
                                           resort to stop the server from freezing
        """
        self.log = logging.getLogger('downloader')
        if not exists(self.JAR):
            raise ValueError("Could not find boilerpipe jar")
        self.timeout = timeout
        self.n_threads = n_threads
        if self.timeout is None:
            self.proc_timeout = None
        else:
            self.proc_timeout = timeout + process_additional_timeout
    async def get_text(self, urls: List[str]) -> List[ExtractedWebDoc]:
        args = ["java", "-jar", self.JAR, *urls, "-t", str(self.n_threads)]
        if self.timeout is not None:
            # only pass the time limit flag when one was set, rather than the string "None"
            args += ["-l", str(self.timeout)]
        process = await asyncio.create_subprocess_exec(
            *args, stdout=asyncio.subprocess.PIPE)
        stdout, _ = await asyncio.wait_for(process.communicate(),
                                           timeout=self.proc_timeout)
        text = stdout.decode("utf-8")
        data = ujson.loads(text)
        ex = data["extracted"]
        errors = data["error"]
        if len(errors) > 0:
            self.log.info("%d extraction errors: %s" % (len(errors), str(list(errors.items()))))
        # preserve the input order, dropping urls that failed to extract
        return [ExtractedWebDoc(url, ex[url]) for url in urls if url in ex]
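
# End-to-end sketch, likewise illustrative rather than part of the original
# module: search Bing for a question, then pull the main text of each hit
# through the boilerpipe extractor. Assumes boilerpipe.jar exists at the path
# in JAR and that a real Bing key is supplied; "v7.0" is again a placeholder.
async def _example_pipeline(api_key: str):
    searcher = AsyncWebSearcher(api_key, "v7.0")
    extractor = AsyncBoilerpipeCliExtractor(n_threads=4, timeout=30)
    try:
        pages = await searcher.run_search("who wrote hamlet", n_docs=5)
        docs = await extractor.get_text([p["url"] for p in pages])
        for doc in docs:
            print(doc.url, doc.text[:80].replace("\n", " "))
    finally:
        searcher.close()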