PageRenderTime 47ms CodeModel.GetById 15ms RepoModel.GetById 1ms app.codeStats 0ms

/r2/r2/config/middleware.py

https://github.com/stevewilber/reddit
Python | 461 lines | 393 code | 30 blank | 38 comment | 33 complexity | e9ecc81edbb1dde2dc9fc2d94fe57547 MD5 | raw file
Possible License(s): MPL-2.0-no-copyleft-exception, Apache-2.0
  1. # The contents of this file are subject to the Common Public Attribution
  2. # License Version 1.0. (the "License"); you may not use this file except in
  3. # compliance with the License. You may obtain a copy of the License at
  4. # http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
  5. # License Version 1.1, but Sections 14 and 15 have been added to cover use of
  6. # software over a computer network and provide for limited attribution for the
  7. # Original Developer. In addition, Exhibit A has been modified to be consistent
  8. # with Exhibit B.
  9. #
  10. # Software distributed under the License is distributed on an "AS IS" basis,
  11. # WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
  12. # the specific language governing rights and limitations under the License.
  13. #
  14. # The Original Code is reddit.
  15. #
  16. # The Original Developer is the Initial Developer. The Initial Developer of
  17. # the Original Code is reddit Inc.
  18. #
  19. # All portions of the code written by reddit are Copyright (c) 2006-2012 reddit
  20. # Inc. All Rights Reserved.
  21. ###############################################################################
  22. """Pylons middleware initialization"""
  23. import re
  24. import urllib
  25. import tempfile
  26. import urlparse
  27. from threading import Lock
  28. from paste.cascade import Cascade
  29. from paste.registry import RegistryManager
  30. from paste.urlparser import StaticURLParser
  31. from paste.deploy.converters import asbool
  32. from pylons import config, Response
  33. from pylons.error import error_template
  34. from pylons.middleware import ErrorDocuments, ErrorHandler, StaticJavascripts
  35. from pylons.wsgiapp import PylonsApp, PylonsBaseWSGIApp
  36. from r2.config.environment import load_environment
  37. from r2.config.rewrites import rewrites
  38. from r2.config.extensions import extension_mapping, set_extension
  39. from r2.lib.utils import is_subdomain
  40. # hack in Paste support for HTTP 429 "Too Many Requests"
  41. from paste import httpexceptions, wsgiwrappers
  42. class HTTPTooManyRequests(httpexceptions.HTTPClientError):
  43. code = 429
  44. title = 'Too Many Requests'
  45. explanation = ('The server has received too many requests from the client.')
  46. httpexceptions._exceptions[429] = HTTPTooManyRequests
  47. wsgiwrappers.STATUS_CODE_TEXT[429] = HTTPTooManyRequests.title
  48. #from pylons.middleware import error_mapper
  49. def error_mapper(code, message, environ, global_conf=None, **kw):
  50. from pylons import c
  51. if environ.get('pylons.error_call'):
  52. return None
  53. else:
  54. environ['pylons.error_call'] = True
  55. if global_conf is None:
  56. global_conf = {}
  57. codes = [304, 400, 401, 403, 404, 415, 429, 503]
  58. if not asbool(global_conf.get('debug')):
  59. codes.append(500)
  60. if code in codes:
  61. # StatusBasedForward expects a relative URL (no SCRIPT_NAME)
  62. d = dict(code = code, message = message)
  63. exception = environ.get('r2.controller.exception')
  64. if exception:
  65. d['explanation'] = exception.explanation
  66. error_data = getattr(exception, 'error_data', None)
  67. if error_data:
  68. environ['extra_error_data'] = error_data
  69. if environ.get('REDDIT_CNAME'):
  70. d['cnameframe'] = 1
  71. if environ.get('REDDIT_NAME'):
  72. d['srname'] = environ.get('REDDIT_NAME')
  73. if environ.get('REDDIT_TAKEDOWN'):
  74. d['takedown'] = environ.get('REDDIT_TAKEDOWN')
  75. #preserve x-sup-id when 304ing
  76. if code == 304:
  77. #check to see if c is useable
  78. try:
  79. c.test
  80. except TypeError:
  81. pass
  82. else:
  83. if c.response.headers.has_key('x-sup-id'):
  84. d['x-sup-id'] = c.response.headers['x-sup-id']
  85. extension = environ.get("extension")
  86. if extension:
  87. url = '/error/document/.%s?%s' % (extension, urllib.urlencode(d))
  88. else:
  89. url = '/error/document/?%s' % (urllib.urlencode(d))
  90. return url
  91. class ProfilingMiddleware(object):
  92. def __init__(self, app, directory):
  93. self.app = app
  94. self.directory = directory
  95. def __call__(self, environ, start_response):
  96. import cProfile
  97. try:
  98. tmpfile = tempfile.NamedTemporaryFile(prefix='profile',
  99. dir=self.directory,
  100. delete=False)
  101. profile = cProfile.Profile()
  102. result = profile.runcall(self.app, environ, start_response)
  103. profile.dump_stats(tmpfile.name)
  104. return result
  105. finally:
  106. tmpfile.close()
  107. class DomainMiddleware(object):
  108. lang_re = re.compile(r"\A\w\w(-\w\w)?\Z")
  109. def __init__(self, app):
  110. self.app = app
  111. def __call__(self, environ, start_response):
  112. g = config['pylons.g']
  113. http_host = environ.get('HTTP_HOST', 'localhost').lower()
  114. domain, s, port = http_host.partition(':')
  115. # remember the port
  116. try:
  117. environ['request_port'] = int(port)
  118. except ValueError:
  119. pass
  120. # localhost is exempt so paster run/shell will work
  121. # media_domain doesn't need special processing since it's just ads
  122. if domain == "localhost" or is_subdomain(domain, g.media_domain):
  123. return self.app(environ, start_response)
  124. # tell reddit_base to redirect to the appropriate subreddit for
  125. # a legacy CNAME
  126. if not is_subdomain(domain, g.domain):
  127. environ['legacy-cname'] = domain
  128. return self.app(environ, start_response)
  129. # figure out what subdomain we're on if any
  130. subdomains = domain[:-len(g.domain) - 1].split('.')
  131. extension_subdomains = dict(m="mobile",
  132. i="compact",
  133. api="api",
  134. rss="rss",
  135. xml="xml",
  136. json="json")
  137. sr_redirect = None
  138. for subdomain in subdomains[:]:
  139. if subdomain in g.reserved_subdomains:
  140. continue
  141. extension = extension_subdomains.get(subdomain)
  142. if extension:
  143. environ['reddit-domain-extension'] = extension
  144. elif self.lang_re.match(subdomain):
  145. environ['reddit-prefer-lang'] = subdomain
  146. environ['reddit-domain-prefix'] = subdomain
  147. else:
  148. sr_redirect = subdomain
  149. subdomains.remove(subdomain)
  150. # if there was a subreddit subdomain, redirect
  151. if sr_redirect and environ.get("FULLPATH"):
  152. r = Response()
  153. if not subdomains and g.domain_prefix:
  154. subdomains.append(g.domain_prefix)
  155. subdomains.append(g.domain)
  156. redir = "%s/r/%s/%s" % ('.'.join(subdomains),
  157. sr_redirect, environ['FULLPATH'])
  158. redir = "http://" + redir.replace('//', '/')
  159. r.status_code = 301
  160. r.headers['location'] = redir
  161. r.content = ""
  162. return r(environ, start_response)
  163. return self.app(environ, start_response)
  164. class SubredditMiddleware(object):
  165. sr_pattern = re.compile(r'^/r/([^/]{2,})')
  166. def __init__(self, app):
  167. self.app = app
  168. def __call__(self, environ, start_response):
  169. path = environ['PATH_INFO']
  170. sr = self.sr_pattern.match(path)
  171. if sr:
  172. environ['subreddit'] = sr.groups()[0]
  173. environ['PATH_INFO'] = self.sr_pattern.sub('', path) or '/'
  174. elif path.startswith("/reddits"):
  175. environ['subreddit'] = 'r'
  176. return self.app(environ, start_response)
  177. class DomainListingMiddleware(object):
  178. domain_pattern = re.compile(r'\A/domain/(([-\w]+\.)+[\w]+)')
  179. def __init__(self, app):
  180. self.app = app
  181. def __call__(self, environ, start_response):
  182. if not environ.has_key('subreddit'):
  183. path = environ['PATH_INFO']
  184. domain = self.domain_pattern.match(path)
  185. if domain:
  186. environ['domain'] = domain.groups()[0]
  187. environ['PATH_INFO'] = self.domain_pattern.sub('', path) or '/'
  188. return self.app(environ, start_response)
  189. class ExtensionMiddleware(object):
  190. ext_pattern = re.compile(r'\.([^/]+)\Z')
  191. def __init__(self, app):
  192. self.app = app
  193. def __call__(self, environ, start_response):
  194. path = environ['PATH_INFO']
  195. fname, sep, path_ext = path.rpartition('.')
  196. domain_ext = environ.get('reddit-domain-extension')
  197. ext = None
  198. if path_ext in extension_mapping:
  199. ext = path_ext
  200. # Strip off the extension.
  201. environ['PATH_INFO'] = path[:-(len(ext) + 1)]
  202. elif domain_ext in extension_mapping:
  203. ext = domain_ext
  204. if ext:
  205. set_extension(environ, ext)
  206. else:
  207. environ['render_style'] = 'html'
  208. environ['content_type'] = 'text/html; charset=UTF-8'
  209. return self.app(environ, start_response)
  210. class RewriteMiddleware(object):
  211. def __init__(self, app):
  212. self.app = app
  213. def rewrite(self, regex, out_template, input):
  214. m = regex.match(input)
  215. out = out_template
  216. if m:
  217. for num, group in enumerate(m.groups('')):
  218. out = out.replace('$%s' % (num + 1), group)
  219. return out
  220. def __call__(self, environ, start_response):
  221. path = environ['PATH_INFO']
  222. for r in rewrites:
  223. newpath = self.rewrite(r[0], r[1], path)
  224. if newpath:
  225. environ['PATH_INFO'] = newpath
  226. break
  227. environ['FULLPATH'] = environ.get('PATH_INFO')
  228. qs = environ.get('QUERY_STRING')
  229. if qs:
  230. environ['FULLPATH'] += '?' + qs
  231. return self.app(environ, start_response)
  232. class StaticTestMiddleware(object):
  233. def __init__(self, app, static_path, domain):
  234. self.app = app
  235. self.static_path = static_path
  236. self.domain = domain
  237. def __call__(self, environ, start_response):
  238. if environ['HTTP_HOST'] == self.domain:
  239. environ['PATH_INFO'] = self.static_path.rstrip('/') + environ['PATH_INFO']
  240. return self.app(environ, start_response)
  241. raise httpexceptions.HTTPNotFound()
  242. class LimitUploadSize(object):
  243. """
  244. Middleware for restricting the size of uploaded files (such as
  245. image files for the CSS editing capability).
  246. """
  247. def __init__(self, app, max_size=1024*500):
  248. self.app = app
  249. self.max_size = max_size
  250. def __call__(self, environ, start_response):
  251. cl_key = 'CONTENT_LENGTH'
  252. if environ['REQUEST_METHOD'] == 'POST':
  253. if cl_key not in environ:
  254. r = Response()
  255. r.status_code = 411
  256. r.content = '<html><head></head><body>length required</body></html>'
  257. return r(environ, start_response)
  258. try:
  259. cl_int = int(environ[cl_key])
  260. except ValueError:
  261. r = Response()
  262. r.status_code = 400
  263. r.content = '<html><head></head><body>bad request</body></html>'
  264. return r(environ, start_response)
  265. if cl_int > self.max_size:
  266. from r2.lib.strings import string_dict
  267. error_msg = string_dict['css_validator_messages']['max_size'] % dict(max_size = self.max_size/1024)
  268. r = Response()
  269. r.status_code = 413
  270. r.content = ("<html>"
  271. "<head>"
  272. "<script type='text/javascript'>"
  273. "parent.completedUploadImage('failed',"
  274. "'',"
  275. "'',"
  276. "[['BAD_CSS_NAME', ''], ['IMAGE_ERROR', '", error_msg,"']],"
  277. "'image-upload');"
  278. "</script></head><body>you shouldn\'t be here</body></html>")
  279. return r(environ, start_response)
  280. return self.app(environ, start_response)
  281. # TODO CleanupMiddleware seems to exist because cookie headers are being duplicated
  282. # somewhere in the response processing chain. It should be removed as soon as we
  283. # find the underlying issue.
  284. class CleanupMiddleware(object):
  285. """
  286. Put anything here that should be called after every other bit of
  287. middleware. This currently includes the code for removing
  288. duplicate headers (such as multiple cookie setting). The behavior
  289. here is to disregard all but the last record.
  290. """
  291. def __init__(self, app):
  292. self.app = app
  293. def __call__(self, environ, start_response):
  294. def custom_start_response(status, headers, exc_info = None):
  295. fixed = []
  296. seen = set()
  297. for head, val in reversed(headers):
  298. head = head.lower()
  299. key = (head, val.split("=", 1)[0])
  300. if key not in seen:
  301. fixed.insert(0, (head, val))
  302. seen.add(key)
  303. return start_response(status, fixed, exc_info)
  304. return self.app(environ, custom_start_response)
  305. #god this shit is disorganized and confusing
  306. class RedditApp(PylonsBaseWSGIApp):
  307. def __init__(self, *args, **kwargs):
  308. super(RedditApp, self).__init__(*args, **kwargs)
  309. self._loading_lock = Lock()
  310. self._controllers = None
  311. def load_controllers(self):
  312. with self._loading_lock:
  313. if not self._controllers:
  314. controllers = __import__(self.package_name + '.controllers').controllers
  315. controllers.load_controllers()
  316. config['r2.plugins'].load_controllers()
  317. self._controllers = controllers
  318. return self._controllers
  319. def find_controller(self, controller_name):
  320. if controller_name in self.controller_classes:
  321. return self.controller_classes[controller_name]
  322. controllers = self.load_controllers()
  323. controller_cls = controllers.get_controller(controller_name)
  324. self.controller_classes[controller_name] = controller_cls
  325. return controller_cls
  326. def make_app(global_conf, full_stack=True, **app_conf):
  327. """Create a Pylons WSGI application and return it
  328. `global_conf`
  329. The inherited configuration for this application. Normally from the
  330. [DEFAULT] section of the Paste ini file.
  331. `full_stack`
  332. Whether or not this application provides a full WSGI stack (by default,
  333. meaning it handles its own exceptions and errors). Disable full_stack
  334. when this application is "managed" by another WSGI middleware.
  335. `app_conf`
  336. The application's local configuration. Normally specified in the
  337. [app:<name>] section of the Paste ini file (where <name> defaults to
  338. main).
  339. """
  340. # Configure the Pylons environment
  341. load_environment(global_conf, app_conf)
  342. g = config['pylons.g']
  343. # The Pylons WSGI app
  344. app = PylonsApp(base_wsgi_app=RedditApp)
  345. # CUSTOM MIDDLEWARE HERE (filtered by the error handling middlewares)
  346. # last thing first from here down
  347. app = CleanupMiddleware(app)
  348. app = LimitUploadSize(app)
  349. profile_directory = g.config.get('profile_directory')
  350. if profile_directory:
  351. app = ProfilingMiddleware(app, profile_directory)
  352. app = DomainListingMiddleware(app)
  353. app = SubredditMiddleware(app)
  354. app = ExtensionMiddleware(app)
  355. app = DomainMiddleware(app)
  356. if asbool(full_stack):
  357. # Handle Python exceptions
  358. app = ErrorHandler(app, global_conf, error_template=error_template,
  359. **config['pylons.errorware'])
  360. # Display error documents for 401, 403, 404 status codes (and 500 when
  361. # debug is disabled)
  362. app = ErrorDocuments(app, global_conf, mapper=error_mapper, **app_conf)
  363. # Establish the Registry for this application
  364. app = RegistryManager(app)
  365. # Static files
  366. javascripts_app = StaticJavascripts()
  367. static_app = StaticURLParser(config['pylons.paths']['static_files'])
  368. static_cascade = [static_app, javascripts_app, app]
  369. if config['r2.plugins'] and g.config['uncompressedJS']:
  370. plugin_static_apps = Cascade([StaticURLParser(plugin.static_dir)
  371. for plugin in config['r2.plugins']])
  372. static_cascade.insert(0, plugin_static_apps)
  373. app = Cascade(static_cascade)
  374. #add the rewrite rules
  375. app = RewriteMiddleware(app)
  376. if not g.config['uncompressedJS'] and g.config['debug']:
  377. static_fallback = StaticTestMiddleware(static_app, g.config['static_path'], g.config['static_domain'])
  378. app = Cascade([static_fallback, app])
  379. return app