/r2/r2/config/middleware.py
Python | 515 lines | 445 code | 23 blank | 47 comment | 37 complexity | 9f85ddf229155bc4b0a972b42bbd5995 MD5 | raw file
Possible License(s): MPL-2.0-no-copyleft-exception, LGPL-2.1
- # The contents of this file are subject to the Common Public Attribution
- # License Version 1.0. (the "License"); you may not use this file except in
- # compliance with the License. You may obtain a copy of the License at
- # http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
- # License Version 1.1, but Sections 14 and 15 have been added to cover use of
- # software over a computer network and provide for limited attribution for the
- # Original Developer. In addition, Exhibit A has been modified to be consistent
- # with Exhibit B.
- #
- # Software distributed under the License is distributed on an "AS IS" basis,
- # WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
- # the specific language governing rights and limitations under the License.
- #
- # The Original Code is Reddit.
- #
- # The Original Developer is the Initial Developer. The Initial Developer of the
- # Original Code is CondeNet, Inc.
- #
- # All portions of the code written by CondeNet are Copyright (c) 2006-2008
- # CondeNet, Inc. All Rights Reserved.
- ################################################################################
- """Pylons middleware initialization"""
- from paste.cascade import Cascade
- from paste.registry import RegistryManager
- from paste.urlparser import StaticURLParser
- from paste.deploy.converters import asbool
- from paste.gzipper import make_gzip_middleware
- from paste.request import resolve_relative_url
- from paste.response import header_value, replace_header
- from pylons import config, request, Response
- from pylons.error import error_template
- from pylons.middleware import ErrorDocuments, ErrorHandler, StaticJavascripts
- from pylons.wsgiapp import PylonsApp, PylonsBaseWSGIApp
- from r2.config.environment import load_environment
- from r2.config.rewrites import rewrites
- from r2.lib.utils import rstrips
- from r2.lib.jsontemplates import api_type
- #middleware stuff
- from r2.lib.html_source import HTMLValidationParser
- from cStringIO import StringIO
- import sys, tempfile, urllib, re, os, hashlib
- #from pylons.middleware import error_mapper
def error_mapper(code, message, environ, global_conf=None, **kw):
    """Map an HTTP status to the URL of the error-document page.

    Used as the ``mapper`` for ErrorDocuments in make_app. Returns a
    relative URL ('/error/document/?...') for handled codes (401, 403, 404,
    503, and 500 when debug is off), otherwise None. Only the first call
    per environ can produce a URL: it sets environ['pylons.error_call'], and
    any later call for the same environ returns None (presumably so that a
    failing error page is not itself re-mapped).
    """
    if environ.get('pylons.error_call'):
        return None
    environ['pylons.error_call'] = True

    conf = {} if global_conf is None else global_conf
    handled = [401, 403, 404, 503]
    # With debug off, 500s also get the friendly error page.
    if not asbool(conf.get('debug')):
        handled.append(500)
    if code not in handled:
        return None

    # StatusBasedForward expects a relative URL (no SCRIPT_NAME)
    params = dict(code = code, message = message)
    if environ.get('REDDIT_CNAME'):
        params['cnameframe'] = 1
    if environ.get('REDDIT_NAME'):
        params['srname'] = environ.get('REDDIT_NAME')
    return '/error/document/?%s' % (urllib.urlencode(params))
class DebugMiddleware(object):
    """Base class for middleware that only acts in debug mode.

    When the global config has ``debug = true`` and the query string
    contains ``keyword`` (with or without a value), the request is handed to
    ``filter(execution_func, prof_arg)`` instead of going straight to the
    wrapped app. ``execution_func()`` runs the wrapped app and returns its
    result; ``prof_arg`` is the URL-unquoted keyword value, or None when the
    value is missing or empty. Subclasses override ``filter``.
    """
    def __init__(self, app, keyword):
        self.app = app
        self.keyword = keyword

    @staticmethod
    def _parse_query(query_string):
        """Split a raw query string into a {name: value} dict.

        A bare name maps to None and "name=" maps to "". Only the first "="
        separates name from value, so values may themselves contain "=".
        Values are not URL-decoded here.
        """
        args = {}
        for pair in query_string.split('&'):
            name, sep, value = pair.partition('=')
            args[name] = value if sep else None
        return args

    def __call__(self, environ, start_response):
        def foo(*a, **kw):
            # NOTE(review): self.res is shared by every request handled by
            # this instance; nothing in this file reads it.
            self.res = self.app(environ, start_response)
            return self.res
        # An ini file without a 'debug' entry means debug is off.
        debug = config['global_conf'].get('debug', 'false').lower() == 'true'
        if debug:
            # QUERY_STRING may be empty or absent in a WSGI environ.
            args = self._parse_query(environ.get('QUERY_STRING', ''))
            if self.keyword in args:
                prof_arg = args[self.keyword]
                prof_arg = urllib.unquote(prof_arg) if prof_arg else None
                return self.filter(foo, prof_arg = prof_arg)
        return self.app(environ, start_response)

    def filter(self, execution_func, prof_arg = None):
        """Hook for subclasses; the base implementation returns None."""
        pass
class ProfilingMiddleware(DebugMiddleware):
    """Profile a request with cProfile when ``?profile`` is in the query.

    ``?profile`` alone renders the whole stats table (sorted by internal
    time, then call count). ``?profile=<file>:<line>(<func>)`` renders only
    the callers and callees of that function. The output is rendered with
    r2.lib.pages.Profiling in place of the normal response.
    """
    def __init__(self, app):
        DebugMiddleware.__init__(self, app, 'profile')

    def filter(self, execution_func, prof_arg = None):
        import cProfile as profile
        from pstats import Stats
        from r2.lib.pages import Profiling

        tmpfile = tempfile.NamedTemporaryFile()
        # prof_arg should look like "file:line(func)". Anything else,
        # including None, means "show the full table".
        try:
            fname, line = prof_arg.split(':')
            line, func = line.split('(')
            func = func.strip(')')
        except (AttributeError, ValueError):
            fname = line = func = None
        try:
            profile.runctx('execution_func()',
                           globals(), locals(), tmpfile.name)
            out = StringIO()
            stats = Stats(tmpfile.name, stream=out)
            stats.sort_stats('time', 'calls')

            def parse_table(lines, ncol):
                # Split each line on spaces; everything from column ncol on
                # is re-joined as the last cell (the function name may
                # contain spaces). Lines with fewer than two cells are dropped.
                table = []
                for s in lines:
                    cells = [x for x in s.split(' ') if x]
                    if len(cells) > 1:
                        table.append(cells[:ncol-1] + [' '.join(cells[ncol-1:])])
                return table

            def time_key(row):
                # Compare column 1 numerically: as strings, "12.345" would
                # sort below "2.000".
                try:
                    return float(row[1])
                except (IndexError, ValueError):
                    return 0.0

            if not fname:
                stats.print_stats()
                statdata = out.getvalue().split('\n')
                headers = '\n'.join(statdata[:6])
                table = parse_table(statdata[6:], 6)
                res = Profiling(header = headers, table = table,
                                path = request.path).render()
                return [unicode(res)]

            query = "%s:%s" % (fname, line)
            stats.print_callees(query)
            stats.print_callers(query)
            data = out.getvalue().split(query)

            callee = data[2].split('->')[1].split('Ordered by')[0]
            callee = parse_table(callee.split('\n'), 4)
            callee.sort(key=time_key, reverse=True)
            callee = [['ncalls', 'tottime', 'cputime']] + callee

            # Find the first chunk (from index 4 on) holding the caller
            # listing. The bound is checked before indexing so a missing
            # "<-" cannot run past the end of the list.
            i = 4
            while i < len(data) and '<-' not in data[i]:
                i += 1
            if i < len(data):
                caller = parse_table(data[i].split('<-')[1].split('\n'), 4)
            else:
                caller = []
            caller.sort(key=time_key, reverse=True)
            caller = [['ncalls', 'tottime', 'cputime']] + caller

            res = Profiling(header = prof_arg,
                            caller = caller, callee = callee,
                            path = request.path).render()
            return [unicode(res)]
        finally:
            tmpfile.close()
class SourceViewMiddleware(DebugMiddleware):
    """Debug-only HTML validation, triggered by ``?chk_source``.

    Runs the wrapped app, feeds the last chunk of its output to an
    HTMLValidationParser, and returns what ``feed`` produced as the body.
    """
    def __init__(self, app):
        DebugMiddleware.__init__(self, app, 'chk_source')

    def filter(self, execution_func, prof_arg = None):
        chunks = list(execution_func())
        validator = HTMLValidationParser()
        return [validator.feed(chunks[-1])]
class DomainMiddleware(object):
    """Interpret the Host header of each request.

    Hosts that are neither the configured base domain nor one of its
    subdomains (and are not 'localhost') are treated as cnames: they are
    served frameless (authorized cnames, or requests that don't qualify for
    framing) or routed to the /frame controller. For hosts on the base
    domain, each subdomain label may select a response extension
    (m, api, rss, xml, json) or a preferred language (e.g. "en", "pt-br").
    Ignored labels are www, origin and beta. Any other label is taken as a
    subreddit and triggers a 301 redirect to /r/<label>/<FULLPATH>.
    """
    lang_re = re.compile(r"^\w\w(-\w\w)?$")

    def __init__(self, app):
        self.app = app
        auth_cnames = config['global_conf'].get('authorized_cnames', '')
        auth_cnames = [x.strip() for x in auth_cnames.split(',')]
        # we are going to be matching with endswith, so make sure there
        # are no empty strings that have snuck in
        self.auth_cnames = [x for x in auth_cnames if x]

    @staticmethod
    def _on_domain(host, domain):
        """True if host is domain itself or a dot-separated subdomain of it."""
        return host == domain or host.endswith('.' + domain)

    def is_auth_cname(self, domain):
        return any(self._on_domain(domain, cname)
                   for cname in self.auth_cnames)

    def __call__(self, environ, start_response):
        # get base domain as defined in INI file
        base_domain = config['global_conf']['domain']
        try:
            sub_domains, request_port = environ['HTTP_HOST'].split(':')
            environ['request_port'] = int(request_port)
        except ValueError:
            # no port, or an unparsable one: keep just the host part
            sub_domains = environ['HTTP_HOST'].split(':')[0]
        except KeyError:
            sub_domains = "localhost"

        # If the host isn't the base domain or a subdomain of it, assume
        # this is a cname, and redirect to the frame controller. The dot
        # boundary matters: a plain endswith() would accept "evilreddit.com"
        # for base domain "reddit.com".
        # Ignore localhost so paster shell still works.
        # If this is an error, don't redirect.
        if (not self._on_domain(sub_domains, base_domain)
            and sub_domains != 'localhost'):
            environ['sub_domain'] = sub_domains
            if not environ.get('extension'):
                if environ['PATH_INFO'].startswith('/frame'):
                    return self.app(environ, start_response)
                elif self.is_auth_cname(sub_domains):
                    environ['frameless_cname'] = True
                    environ['authorized_cname'] = True
                elif ("redditSession" in environ.get('HTTP_COOKIE', '')
                      and environ['REQUEST_METHOD'] != 'POST'
                      and not environ['PATH_INFO'].startswith('/error')):
                    environ['original_path'] = environ['PATH_INFO']
                    environ['PATH_INFO'] = '/frame'
                else:
                    environ['frameless_cname'] = True
            return self.app(environ, start_response)

        sub_domains = sub_domains[:-len(base_domain)].strip('.')
        sub_domains = sub_domains.split('.')

        sr_redirect = None
        for sd in list(sub_domains):
            # subdomains to disregard completely (kept in the host)
            if sd in ('www', 'origin', 'beta'):
                continue
            # subdomains which change the extension
            elif sd == 'm':
                environ['reddit-domain-extension'] = 'mobile'
            elif sd in ('api', 'rss', 'xml', 'json'):
                environ['reddit-domain-extension'] = sd
            elif (len(sd) == 2 or (len(sd) == 5 and sd[2] == '-')) and self.lang_re.match(sd):
                environ['reddit-prefer-lang'] = sd
            else:
                sr_redirect = sd
            sub_domains.remove(sd)

        if sr_redirect and environ.get("FULLPATH"):
            r = Response()
            sub_domains.append(base_domain)
            redir = "%s/r/%s/%s" % ('.'.join(sub_domains),
                                    sr_redirect, environ['FULLPATH'])
            redir = "http://" + redir.replace('//', '/')
            r.status_code = 301
            r.headers['location'] = redir
            r.content = ""
            return r(environ, start_response)

        return self.app(environ, start_response)
class SubredditMiddleware(object):
    """Move a leading "/r/<name>" out of PATH_INFO into environ['subreddit'].

    Names must be at least two characters. The prefix is stripped from
    PATH_INFO ("/r/pics" alone becomes "/"). Paths starting with
    "/categories" get subreddit 'r' and keep their PATH_INFO.
    """
    sr_pattern = re.compile(r'^/r/([^/]{2,})')

    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        path = environ['PATH_INFO']
        match = self.sr_pattern.match(path)
        if match is None:
            if path.startswith("/categories"):
                environ['subreddit'] = 'r'
        else:
            environ['subreddit'] = match.group(1)
            environ['PATH_INFO'] = path[match.end():] or '/'
        return self.app(environ, start_response)
class DomainListingMiddleware(object):
    """Route "/domain/<host>/..." by moving <host> into environ['domain'].

    The matched prefix is stripped from PATH_INFO ("/domain/<host>" alone
    becomes "/"). Nothing happens if environ['subreddit'] is already set
    (e.g. by SubredditMiddleware).
    """
    domain_pattern = re.compile(r'^/domain/(([-\w]+\.)+[\w]+)')

    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        # "in" replaces dict.has_key, which is deprecated and gone in Python 3.
        if 'subreddit' not in environ:
            path = environ['PATH_INFO']
            domain = self.domain_pattern.match(path)
            if domain:
                environ['domain'] = domain.group(1)
                environ['PATH_INFO'] = self.domain_pattern.sub('', path) or '/'
        return self.app(environ, start_response)
class ExtensionMiddleware(object):
    """Choose render_style/content_type from a response-format extension.

    The extension comes from the end of PATH_INFO (e.g. "/hot.json", which
    is then stripped to "/hot"), or else from
    environ['reddit-domain-extension'] (set by DomainMiddleware from the
    subdomain). A recognized extension sets environ['extension'],
    ['render_style'] and ['content_type']. Otherwise the request is HTML.
    """
    ext_pattern = re.compile(r'\.([^/]+)$')

    extensions = {'rss' : ('xml', 'text/xml; charset=UTF-8'),
                  'xml' : ('xml', 'text/xml; charset=UTF-8'),
                  'js' : ('js', 'text/javascript; charset=UTF-8'),
                  #'png' : ('png', 'image/png'),
                  #'css' : ('css', 'text/css'),
                  'api' : (api_type(), 'application/json; charset=UTF-8'),
                  'json' : (api_type(), 'application/json; charset=UTF-8'),
                  'json-html' : (api_type('html'), 'application/json; charset=UTF-8')}

    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        path = environ['PATH_INFO']
        ext = None
        # An explicit path extension wins over the domain's. Checking both
        # in one loop over this (unordered) dict would let whichever key came
        # first win, and could leave the path's suffix unstripped.
        for candidate in self.extensions:
            if path.endswith('.' + candidate):
                ext = candidate
                #strip off the extension
                environ['PATH_INFO'] = path[:-(len(candidate) + 1)]
                break
        if ext is None:
            domain_ext = environ.get('reddit-domain-extension')
            if domain_ext in self.extensions:
                ext = domain_ext

        if ext is None:
            environ['render_style'] = 'html'
            environ['content_type'] = 'text/html; charset=UTF-8'
        else:
            environ['extension'] = ext
            environ['render_style'], environ['content_type'] = self.extensions[ext]
        return self.app(environ, start_response)
class RewriteMiddleware(object):
    """Apply the first matching rule from r2.config.rewrites to PATH_INFO.

    Each rule is (compiled_regex, template). "$N" in the template is
    replaced by regex group N (unmatched groups become ""). Afterwards
    environ['FULLPATH'] is set to PATH_INFO plus "?QUERY_STRING" when a
    query string is present.
    """
    def __init__(self, app):
        self.app = app

    def rewrite(self, regex, out_template, input):
        """Return out_template with $N references filled, or None on no match."""
        m = regex.match(input)
        if not m:
            return None
        groups = m.groups('')
        out = out_template
        # Substitute higher-numbered references first so "$10" is not
        # consumed as "$1" followed by "0".
        for num in range(len(groups), 0, -1):
            out = out.replace('$%s' % num, groups[num - 1])
        return out

    def __call__(self, environ, start_response):
        path = environ['PATH_INFO']
        for r in rewrites:
            newpath = self.rewrite(r[0], r[1], path)
            if newpath:
                environ['PATH_INFO'] = newpath
                break

        environ['FULLPATH'] = environ.get('PATH_INFO')
        qs = environ.get('QUERY_STRING')
        if qs:
            environ['FULLPATH'] += '?' + qs

        return self.app(environ, start_response)
class RequestLogMiddleware(object):
    """Dump each request's CGI-style environ to a file while it is handled.

    The file is named "<process_iden>-<sha1 of the dump>" inside log_path.
    It is removed once the wrapped app returns, so a file left behind
    belongs to a request whose app call raised or has not yet returned.
    """
    def __init__(self, log_path, process_iden, app):
        self.log_path = log_path
        self.app = app
        self.process_iden = str(process_iden)

    def __call__(self, environ, start_response):
        # Only all-upper-case keys (PATH_INFO, HTTP_* ...); keys such as
        # 'wsgi.input' contain lower-case letters and are skipped.
        dump = '\n'.join('%s: %s' % (k, v) for k, v in environ.iteritems()
                         if k.isupper())
        iden = self.process_iden + '-' + hashlib.sha1(dump).hexdigest()
        fname = os.path.join(self.log_path, iden)

        f = open(fname, 'w')
        try:
            f.write(dump)
        finally:
            f.close()

        r = self.app(environ, start_response)

        try:
            os.remove(fname)
        except OSError:
            # Already gone, e.g. an identical concurrent request in this
            # process produced (and removed) the same file name.
            pass
        return r
class LimitUploadSize(object):
    """
    Middleware for restricting the size of uploaded files (such as
    image files for the CSS editing capability).

    POST requests whose CONTENT_LENGTH is missing, empty, non-numeric, or
    larger than max_size bytes get a 500 response. Its body calls
    parent.too_big() from an inline script. Other requests pass through.
    """
    def __init__(self, app, max_size=1024*500):
        self.app = app
        self.max_size = max_size

    def _too_big(self, environ):
        """True if this request must be rejected as oversized."""
        if environ['REQUEST_METHOD'] != 'POST':
            return False
        try:
            return int(environ['CONTENT_LENGTH']) > self.max_size
        except (KeyError, ValueError):
            # CONTENT_LENGTH may be absent or empty; an unknown size is
            # treated like an oversized one rather than crashing in int().
            return True

    def __call__(self, environ, start_response):
        if self._too_big(environ):
            r = Response()
            r.status_code = 500
            r.content = '<html><head></head><body><script type="text/javascript">parent.too_big();</script>request too big</body></html>'
            return r(environ, start_response)

        return self.app(environ, start_response)
class AbsoluteRedirectMiddleware(object):
    """Turn relative Location headers on redirects into absolute URLs.

    Applies to status 301, 302, 303 and 307; the URL is resolved against
    the current request with paste's resolve_relative_url.
    """
    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        def absolutizing_start_response(status, headers, exc_info=None):
            code = int(status.split(None, 1)[0])
            if code in (301, 302, 303, 307):
                target = header_value(headers, 'location')
                if target:
                    absolute = resolve_relative_url(target, environ)
                    replace_header(headers, 'location', absolute)
            return start_response(status, headers, exc_info)

        return self.app(environ, absolutizing_start_response)
class CleanupMiddleware(object):
    """
    Outermost middleware: its start_response wrapper sees the headers
    after every other piece of middleware. Header names are normalized
    with str.title(). For repeated headers only the last occurrence is
    kept, except Set-Cookie, where every occurrence is kept. Surviving
    headers keep their original relative order.
    """
    def __init__(self, app):
        self.app = app

    def __call__(self, environ, start_response):
        def dedup_start_response(status, headers, exc_info = None):
            kept = []
            names = set()
            # Walk backwards so the last occurrence of each name wins.
            for name, value in reversed(headers):
                name = name.title()
                if name != 'Set-Cookie' and name in names:
                    continue
                names.add(name)
                kept.append((name, value))
            kept.reverse()
            return start_response(status, kept, exc_info)

        return self.app(environ, dedup_start_response)
# NOTE: the controller lookup and app wiring below are disorganized and confusing; worth a cleanup.
class RedditApp(PylonsBaseWSGIApp):
    """Pylons app that resolves controllers by naming convention.

    A controller name such as "front" maps to the class FrontController in
    the "<package_name>.controllers" module. Resolved classes are cached in
    self.controller_classes.
    """
    def find_controller(self, controller):
        if controller not in self.controller_classes:
            module_name = self.package_name + '.controllers'
            __import__(module_name)
            controller_cls = getattr(sys.modules[module_name],
                                     controller.capitalize() + 'Controller')
            self.controller_classes[controller] = controller_cls
        return self.controller_classes[controller]
def make_app(global_conf, full_stack=True, **app_conf):
    """Create a Pylons WSGI application and return it

    `global_conf`
        The inherited configuration for this application. Normally from the
        [DEFAULT] section of the Paste ini file.

    `full_stack`
        Whether or not this application provides a full WSGI stack (by default,
        meaning it handles its own exceptions and errors). Disable full_stack
        when this application is "managed" by another WSGI middleware.

    `app_conf`
        The application's local configuration. Normally specified in the
        [app:<name>] section of the Paste ini file (where <name> defaults to
        main).

    Each wrapper added below sits outside the ones added before it, so the
    last one (CleanupMiddleware) sees each request first.
    """
    # Configure the Pylons environment
    load_environment(global_conf, app_conf)

    # The Pylons WSGI app
    app = PylonsApp(base_wsgi_app=RedditApp)

    # Reddit-specific middleware, innermost first; all of it is inside the
    # error-handling middleware added below.
    for wrapper in (LimitUploadSize, ProfilingMiddleware,
                    SourceViewMiddleware, DomainListingMiddleware,
                    SubredditMiddleware, ExtensionMiddleware,
                    DomainMiddleware):
        app = wrapper(app)

    # Optionally record in-flight requests on disk.
    log_path = global_conf.get('log_path')
    if log_path:
        app = RequestLogMiddleware(log_path,
                                   global_conf.get('scgi_port', 'default'),
                                   app)

    # TODO: gzip middleware is left disabled because it breaks on 404.
    #app = make_gzip_middleware(app, app_conf)

    if asbool(full_stack):
        # Handle Python exceptions
        app = ErrorHandler(app, global_conf, error_template=error_template,
                           **config['pylons.errorware'])
        # Error documents for the codes error_mapper handles (401, 403, 404,
        # 503, and 500 when debug is disabled)
        app = ErrorDocuments(app, global_conf, mapper=error_mapper, **app_conf)

    # Establish the Registry for this application
    app = RegistryManager(app)

    # Static files, then the JavaScript app, then the application, tried in
    # that order by Cascade. Static files are cacheable for 7 days.
    javascripts_app = StaticJavascripts()
    static_app = StaticURLParser(config['pylons.paths']['static_files'],
                                 cache_max_age=604800)
    app = Cascade([static_app, javascripts_app, app])

    app = AbsoluteRedirectMiddleware(app)

    # Rewrite rules run before everything above.
    app = RewriteMiddleware(app)

    app = CleanupMiddleware(app)

    return app