PageRenderTime 47ms CodeModel.GetById 16ms RepoModel.GetById 0ms app.codeStats 0ms

/r2/r2/config/middleware.py

https://github.com/wangmxf/lesswrong
Python | 515 lines | 445 code | 23 blank | 47 comment | 37 complexity | 9f85ddf229155bc4b0a972b42bbd5995 MD5 | raw file
Possible License(s): MPL-2.0-no-copyleft-exception, LGPL-2.1
  1. # The contents of this file are subject to the Common Public Attribution
  2. # License Version 1.0. (the "License"); you may not use this file except in
  3. # compliance with the License. You may obtain a copy of the License at
  4. # http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
  5. # License Version 1.1, but Sections 14 and 15 have been added to cover use of
  6. # software over a computer network and provide for limited attribution for the
  7. # Original Developer. In addition, Exhibit A has been modified to be consistent
  8. # with Exhibit B.
  9. #
  10. # Software distributed under the License is distributed on an "AS IS" basis,
  11. # WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
  12. # the specific language governing rights and limitations under the License.
  13. #
  14. # The Original Code is Reddit.
  15. #
  16. # The Original Developer is the Initial Developer. The Initial Developer of the
  17. # Original Code is CondeNet, Inc.
  18. #
  19. # All portions of the code written by CondeNet are Copyright (c) 2006-2008
  20. # CondeNet, Inc. All Rights Reserved.
  21. ################################################################################
  22. """Pylons middleware initialization"""
  23. from paste.cascade import Cascade
  24. from paste.registry import RegistryManager
  25. from paste.urlparser import StaticURLParser
  26. from paste.deploy.converters import asbool
  27. from paste.gzipper import make_gzip_middleware
  28. from paste.request import resolve_relative_url
  29. from paste.response import header_value, replace_header
  30. from pylons import config, request, Response
  31. from pylons.error import error_template
  32. from pylons.middleware import ErrorDocuments, ErrorHandler, StaticJavascripts
  33. from pylons.wsgiapp import PylonsApp, PylonsBaseWSGIApp
  34. from r2.config.environment import load_environment
  35. from r2.config.rewrites import rewrites
  36. from r2.lib.utils import rstrips
  37. from r2.lib.jsontemplates import api_type
  38. #middleware stuff
  39. from r2.lib.html_source import HTMLValidationParser
  40. from cStringIO import StringIO
  41. import sys, tempfile, urllib, re, os, hashlib
  42. #from pylons.middleware import error_mapper
  43. def error_mapper(code, message, environ, global_conf=None, **kw):
  44. if environ.get('pylons.error_call'):
  45. return None
  46. else:
  47. environ['pylons.error_call'] = True
  48. if global_conf is None:
  49. global_conf = {}
  50. codes = [401, 403, 404, 503]
  51. if not asbool(global_conf.get('debug')):
  52. codes.append(500)
  53. if code in codes:
  54. # StatusBasedForward expects a relative URL (no SCRIPT_NAME)
  55. d = dict(code = code, message = message)
  56. if environ.get('REDDIT_CNAME'):
  57. d['cnameframe'] = 1
  58. if environ.get('REDDIT_NAME'):
  59. d['srname'] = environ.get('REDDIT_NAME')
  60. url = '/error/document/?%s' % (urllib.urlencode(d))
  61. return url
  62. class DebugMiddleware(object):
  63. def __init__(self, app, keyword):
  64. self.app = app
  65. self.keyword = keyword
  66. def __call__(self, environ, start_response):
  67. def foo(*a, **kw):
  68. self.res = self.app(environ, start_response)
  69. return self.res
  70. debug = config['global_conf']['debug'].lower() == 'true'
  71. args = {}
  72. for x in environ['QUERY_STRING'].split('&'):
  73. x = x.split('=')
  74. args[x[0]] = x[1] if x[1:] else None
  75. if debug and self.keyword in args.keys():
  76. prof_arg = args.get(self.keyword)
  77. prof_arg = urllib.unquote(prof_arg) if prof_arg else None
  78. return self.filter(foo, prof_arg = prof_arg)
  79. return self.app(environ, start_response)
  80. def filter(self, execution_func, prof_arg = None):
  81. pass
  82. class ProfilingMiddleware(DebugMiddleware):
  83. def __init__(self, app):
  84. DebugMiddleware.__init__(self, app, 'profile')
  85. def filter(self, execution_func, prof_arg = None):
  86. import cProfile as profile
  87. from pstats import Stats
  88. tmpfile = tempfile.NamedTemporaryFile()
  89. try:
  90. file, line = prof_arg.split(':')
  91. line, func = line.split('(')
  92. func = func.strip(')')
  93. except:
  94. file = line = func = None
  95. try:
  96. profile.runctx('execution_func()',
  97. globals(), locals(), tmpfile.name)
  98. out = StringIO()
  99. stats = Stats(tmpfile.name, stream=out)
  100. stats.sort_stats('time', 'calls')
  101. def parse_table(t, ncol):
  102. table = []
  103. for s in t:
  104. t = [x for x in s.split(' ') if x]
  105. if len(t) > 1:
  106. table += [t[:ncol-1] + [' '.join(t[ncol-1:])]]
  107. return table
  108. def cmp(n):
  109. def _cmp(x, y):
  110. return 0 if x[n] == y[n] else 1 if x[n] < y[n] else -1
  111. return _cmp
  112. if not file:
  113. stats.print_stats()
  114. stats_str = out.getvalue()
  115. statdata = stats_str.split('\n')
  116. headers = '\n'.join(statdata[:6])
  117. table = parse_table(statdata[6:], 6)
  118. from r2.lib.pages import Profiling
  119. res = Profiling(header = headers, table = table,
  120. path = request.path).render()
  121. return [unicode(res)]
  122. else:
  123. query = "%s:%s" % (file, line)
  124. stats.print_callees(query)
  125. stats.print_callers(query)
  126. statdata = out.getvalue()
  127. data = statdata.split(query)
  128. callee = data[2].split('->')[1].split('Ordered by')[0]
  129. callee = parse_table(callee.split('\n'), 4)
  130. callee.sort(cmp(1))
  131. callee = [['ncalls', 'tottime', 'cputime']] + callee
  132. i = 4
  133. while '<-' not in data[i] and i < len(data): i+= 1
  134. caller = data[i].split('<-')[1]
  135. caller = parse_table(caller.split('\n'), 4)
  136. caller.sort(cmp(1))
  137. caller = [['ncalls', 'tottime', 'cputime']] + caller
  138. from r2.lib.pages import Profiling
  139. res = Profiling(header = prof_arg,
  140. caller = caller, callee = callee,
  141. path = request.path).render()
  142. return [unicode(res)]
  143. finally:
  144. tmpfile.close()
  145. class SourceViewMiddleware(DebugMiddleware):
  146. def __init__(self, app):
  147. DebugMiddleware.__init__(self, app, 'chk_source')
  148. def filter(self, execution_func, prof_arg = None):
  149. output = execution_func()
  150. output = [x for x in output]
  151. parser = HTMLValidationParser()
  152. res = parser.feed(output[-1])
  153. return [res]
  154. class DomainMiddleware(object):
  155. lang_re = re.compile(r"^\w\w(-\w\w)?$")
  156. def __init__(self, app):
  157. self.app = app
  158. auth_cnames = config['global_conf'].get('authorized_cnames', '')
  159. auth_cnames = [x.strip() for x in auth_cnames.split(',')]
  160. # we are going to be matching with endswith, so make sure there
  161. # are no empty strings that have snuck in
  162. self.auth_cnames = filter(None, auth_cnames)
  163. def is_auth_cname(self, domain):
  164. return any((domain == cname or domain.endswith('.' + cname))
  165. for cname in self.auth_cnames)
  166. def __call__(self, environ, start_response):
  167. # get base domain as defined in INI file
  168. base_domain = config['global_conf']['domain']
  169. try:
  170. sub_domains, request_port = environ['HTTP_HOST'].split(':')
  171. environ['request_port'] = int(request_port)
  172. except ValueError:
  173. sub_domains = environ['HTTP_HOST'].split(':')[0]
  174. except KeyError:
  175. sub_domains = "localhost"
  176. #If the domain doesn't end with base_domain, assume
  177. #this is a cname, and redirect to the frame controller.
  178. #Ignore localhost so paster shell still works.
  179. #If this is an error, don't redirect
  180. if (not sub_domains.endswith(base_domain)
  181. and (not sub_domains == 'localhost')):
  182. environ['sub_domain'] = sub_domains
  183. if not environ.get('extension'):
  184. if environ['PATH_INFO'].startswith('/frame'):
  185. return self.app(environ, start_response)
  186. elif self.is_auth_cname(sub_domains):
  187. environ['frameless_cname'] = True
  188. environ['authorized_cname'] = True
  189. elif ("redditSession" in environ.get('HTTP_COOKIE', '')
  190. and environ['REQUEST_METHOD'] != 'POST'
  191. and not environ['PATH_INFO'].startswith('/error')):
  192. environ['original_path'] = environ['PATH_INFO']
  193. environ['PATH_INFO'] = '/frame'
  194. else:
  195. environ['frameless_cname'] = True
  196. return self.app(environ, start_response)
  197. sub_domains = sub_domains[:-len(base_domain)].strip('.')
  198. sub_domains = sub_domains.split('.')
  199. sr_redirect = None
  200. for sd in list(sub_domains):
  201. # subdomains to disregard completely
  202. if sd in ('www', 'origin', 'beta'):
  203. continue
  204. # subdomains which change the extension
  205. elif sd == 'm':
  206. environ['reddit-domain-extension'] = 'mobile'
  207. elif sd in ('api', 'rss', 'xml', 'json'):
  208. environ['reddit-domain-extension'] = sd
  209. elif (len(sd) == 2 or (len(sd) == 5 and sd[2] == '-')) and self.lang_re.match(sd):
  210. environ['reddit-prefer-lang'] = sd
  211. else:
  212. sr_redirect = sd
  213. sub_domains.remove(sd)
  214. if sr_redirect and environ.get("FULLPATH"):
  215. r = Response()
  216. sub_domains.append(base_domain)
  217. redir = "%s/r/%s/%s" % ('.'.join(sub_domains),
  218. sr_redirect, environ['FULLPATH'])
  219. redir = "http://" + redir.replace('//', '/')
  220. r.status_code = 301
  221. r.headers['location'] = redir
  222. r.content = ""
  223. return r(environ, start_response)
  224. return self.app(environ, start_response)
  225. class SubredditMiddleware(object):
  226. sr_pattern = re.compile(r'^/r/([^/]{2,})')
  227. def __init__(self, app):
  228. self.app = app
  229. def __call__(self, environ, start_response):
  230. path = environ['PATH_INFO']
  231. sr = self.sr_pattern.match(path)
  232. if sr:
  233. environ['subreddit'] = sr.groups()[0]
  234. environ['PATH_INFO'] = self.sr_pattern.sub('', path) or '/'
  235. elif path.startswith("/categories"):
  236. environ['subreddit'] = 'r'
  237. return self.app(environ, start_response)
  238. class DomainListingMiddleware(object):
  239. domain_pattern = re.compile(r'^/domain/(([-\w]+\.)+[\w]+)')
  240. def __init__(self, app):
  241. self.app = app
  242. def __call__(self, environ, start_response):
  243. if not environ.has_key('subreddit'):
  244. path = environ['PATH_INFO']
  245. domain = self.domain_pattern.match(path)
  246. if domain:
  247. environ['domain'] = domain.groups()[0]
  248. environ['PATH_INFO'] = self.domain_pattern.sub('', path) or '/'
  249. return self.app(environ, start_response)
  250. class ExtensionMiddleware(object):
  251. ext_pattern = re.compile(r'\.([^/]+)$')
  252. extensions = {'rss' : ('xml', 'text/xml; charset=UTF-8'),
  253. 'xml' : ('xml', 'text/xml; charset=UTF-8'),
  254. 'js' : ('js', 'text/javascript; charset=UTF-8'),
  255. #'png' : ('png', 'image/png'),
  256. #'css' : ('css', 'text/css'),
  257. 'api' : (api_type(), 'application/json; charset=UTF-8'),
  258. 'json' : (api_type(), 'application/json; charset=UTF-8'),
  259. 'json-html' : (api_type('html'), 'application/json; charset=UTF-8')}
  260. def __init__(self, app):
  261. self.app = app
  262. def __call__(self, environ, start_response):
  263. path = environ['PATH_INFO']
  264. domain_ext = environ.get('reddit-domain-extension')
  265. for ext, val in self.extensions.iteritems():
  266. if ext == domain_ext or path.endswith('.' + ext):
  267. environ['extension'] = ext
  268. environ['render_style'] = val[0]
  269. environ['content_type'] = val[1]
  270. #strip off the extension
  271. if path.endswith('.' + ext):
  272. environ['PATH_INFO'] = path[:-(len(ext) + 1)]
  273. break
  274. else:
  275. environ['render_style'] = 'html'
  276. environ['content_type'] = 'text/html; charset=UTF-8'
  277. return self.app(environ, start_response)
  278. class RewriteMiddleware(object):
  279. def __init__(self, app):
  280. self.app = app
  281. def rewrite(self, regex, out_template, input):
  282. m = regex.match(input)
  283. out = out_template
  284. if m:
  285. for num, group in enumerate(m.groups('')):
  286. out = out.replace('$%s' % (num + 1), group)
  287. return out
  288. def __call__(self, environ, start_response):
  289. path = environ['PATH_INFO']
  290. for r in rewrites:
  291. newpath = self.rewrite(r[0], r[1], path)
  292. if newpath:
  293. environ['PATH_INFO'] = newpath
  294. break
  295. environ['FULLPATH'] = environ.get('PATH_INFO')
  296. qs = environ.get('QUERY_STRING')
  297. if qs:
  298. environ['FULLPATH'] += '?' + qs
  299. return self.app(environ, start_response)
  300. class RequestLogMiddleware(object):
  301. def __init__(self, log_path, process_iden, app):
  302. self.log_path = log_path
  303. self.app = app
  304. self.process_iden = str(process_iden)
  305. def __call__(self, environ, start_response):
  306. request = '\n'.join('%s: %s' % (k,v) for k,v in environ.iteritems()
  307. if k.isupper())
  308. iden = self.process_iden + '-' + hashlib.sha1(request).hexdigest()
  309. fname = os.path.join(self.log_path, iden)
  310. f = open(fname, 'w')
  311. f.write(request)
  312. f.close()
  313. r = self.app(environ, start_response)
  314. if os.path.exists(fname):
  315. try:
  316. os.remove(fname)
  317. except OSError:
  318. pass
  319. return r
  320. class LimitUploadSize(object):
  321. """
  322. Middleware for restricting the size of uploaded files (such as
  323. image files for the CSS editing capability).
  324. """
  325. def __init__(self, app, max_size=1024*500):
  326. self.app = app
  327. self.max_size = max_size
  328. def __call__(self, environ, start_response):
  329. cl_key = 'CONTENT_LENGTH'
  330. if environ['REQUEST_METHOD'] == 'POST':
  331. if ((cl_key not in environ)
  332. or int(environ[cl_key]) > self.max_size):
  333. r = Response()
  334. r.status_code = 500
  335. r.content = '<html><head></head><body><script type="text/javascript">parent.too_big();</script>request too big</body></html>'
  336. return r(environ, start_response)
  337. return self.app(environ, start_response)
  338. class AbsoluteRedirectMiddleware(object):
  339. def __init__(self, app):
  340. self.app = app
  341. def __call__(self, environ, start_response):
  342. def start_response_wrapper(status, headers, exc_info=None):
  343. location_header = 'location'
  344. status_code = int(status.split(None,1)[0])
  345. if (status_code >= 301 and status_code <= 303) or status_code == 307:
  346. location = header_value(headers, location_header)
  347. if location:
  348. replace_header(headers, location_header, resolve_relative_url(location, environ))
  349. return start_response(status, headers, exc_info)
  350. return self.app(environ, start_response_wrapper)
  351. class CleanupMiddleware(object):
  352. """
  353. Put anything here that should be called after every other bit of
  354. middleware. This currently includes the code for removing
  355. duplicate headers (except multiple cookie setting). The behavior
  356. here is to disregard all but the last record.
  357. """
  358. def __init__(self, app):
  359. self.app = app
  360. def __call__(self, environ, start_response):
  361. def custom_start_response(status, headers, exc_info = None):
  362. fixed = []
  363. seen = set()
  364. for head, val in reversed(headers):
  365. head = head.title()
  366. if head == 'Set-Cookie' or head not in seen:
  367. fixed.insert(0, (head, val))
  368. seen.add(head)
  369. return start_response(status, fixed, exc_info)
  370. return self.app(environ, custom_start_response)
# NOTE: the controller lookup and middleware wiring below are disorganized and hard to follow.
  372. class RedditApp(PylonsBaseWSGIApp):
  373. def find_controller(self, controller):
  374. if controller in self.controller_classes:
  375. return self.controller_classes[controller]
  376. full_module_name = self.package_name + '.controllers'
  377. class_name = controller.capitalize() + 'Controller'
  378. __import__(self.package_name + '.controllers')
  379. mycontroller = getattr(sys.modules[full_module_name], class_name)
  380. self.controller_classes[controller] = mycontroller
  381. return mycontroller
def make_app(global_conf, full_stack=True, **app_conf):
    """Create a Pylons WSGI application and return it

    `global_conf`
        The inherited configuration for this application. Normally from the
        [DEFAULT] section of the Paste ini file.

    `full_stack`
        Whether or not this application provides a full WSGI stack (by default,
        meaning it handles its own exceptions and errors). Disable full_stack
        when this application is "managed" by another WSGI middleware.

    `app_conf`
        The application's local configuration. Normally specified in the
        [app:<name>] section of the Paste ini file (where <name> defaults to
        main).

    Each ``app = X(app)`` below wraps the app built so far.  The wrapper
    applied LAST is the outermost one and sees the request first, so the
    order of these statements matters.  Outermost to innermost:

        CleanupMiddleware -> RewriteMiddleware -> AbsoluteRedirectMiddleware
        -> Cascade([static_app, javascripts_app, <rest>])
        -> RegistryManager -> ErrorDocuments -> ErrorHandler  (full_stack only)
        -> RequestLogMiddleware  (only when log_path is set)
        -> DomainMiddleware -> ExtensionMiddleware -> SubredditMiddleware
        -> DomainListingMiddleware -> SourceViewMiddleware
        -> ProfilingMiddleware -> LimitUploadSize -> PylonsApp(RedditApp)
    """
    # Configure the Pylons environment
    load_environment(global_conf, app_conf)

    # The Pylons WSGI app
    app = PylonsApp(base_wsgi_app=RedditApp)

    # CUSTOM MIDDLEWARE HERE (filtered by the error handling middlewares)
    # Applied innermost first; DomainMiddleware (applied last here) runs
    # before ExtensionMiddleware, which reads the 'reddit-domain-extension'
    # key that DomainMiddleware sets.
    app = LimitUploadSize(app)
    app = ProfilingMiddleware(app)
    app = SourceViewMiddleware(app)
    app = DomainListingMiddleware(app)
    app = SubredditMiddleware(app)
    app = ExtensionMiddleware(app)
    app = DomainMiddleware(app)

    # Optional per-request environ dumps; scgi_port (or 'default') becomes
    # the prefix of each dump file's name.
    log_path = global_conf.get('log_path')
    if log_path:
        process_iden = global_conf.get('scgi_port', 'default')
        app = RequestLogMiddleware(log_path, process_iden, app)

    #TODO: breaks on 404
    #app = make_gzip_middleware(app, app_conf)

    if asbool(full_stack):
        # Handle Python exceptions
        app = ErrorHandler(app, global_conf, error_template=error_template,
                           **config['pylons.errorware'])

        # Display error documents for 401, 403, 404 and 503 status codes (and
        # 500 when debug is disabled); the code list lives in error_mapper.
        app = ErrorDocuments(app, global_conf, mapper=error_mapper, **app_conf)

    # Establish the Registry for this application
    app = RegistryManager(app)

    # Static files
    javascripts_app = StaticJavascripts()
    # Set cache headers indicating the client should cache for 7 days
    # (604800 seconds)
    static_app = StaticURLParser(config['pylons.paths']['static_files'], cache_max_age=604800)
    app = Cascade([static_app, javascripts_app, app])

    # Turn relative Location headers on redirects into absolute URLs.
    app = AbsoluteRedirectMiddleware(app)

    #add the rewrite rules
    # (also sets environ['FULLPATH'], which DomainMiddleware reads)
    app = RewriteMiddleware(app)

    # Outermost: de-duplicates response headers after everything else.
    app = CleanupMiddleware(app)

    return app