
/r2/r2/lib/utils/utils.py

https://github.com/wangmxf/lesswrong
Possible License(s): MPL-2.0-no-copyleft-exception, LGPL-2.1
# The contents of this file are subject to the Common Public Attribution
# License Version 1.0. (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
# License Version 1.1, but Sections 14 and 15 have been added to cover use of
# software over a computer network and provide for limited attribution for the
# Original Developer. In addition, Exhibit A has been modified to be consistent
# with Exhibit B.
#
# Software distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
# the specific language governing rights and limitations under the License.
#
# The Original Code is Reddit.
#
# The Original Developer is the Initial Developer.  The Initial Developer of the
# Original Code is CondeNet, Inc.
#
# All portions of the code written by CondeNet are Copyright (c) 2006-2008
# CondeNet, Inc. All Rights Reserved.
################################################################################
from urllib import unquote_plus, quote_plus, urlopen, urlencode
from urlparse import urlparse, urlunparse
from threading import local, Thread
import Queue
from copy import deepcopy
import cPickle as pickle
# 'time' added here: timeit() below calls time.time(), but the module
# never imported it
import re, datetime, math, random, string, os, time, yaml
from datetime import datetime, timedelta, tzinfo
from pylons.i18n import ungettext, _
from r2.lib.filters import _force_unicode
from mako.filters import url_escape, url_unescape
from pylons import g

iters = (list, tuple, set)
def tup(item, ret_is_single=False):
    """Forces casting of item to a tuple (for a list) or generates a
    single element tuple (for anything else)"""
    #return true for iterables, except for strings, which is what we want
    if hasattr(item, '__iter__'):
        return (item, False) if ret_is_single else item
    else:
        return ((item,), True) if ret_is_single else (item,)
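# Example usage: iterables pass through untouched (strings count as
# scalars, since str has no __iter__ in Python 2); ret_is_single=True
# also reports whether the input was wrapped.
#
#   >>> tup([1, 2])
#   [1, 2]
#   >>> tup('abc')
#   ('abc',)
#   >>> tup(1, ret_is_single=True)
#   ((1,), True)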
def randstr(len, reallyrandom = False):
    """If reallyrandom = False, generates a random alphanumeric string
    (base-36 compatible) of length len.  If reallyrandom, add
    uppercase and punctuation (which we'll call 'base-93' for the sake
    of argument) and suitable for use as salt."""
    alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789'
    if reallyrandom:
        alphabet += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&\()*+,-./:;<=>?@[\\]^_{|}~'
    return ''.join(random.choice(alphabet)
                   for i in range(len))
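# Example usage (output is random; values shown are illustrative only):
#
#   randstr(8)                     # -> e.g. 'q3vz81mf'
#   randstr(8, reallyrandom=True)  # -> e.g. 'R;x_2!kP'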
class Storage(dict):
    """
    A Storage object is like a dictionary except `obj.foo` can be used
    in addition to `obj['foo']`.

    >>> o = storage(a=1)
    >>> o.a
    1
    >>> o['a']
    1
    >>> o.a = 2
    >>> o['a']
    2
    >>> del o.a
    >>> o.a
    Traceback (most recent call last):
        ...
    AttributeError: 'a'
    """
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError, k:
            raise AttributeError, k

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError, k:
            raise AttributeError, k

    def __repr__(self):
        return '<Storage ' + dict.__repr__(self) + '>'

storage = Storage
def storify(mapping, *requireds, **defaults):
    """
    Creates a `storage` object from dictionary `mapping`, raising `KeyError`
    if `mapping` doesn't have all of the keys in `requireds` and using the
    default values for keys found in `defaults`.

    For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the
    equivalent of `storage({'a':1, 'b':2, 'c':3})`.

    If a `storify` value is a list (e.g. multiple values in a form submission),
    `storify` returns the last element of the list, unless the key appears in
    `defaults` as a list. Thus:

    >>> storify({'a':[1, 2]}).a
    2
    >>> storify({'a':[1, 2]}, a=[]).a
    [1, 2]
    >>> storify({'a':1}, a=[]).a
    [1]
    >>> storify({}, a=[]).a
    []

    Similarly, if the value has a `value` attribute, `storify` will return
    _its_ value, unless the key appears in `defaults` as a dictionary.

    >>> storify({'a':storage(value=1)}).a
    1
    >>> storify({'a':storage(value=1)}, a={}).a
    <Storage {'value': 1}>
    >>> storify({}, a={}).a
    {}
    """
    def getvalue(x):
        if hasattr(x, 'value'):
            return x.value
        else:
            return x

    stor = Storage()
    for key in requireds + tuple(mapping.keys()):
        value = mapping[key]
        if isinstance(value, list):
            if isinstance(defaults.get(key), list):
                value = [getvalue(x) for x in value]
            else:
                value = value[-1]
        if not isinstance(defaults.get(key), dict):
            value = getvalue(value)
        if isinstance(defaults.get(key), list) and not isinstance(value, list):
            value = [value]
        setattr(stor, key, value)

    for (key, value) in defaults.iteritems():
        result = value
        if hasattr(stor, key):
            result = stor[key]
        if value == () and not isinstance(result, tuple):
            result = (result,)
        setattr(stor, key, result)

    return stor
def _strips(direction, text, remove):
    if direction == 'l':
        if text.startswith(remove):
            return text[len(remove):]
    elif direction == 'r':
        if text.endswith(remove):
            return text[:-len(remove)]
    else:
        raise ValueError, "Direction needs to be r or l."
    return text

def rstrips(text, remove):
    """
    removes the string `remove` from the right of `text`

    >>> rstrips("foobar", "bar")
    'foo'
    """
    return _strips('r', text, remove)

def lstrips(text, remove):
    """
    removes the string `remove` from the left of `text`

    >>> lstrips("foobar", "foo")
    'bar'
    """
    return _strips('l', text, remove)

def strips(text, remove):
    """removes the string `remove` from both sides of `text`

    >>> strips("foobarfoo", "foo")
    'bar'
    """
    return rstrips(lstrips(text, remove), remove)
class Results():
    def __init__(self, sa_ResultProxy, build_fn, do_batch=False):
        self.rp = sa_ResultProxy
        self.fn = build_fn
        self.do_batch = do_batch

    @property
    def rowcount(self):
        return self.rp.rowcount

    def _fetch(self, res):
        if self.do_batch:
            return self.fn(res)
        else:
            return [self.fn(row) for row in res]

    def fetchall(self):
        return self._fetch(self.rp.fetchall())

    def fetchmany(self, n):
        rows = self._fetch(self.rp.fetchmany(n))
        if rows:
            return rows
        else:
            raise StopIteration

    def fetchone(self):
        row = self.rp.fetchone()
        if row:
            if self.do_batch:
                row = tup(row)
                return self.fn(row)[0]
            else:
                return self.fn(row)
        else:
            raise StopIteration
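# Example sketch (hypothetical: assumes `rp` is an SQLAlchemy ResultProxy
# and that the builder turns a raw row into whatever object you want):
#
#   res = Results(rp, lambda row: row['id'])
#   ids = res.fetchall()   # -> [id1, id2, ...]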
def string2js(s):
    """adapted from http://svn.red-bean.com/bob/simplejson/trunk/simplejson/encoder.py"""
    # the character class must include '/' and all of \x00-\x1f so that
    # every key in ESCAPE_DCT (filled in below) is actually reachable
    ESCAPE = re.compile(r'[\x00-\x1f\\"/\b\f\n\r\t]')
    ESCAPE_ASCII = re.compile(r'([\\"/]|[^\ -~])')
    ESCAPE_DCT = {
        # escape all forward slashes to prevent </script> attack
        '/': '\\/',
        '\\': '\\\\',
        '"': '\\"',
        '\b': '\\b',
        '\f': '\\f',
        '\n': '\\n',
        '\r': '\\r',
        '\t': '\\t',
    }
    for i in range(0x20):
        ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,))

    def replace(match):
        return ESCAPE_DCT[match.group(0)]
    return '"' + ESCAPE.sub(replace, s) + '"'
r_base_url = re.compile("(?i)(?:.+?://)?(?:www[\d]*\.)?([^#]*[^#/])/?")
def base_url(url):
    res = r_base_url.findall(url)
    return (res and res[0]) or url

r_domain = re.compile("(?i)(?:.+?://)?(?:www[\d]*\.)?([^/:#?]*)")
def domain(s):
    """
    Takes a URL and returns the domain part, minus www., if
    present
    """
    res = r_domain.findall(s)
    domain = (res and res[0]) or s
    return domain.lower()

r_path_component = re.compile(".*?/(.*)")
def path_component(s):
    """
    takes a url http://www.foo.com/i/like/cheese and returns
    i/like/cheese
    """
    res = r_path_component.findall(base_url(s))
    return (res and res[0]) or s
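# Example usage:
#
#   >>> domain('http://www.foo.com/i/like/cheese')
#   'foo.com'
#   >>> base_url('http://www.foo.com/i/like/cheese#frag')
#   'foo.com/i/like/cheese'
#   >>> path_component('http://www.foo.com/i/like/cheese')
#   'i/like/cheese'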
r_title = re.compile('<title>(.*?)<\/title>', re.I|re.S)
r_charset = re.compile("<meta.*charset\W*=\W*([\w_-]+)", re.I|re.S)
r_encoding = re.compile("<?xml.*encoding=\W*([\w_-]+)", re.I|re.S)
def get_title(url):
    """Fetches the contents of url and extracts (and utf-8 encodes)
    the contents of <title>"""
    import chardet
    if not url or not url.startswith('http://'):
        return None
    try:
        content = urlopen(url).read()
        t = r_title.findall(content)
        if t:
            title = t[0].strip()
            en = (r_charset.findall(content) or
                  r_encoding.findall(content))
            encoding = en[0] if en else chardet.detect(content)["encoding"]
            if encoding:
                title = unicode(title, encoding).encode("utf-8")
            return title
    except:
        return None
valid_schemes = ('http', 'https', 'ftp', 'mailto')
def sanitize_url(url, require_scheme = False):
    """Validates that the url is of the form

    scheme://domain/path/to/content#anchor?cruft

    using the python built-in urlparse.  If the url fails to validate,
    returns None.  If no scheme is provided and 'require_scheme =
    False' is set, the url is returned with scheme 'http', provided it
    otherwise validates"""
    if not url or ' ' in url:
        return

    url = url.strip()
    if url.lower() == 'self':
        return url

    u = urlparse(url)
    # first pass: make sure a scheme has been specified
    if not require_scheme and not u.scheme:
        url = 'http://' + url
        u = urlparse(url)

    if (u.scheme and u.scheme in valid_schemes
        and u.hostname and len(u.hostname) < 255
        and '%' not in u.netloc):
        return url
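# Example usage: scheme-less urls get normalized, while disallowed
# schemes are rejected (the function returns None rather than raising).
#
#   >>> sanitize_url('reddit.com/r/pics')
#   'http://reddit.com/r/pics'
#   >>> sanitize_url('javascript:alert(1)') is None
#   True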
def timeago(interval):
    """Returns a datetime object corresponding to time 'interval' in
    the past.  Interval is of the same form as is returned by
    timetext(), i.e., '10 seconds'.  The interval must be passed in in
    English (i.e., untranslated) and the format is

    [num] second|minute|hour|day|week|month|year(s)
    """
    from pylons import g
    return datetime.now(g.tz) - timeinterval_fromstr(interval)

def timefromnow(interval):
    "The opposite of timeago"
    from pylons import g
    return datetime.now(g.tz) + timeinterval_fromstr(interval)

def timeinterval_fromstr(interval):
    "Used by timeago and timefromnow to generate timedeltas from friendly text"
    parts = interval.strip().split(' ')
    if len(parts) == 1:
        num = 1
        period = parts[0]
    elif len(parts) == 2:
        num, period = parts
        num = int(num)
    else:
        raise ValueError, 'format should be ([num] second|minute|etc)'
    period = rstrips(period, 's')
    d = dict(second = 1,
             minute = 60,
             hour   = 60 * 60,
             day    = 60 * 60 * 24,
             week   = 60 * 60 * 24 * 7,
             month  = 60 * 60 * 24 * 30,
             year   = 60 * 60 * 24 * 365)[period]
    delta = num * d
    return timedelta(0, delta)
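# Example usage (timeago/timefromnow additionally need pylons' g.tz to
# be configured, so only the parser is shown):
#
#   >>> timeinterval_fromstr('10 minutes')
#   datetime.timedelta(0, 600)
#   >>> timeinterval_fromstr('hour')
#   datetime.timedelta(0, 3600)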
def timetext(delta, resolution = 1, bare=True):
    """
    Takes a timedelta object, returns the time between then and now
    as a nicely formatted string, e.g. "10 minutes"

    Adapted from django which was adapted from
    http://blog.natbat.co.uk/archive/2003/Jun/14/time_since
    """
    chunks = (
        (60 * 60 * 24 * 365, lambda n: ungettext('year', 'years', n)),
        (60 * 60 * 24 * 30, lambda n: ungettext('month', 'months', n)),
        (60 * 60 * 24, lambda n: ungettext('day', 'days', n)),
        (60 * 60, lambda n: ungettext('hour', 'hours', n)),
        (60, lambda n: ungettext('minute', 'minutes', n)),
        (1, lambda n: ungettext('second', 'seconds', n))
    )
    delta = max(delta, timedelta(0))
    since = delta.days * 24 * 60 * 60 + delta.seconds
    for i, (seconds, name) in enumerate(chunks):
        count = math.floor(since / seconds)
        if count != 0:
            break

    from r2.lib.strings import strings
    if count == 0 and delta.seconds == 0 and delta != timedelta(0):
        n = math.floor(delta.microseconds / 1000)
        s = strings.number_label % (n, ungettext("millisecond",
                                                 "milliseconds", n))
    else:
        s = strings.number_label % (count, name(int(count)))
        if resolution > 1:
            if i + 1 < len(chunks):
                # Now get the second item
                seconds2, name2 = chunks[i + 1]
                count2 = (since - (seconds * count)) / seconds2
                if count2 != 0:
                    s += ', %d %s' % (count2, name2(count2))

    if not bare:
        s += ' ' + _('ago')

    return s

def timesince(d, resolution = 1, bare = True):
    from pylons import g
    return timetext(datetime.now(g.tz) - d, resolution, bare)

def timeuntil(d, resolution = 1, bare = True):
    from pylons import g
    return timetext(d - datetime.now(g.tz), resolution, bare)
def epochtime(date):
    if not date:
        return "0"
    date = date.astimezone(g.tz)
    return date.strftime("%s")

def prettytime(date, seconds = False):
    date = date.astimezone(g.tz)
    return date.strftime('%d %B %Y %I:%M:%S%p' if seconds else '%d %B %Y %I:%M%p')

def rfc822format(date):
    return date.strftime('%a, %d %b %Y %H:%M:%S %z')

def usformat(date):
    """
    Format a datetime in US date format

    Makes the date readable by the Protoplasm datetime picker
    """
    return date.strftime('%m-%d-%Y %H:%M:%S')
def median(nums):
    """Find the median of a list of numbers, which is assumed to already be sorted."""
    count = len(nums)
    mid = count // 2
    if count % 2:
        return nums[mid]
    else:
        return (nums[mid - 1] + nums[mid]) / 2
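# Example usage (note that with ints, the even-length case uses Python 2
# integer division):
#
#   >>> median([1, 3, 7])
#   3
#   >>> median([1, 3, 5, 8])
#   4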
def to_base(q, alphabet):
    if q < 0:
        raise ValueError, "must supply a positive integer"
    l = len(alphabet)
    converted = []
    while q != 0:
        q, r = divmod(q, l)
        converted.insert(0, alphabet[r])
    return "".join(converted) or '0'

def to36(q):
    return to_base(q, '0123456789abcdefghijklmnopqrstuvwxyz')
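# Example usage:
#
#   >>> to_base(255, '0123456789abcdef')
#   'ff'
#   >>> to36(1234)
#   'ya'
#   >>> to36(0)
#   '0'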
def query_string(dict):
    pairs = []
    for k, v in dict.iteritems():
        if v is not None:
            try:
                k = url_escape(unicode(k).encode('utf-8'))
                v = url_escape(unicode(v).encode('utf-8'))
                pairs.append(k + '=' + v)
            except UnicodeDecodeError:
                continue
    if pairs:
        return '?' + '&'.join(pairs)
    else:
        return ''
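# Example usage (None values are dropped; pair order follows dict
# iteration order, so it is unspecified when there are multiple keys):
#
#   >>> query_string({'q': 'cats'})
#   '?q=cats'
#   >>> query_string({'q': None})
#   ''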
class UrlParser(object):
    """
    Wrapper for urlparse and urlunparse for making changes to urls.
    All attributes present on the tuple-like object returned by
    urlparse are present on this class, and are settable, with the
    exception of netloc, which is instead treated via a getter method
    as a concatenation of hostname and port.

    Unlike urlparse, this class allows the query parameters to be
    converted to a dictionary via the query_dict method (and
    correspondingly updated via update_query).  The extension of the
    path can also be set and queried.

    The class also contains reddit-specific functions for setting,
    checking, and getting a path's subreddit.  It also can convert
    paths between in-frame and out of frame cname'd forms.
    """

    __slots__ = ['scheme', 'path', 'params', 'query',
                 'fragment', 'username', 'password', 'hostname',
                 'port', '_url_updates', '_orig_url', '_query_dict']

    valid_schemes = ('http', 'https', 'ftp', 'mailto')
    cname_get = "cnameframe"

    def __init__(self, url):
        u = urlparse(url)
        for s in self.__slots__:
            if hasattr(u, s):
                setattr(self, s, getattr(u, s))
        self._url_updates = {}
        self._orig_url = url
        self._query_dict = None

    def update_query(self, **updates):
        """
        Can be used instead of self.query_dict.update() to add/change
        query params in situations where the original contents are not
        required.
        """
        self._url_updates.update(updates)

    @property
    def query_dict(self):
        """
        Parses the `query' attribute of the original urlparse and
        generates a dictionary where both the keys and values have
        been url_unescape'd.  Any updates or changes to the resulting
        dict will be reflected in the updated query params
        """
        if self._query_dict is None:
            def _split(param):
                p = param.split('=')
                return (unquote_plus(p[0]),
                        unquote_plus('='.join(p[1:])))
            self._query_dict = dict(_split(p) for p in self.query.split('&')
                                    if p)
        return self._query_dict

    def path_extension(self):
        """
        Fetches the current extension of the path.
        """
        return self.path.split('/')[-1].split('.')[-1]

    def set_extension(self, extension):
        """
        Changes the extension of the path to the provided value (the
        "." should not be included in the extension as a "." is
        provided)
        """
        pieces = self.path.split('/')
        dirs = pieces[:-1]
        base = pieces[-1].split('.')
        base = '.'.join(base[:-1] if len(base) > 1 else base)
        if extension:
            base += '.' + extension
        dirs.append(base)
        self.path = '/'.join(dirs)
        return self

    def unparse(self):
        """
        Converts the url back to a string, applying all updates made
        to the fields thereof.

        Note: if a host name has been added and none was present
        before, will enforce scheme -> "http" unless otherwise
        specified.  Double-slashes are removed from the resultant
        path, and the query string is reconstructed only if the
        query_dict has been modified/updated.
        """
        # only parse the query params if there is an update dict
        q = self.query
        if self._url_updates or self._query_dict is not None:
            q = self._query_dict or self.query_dict
            q.update(self._url_updates)
            q = query_string(q).lstrip('?')

        # make sure the port is not doubly specified
        if self.port and ":" in self.hostname:
            self.hostname = self.hostname.split(':')[0]

        # if there is a netloc, there had better be a scheme
        if self.netloc and not self.scheme:
            self.scheme = "http"

        return urlunparse((self.scheme, self.netloc,
                           self.path.replace('//', '/'),
                           self.params, q, self.fragment))

    def path_has_subreddit(self):
        """
        utility method for checking if the path starts with a
        subreddit specifier (namely /r/ or /categories/).
        """
        return (self.path.startswith('/r/') or
                self.path.startswith('/categories/'))

    def get_subreddit(self):
        """checks if the current url refers to a subreddit and returns
        that subreddit object.  The cases here are:

          * the hostname is unset or is g.domain, in which case it
            looks for /r/XXXX or /categories.  The default in this
            case is Default.
          * the hostname is a cname to a known subreddit.

        On failure to find a subreddit, returns None.
        """
        from pylons import g
        from r2.models import Subreddit, Sub, NotFound, Default
        try:
            if not self.hostname or self.hostname.startswith(g.domain):
                if self.path.startswith('/r/'):
                    return Subreddit._by_name(self.path.split('/')[2])
                elif self.path.startswith('/categories/'):
                    return Sub
                else:
                    return Default
            elif self.hostname:
                return Subreddit._by_domain(self.hostname)
        except NotFound:
            pass
        return None

    def is_reddit_url(self, subreddit = None):
        """utility method for seeing if the url is associated with
        reddit as we don't necessarily want to mangle non-reddit
        domains

        returns true only if hostname is nonexistent, a subdomain of
        g.domain, or a subdomain of the provided subreddit's cname.
        """
        from pylons import g
        return (not self.hostname or
                self.hostname.endswith(g.domain) or
                (subreddit and subreddit.domain and
                 self.hostname.endswith(subreddit.domain)))

    def path_add_subreddit(self, subreddit):
        """
        Adds the subreddit's path to the path if another subreddit's
        prefix is not already present.
        """
        if not self.path_has_subreddit() and subreddit.path != '/categories/':
            self.path = (subreddit.path + self.path)
        return self

    @property
    def netloc(self):
        """
        Getter method which returns the hostname:port, or empty string
        if no hostname is present.
        """
        if not self.hostname:
            return ""
        elif self.port:
            return self.hostname + ":" + str(self.port)
        return self.hostname

    def mk_cname(self, require_frame = True, subreddit = None, port = None):
        """
        Converts a ?cnameframe url into the corresponding cnamed
        domain if applicable.  Useful for frame-busting on redirect.
        """
        # make sure the url is indeed in a frame
        if require_frame and not self.query_dict.has_key(self.cname_get):
            return self

        # fetch the subreddit and make sure it is cname'd
        subreddit = subreddit or self.get_subreddit()
        if subreddit and subreddit.domain:
            # no guarantee there was a scheme
            self.scheme = self.scheme or "http"

            # update the domain (preserving the port)
            self.hostname = subreddit.domain
            self.port = self.port or port

            # and remove any cnameframe GET parameters
            if self.query_dict.has_key(self.cname_get):
                del self._query_dict[self.cname_get]

            # remove the subreddit reference
            self.path = lstrips(self.path, subreddit.path)
            if not self.path.startswith('/'):
                self.path = '/' + self.path

        return self

    def is_in_frame(self):
        """
        Checks if the url is in a frame by determining if
        cls.cname_get is present.
        """
        return self.query_dict.has_key(self.cname_get)

    def put_in_frame(self):
        """
        Adds the cls.cname_get get parameter to the query string.
        """
        self.update_query(**{self.cname_get: random.random()})

    def __repr__(self):
        return "<URL %s>" % repr(self.unparse())
def to_js(content, callback="document.write", escape=True):
    before = after = ''
    if callback:
        before = callback + "("
        after = ");"
    if escape:
        content = string2js(content)
    return before + content + after
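# Example usage:
#
#   >>> to_js('hello')
#   'document.write("hello");'
#   >>> to_js('{"a": 1}', callback=None, escape=False)
#   '{"a": 1}'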
class TransSet(local):
    def __init__(self, items = ()):
        self.set = set(items)
        self.trans = False

    def begin(self):
        self.trans = True

    def add_engine(self, engine):
        if self.trans:
            return self.set.add(engine.begin())

    def clear(self):
        return self.set.clear()

    def __iter__(self):
        return self.set.__iter__()

    def commit(self):
        for t in self:
            t.commit()
        self.clear()

    def rollback(self):
        for t in self:
            t.rollback()
        self.clear()

    def __del__(self):
        self.commit()
def pload(fname, default = None):
    "Load a pickled object from a file"
    try:
        f = file(fname, 'r')
        d = pickle.load(f)
    except IOError:
        d = default
    else:
        f.close()
    return d

def psave(fname, d):
    "Save a pickled object into a file"
    f = file(fname, 'w')
    pickle.dump(d, f)
    f.close()
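# Example usage ('/tmp/state.pkl' is just an illustrative path):
#
#   psave('/tmp/state.pkl', {'a': 1})
#   pload('/tmp/state.pkl')             # -> {'a': 1}
#   pload('/does/not/exist', default={})  # -> {}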
def unicode_safe(res):
    try:
        return str(res)
    except UnicodeEncodeError:
        return unicode(res).encode('utf-8')

def decompose_fullname(fullname):
    """
    decompose_fullname("t3_e4fa") ->
        (Thing, 3, 658918)
    """
    from r2.lib.db.thing import Thing, Relation
    if fullname[0] == 't':
        type_class = Thing
    elif fullname[0] == 'r':
        type_class = Relation

    type_id36, thing_id36 = fullname[1:].split('_')

    type_id = int(type_id36, 36)
    id = int(thing_id36, 36)

    return (type_class, type_id, id)
class Worker:
    def __init__(self):
        self.q = Queue.Queue()
        self.t = Thread(target=self._handle)
        self.t.setDaemon(True)
        self.t.start()

    def _handle(self):
        while True:
            fn = self.q.get()
            try:
                fn()
            except:
                import traceback
                print traceback.format_exc()

    def do(self, fn):
        self.q.put(fn)

worker = Worker()

def asynchronous(func):
    def _asynchronous(*a, **kw):
        f = lambda: func(*a, **kw)
        worker.do(f)
    return _asynchronous
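# Example usage: the decorated function is queued to the module-level
# worker thread and runs in the background; the caller gets None back
# immediately (send_email here is purely illustrative).
#
#   @asynchronous
#   def send_email(to, body):
#       pass  # slow I/O happens off the calling thread
#
#   send_email('user@example.com', 'hi')  # returns at once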
def cols(lst, ncols):
    """divides a list into columns, and returns the rows.
    e.g. cols(list('abcdef'), 2) returns
    [('a', 'd'), ('b', 'e'), ('c', 'f')]"""
    nrows = int(math.ceil(1.*len(lst) / ncols))
    lst = lst + [None for i in range(len(lst), nrows*ncols)]
    cols = [lst[i:i+nrows] for i in range(0, nrows*ncols, nrows)]
    rows = zip(*cols)
    rows = [filter(lambda x: x is not None, r) for r in rows]
    return rows
def fetch_things(t_class, since, until, batch_fn=None,
                 *query_params, **extra_query_dict):
    """
    Simple utility function to fetch all Things of class t_class
    (spam or not, but not deleted) that were created from 'since'
    to 'until'
    """
    from r2.lib.db.operators import asc

    if not batch_fn:
        batch_fn = lambda x: x

    query_params = ([t_class.c._date >= since,
                     t_class.c._date < until,
                     t_class.c._spam == (True, False)]
                    + list(query_params))
    query_dict = {'sort': asc('_date'),
                  'limit': 100,
                  'data': True}
    query_dict.update(extra_query_dict)

    q = t_class._query(*query_params,
                       **query_dict)

    orig_rules = deepcopy(q._rules)

    things = list(q)
    while things:
        things = batch_fn(things)
        for t in things:
            yield t
        q._rules = deepcopy(orig_rules)
        q._after(t)
        things = list(q)
def fetch_things2(query, chunk_size = 100, batch_fn = None):
    """Incrementally run query with a limit of chunk_size until there are
    no results left.  batch_fn transforms the results for each chunk
    before returning."""
    orig_rules = deepcopy(query._rules)
    query._limit = chunk_size
    items = list(query)
    done = False
    while items and not done:
        # don't need to query again at the bottom if we didn't get enough
        if len(items) < chunk_size:
            done = True
        if batch_fn:
            items = batch_fn(items)
        for i in items:
            yield i
        if not done:
            query._rules = deepcopy(orig_rules)
            query._after(i)
            items = list(query)
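# Example sketch (hypothetical: assumes an r2 Thing query built elsewhere,
# e.g. with Link._query(...) and the desc sort operator):
#
#   q = Link._query(sort=desc('_date'), data=True)
#   for link in fetch_things2(q, chunk_size=200):
#       process(link)  # process() stands in for whatever work you do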
def set_emptying_cache():
    """
    The default thread-local cache is a regular dictionary, which
    isn't designed for long-running processes.  This sets the
    thread-local cache to be a SelfEmptyingCache, which naively
    empties itself out every N requests
    """
    from pylons import g
    from r2.lib.cache import SelfEmptyingCache
    g.cache.caches = [SelfEmptyingCache(),] + list(g.cache.caches[1:])
def find_recent_broken_things(from_time = None, delete = False):
    """
    Occasionally (usually during app-server crashes), Things will
    be partially written out to the database.  Things missing data
    attributes break the contract for these things, which often
    breaks various pages.  This function hunts for and destroys
    them as appropriate.
    """
    from r2.models import Link, Comment

    if not from_time:
        from_time = timeago("1 hour")

    to_time = timeago("60 seconds")

    for (cls, attrs) in ((Link, ('author_id', 'sr_id')),
                         (Comment, ('author_id', 'sr_id', 'body', 'link_id'))):
        find_broken_things(cls, attrs,
                           from_time, to_time,
                           delete=delete)

def find_broken_things(cls, attrs, from_time, to_time, delete = False):
    """
    Take a class and list of attributes, searching the database
    for Things of that class that are missing those attributes,
    deleting them if requested
    """
    for t in fetch_things(cls, from_time, to_time):
        for a in attrs:
            try:
                # try to retrieve the attribute
                getattr(t, a)
            except AttributeError:
                # that failed; let's explicitly load it and try again
                print "Reloading %s" % t._fullname
                t._load()
                try:
                    getattr(t, a)
                except AttributeError:
                    # it still broke.  We should delete it
                    print "%s is missing '%s'" % (t._fullname, a)
                    if delete:
                        t._deleted = True
                        t._commit()
                    break
def timeit(func):
    "Run some function, and return (RunTimeInSeconds, Result)"
    before = time.time()
    res = func()
    return (time.time() - before, res)
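# Example usage:
#
#   elapsed, result = timeit(lambda: sum(xrange(1000)))
#   # result == 499500; elapsed is the wall-clock time in seconds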
def lineno():
    "Prints the current line number in our program."
    import inspect
    print "%s\t%s" % (datetime.now(), inspect.currentframe().f_back.f_lineno)
class IteratorChunker(object):
    def __init__(self, it):
        self.it = it
        self.done = False

    def next_chunk(self, size):
        chunk = []
        if not self.done:
            try:
                for i in xrange(size):
                    chunk.append(self.it.next())
            except StopIteration:
                self.done = True
        return chunk
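# Example usage:
#
#   >>> c = IteratorChunker(iter([1, 2, 3, 4, 5]))
#   >>> c.next_chunk(2)
#   [1, 2]
#   >>> c.next_chunk(2)
#   [3, 4]
#   >>> c.next_chunk(2)
#   [5]
#   >>> c.done
#   True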
def IteratorFilter(iterator, fn):
    for x in iterator:
        if fn(x):
            yield x

def UniqueIterator(iterator):
    """
    Takes an iterator and returns an iterator that returns only the
    first occurrence of each entry
    """
    so_far = set()
    def no_dups(x):
        if x in so_far:
            return False
        else:
            so_far.add(x)
            return True
    return IteratorFilter(iterator, no_dups)
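# Example usage:
#
#   >>> list(UniqueIterator([1, 2, 1, 3, 2]))
#   [1, 2, 3]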
# def modhash(user, rand = None, test = False):
#     return user.name

# def valid_hash(user, hash):
#     return True

def check_cheating(loc):
    pass

def vote_hash(user, thing, note='valid'):
    return user.name

def valid_vote_hash(hash, user, thing):
    return True

def safe_eval_str(unsafe_str):
    return unsafe_str.replace('\\x3d', '=').replace('\\x26', '&')
rx_whitespace = re.compile('\s+', re.UNICODE)
rx_notsafe = re.compile('\W+', re.UNICODE)
rx_underscore = re.compile('_+', re.UNICODE)
def title_to_url(title, max_length = 50):
    """Takes a string and makes it suitable for use in URLs"""
    title = _force_unicode(title)          # make sure the title is unicode
    title = rx_whitespace.sub('_', title)  # replace whitespace with underscores
    title = rx_notsafe.sub('', title)      # strip non-word characters
    title = rx_underscore.sub('_', title)  # collapse underscore runs
    title = title.strip('_')               # remove leading/trailing underscores
    title = title.lower()                  # lowercase the title

    if len(title) > max_length:
        # truncate to nearest word
        title = title[:max_length]
        last_word = title.rfind('_')
        if last_word > 0:
            title = title[:last_word]
    return title
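# Example usage (assuming the r2 filter imports above are available):
#
#   title_to_url('Hello, World!')  # -> 'hello_world'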
def trace(fn):
    from pylons import g
    def new_fn(*a, **kw):
        ret = fn(*a, **kw)
        g.log.debug("Fn: %s; a=%s; kw=%s\nRet: %s"
                    % (fn, a, kw, ret))
        return ret
    return new_fn
def remote_addr(env):
    """
    Returns the remote address for the WSGI env passed

    Takes proxies into consideration
    """
    # In production the remote address is always the load balancer,
    # so check X-Forwarded-For first.
    # E.g. HTTP_X_FORWARDED_FOR: '66.249.72.73, 75.101.144.164'
    if env.has_key('HTTP_X_FORWARDED_FOR'):
        ips = re.split(r'\s*,\s*', env['HTTP_X_FORWARDED_FOR'])
        if len(ips) > 0:
            return ips[0]
    return env['REMOTE_ADDR']
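# Example usage:
#
#   >>> remote_addr({'HTTP_X_FORWARDED_FOR': '66.249.72.73, 75.101.144.164',
#   ...              'REMOTE_ADDR': '10.0.0.1'})
#   '66.249.72.73'
#   >>> remote_addr({'REMOTE_ADDR': '10.0.0.1'})
#   '10.0.0.1'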
# A class building tzinfo objects for a fixed offset.
class FixedOffset(tzinfo):
    """Fixed offset in hours east from UTC. name may be None"""
    def __init__(self, offset, name):
        self.offset = timedelta(hours = offset)
        self.name = name
        # tzinfo.__init__(self, name)

    def utcoffset(self, dt):
        return self.offset

    def tzname(self, dt):
        return self.name

    def dst(self, dt):
        return timedelta(0)
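# Example usage:
#
#   >>> FixedOffset(5.5, 'IST').utcoffset(None)
#   datetime.timedelta(0, 19800)
#   >>> FixedOffset(-5, 'EST').tzname(None)
#   'EST'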