
/r2/r2/lib/utils/utils.py

https://github.com/stevewilber/reddit
Python | 1463 lines | 1339 code | 58 blank | 66 comment
Possible License(s): MPL-2.0-no-copyleft-exception, Apache-2.0
# The contents of this file are subject to the Common Public Attribution
# License Version 1.0. (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
# License Version 1.1, but Sections 14 and 15 have been added to cover use of
# software over a computer network and provide for limited attribution for the
# Original Developer. In addition, Exhibit A has been modified to be consistent
# with Exhibit B.
#
# Software distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
# the specific language governing rights and limitations under the License.
#
# The Original Code is reddit.
#
# The Original Developer is the Initial Developer. The Initial Developer of
# the Original Code is reddit Inc.
#
# All portions of the code written by reddit are Copyright (c) 2006-2012 reddit
# Inc. All Rights Reserved.
###############################################################################
import os
import base64
import traceback
import ConfigParser
from urllib import unquote_plus
from urllib2 import urlopen
from urlparse import urlparse, urlunparse
import signal
from copy import deepcopy
import cPickle as pickle
import re, math, random
import time  # the time module itself is used by timeit() below
import boto
from decimal import Decimal
from BeautifulSoup import BeautifulSoup, SoupStrainer

from time import sleep
from datetime import datetime, timedelta
from functools import wraps, partial, WRAPPER_ASSIGNMENTS
from pylons import g
from pylons.i18n import ungettext, _
from r2.lib.filters import _force_unicode, _force_utf8
from mako.filters import url_escape
from r2.lib.contrib import ipaddress
# RequirementException is raised by parse_http_basic() below
from r2.lib.require import RequirementException, require, require_split
import snudown

from r2.lib.utils._utils import *

iters = (list, tuple, set)

def randstr(len, reallyrandom = False):
    """If reallyrandom = False, generates a random alphanumeric string
    (base-36 compatible) of length len. If reallyrandom, add
    uppercase and punctuation (which we'll call 'base-93' for the sake
    of argument) and suitable for use as salt."""
    alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789'
    if reallyrandom:
        alphabet += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&\()*+,-./:;<=>?@[\\]^_{|}~'
    return ''.join(random.choice(alphabet)
                   for i in range(len))

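# Illustrative usage of randstr() (added for clarity, not part of the original
# file); output is random, so the values below only show the expected shape:
#
#     >>> randstr(6)                          # doctest: +SKIP
#     'q3zk7f'                                # six base-36 characters
#     >>> randstr(10, reallyrandom=True)      # doctest: +SKIP
#     'aB#7{x]Q2!'                            # mixed case plus punctuation, e.g. for salts
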
class Storage(dict):
    """
    A Storage object is like a dictionary except `obj.foo` can be used
    in addition to `obj['foo']`.

    >>> o = storage(a=1)
    >>> o.a
    1
    >>> o['a']
    1
    >>> o.a = 2
    >>> o['a']
    2
    >>> del o.a
    >>> o.a
    Traceback (most recent call last):
        ...
    AttributeError: 'a'
    """
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError, k:
            raise AttributeError, k

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError, k:
            raise AttributeError, k

    def __repr__(self):
        return '<Storage ' + dict.__repr__(self) + '>'

storage = Storage

def storify(mapping, *requireds, **defaults):
    """
    Creates a `storage` object from dictionary `mapping`, raising `KeyError` if
    `mapping` doesn't have all of the keys in `requireds` and using the default
    values for keys found in `defaults`.

    For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the equivalent of
    `storage({'a':1, 'b':2, 'c':3})`.

    If a `storify` value is a list (e.g. multiple values in a form submission),
    `storify` returns the last element of the list, unless the key appears in
    `defaults` as a list. Thus:

    >>> storify({'a':[1, 2]}).a
    2
    >>> storify({'a':[1, 2]}, a=[]).a
    [1, 2]
    >>> storify({'a':1}, a=[]).a
    [1]
    >>> storify({}, a=[]).a
    []

    Similarly, if the value has a `value` attribute, `storify` will return _its_
    value, unless the key appears in `defaults` as a dictionary.

    >>> storify({'a':storage(value=1)}).a
    1
    >>> storify({'a':storage(value=1)}, a={}).a
    <Storage {'value': 1}>
    >>> storify({}, a={}).a
    {}
    """
    def getvalue(x):
        if hasattr(x, 'value'):
            return x.value
        else:
            return x

    stor = Storage()
    for key in requireds + tuple(mapping.keys()):
        value = mapping[key]
        if isinstance(value, list):
            if isinstance(defaults.get(key), list):
                value = [getvalue(x) for x in value]
            else:
                value = value[-1]
        if not isinstance(defaults.get(key), dict):
            value = getvalue(value)
        if isinstance(defaults.get(key), list) and not isinstance(value, list):
            value = [value]
        setattr(stor, key, value)

    for (key, value) in defaults.iteritems():
        result = value
        if hasattr(stor, key):
            result = stor[key]
        if value == () and not isinstance(result, tuple):
            result = (result,)
        setattr(stor, key, result)

    return stor

class Enum(Storage):
    def __init__(self, *a):
        self.name = tuple(a)
        Storage.__init__(self, ((e, i) for i, e in enumerate(a)))

    def __contains__(self, item):
        if isinstance(item, int):
            return item in self.values()
        else:
            return Storage.__contains__(self, item)

class Results():
    def __init__(self, sa_ResultProxy, build_fn, do_batch=False):
        self.rp = sa_ResultProxy
        self.fn = build_fn
        self.do_batch = do_batch

    @property
    def rowcount(self):
        return self.rp.rowcount

    def _fetch(self, res):
        if self.do_batch:
            return self.fn(res)
        else:
            return [self.fn(row) for row in res]

    def fetchall(self):
        return self._fetch(self.rp.fetchall())

    def fetchmany(self, n):
        rows = self._fetch(self.rp.fetchmany(n))
        if rows:
            return rows
        else:
            raise StopIteration

    def fetchone(self):
        row = self.rp.fetchone()
        if row:
            if self.do_batch:
                row = tup(row)
                return self.fn(row)[0]
            else:
                return self.fn(row)
        else:
            raise StopIteration

def strip_www(domain):
    if domain.count('.') >= 2 and domain.startswith("www."):
        return domain[4:]
    else:
        return domain

def is_subdomain(subdomain, base):
    """Check if a domain is equal to or a subdomain of a base domain."""
    return subdomain == base or (subdomain is not None and subdomain.endswith('.' + base))

r_base_url = re.compile("(?i)(?:.+?://)?(?:www[\d]*\.)?([^#]*[^#/])/?")
def base_url(url):
    res = r_base_url.findall(url)
    return (res and res[0]) or url

r_domain = re.compile("(?i)(?:.+?://)?(?:www[\d]*\.)?([^/:#?]*)")
def domain(s):
    """
    Takes a URL and returns the domain part, minus www., if
    present
    """
    res = r_domain.findall(s)
    domain = (res and res[0]) or s
    return domain.lower()

r_path_component = re.compile(".*?/(.*)")
def path_component(s):
    """
    takes a url http://www.foo.com/i/like/cheese and returns
    i/like/cheese
    """
    res = r_path_component.findall(base_url(s))
    return (res and res[0]) or s

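# Illustrative behaviour of the URL helpers above; the example values are
# added for clarity and are not part of the original file:
#
#     >>> strip_www('www.reddit.com'), strip_www('www.com')
#     ('reddit.com', 'www.com')             # 'www.com' is short enough to be a real domain
#     >>> is_subdomain('pics.reddit.com', 'reddit.com')
#     True
#     >>> base_url('http://www.foo.com/i/like/cheese#anchor')
#     'foo.com/i/like/cheese'
#     >>> domain('https://WWW.Example.COM:8080/path?x=1')
#     'example.com'                         # scheme, www, port, path and query are dropped
#     >>> path_component('http://www.foo.com/i/like/cheese')
#     'i/like/cheese'
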
def get_title(url):
    """Fetches the contents of url and extracts (and utf-8 encodes)
       the contents of <title>"""
    if not url or not (url.startswith('http://') or url.startswith('https://')):
        return None

    try:
        opener = urlopen(url, timeout=15)

        # Attempt to find the title in the first 1kb
        data = opener.read(1024)
        title = extract_title(data)

        # Title not found in the first kb, try searching an additional 2kb
        if not title:
            data += opener.read(2048)
            title = extract_title(data)

        opener.close()
        return title

    except:
        return None

def extract_title(data):
    """Tries to extract the value of the title element from a string of HTML"""
    bs = BeautifulSoup(data, convertEntities=BeautifulSoup.HTML_ENTITIES)
    if not bs:
        return

    title_bs = bs.html.head.title

    if not title_bs or not title_bs.string:
        return

    return title_bs.string.encode('utf-8').strip()

valid_schemes = ('http', 'https', 'ftp', 'mailto')
valid_dns = re.compile('\A[-a-zA-Z0-9]+\Z')
def sanitize_url(url, require_scheme = False):
    """Validates that the url is of the form

    scheme://domain/path/to/content#anchor?cruft

    using the python built-in urlparse.  If the url fails to validate,
    returns None.  If no scheme is provided and 'require_scheme =
    False' is set, the url is returned with scheme 'http', provided it
    otherwise validates"""

    if not url:
        return

    url = url.strip()
    if url.lower() == 'self':
        return url

    try:
        u = urlparse(url)
        # first pass: make sure a scheme has been specified
        if not require_scheme and not u.scheme:
            url = 'http://' + url
            u = urlparse(url)
    except ValueError:
        return

    if u.scheme and u.scheme in valid_schemes:
        # if there is a scheme and no hostname, it is a bad url.
        if not u.hostname:
            return
        if u.username is not None or u.password is not None:
            return

        labels = u.hostname.split('.')
        for label in labels:
            try:
                # if this succeeds, this portion of the dns is almost
                # valid and converted to ascii
                label = label.encode('idna')
            except TypeError:
                print "label sucks: [%r]" % label
                raise
            except UnicodeError:
                return
            else:
                # then if this succeeds, this portion of the dns is really valid
                if not re.match(valid_dns, label):
                    return

        return url

def trunc_string(text, length):
    return text[0:length]+'...' if len(text)>length else text

# Truncate a time to a certain number of minutes
# e.g, trunc_time(5:52, 30) == 5:30
def trunc_time(time, mins, hours=None):
    if hours is not None:
        if hours < 1 or hours > 60:
            raise ValueError("Hours %d is weird" % hours)
        time = time.replace(hour = hours * (time.hour / hours))

    if mins < 1 or mins > 60:
        raise ValueError("Mins %d is weird" % mins)

    return time.replace(minute = mins * (time.minute / mins),
                        second = 0,
                        microsecond = 0)

def long_datetime(datetime):
    return datetime.astimezone(g.tz).ctime() + " " + str(g.tz)

def median(l):
    if l:
        s = sorted(l)
        i = len(s) / 2
        return s[i]

def query_string(dict):
    pairs = []
    for k,v in dict.iteritems():
        if v is not None:
            try:
                k = url_escape(_force_unicode(k))
                v = url_escape(_force_unicode(v))
                pairs.append(k + '=' + v)
            except UnicodeDecodeError:
                continue
    if pairs:
        return '?' + '&'.join(pairs)
    else:
        return ''

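# Illustrative usage of query_string(); the dict and output below are examples
# added for clarity, not part of the original file.  Keys and values are
# url-escaped and None values are skipped; pair order follows dict iteration
# order, so the result below could come out with the pairs swapped:
#
#     >>> query_string({'after': 't3_abc', 'count': 25, 'limit': None})  # doctest: +SKIP
#     '?count=25&after=t3_abc'
#     >>> query_string({})
#     ''
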
class UrlParser(object):
    """
    Wrapper for urlparse and urlunparse for making changes to urls.
    All attributes present on the tuple-like object returned by
    urlparse are present on this class, and are settable, with the
    exception of netloc, which is instead treated via a getter method
    as a concatenation of hostname and port.

    Unlike urlparse, this class allows the query parameters to be
    converted to a dictionary via the query_dict method (and
    correspondingly updated via update_query).  The extension of the
    path can also be set and queried.

    The class also contains reddit-specific functions for setting,
    checking, and getting a path's subreddit.  It also can convert
    paths between in-frame and out of frame cname'd forms.
    """

    __slots__ = ['scheme', 'path', 'params', 'query',
                 'fragment', 'username', 'password', 'hostname',
                 'port', '_url_updates', '_orig_url', '_query_dict']

    valid_schemes = ('http', 'https', 'ftp', 'mailto')
    cname_get = "cnameframe"

    def __init__(self, url):
        u = urlparse(url)
        for s in self.__slots__:
            if hasattr(u, s):
                setattr(self, s, getattr(u, s))
        self._url_updates = {}
        self._orig_url = url
        self._query_dict = None

    def update_query(self, **updates):
        """
        Can be used instead of self.query_dict.update() to add/change
        query params in situations where the original contents are not
        required.
        """
        self._url_updates.update(updates)

    @property
    def query_dict(self):
        """
        Parses the `query' attribute of the original urlparse and
        generates a dictionary where both the keys and values have
        been url_unescape'd.  Any updates or changes to the resulting
        dict will be reflected in the updated query params
        """
        if self._query_dict is None:
            def _split(param):
                p = param.split('=')
                return (unquote_plus(p[0]),
                        unquote_plus('='.join(p[1:])))
            self._query_dict = dict(_split(p) for p in self.query.split('&')
                                    if p)
        return self._query_dict

    def path_extension(self):
        """
        Fetches the current extension of the path.
        """
        return self.path.split('/')[-1].split('.')[-1]

    def set_extension(self, extension):
        """
        Changes the extension of the path to the provided value (the
        "." should not be included in the extension as a "." is
        provided)
        """
        pieces = self.path.split('/')
        dirs = pieces[:-1]
        base = pieces[-1].split('.')
        base = '.'.join(base[:-1] if len(base) > 1 else base)
        if extension:
            base += '.' + extension
        dirs.append(base)
        self.path = '/'.join(dirs)
        return self

    def unparse(self):
        """
        Converts the url back to a string, applying all updates made
        to the fields thereof.

        Note: if a host name has been added and none was present
        before, will enforce scheme -> "http" unless otherwise
        specified.  Double-slashes are removed from the resultant
        path, and the query string is reconstructed only if the
        query_dict has been modified/updated.
        """
        # only parse the query params if there is an update dict
        q = self.query
        if self._url_updates or self._query_dict is not None:
            q = self._query_dict or self.query_dict
            q.update(self._url_updates)
            q = query_string(q).lstrip('?')

        # make sure the port is not doubly specified
        if self.port and ":" in self.hostname:
            self.hostname = self.hostname.split(':')[0]

        # if there is a netloc, there had better be a scheme
        if self.netloc and not self.scheme:
            self.scheme = "http"

        return urlunparse((self.scheme, self.netloc,
                           self.path.replace('//', '/'),
                           self.params, q, self.fragment))

    def path_has_subreddit(self):
        """
        utility method for checking if the path starts with a
        subreddit specifier (namely /r/ or /reddits/).
        """
        return (self.path.startswith('/r/') or
                self.path.startswith('/reddits/'))

    def get_subreddit(self):
        """checks if the current url refers to a subreddit and returns
        that subreddit object.  The cases here are:

          * the hostname is unset or is g.domain, in which case it
            looks for /r/XXXX or /reddits.  The default in this case
            is DefaultSR.
          * the hostname is a cname to a known subreddit.

        On failure to find a subreddit, returns None.
        """
        from pylons import g
        from r2.models import Subreddit, Sub, NotFound, DefaultSR
        try:
            if not self.hostname or self.hostname.startswith(g.domain):
                if self.path.startswith('/r/'):
                    return Subreddit._by_name(self.path.split('/')[2])
                elif self.path.startswith('/reddits/'):
                    return Sub
                else:
                    return DefaultSR()
            elif self.hostname:
                return Subreddit._by_domain(self.hostname)
        except NotFound:
            pass
        return None

    def is_reddit_url(self, subreddit = None):
        """utility method for seeing if the url is associated with
        reddit as we don't necessarily want to mangle non-reddit
        domains

        returns true only if hostname is nonexistent, a subdomain of
        g.domain, or a subdomain of the provided subreddit's cname.
        """
        from pylons import g
        return (not self.hostname or
                is_subdomain(self.hostname, g.domain) or
                (subreddit and subreddit.domain and
                 is_subdomain(self.hostname, subreddit.domain)))

    def path_add_subreddit(self, subreddit):
        """
        Adds the subreddit's path to the path if another subreddit's
        prefix is not already present.
        """
        if not self.path_has_subreddit():
            self.path = (subreddit.path + self.path)
        return self

    @property
    def netloc(self):
        """
        Getter method which returns the hostname:port, or empty string
        if no hostname is present.
        """
        if not self.hostname:
            return ""
        elif getattr(self, "port", None):
            return self.hostname + ":" + str(self.port)
        return self.hostname

    def mk_cname(self, require_frame = True, subreddit = None, port = None):
        """
        Converts a ?cnameframe url into the corresponding cnamed
        domain if applicable.  Useful for frame-busting on redirect.
        """
        # make sure the url is indeed in a frame
        if require_frame and not self.query_dict.has_key(self.cname_get):
            return self

        # fetch the subreddit and make sure it has a cname'd domain
        subreddit = subreddit or self.get_subreddit()
        if subreddit and subreddit.domain:

            # no guarantee there was a scheme
            self.scheme = self.scheme or "http"

            # update the domain (preserving the port)
            self.hostname = subreddit.domain
            self.port = self.port or port

            # and remove any cnameframe GET parameters
            if self.query_dict.has_key(self.cname_get):
                del self._query_dict[self.cname_get]

            # remove the subreddit reference
            self.path = lstrips(self.path, subreddit.path)
            if not self.path.startswith('/'):
                self.path = '/' + self.path

        return self

    def is_in_frame(self):
        """
        Checks if the url is in a frame by determining if
        cls.cname_get is present.
        """
        return self.query_dict.has_key(self.cname_get)

    def put_in_frame(self):
        """
        Adds the cls.cname_get get parameter to the query string.
        """
        self.update_query(**{self.cname_get:random.random()})

    def __repr__(self):
        return "<URL %s>" % repr(self.unparse())

    def domain_permutations(self, fragments=False, subdomains=True):
        """
        Takes a domain like `www.reddit.com`, and returns a list of ways
        that a user might search for it, like:
        * www
        * reddit
        * com
        * www.reddit.com
        * reddit.com
        * com
        """
        ret = set()
        if self.hostname:
            r = self.hostname.split('.')

            if subdomains:
                for x in xrange(len(r)-1):
                    ret.add('.'.join(r[x:len(r)]))

            if fragments:
                for x in r:
                    ret.add(x)

        return ret

    @classmethod
    def base_url(cls, url):
        u = cls(url)

        # strip off any www and lowercase the hostname:
        netloc = strip_www(u.netloc.lower())

        # http://code.google.com/web/ajaxcrawling/docs/specification.html
        fragment = u.fragment if u.fragment.startswith("!") else ""

        return urlunparse((u.scheme.lower(), netloc,
                           u.path, u.params, u.query, fragment))

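# Illustrative usage of UrlParser (added for clarity, not part of the original
# file); the URL below is a made-up example:
#
#     >>> u = UrlParser('http://www.example.com/r/pics/top.json?limit=10')
#     >>> u.query_dict
#     {'limit': '10'}
#     >>> u.update_query(after='t3_abc')
#     >>> u.set_extension('rss').unparse()    # doctest: +SKIP  (query pair order varies)
#     'http://www.example.com/r/pics/top.rss?limit=10&after=t3_abc'
#     >>> u.path_has_subreddit()
#     True
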
def to_js(content, callback="document.write", escape=True):
    before = after = ''
    if callback:
        before = callback + "("
        after = ");"
    if escape:
        content = string2js(content)
    return before + content + after

def pload(fname, default = None):
    "Load a pickled object from a file"
    try:
        f = file(fname, 'r')
        d = pickle.load(f)
    except IOError:
        d = default
    else:
        f.close()
    return d

def psave(fname, d):
    "Save a pickled object into a file"
    f = file(fname, 'w')
    pickle.dump(d, f)
    f.close()

def unicode_safe(res):
    try:
        return str(res)
    except UnicodeEncodeError:
        try:
            return unicode(res).encode('utf-8')
        except UnicodeEncodeError:
            return res.decode('utf-8').encode('utf-8')

def decompose_fullname(fullname):
    """
        decompose_fullname("t3_e4fa") ->
            (Thing, 3, 658918)
    """
    from r2.lib.db.thing import Thing, Relation
    if fullname[0] == 't':
        type_class = Thing
    elif fullname[0] == 'r':
        type_class = Relation

    type_id36, thing_id36 = fullname[1:].split('_')

    type_id = int(type_id36, 36)
    id = int(thing_id36, 36)

    return (type_class, type_id, id)

def cols(lst, ncols):
    """divides a list into columns, and returns the
    rows. e.g. cols('abcdef', 2) returns (('a', 'd'), ('b', 'e'), ('c',
    'f'))"""
    nrows = int(math.ceil(1.*len(lst) / ncols))
    lst = lst + [None for i in range(len(lst), nrows*ncols)]
    cols = [lst[i:i+nrows] for i in range(0, nrows*ncols, nrows)]
    rows = zip(*cols)
    rows = [filter(lambda x: x is not None, r) for r in rows]
    return rows

def fetch_things(t_class, since, until, batch_fn=None,
                 *query_params, **extra_query_dict):
    """
        Simple utility function to fetch all Things of class t_class
        (spam or not, but not deleted) that were created from 'since'
        to 'until'
    """
    from r2.lib.db.operators import asc

    if not batch_fn:
        batch_fn = lambda x: x

    query_params = ([t_class.c._date >= since,
                     t_class.c._date < until,
                     t_class.c._spam == (True, False)]
                    + list(query_params))
    query_dict = {'sort': asc('_date'),
                  'limit': 100,
                  'data': True}
    query_dict.update(extra_query_dict)

    q = t_class._query(*query_params,
                       **query_dict)

    orig_rules = deepcopy(q._rules)

    things = list(q)
    while things:
        things = batch_fn(things)
        for t in things:
            yield t
        q._rules = deepcopy(orig_rules)
        q._after(t)
        things = list(q)

def fetch_things2(query, chunk_size = 100, batch_fn = None, chunks = False):
    """Incrementally run query with a limit of chunk_size until there are
    no results left. batch_fn transforms the results for each chunk
    before returning."""
    orig_rules = deepcopy(query._rules)
    query._limit = chunk_size
    items = list(query)
    done = False
    while items and not done:
        # don't need to query again at the bottom if we didn't get enough
        if len(items) < chunk_size:
            done = True

        after = items[-1]

        if batch_fn:
            items = batch_fn(items)

        if chunks:
            yield items
        else:
            for i in items:
                yield i

        if not done:
            query._rules = deepcopy(orig_rules)
            query._after(after)
            items = list(query)

def fix_if_broken(thing, delete = True, fudge_links = False):
    from r2.models import Link, Comment, Subreddit, Message

    # the minimum set of attributes that are required
    attrs = dict((cls, cls._essentials)
                 for cls
                 in (Link, Comment, Subreddit, Message))

    if thing.__class__ not in attrs:
        raise TypeError

    tried_loading = False
    for attr in attrs[thing.__class__]:
        try:
            # try to retrieve the attribute
            getattr(thing, attr)
        except AttributeError:
            # that failed; let's explicitly load it and try again
            if not tried_loading:
                tried_loading = True
                thing._load()

            try:
                getattr(thing, attr)
            except AttributeError:
                if not delete:
                    raise
                if isinstance(thing, Link) and fudge_links:
                    if attr == "sr_id":
                        thing.sr_id = 6
                        print "Fudging %s.sr_id to %d" % (thing._fullname,
                                                          thing.sr_id)
                    elif attr == "author_id":
                        thing.author_id = 8244672
                        print "Fudging %s.author_id to %d" % (thing._fullname,
                                                              thing.author_id)
                    else:
                        print "Got weird attr %s; can't fudge" % attr

                if not thing._deleted:
                    print "%s is missing %r, deleting" % (thing._fullname, attr)
                    thing._deleted = True
                    thing._commit()

                if not fudge_links:
                    break

def find_recent_broken_things(from_time = None, to_time = None,
                              delete = False):
    """
        Occasionally (usually during app-server crashes), Things will
        be partially written out to the database. Things missing data
        attributes break the contract for these things, which often
        breaks various pages. This function hunts for and destroys
        them as appropriate.
    """
    from r2.models import Link, Comment
    from r2.lib.db.operators import desc
    from pylons import g

    from_time = from_time or timeago('1 hour')
    to_time = to_time or datetime.now(g.tz)

    for cls in (Link, Comment):
        q = cls._query(cls.c._date > from_time,
                       cls.c._date < to_time,
                       data=True,
                       sort=desc('_date'))
        for thing in fetch_things2(q):
            fix_if_broken(thing, delete = delete)

def timeit(func):
    "Run some function, and return (RunTimeInSeconds,Result)"
    before = time.time()
    res = func()
    return (time.time() - before, res)

def lineno():
    "Returns the current line number in our program."
    import inspect
    print "%s\t%s" % (datetime.now(), inspect.currentframe().f_back.f_lineno)

def IteratorFilter(iterator, fn):
    for x in iterator:
        if fn(x):
            yield x

def UniqueIterator(iterator, key = lambda x: x):
    """
    Takes an iterator and returns an iterator that returns only the
    first occurrence of each entry
    """
    so_far = set()
    def no_dups(x):
        k = key(x)
        if k in so_far:
            return False
        else:
            so_far.add(k)
            return True

    return IteratorFilter(iterator, no_dups)

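# Illustrative usage of UniqueIterator (example added for clarity, not part of
# the original file):
#
#     >>> list(UniqueIterator([1, 2, 1, 3, 2]))
#     [1, 2, 3]
#     >>> list(UniqueIterator(['Cat', 'cat', 'dog'], key=lambda s: s.lower()))
#     ['Cat', 'dog']
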
def modhash(user, rand = None, test = False):
    return user.name

def valid_hash(user, hash):
    return True

def check_cheating(loc):
    pass

def vote_hash(user, thing, note='valid'):
    return user.name

def valid_vote_hash(hash, user, thing):
    return True

def safe_eval_str(unsafe_str):
    return unsafe_str.replace('\\x3d', '=').replace('\\x26', '&')

rx_whitespace = re.compile('\s+', re.UNICODE)
rx_notsafe = re.compile('\W+', re.UNICODE)
rx_underscore = re.compile('_+', re.UNICODE)
def title_to_url(title, max_length = 50):
    """Takes a string and makes it suitable for use in URLs"""
    title = _force_unicode(title)           # make sure the title is unicode
    title = rx_whitespace.sub('_', title)   # replace whitespace with underscores
    title = rx_notsafe.sub('', title)       # remove non-word characters
    title = rx_underscore.sub('_', title)   # collapse consecutive underscores
    title = title.strip('_')                # strip leading/trailing underscores
    title = title.lower()                   # lowercase the title

    if len(title) > max_length:
        # truncate to nearest word
        title = title[:max_length]
        last_word = title.rfind('_')
        if (last_word > 0):
            title = title[:last_word]
    return title or "_"

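# Illustrative usage of title_to_url (example added for clarity, not part of
# the original file):
#
#     >>> title_to_url("Hello, World! -- what's new?")
#     'hello_world_whats_new'
#     >>> len(title_to_url('A' * 80))      # long titles are truncated
#     50
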
def dbg(s):
    import sys
    sys.stderr.write('%s\n' % (s,))

def trace(fn):
    def new_fn(*a, **kw):
        ret = fn(*a, **kw)
        dbg("Fn: %s; a=%s; kw=%s\nRet: %s"
            % (fn, a, kw, ret))
        return ret
    return new_fn

def common_subdomain(domain1, domain2):
    if not domain1 or not domain2:
        return ""
    domain1 = domain1.split(":")[0]
    domain2 = domain2.split(":")[0]

    if len(domain1) > len(domain2):
        domain1, domain2 = domain2, domain1

    if domain1 == domain2:
        return domain1
    else:
        dom = domain1.split(".")
        for i in range(len(dom), 1, -1):
            d = '.'.join(dom[-i:])
            if domain2.endswith(d):
                return d
    return ""

def interleave_lists(*args):
    max_len = max(len(x) for x in args)
    for i in xrange(max_len):
        for a in args:
            if i < len(a):
                yield a[i]

def link_from_url(path, filter_spam = False, multiple = True):
    from pylons import c
    from r2.models import IDBuilder, Link, Subreddit, NotFound

    if not path:
        return

    try:
        links = Link._by_url(path, c.site)
    except NotFound:
        return [] if multiple else None

    return filter_links(tup(links), filter_spam = filter_spam,
                        multiple = multiple)

def filter_links(links, filter_spam = False, multiple = True):
    # run the list through a builder to remove any that the user
    # isn't allowed to see
    from pylons import c
    from r2.models import IDBuilder, Link, Subreddit, NotFound
    links = IDBuilder([link._fullname for link in links],
                      skip = False).get_items()[0]
    if not links:
        return

    if filter_spam:
        # first, try to remove any spam
        links_nonspam = [ link for link in links
                          if not link._spam ]
        if links_nonspam:
            links = links_nonspam

    # if it occurs in one or more of their subscriptions, show them
    # that one first
    subs = set(Subreddit.user_subreddits(c.user, limit = None))
    def cmp_links(a, b):
        if a.sr_id in subs and b.sr_id not in subs:
            return -1
        elif a.sr_id not in subs and b.sr_id in subs:
            return 1
        else:
            return cmp(b._hot, a._hot)
    links = sorted(links, cmp = cmp_links)

    # among those, show them the hottest one
    return links if multiple else links[0]

def link_duplicates(article):
    # don't bother looking it up if the link doesn't have a URL anyway
    if getattr(article, 'is_self', False):
        return []

    return url_links(article.url, exclude = article._fullname)

def url_links(url, exclude=None):
    from r2.models import Link, NotFound

    try:
        links = tup(Link._by_url(url, None))
    except NotFound:
        links = []

    links = [ link for link in links
              if link._fullname != exclude ]

    return links

class TimeoutFunctionException(Exception):
    pass

class TimeoutFunction:
    """Force an operation to timeout after N seconds. Works with POSIX
       signals, so it's not safe to use in a multi-threaded environment"""
    def __init__(self, function, timeout):
        self.timeout = timeout
        self.function = function

    def handle_timeout(self, signum, frame):
        raise TimeoutFunctionException()

    def __call__(self, *args):
        # can only be called from the main thread
        old = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.timeout)
        try:
            result = self.function(*args)
        finally:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old)
        return result

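# Illustrative usage of TimeoutFunction (example added for clarity, not part
# of the original file); it wraps a callable so it raises if it runs too long:
#
#     slow_fetch = TimeoutFunction(lambda: urlopen("http://example.com").read(), 5)
#     try:
#         body = slow_fetch()
#     except TimeoutFunctionException:
#         body = None   # gave up after 5 seconds
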
def make_offset_date(start_date, interval, future = True,
                     business_days = False):
    """
    Generates a date in the future or past "interval" days from start_date.

    Can optionally give weekends no weight in the calculation if
    "business_days" is set to true.
    """
    if interval is not None:
        interval = int(interval)
        if business_days:
            weeks = interval / 7
            dow = start_date.weekday()
            if future:
                future_dow = (dow + interval) % 7
                if dow > future_dow or future_dow > 4:
                    weeks += 1
            else:
                future_dow = (dow - interval) % 7
                if dow < future_dow or future_dow > 4:
                    weeks += 1
            interval += 2 * weeks
        if future:
            return start_date + timedelta(interval)
        return start_date - timedelta(interval)
    return start_date

def to_date(d):
    if isinstance(d, datetime):
        return d.date()
    return d

def in_chunks(it, size=25):
    chunk = []
    it = iter(it)
    try:
        while True:
            chunk.append(it.next())
            if len(chunk) >= size:
                yield chunk
                chunk = []
    except StopIteration:
        if chunk:
            yield chunk

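# Illustrative usage of in_chunks (example added for clarity, not part of the
# original file); any iterable works, and the last chunk may be short:
#
#     >>> list(in_chunks(range(7), size=3))
#     [[0, 1, 2], [3, 4, 5], [6]]
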
def spaceout(items, targetseconds,
             minsleep = 0, die = False,
             estimate = None):
    """Given a list of items and a function to apply to them, space
       the execution out over the target number of seconds and
       optionally stop when we're out of time"""
    targetseconds = float(targetseconds)
    state = [1.0]

    if estimate is None:
        try:
            estimate = len(items)
        except TypeError:
            # if we can't come up with an estimate, the best we can do
            # is just enforce the minimum sleep time (and the max
            # targetseconds if die==True)
            pass

    mean = lambda lst: sum(float(x) for x in lst)/float(len(lst))
    beginning = datetime.now()

    for item in items:
        start = datetime.now()
        yield item
        end = datetime.now()

        took_delta = end - start
        took = (took_delta.days * 24 * 60 * 60
                + took_delta.seconds
                + took_delta.microseconds/1000000.0)
        state.append(took)
        if len(state) > 10:
            del state[0]

        if die and end > beginning + timedelta(seconds=targetseconds):
            # we ran out of time, ignore the rest of the iterator
            break

        if estimate is None:
            if minsleep:
                # we have no idea how many items we're going to get
                sleep(minsleep)
        else:
            sleeptime = max((targetseconds / estimate) - mean(state),
                            minsleep)
            if sleeptime > 0:
                sleep(sleeptime)

def progress(it, verbosity=100, key=repr, estimate=None, persec=True):
    """An iterator that yields everything from `it', but prints progress
       information along the way, including time-estimates if
       possible"""
    from itertools import islice
    from datetime import datetime
    import sys

    now = start = datetime.now()
    elapsed = start - start

    # try to guess at the estimate if we can
    if estimate is None:
        try:
            estimate = len(it)
        except:
            pass

    def timedelta_to_seconds(td):
        return td.days * (24*60*60) + td.seconds + (float(td.microseconds) / 1000000)

    def format_timedelta(td, sep=''):
        ret = []
        s = timedelta_to_seconds(td)
        if s < 0:
            neg = True
            s *= -1
        else:
            neg = False

        if s >= (24*60*60):
            days = int(s//(24*60*60))
            ret.append('%dd' % days)
            s -= days*(24*60*60)
        if s >= 60*60:
            hours = int(s//(60*60))
            ret.append('%dh' % hours)
            s -= hours*(60*60)
        if s >= 60:
            minutes = int(s//60)
            ret.append('%dm' % minutes)
            s -= minutes*60
        if s >= 1:
            seconds = int(s)
            ret.append('%ds' % seconds)
            s -= seconds

        if not ret:
            return '0s'

        return ('-' if neg else '') + sep.join(ret)

    def format_datetime(dt, show_date=False):
        if show_date:
            return dt.strftime('%Y-%m-%d %H:%M')
        else:
            return dt.strftime('%H:%M:%S')

    def deq(dt1, dt2):
        "Indicates whether the two datetimes' dates describe the same (day,month,year)"
        d1, d2 = dt1.date(), dt2.date()
        return (d1.day == d2.day
                and d1.month == d2.month
                and d1.year == d2.year)

    sys.stderr.write('Starting at %s\n' % (start,))

    # we're going to islice it so we need to start an iterator
    it = iter(it)

    seen = 0
    while True:
        this_chunk = 0
        thischunk_started = datetime.now()

        # the simple bit: just iterate and yield
        for item in islice(it, verbosity):
            this_chunk += 1
            seen += 1
            yield item

        if this_chunk < verbosity:
            # we're done, the iterator is empty
            break

        now = datetime.now()
        elapsed = now - start
        thischunk_seconds = timedelta_to_seconds(now - thischunk_started)

        if estimate:
            # the estimate is based on the total number of items that
            # we've processed in the total amount of time that's
            # passed, so it should smooth over momentary spikes in
            # speed (but will take a while to adjust to long-term
            # changes in speed)
            remaining = ((elapsed/seen)*estimate)-elapsed
            completion = now + remaining
            count_str = ('%d/%d %.2f%%'
                         % (seen, estimate, float(seen)/estimate*100))
            completion_str = format_datetime(completion, not deq(completion,now))
            estimate_str = (' (%s remaining; completion %s)'
                            % (format_timedelta(remaining),
                               completion_str))
        else:
            count_str = '%d' % seen
            estimate_str = ''

        if key:
            key_str = ': %s' % key(item)
        else:
            key_str = ''

        # unlike the estimate, the persec count is the number per
        # second for *this* batch only, without smoothing
        if persec and thischunk_seconds > 0:
            persec_str = ' (%.1f/s)' % (float(this_chunk)/thischunk_seconds,)
        else:
            persec_str = ''

        sys.stderr.write('%s%s, %s%s%s\n'
                         % (count_str, persec_str,
                            format_timedelta(elapsed), estimate_str, key_str))

    now = datetime.now()
    elapsed = now - start
    elapsed_seconds = timedelta_to_seconds(elapsed)
    if persec and seen > 0 and elapsed_seconds > 0:
        persec_str = ' (@%.1f/sec)' % (float(seen)/elapsed_seconds)
    else:
        persec_str = ''
    sys.stderr.write('Processed %d%s items in %s..%s (%s)\n'
                     % (seen,
                        persec_str,
                        format_datetime(start, not deq(start, now)),
                        format_datetime(now, not deq(start, now)),
                        format_timedelta(elapsed)))

class Hell(object):
    def __str__(self):
        return "boom!"

class Bomb(object):
    @classmethod
    def __getattr__(cls, key):
        raise Hell()

    @classmethod
    def __setattr__(cls, key, val):
        raise Hell()

    @classmethod
    def __repr__(cls):
        raise Hell()

class SimpleSillyStub(object):
    """A simple stub object that does nothing when you call its methods."""
    def __nonzero__(self):
        return False

    def __getattr__(self, name):
        return self.stub

    def stub(self, *args, **kwargs):
        pass

def strordict_fullname(item, key='fullname'):
    """Sometimes we migrate AMQP queues from simple strings to pickled
       dictionaries. During the migratory period there may be items in
       the queue of both types, so this function tries to detect which
       the item is. It shouldn't really be used on a given queue for more
       than a few hours or days"""
    try:
        d = pickle.loads(item)
    except:
        d = {key: item}

    if (not isinstance(d, dict)
        or key not in d
        or not isinstance(d[key], str)):
        raise ValueError('Error trying to migrate %r (%r)'
                         % (item, d))

    return d

def thread_dump(*a):
    import sys, traceback
    from datetime import datetime

    sys.stderr.write('%(t)s Thread Dump @%(d)s %(t)s\n' % dict(t='*'*15,
                                                               d=datetime.now()))

    for thread_id, stack in sys._current_frames().items():
        sys.stderr.write('\t-- Thread ID: %s--\n' % (thread_id,))

        for filename, lineno, fnname, line in traceback.extract_stack(stack):
            sys.stderr.write('\t\t%(filename)s(%(lineno)d): %(fnname)s\n'
                             % dict(filename=filename, lineno=lineno, fnname=fnname))
            sys.stderr.write('\t\t\t%(line)s\n' % dict(line=line))

def constant_time_compare(actual, expected):
    """
    Returns True if the two strings are equal, False otherwise.

    The time taken is dependent on the number of characters provided
    instead of the number of characters that match.
    """
    actual_len = len(actual)
    expected_len = len(expected)
    result = actual_len ^ expected_len
    if expected_len > 0:
        for i in xrange(actual_len):
            result |= ord(actual[i]) ^ ord(expected[i % expected_len])
    return result == 0

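# Illustrative usage of constant_time_compare (example added for clarity, not
# part of the original file); useful when comparing secrets such as HMACs, so
# an attacker can't learn a matching prefix from timing differences:
#
#     >>> constant_time_compare('secret-token', 'secret-token')
#     True
#     >>> constant_time_compare('secret-token', 'secret-tokem')
#     False
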
def wraps_api(f):
    # work around issue where wraps() requires attributes to exist
    if not hasattr(f, '_api_doc'):
        f._api_doc = {}
    return wraps(f, assigned=WRAPPER_ASSIGNMENTS+('_api_doc',))

def extract_urls_from_markdown(md):
    "Extract URLs that will be hot links from a piece of raw Markdown."
    html = snudown.markdown(_force_utf8(md))
    links = SoupStrainer("a")

    for link in BeautifulSoup(html, parseOnlyThese=links):
        url = link.get('href')
        if url:
            yield url

def summarize_markdown(md):
    """Get the first paragraph of some Markdown text, potentially truncated."""
    first_graf, sep, rest = md.partition("\n\n")
    return first_graf[:500]

def find_containing_network(ip_ranges, address):
    """Find an IP network that contains the given address."""
    addr = ipaddress.ip_address(address)
    for network in ip_ranges:
        if addr in network:
            return network
    return None

def is_throttled(address):
    """Determine if an IP address is in a throttled range."""
    return bool(find_containing_network(g.throttles, address))

def parse_http_basic(authorization_header):
    """Parse the username/credentials out of an HTTP Basic Auth header.

    Raises RequirementException if anything is uncool.
    """
    auth_scheme, auth_token = require_split(authorization_header, 2)
    require(auth_scheme.lower() == "basic")

    try:
        auth_data = base64.b64decode(auth_token)
    except TypeError:
        raise RequirementException

    return require_split(auth_data, 2, ":")

def simple_traceback():
    """Generate a pared-down traceback that's human readable but small."""
    stack_trace = traceback.extract_stack(limit=7)[:-2]
    return "\n".join(":".join((os.path.basename(filename),
                               function_name,
                               str(line_number),
                              ))
                     for filename, line_number, function_name, text
                     in stack_trace)

class GoldPrice(object):
    """Simple price math / formatting type.

    Prices are assumed to be USD at the moment.
    """
    def __init__(self, decimal):
        self.decimal = Decimal(decimal)

    def __mul__(self, other):
        return type(self)(self.decimal * other)

    def __div__(self, other):
        return type(self)(self.decimal / other)

    def __str__(self):
        return "$%s" % self.decimal.quantize(Decimal("1.00"))

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self)

    @property
    def pennies(self):
        return int(self.decimal * 100)

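# Illustrative usage of GoldPrice (example added for clarity, not part of the
# original file); amounts are Decimal-based, so string input avoids float noise:
#
#     >>> price = GoldPrice("3.99")
#     >>> print price * 12
#     $47.88
#     >>> price.pennies
#     399
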
def config_gold_price(v, key=None, data=None):
    return GoldPrice(v)

def read_static_file_config(config_file):
    parser = ConfigParser.RawConfigParser()
    with open(config_file, "r") as cf:
        parser.readfp(cf)
    config = dict(parser.items("static_files"))

    s3 = boto.connect_s3(config["aws_access_key_id"],
                         config["aws_secret_access_key"])
    bucket = s3.get_bucket(config["bucket"])

    return bucket, config