
/r2/r2/lib/cloudsearch.py

https://github.com/stevewilber/reddit

# The contents of this file are subject to the Common Public Attribution
# License Version 1.0. (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
# License Version 1.1, but Sections 14 and 15 have been added to cover use of
# software over a computer network and provide for limited attribution for the
# Original Developer. In addition, Exhibit A has been modified to be consistent
# with Exhibit B.
#
# Software distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
# the specific language governing rights and limitations under the License.
#
# The Original Code is reddit.
#
# The Original Developer is the Initial Developer. The Initial Developer of
# the Original Code is reddit Inc.
#
# All portions of the code written by reddit are Copyright (c) 2006-2012 reddit
# Inc. All Rights Reserved.
###############################################################################
import collections
import cPickle as pickle
from datetime import datetime
import functools
import httplib
import json
from lxml import etree
from pylons import g, c
import re
import time
import urllib

import l2cs

from r2.lib import amqp, filters
from r2.lib.db.operators import desc
import r2.lib.utils as r2utils
from r2.models import (Account, Link, Subreddit, Thing, All, DefaultSR,
                       MultiReddit, DomainSR, Friends, ModContribSR,
                       FakeSubreddit, NotFound)


_CHUNK_SIZE = 4000000  # Approx. 4 MB, to stay under the 5MB limit
_VERSION_OFFSET = 13257906857
ILLEGAL_XML = re.compile(u'[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]')


def _safe_xml_str(s, use_encoding="utf-8"):
    '''Replace invalid-in-XML unicode control characters with '\uFFFD'.

    Also coerces the result to unicode.
    '''
    if not isinstance(s, unicode):
        if isinstance(s, str):
            s = unicode(s, use_encoding, errors="replace")
        else:
            # ints will raise TypeError if the "errors" kwarg
            # is passed, but since it's not a str no problem
            s = unicode(s)
    s = ILLEGAL_XML.sub(u"\uFFFD", s)
    return s
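
# Illustrative only (added; not part of the original module): _safe_xml_str
# decodes byte strings and strips XML-illegal control characters, e.g.
#   _safe_xml_str("caf\xc3\xa9\x00")  ->  u"caf\xe9\ufffd"
#   _safe_xml_str(42)                 ->  u"42"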


def safe_get(get_fn, ids, return_dict=True, **kw):
    items = {}
    for i in ids:
        try:
            item = get_fn(i, **kw)
        except NotFound:
            g.log.info("%r failed for %r", get_fn, i)
        else:
            items[i] = item
    if return_dict:
        return items
    else:
        return items.values()


class CloudSearchHTTPError(httplib.HTTPException): pass
class InvalidQuery(Exception): pass


Field = collections.namedtuple("Field", "name cloudsearch_type "
                                        "lucene_type function")

SAME_AS_CLOUDSEARCH = object()

FIELD_TYPES = (int, str, datetime, SAME_AS_CLOUDSEARCH, "yesno")


def field(name=None, cloudsearch_type=str, lucene_type=SAME_AS_CLOUDSEARCH):
    if lucene_type is SAME_AS_CLOUDSEARCH:
        lucene_type = cloudsearch_type
    if cloudsearch_type not in FIELD_TYPES + (None,):
        raise ValueError("cloudsearch_type %r not in %r" %
                         (cloudsearch_type, FIELD_TYPES))
    if lucene_type not in FIELD_TYPES + (None,):
        raise ValueError("lucene_type %r not in %r" %
                         (lucene_type, FIELD_TYPES))
    if callable(name):
        # Simple case; decorated as '@field'; act as a decorator instead
        # of a decorator factory
        function = name
        name = None
    else:
        function = None

    def field_inner(fn):
        fn.field = Field(name or fn.func_name, cloudsearch_type,
                         lucene_type, fn)
        return fn

    if function:
        return field_inner(function)
    else:
        return field_inner
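
# Illustrative only (added; not from the original source): the decorator can
# be used either bare or as a factory on a FieldsBase subclass, e.g.
#
#   @field                           # name defaults to fn.func_name
#   def title(self):
#       return self.link.title
#
#   @field(name="type")              # explicit cloudsearch field name
#   def type_(self):
#       return self.sr.type
#
# Each decorated method gets a .field attribute (a Field namedtuple) that
# FieldsMeta later collects into cls._fields.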


class FieldsMeta(type):
    def __init__(cls, name, bases, attrs):
        type.__init__(cls, name, bases, attrs)
        fields = []
        for attr in attrs.itervalues():
            if hasattr(attr, "field"):
                fields.append(attr.field)
        cls._fields = tuple(fields)


class FieldsBase(object):
    __metaclass__ = FieldsMeta

    def fields(self):
        data = {}
        for field in self._fields:
            if field.cloudsearch_type is None:
                continue
            val = field.function(self)
            if val is not None:
                data[field.name] = val
        return data

    @classmethod
    def all_fields(cls):
        return cls._fields

    @classmethod
    def cloudsearch_fields(cls, type_=None, types=FIELD_TYPES):
        types = (type_,) if type_ else types
        return [f for f in cls._fields if f.cloudsearch_type in types]

    @classmethod
    def lucene_fields(cls, type_=None, types=FIELD_TYPES):
        types = (type_,) if type_ else types
        return [f for f in cls._fields if f.lucene_type in types]

    @classmethod
    def cloudsearch_fieldnames(cls, type_=None, types=FIELD_TYPES):
        return [f.name for f in cls.cloudsearch_fields(type_=type_,
                                                       types=types)]

    @classmethod
    def lucene_fieldnames(cls, type_=None, types=FIELD_TYPES):
        return [f.name for f in cls.lucene_fields(type_=type_, types=types)]
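
# A minimal sketch (hypothetical class, added for clarity) of how the pieces
# fit together: FieldsMeta gathers the @field-decorated methods into _fields,
# and fields() evaluates each one against an instance:
#
#   class ExampleFields(FieldsBase):
#       def __init__(self, thing):
#           self.thing = thing
#
#       @field
#       def fullname(self):
#           return self.thing._fullname
#
#   ExampleFields(thing).fields()  ->  {"fullname": thing._fullname}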


class LinkFields(FieldsBase):
    def __init__(self, link, author, sr):
        self.link = link
        self.author = author
        self.sr = sr

    @field(cloudsearch_type=int, lucene_type=None)
    def ups(self):
        return max(0, self.link._ups)

    @field(cloudsearch_type=int, lucene_type=None)
    def downs(self):
        return max(0, self.link._downs)

    @field(cloudsearch_type=int, lucene_type=None)
    def num_comments(self):
        return max(0, getattr(self.link, 'num_comments', 0))

    @field
    def fullname(self):
        return self.link._fullname

    @field
    def subreddit(self):
        return self.sr.name

    @field
    def reddit(self):
        return self.sr.name

    @field
    def title(self):
        return self.link.title

    @field(cloudsearch_type=int)
    def sr_id(self):
        return self.link.sr_id

    @field(cloudsearch_type=int, lucene_type=datetime)
    def timestamp(self):
        return int(time.mktime(self.link._date.utctimetuple()))

    @field(cloudsearch_type=int, lucene_type="yesno")
    def over18(self):
        nsfw = (self.sr.over_18 or self.link.over_18 or
                Link._nsfw.findall(self.link.title))
        return (1 if nsfw else 0)

    @field(cloudsearch_type=None, lucene_type="yesno")
    def nsfw(self):
        return NotImplemented

    @field(cloudsearch_type=int, lucene_type="yesno")
    def is_self(self):
        return (1 if self.link.is_self else 0)

    @field(name="self", cloudsearch_type=None, lucene_type="yesno")
    def self_(self):
        return NotImplemented

    @field
    def author_fullname(self):
        return self.author._fullname

    @field(name="author")
    def author_field(self):
        return '[deleted]' if self.author._deleted else self.author.name

    @field(cloudsearch_type=int)
    def type_id(self):
        return self.link._type_id

    @field
    def site(self):
        if self.link.is_self:
            return g.domain
        else:
            url = r2utils.UrlParser(self.link.url)
            try:
                return list(url.domain_permutations())
            except ValueError:
                return None

    @field
    def selftext(self):
        if self.link.is_self and self.link.selftext:
            return self.link.selftext
        else:
            return None

    @field
    def url(self):
        if not self.link.is_self:
            return self.link.url
        else:
            return None

    @field
    def flair_css_class(self):
        return self.link.flair_css_class

    @field
    def flair_text(self):
        return self.link.flair_text

    @field(cloudsearch_type=None, lucene_type=str)
    def flair(self):
        return NotImplemented


class SubredditFields(FieldsBase):
    def __init__(self, sr):
        self.sr = sr

    @field
    def name(self):
        return self.sr.name

    @field
    def title(self):
        return self.sr.title

    @field(name="type")
    def type_(self):
        return self.sr.type

    @field
    def language(self):
        return self.sr.lang

    @field
    def header_title(self):
        return self.sr.header_title

    @field
    def description(self):
        return self.sr.public_description

    @field
    def sidebar(self):
        return self.sr.description

    @field
    def over18(self):
        return self.sr.over_18

    @field
    def link_type(self):
        return self.sr.link_type

    @field
    def activity(self):
        return self.sr._downs

    @field
    def subscribers(self):
        return self.sr._ups

    @field
    def type_id(self):
        return self.sr._type_id


class CloudSearchUploader(object):
    use_safe_get = False
    types = ()

    def __init__(self, doc_api, things=None, version_offset=_VERSION_OFFSET):
        self.doc_api = doc_api
        self._version_offset = version_offset
        self.things = self.desired_things(things) if things else []

    @classmethod
    def desired_fullnames(cls, items):
        '''Pull fullnames that represent instances of 'types' out of items'''
        fullnames = set()
        type_ids = [type_._type_id for type_ in cls.types]
        for item in items:
            item_type = r2utils.decompose_fullname(item['fullname'])[1]
            if item_type in type_ids:
                fullnames.add(item['fullname'])
        return fullnames

    @classmethod
    def desired_things(cls, things):
        return [t for t in things if isinstance(t, cls.types)]

    def _version_tenths(self):
        '''Cloudsearch documents don't update unless the sent "version" field
        is higher than the one currently indexed. As our documents don't have
        "versions" and could in theory be updated multiple times in one
        second, for now, use "tenths of a second since 12:00:00.00 1/1/2012"
        as the "version" - this will last approximately 13 years until bumping
        up against the version max of 2^32 for cloudsearch docs.
        '''
        return int(time.time() * 10) - self._version_offset

    def _version_seconds(self):
        return int(time.time()) - int(self._version_offset / 10)

    _version = _version_tenths
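
    # Worked arithmetic for the docstring's "approximately 13 years" claim
    # (added for clarity, not in the original): the version field maxes out at
    # 2**32 = 4294967296 tenths of a second, i.e. about 429496729 seconds, or
    # roughly 13.6 years after the 1/1/2012 epoch baked into _VERSION_OFFSET.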

    def add_xml(self, thing, version):
        add = etree.Element("add", id=thing._fullname, version=str(version),
                            lang="en")
        for field_name, value in self.fields(thing).iteritems():
            field = etree.SubElement(add, "field", name=field_name)
            field.text = _safe_xml_str(value)
        return add
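
    # For illustration (added; the values are made up): a single add document
    # looks roughly like
    #   <add id="t3_xxxxx" version="123456789" lang="en">
    #     <field name="title">Example title</field>
    #     <field name="sr_id">42</field>
    #     ...
    #   </add>
    # and xml_from_things() wraps many of these in a single <batch> element.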

    def delete_xml(self, thing, version=None):
        '''Return the cloudsearch XML representation of
        "delete this from the index"
        '''
        version = str(version or self._version())
        delete = etree.Element("delete", id=thing._fullname, version=version)
        return delete

    def delete_ids(self, ids):
        '''Delete documents from the index.
        'ids' should be a list of fullnames
        '''
        version = self._version()
        deletes = [etree.Element("delete", id=id_, version=str(version))
                   for id_ in ids]
        batch = etree.Element("batch")
        batch.extend(deletes)
        return self.send_documents(batch)

    def xml_from_things(self):
        '''Generate a <batch> XML tree to send to cloudsearch for
        adding/updating/deleting the given things
        '''
        batch = etree.Element("batch")
        self.batch_lookups()
        version = self._version()
        for thing in self.things:
            try:
                if thing._spam or thing._deleted:
                    delete_node = self.delete_xml(thing, version)
                    batch.append(delete_node)
                elif self.should_index(thing):
                    add_node = self.add_xml(thing, version)
                    batch.append(add_node)
            except (AttributeError, KeyError) as e:
                # Problem! Bail out, which means these items won't get
                # "consumed" from the queue. If the problem is from DB
                # lag or a transient issue, then the queue consumer
                # will succeed eventually. If it's something else,
                # then manually run a consumer with 'use_safe_get'
                # on to get past the bad Thing in the queue
                if not self.use_safe_get:
                    raise
                else:
                    g.log.warn("Ignoring problem on thing %r.\n\n%r",
                               thing, e)
        return batch

    def should_index(self, thing):
        raise NotImplementedError

    def batch_lookups(self):
        pass

    def fields(self, thing):
        raise NotImplementedError

    def inject(self, quiet=False):
        '''Send things to cloudsearch. Return value is time elapsed, in
        seconds, of the communication with the cloudsearch endpoint.
        '''
        xml_things = self.xml_from_things()

        cs_start = datetime.now(g.tz)
        if len(xml_things):
            sent = self.send_documents(xml_things)
            if not quiet:
                print sent
        return (datetime.now(g.tz) - cs_start).total_seconds()

    def send_documents(self, docs):
        '''Open a connection to the cloudsearch endpoint, and send the
        documents for indexing. Multiple requests are sent if a large number
        of documents are being sent (see chunk_xml()).

        Raises CloudSearchHTTPError if the endpoint indicates a failure.
        '''
        responses = []
        connection = httplib.HTTPConnection(self.doc_api, 80)
        chunker = chunk_xml(docs)
        try:
            for data in chunker:
                headers = {}
                headers['Content-Type'] = 'application/xml'
                # HTTPLib calculates Content-Length header automatically
                connection.request('POST', "/2011-02-01/documents/batch",
                                   data, headers)
                response = connection.getresponse()
                if 200 <= response.status < 300:
                    responses.append(response.read())
                else:
                    raise CloudSearchHTTPError(response.status,
                                               response.reason,
                                               response.read())
        finally:
            connection.close()
        return responses


class LinkUploader(CloudSearchUploader):
    types = (Link,)

    def __init__(self, doc_api, things=None, version_offset=_VERSION_OFFSET):
        super(LinkUploader, self).__init__(doc_api, things, version_offset)
        self.accounts = {}
        self.srs = {}

    def fields(self, thing):
        '''Return fields relevant to a Link search index'''
        account = self.accounts[thing.author_id]
        sr = self.srs[thing.sr_id]
        return LinkFields(thing, account, sr).fields()

    def batch_lookups(self):
        author_ids = [thing.author_id for thing in self.things
                      if hasattr(thing, 'author_id')]
        try:
            self.accounts = Account._byID(author_ids, data=True,
                                          return_dict=True)
        except NotFound:
            if self.use_safe_get:
                self.accounts = safe_get(Account._byID, author_ids, data=True,
                                         return_dict=True)
            else:
                raise

        sr_ids = [thing.sr_id for thing in self.things
                  if hasattr(thing, 'sr_id')]
        try:
            self.srs = Subreddit._byID(sr_ids, data=True, return_dict=True)
        except NotFound:
            if self.use_safe_get:
                self.srs = safe_get(Subreddit._byID, sr_ids, data=True,
                                    return_dict=True)
            else:
                raise

    def should_index(self, thing):
        return (thing.promoted is None and getattr(thing, "sr_id", None) != -1)


class SubredditUploader(CloudSearchUploader):
    types = (Subreddit,)
    _version = CloudSearchUploader._version_seconds

    def fields(self, thing):
        return SubredditFields(thing).fields()

    def should_index(self, thing):
        return getattr(thing, 'author_id', None) != -1


def chunk_xml(xml, depth=0):
    '''Chunk POST data into pieces that are smaller than the upload limit.
    Ideally, this never happens (if chunking is necessary, it would be better
    to avoid xml'ifying before testing content_length).
    '''
    data = etree.tostring(xml)
    content_length = len(data)

    if content_length < _CHUNK_SIZE:
        yield data
    else:
        depth += 1
        print "WARNING: Chunking (depth=%s)" % depth
        half = len(xml) / 2
        left_half = xml  # for ease of reading
        right_half = etree.Element("batch")
        # etree magic simultaneously removes the elements from one tree
        # when they are appended to a different tree
        # (use extend(), not append(); a slice of an Element is a list)
        right_half.extend(xml[half:])
        for chunk in chunk_xml(left_half, depth=depth):
            yield chunk
        for chunk in chunk_xml(right_half, depth=depth):
            yield chunk
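
# Rough sketch of the behavior (added for clarity): if the serialized <batch>
# exceeds _CHUNK_SIZE, its children are split in half and each half is chunked
# recursively, so an oversized batch yields several POST bodies, each of which
# is itself a well-formed <batch> document small enough to upload.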


def _run_changed(msgs, chan):
    '''Consume the cloudsearch_changes queue, and print reporting information
    on how long it took and how many remain
    '''
    start = datetime.now(g.tz)

    changed = [pickle.loads(msg.body) for msg in msgs]

    fullnames = set()
    fullnames.update(LinkUploader.desired_fullnames(changed))
    fullnames.update(SubredditUploader.desired_fullnames(changed))
    things = Thing._by_fullname(fullnames, data=True, return_dict=False)

    link_uploader = LinkUploader(g.CLOUDSEARCH_DOC_API, things=things)
    subreddit_uploader = SubredditUploader(g.CLOUDSEARCH_SUBREDDIT_DOC_API,
                                           things=things)

    link_time = link_uploader.inject()
    subreddit_time = subreddit_uploader.inject()
    cloudsearch_time = link_time + subreddit_time

    totaltime = (datetime.now(g.tz) - start).total_seconds()

    print ("%s: %d messages in %.2fs seconds (%.2fs secs waiting on "
           "cloudsearch); %d duplicates, %s remaining)" %
           (start, len(changed), totaltime, cloudsearch_time,
            len(changed) - len(things),
            msgs[-1].delivery_info.get('message_count', 'unknown')))


def run_changed(drain=False, min_size=500, limit=1000, sleep_time=10,
                use_safe_get=False, verbose=False):
    '''Run by `cron` (through `paster run`) on a schedule to send Things to
    Amazon CloudSearch
    '''
    if use_safe_get:
        CloudSearchUploader.use_safe_get = True
    amqp.handle_items('cloudsearch_changes', _run_changed, min_size=min_size,
                      limit=limit, drain=drain, sleep_time=sleep_time,
                      verbose=verbose)


def _progress_key(item):
    return "%s/%s" % (item._id, item._date)


_REBUILD_INDEX_CACHE_KEY = "cloudsearch_cursor_%s"


def rebuild_link_index(start_at=None, sleeptime=1, cls=Link,
                       uploader=LinkUploader, doc_api='CLOUDSEARCH_DOC_API',
                       estimate=50000000, chunk_size=1000):
    cache_key = _REBUILD_INDEX_CACHE_KEY % uploader.__name__.lower()
    doc_api = getattr(g, doc_api)
    uploader = uploader(doc_api)

    if start_at is _REBUILD_INDEX_CACHE_KEY:
        start_at = g.cache.get(cache_key)
        if not start_at:
            raise ValueError("Told me to use '%s' key, but it's not set" %
                             cache_key)

    q = cls._query(cls.c._deleted == (True, False),
                   sort=desc('_date'), data=True)
    if start_at:
        after = cls._by_fullname(start_at)
        assert isinstance(after, cls)
        q._after(after)
    q = r2utils.fetch_things2(q, chunk_size=chunk_size)
    q = r2utils.progress(q, verbosity=1000, estimate=estimate, persec=True,
                         key=_progress_key)
    for chunk in r2utils.in_chunks(q, size=chunk_size):
        uploader.things = chunk
        for x in range(5):
            try:
                uploader.inject()
            except httplib.HTTPException as err:
                print "Got %s, sleeping %s secs" % (err, x)
                time.sleep(x)
                continue
            else:
                break
        else:
            raise err
        last_update = chunk[-1]
        g.cache.set(cache_key, last_update._fullname)
        time.sleep(sleeptime)


rebuild_subreddit_index = functools.partial(rebuild_link_index,
                                            cls=Subreddit,
                                            uploader=SubredditUploader,
                                            doc_api='CLOUDSEARCH_SUBREDDIT_DOC_API',
                                            estimate=200000,
                                            chunk_size=1000)
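
# Illustrative invocations (added; not part of the original module):
#   rebuild_link_index()                               # full reindex, newest first
#   rebuild_link_index(start_at=_REBUILD_INDEX_CACHE_KEY)
#       # resume from the fullname cursor cached under
#       # "cloudsearch_cursor_linkuploader" by a previous run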


def test_run_link(start_link, count=1000):
    '''Inject `count` number of links, starting with `start_link`'''
    if isinstance(start_link, basestring):
        start_link = int(start_link, 36)
    links = Link._byID(range(start_link - count, start_link), data=True,
                       return_dict=False)
    uploader = LinkUploader(g.CLOUDSEARCH_DOC_API, things=links)
    return uploader.inject()


def test_run_srs(*sr_names):
    '''Inject Subreddits by name into the index'''
    srs = Subreddit._by_name(sr_names).values()
    uploader = SubredditUploader(g.CLOUDSEARCH_SUBREDDIT_DOC_API, things=srs)
    return uploader.inject()
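
# Example usage of the test helpers (added; the ids and names are made up):
#   test_run_link('c0ffee', count=100)   # the 100 links preceding id36 'c0ffee'
#   test_run_srs('pics', 'funny')        # inject these subreddits by name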


### Query Code ###

class Results(object):
    def __init__(self, docs, hits, facets):
        self.docs = docs
        self.hits = hits
        self._facets = facets
        self._subreddits = []

    def __repr__(self):
        return '%s(%r, %r, %r)' % (self.__class__.__name__,
                                   self.docs,
                                   self.hits,
                                   self._facets)

    @property
    def subreddit_facets(self):
        '''Filter out subreddits that the user isn't allowed to see'''
        if not self._subreddits and 'reddit' in self._facets:
            sr_facets = [(sr['value'], sr['count']) for sr in
                         self._facets['reddit']]

            # look up subreddits
            srs_by_name = Subreddit._by_name([name for name, count
                                              in sr_facets])

            sr_facets = [(srs_by_name[name], count) for name, count
                         in sr_facets if name in srs_by_name]

            # filter by can_view
            self._subreddits = [(sr, count) for sr, count in sr_facets
                                if sr.can_view(c.user)]

        return self._subreddits


_SEARCH = "/2011-02-01/search?"
INVALID_QUERY_CODES = ('CS-UnknownFieldInMatchExpression',
                       'CS-IncorrectFieldTypeInMatchExpression',
                       'CS-InvalidMatchSetExpression',)
DEFAULT_FACETS = {"reddit": {"count": 20}}


def basic_query(query=None, bq=None, faceting=None, size=1000,
                start=0, rank="-relevance", return_fields=None,
                record_stats=False, search_api=None):
    if search_api is None:
        search_api = g.CLOUDSEARCH_SEARCH_API
    if faceting is None:
        faceting = DEFAULT_FACETS
    path = _encode_query(query, bq, faceting, size, start, rank, return_fields)

    timer = None
    if record_stats:
        timer = g.stats.get_timer("cloudsearch_timer")
        timer.start()

    connection = httplib.HTTPConnection(search_api, 80)
    try:
        connection.request('GET', path)
        resp = connection.getresponse()
        response = resp.read()
        if record_stats:
            g.stats.action_count("event.search_query", resp.status)
        if resp.status >= 300:
            try:
                reasons = json.loads(response)
            except ValueError:
                pass
            else:
                messages = reasons.get("messages", [])
                for message in messages:
                    if message['code'] in INVALID_QUERY_CODES:
                        raise InvalidQuery(resp.status, resp.reason, message,
                                           path, reasons)
            raise CloudSearchHTTPError(resp.status, resp.reason, path,
                                       response)
    finally:
        connection.close()

    if timer is not None:
        timer.stop()

    return json.loads(response)


basic_link = functools.partial(basic_query, size=10, start=0,
                               rank="-relevance",
                               return_fields=['title', 'reddit',
                                              'author_fullname'],
                               record_stats=False,
                               search_api=g.CLOUDSEARCH_SEARCH_API)


basic_subreddit = functools.partial(basic_query,
                                    faceting=None,
                                    size=10, start=0,
                                    rank="-activity",
                                    return_fields=['title', 'reddit',
                                                   'author_fullname'],
                                    record_stats=False,
                                    search_api=g.CLOUDSEARCH_SUBREDDIT_SEARCH_API)


def _encode_query(query, bq, faceting, size, start, rank, return_fields):
    if not (query or bq):
        raise ValueError("Need query or bq")
    params = {}
    if bq:
        params["bq"] = bq
    else:
        params["q"] = query
    params["results-type"] = "json"
    params["size"] = size
    params["start"] = start
    params["rank"] = rank
    if faceting:
        params["facet"] = ",".join(faceting.iterkeys())
        for facet, options in faceting.iteritems():
            params["facet-%s-top-n" % facet] = options.get("count", 20)
            if "sort" in options:
                params["facet-%s-sort" % facet] = options["sort"]
    if return_fields:
        params["return-fields"] = ",".join(return_fields)
    encoded_query = urllib.urlencode(params)
    path = _SEARCH + encoded_query
    return path
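
# For illustration (added; the query text is made up and the parameter order
# from urllib.urlencode is not guaranteed): a boolean query for 10 results
# with the default reddit facet encodes to something like
#   /2011-02-01/search?bq=%28and+title%3A%27cats%27%29&results-type=json
#       &size=10&start=0&rank=-relevance&facet=reddit&facet-reddit-top-n=20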


class CloudSearchQuery(object):
    '''Represents a search query sent to cloudsearch'''
    search_api = None
    sorts = {}
    sorts_menu_mapping = {}
    known_syntaxes = ("cloudsearch", "lucene", "plain")
    default_syntax = "plain"
    lucene_parser = None

    def __init__(self, query, sr=None, sort=None, syntax=None, raw_sort=None,
                 faceting=None):
        if syntax is None:
            syntax = self.default_syntax
        elif syntax not in self.known_syntaxes:
            raise ValueError("Unknown search syntax: %s" % syntax)
        self.query = filters._force_unicode(query or u'')
        self.converted_data = None
        self.syntax = syntax
        self.sr = sr
        self._sort = sort
        if raw_sort:
            self.sort = raw_sort
        else:
            self.sort = self.sorts[sort]
        self.faceting = faceting
        self.bq = u''
        self.results = None

    def run(self, after=None, reverse=False, num=1000, _update=False):
        if not self.query:
            return Results([], 0, {})

        results = self._run(_update=_update)

        docs, hits, facets = results.docs, results.hits, results._facets

        after_docs = r2utils.get_after(docs, after, num, reverse=reverse)

        self.results = Results(after_docs, hits, facets)
        return self.results

    def _run(self, start=0, num=1000, _update=False):
        '''Run the search against self.query'''
        q = None
        if self.syntax == "cloudsearch":
            self.bq = self.customize_query(self.query)
        elif self.syntax == "lucene":
            bq = l2cs.convert(self.query, self.lucene_parser)
            self.converted_data = {"syntax": "cloudsearch",
                                   "converted": bq}
            self.bq = self.customize_query(bq)
        elif self.syntax == "plain":
            q = self.query.encode('utf-8')

        if g.sqlprinting:
            g.log.info("%s", self)

        return self._run_cached(q, self.bq.encode('utf-8'), self.sort,
                                self.faceting, start=start, num=num,
                                _update=_update)

    def customize_query(self, bq):
        return bq

    def __repr__(self):
        '''Return a string representation of this query'''
        result = ["<", self.__class__.__name__, "> query:",
                  repr(self.query), " "]
        if self.bq:
            result.append(" bq:")
            result.append(repr(self.bq))
            result.append(" ")
        result.append("sort:")
        result.append(self.sort)
        return ''.join(result)

    @classmethod
    def _run_cached(cls, query, bq, sort="relevance", faceting=None, start=0,
                    num=1000, _update=False):
        '''Query the cloudsearch API. The _update parameter allows for
        supposed easy memoization at a later date.

        Example result set:

        {u'facets': {u'reddit': {u'constraints':
                                 [{u'count': 114, u'value': u'politics'},
                                  {u'count': 42, u'value': u'atheism'},
                                  {u'count': 27, u'value': u'wtf'},
                                  {u'count': 19, u'value': u'gaming'},
                                  {u'count': 12, u'value': u'bestof'},
                                  {u'count': 12, u'value': u'tf2'},
                                  {u'count': 11, u'value': u'AdviceAnimals'},
                                  {u'count': 9, u'value': u'todayilearned'},
                                  {u'count': 9, u'value': u'pics'},
                                  {u'count': 9, u'value': u'funny'}]}},
         u'hits': {u'found': 399,
                   u'hit': [{u'id': u't3_11111'},
                            {u'id': u't3_22222'},
                            {u'id': u't3_33333'},
                            {u'id': u't3_44444'},
                            ...
                            ],
                   u'start': 0},
         u'info': {u'cpu-time-ms': 10,
                   u'messages': [{u'code': u'CS-InvalidFieldOrRankAliasInRankParameter',
                                  u'message': u"Unable to create score object for rank '-hot'",
                                  u'severity': u'warning'}],
                   u'rid': u'<hash>',
                   u'time-ms': 9},
         u'match-expr': u"(label 'my query')",
         u'rank': u'-text_relevance'}

        '''
        response = basic_query(query=query, bq=bq, size=num, start=start,
                               rank=sort, search_api=cls.search_api,
                               faceting=faceting, record_stats=True)

        warnings = response['info'].get('messages', [])
        for warning in warnings:
            g.log.warn("%(code)s (%(severity)s): %(message)s" % warning)

        hits = response['hits']['found']
        docs = [doc['id'] for doc in response['hits']['hit']]

        facets = response.get('facets', {})
        for facet in facets.keys():
            values = facets[facet]['constraints']
            facets[facet] = values

        results = Results(docs, hits, facets)
        return results


class LinkSearchQuery(CloudSearchQuery):
    search_api = g.CLOUDSEARCH_SEARCH_API
    sorts = {'relevance': '-relevance',
             'hot': '-hot2',
             'top': '-top',
             'new': '-timestamp',
             'comments': '-num_comments',
             }
    sorts_menu_mapping = {'relevance': 1,
                          'hot': 2,
                          'new': 3,
                          'top': 4,
                          'comments': 5,
                          }

    schema = l2cs.make_schema(LinkFields.lucene_fieldnames())
    lucene_parser = l2cs.make_parser(
        int_fields=LinkFields.lucene_fieldnames(type_=int),
        yesno_fields=LinkFields.lucene_fieldnames(type_="yesno"),
        schema=schema)
    known_syntaxes = ("cloudsearch", "lucene", "plain")
    default_syntax = "lucene"

    def customize_query(self, bq):
        subreddit_query = self._get_sr_restriction(self.sr)
        return self.create_boolean_query(bq, subreddit_query)

    @classmethod
    def create_boolean_query(cls, query, subreddit_query):
        '''Join a (user-entered) text query with the generated subreddit query

        Input:
            base_query: user input from the search textbox
            subreddit_query: output from _get_sr_restriction(sr)

        Test cases:
            base_query: simple, simple with quotes, boolean, boolean w/ parens
            subreddit_query: None, in parens '(or sr_id:1 sr_id:2 ...)',
                             without parens "author:'foo'"
        '''
        if subreddit_query:
            bq = "(and %s %s)" % (query, subreddit_query)
        else:
            bq = query
        return bq
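
    # For illustration (added; the queries are made up): joining a converted
    # user query with a MultiReddit restriction yields
    #   create_boolean_query("(field title 'cats')", "(or sr_id:1 sr_id:2 )")
    #     -> "(and (field title 'cats') (or sr_id:1 sr_id:2 ))"
    # and with subreddit_query=None the user query is passed through unchanged.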

    @staticmethod
    def _get_sr_restriction(sr):
        '''Return a cloudsearch appropriate query string that restricts
        results to only contain results from self.sr
        '''
        bq = []
        if (not sr) or sr == All or isinstance(sr, DefaultSR):
            return None
        elif isinstance(sr, MultiReddit):
            bq = ["(or"]
            for sr_id in sr.sr_ids:
                bq.append("sr_id:%s" % sr_id)
            bq.append(")")
        elif isinstance(sr, DomainSR):
            bq = ["site:'%s'" % sr.domain]
        elif sr == Friends:
            if not c.user_is_loggedin or not c.user.friends:
                return None
            bq = ["(or"]
            # The query limit is roughly 8k bytes. Limit to 200 friends to
            # avoid getting too close to that limit
            friend_ids = c.user.friends[:200]
            friends = ["author_fullname:'%s'" %
                       Account._fullname_from_id36(r2utils.to36(id_))
                       for id_ in friend_ids]
            bq.extend(friends)
            bq.append(")")
        elif isinstance(sr, ModContribSR):
            bq = ["(or"]
            for sr_id in sr.sr_ids:
                bq.append("sr_id:%s" % sr_id)
            bq.append(")")
        elif not isinstance(sr, FakeSubreddit):
            bq = ["sr_id:%s" % sr._id]

        return ' '.join(bq)


class SubredditSearchQuery(CloudSearchQuery):
    search_api = g.CLOUDSEARCH_SUBREDDIT_SEARCH_API
    sorts = {'relevance': '-activity',
             None: '-activity',
             }
    sorts_menu_mapping = {'relevance': 1,
                          }

    known_syntaxes = ("plain",)
    default_syntax = "plain"