
/scrapy/tests/test_engine.py

http://github.com/scrapy/scrapy
Possible License(s): BSD-3-Clause
  1. """
  2. Scrapy engine tests
  3. This starts a testing web server (using twisted.server.Site) and then crawls it
  4. with the Scrapy crawler.
  5. To view the testing web server in a browser you can start it by running this
  6. module with the ``runserver`` argument::
  7. python test_engine.py runserver
  8. """

import sys, os, re, urlparse

from twisted.internet import reactor, defer
from twisted.web import server, static, util
from twisted.trial import unittest

from scrapy import signals
from scrapy.utils.test import get_crawler
from scrapy.xlib.pydispatch import dispatcher
from scrapy.tests import tests_datadir
from scrapy.spider import BaseSpider
from scrapy.item import Item, Field
from scrapy.contrib.linkextractors.sgml import SgmlLinkExtractor
from scrapy.http import Request
from scrapy.utils.signal import disconnect_all

class TestItem(Item):
    name = Field()
    url = Field()
    price = Field()

class TestSpider(BaseSpider):
    name = "scrapytest.org"
    allowed_domains = ["scrapytest.org", "localhost"]

    itemurl_re = re.compile(r"item\d+.html")
    name_re = re.compile(r"<h1>(.*?)</h1>", re.M)
    price_re = re.compile(r">Price: \$(.*?)<", re.M)

    def parse(self, response):
        xlink = SgmlLinkExtractor()
        itemre = re.compile(self.itemurl_re)
        for link in xlink.extract_links(response):
            if itemre.search(link.url):
                yield Request(url=link.url, callback=self.parse_item)

    def parse_item(self, response):
        item = TestItem()
        m = self.name_re.search(response.body)
        if m:
            item['name'] = m.group(1)
        item['url'] = response.url
        m = self.price_re.search(response.body)
        if m:
            item['price'] = m.group(1)
        return item

def start_test_site(debug=False):
    root_dir = os.path.join(tests_datadir, "test_site")
    r = static.File(root_dir)
    r.putChild("redirect", util.Redirect("/redirected"))
    r.putChild("redirected", static.Data("Redirected here", "text/plain"))

    port = reactor.listenTCP(0, server.Site(r), interface="127.0.0.1")
    if debug:
        print "Test server running at http://localhost:%d/ - hit Ctrl-C to finish." \
            % port.getHost().port
    return port
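
# CrawlerRun drives a full crawl against the local test site: it connects a
# handler (via pydispatcher) to every public signal in scrapy.signals, starts
# the crawler, and returns a Deferred that fires once signals.engine_stopped
# is received, so the test can simply yield on it.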

class CrawlerRun(object):
    """A class to run the crawler and keep track of the events that occurred"""

    def __init__(self):
        self.spider = None
        self.respplug = []
        self.reqplug = []
        self.itemresp = []
        self.signals_catched = {}

    def run(self):
        self.port = start_test_site()
        self.portno = self.port.getHost().port

        start_urls = [self.geturl("/"), self.geturl("/redirect")]
        self.spider = TestSpider(start_urls=start_urls)

        for name, signal in vars(signals).items():
            if not name.startswith('_'):
                dispatcher.connect(self.record_signal, signal)
        dispatcher.connect(self.item_scraped, signals.item_scraped)
        dispatcher.connect(self.request_received, signals.request_received)
        dispatcher.connect(self.response_downloaded, signals.response_downloaded)

        self.crawler = get_crawler()
        self.crawler.install()
        self.crawler.configure()
        self.crawler.crawl(self.spider)
        self.crawler.start()

        self.deferred = defer.Deferred()
        dispatcher.connect(self.stop, signals.engine_stopped)
        return self.deferred

    def stop(self):
        self.port.stopListening()
        for name, signal in vars(signals).items():
            if not name.startswith('_'):
                disconnect_all(signal)
        self.crawler.uninstall()
        self.deferred.callback(None)

    def geturl(self, path):
        return "http://localhost:%s%s" % (self.portno, path)

    def getpath(self, url):
        u = urlparse.urlparse(url)
        return u.path

    def item_scraped(self, item, spider, response):
        self.itemresp.append((item, response))

    def request_received(self, request, spider):
        self.reqplug.append((request, spider))

    def response_downloaded(self, response, spider):
        self.respplug.append((response, spider))

    def record_signal(self, *args, **kwargs):
        """Record a signal and its parameters"""
        signalargs = kwargs.copy()
        sig = signalargs.pop('signal')
        signalargs.pop('sender', None)
        self.signals_catched[sig] = signalargs
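
# EngineTest runs CrawlerRun from an inlineCallbacks test: it yields on the
# Deferred returned by run() and then inspects the recorded events: the URLs
# downloaded, the requests issued by the spider, the response status codes,
# the scraped items, and the signals caught during the crawl.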

class EngineTest(unittest.TestCase):

    @defer.inlineCallbacks
    def test_crawler(self):
        self.run = CrawlerRun()
        yield self.run.run()
        self._assert_visited_urls()
        self._assert_received_requests()
        self._assert_downloaded_responses()
        self._assert_scraped_items()
        self._assert_signals_catched()

    def _assert_visited_urls(self):
        must_be_visited = ["/", "/redirect", "/redirected",
                           "/item1.html", "/item2.html", "/item999.html"]
        urls_visited = set([rp[0].url for rp in self.run.respplug])
        urls_expected = set([self.run.geturl(p) for p in must_be_visited])
        assert urls_expected <= urls_visited, "URLs not visited: %s" % list(urls_expected - urls_visited)

    def _assert_received_requests(self):
        # 3 requests should be received from the spider. start_urls and redirects don't count
        self.assertEqual(3, len(self.run.reqplug))
        paths_expected = ['/item999.html', '/item2.html', '/item1.html']
        urls_requested = set([rq[0].url for rq in self.run.reqplug])
        urls_expected = set([self.run.geturl(p) for p in paths_expected])
        assert urls_expected <= urls_requested

    def _assert_downloaded_responses(self):
        # response tests
        self.assertEqual(6, len(self.run.respplug))
        for response, _ in self.run.respplug:
            if self.run.getpath(response.url) == '/item999.html':
                self.assertEqual(404, response.status)
            if self.run.getpath(response.url) == '/redirect':
                self.assertEqual(302, response.status)

    def _assert_scraped_items(self):
        self.assertEqual(2, len(self.run.itemresp))
        for item, response in self.run.itemresp:
            self.assertEqual(item['url'], response.url)
            if 'item1.html' in item['url']:
                self.assertEqual('Item 1 name', item['name'])
                self.assertEqual('100', item['price'])
            if 'item2.html' in item['url']:
                self.assertEqual('Item 2 name', item['name'])
                self.assertEqual('200', item['price'])

    def _assert_signals_catched(self):
        assert signals.engine_started in self.run.signals_catched
        assert signals.engine_stopped in self.run.signals_catched
        assert signals.spider_opened in self.run.signals_catched
        assert signals.spider_idle in self.run.signals_catched
        assert signals.spider_closed in self.run.signals_catched

        self.assertEqual({'spider': self.run.spider},
                         self.run.signals_catched[signals.spider_opened])
        self.assertEqual({'spider': self.run.spider},
                         self.run.signals_catched[signals.spider_idle])
        self.assertEqual({'spider': self.run.spider, 'reason': 'finished'},
                         self.run.signals_catched[signals.spider_closed])

if __name__ == "__main__":
    if len(sys.argv) > 1 and sys.argv[1] == 'runserver':
        start_test_site(debug=True)
        reactor.run()