PageRenderTime 28ms CodeModel.GetById 11ms app.highlight 12ms RepoModel.GetById 2ms app.codeStats 0ms

/scalate-core/src/main/scala/org/fusesource/scalate/servlet/TemplateEngineFilter.scala

http://github.com/scalate/scalate
Scala | 154 lines | 92 code | 23 blank | 39 comment | 0 complexity | 95cd1d7880ecc042e67461003a59fe78 MD5 | raw file
  1/**
  2 * Copyright (C) 2009-2011 the original author or authors.
  3 * See the notice.md file distributed with this work for additional
  4 * information regarding copyright ownership.
  5 *
  6 * Licensed under the Apache License, Version 2.0 (the "License");
  7 * you may not use this file except in compliance with the License.
  8 * You may obtain a copy of the License at
  9 *
 10 *     http://www.apache.org/licenses/LICENSE-2.0
 11 *
 12 * Unless required by applicable law or agreed to in writing, software
 13 * distributed under the License is distributed on an "AS IS" BASIS,
 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 * See the License for the specific language governing permissions and
 16 * limitations under the License.
 17 */
 18package org.fusesource.scalate.servlet
 19
 20import javax.servlet._
 21import javax.servlet.http.{ HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse }
 22import org.fusesource.scalate.support.TemplateFinder
 23import org.fusesource.scalate.util.Log
 24
 25object TemplateEngineFilter extends Log
 26
 27/**
 28 * Servlet filter which auto routes to the scalate engines for paths which have a scalate template
 29 * defined.
 30 *
 31 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
 32 */
 33class TemplateEngineFilter extends Filter {
 34  import TemplateEngineFilter._
 35
 36  var config: FilterConfig = _
 37  var engine: ServletTemplateEngine = _
 38  var finder: TemplateFinder = _
 39  var errorUris: List[String] = ServletHelper.errorUris()
 40
 41  /**
 42   * Called by the servlet engine to create the template engine and configure this filter
 43   */
 44  def init(filterConfig: FilterConfig) = {
 45    config = filterConfig
 46    engine = createTemplateEngine(config)
 47    finder = new TemplateFinder(engine)
 48
 49    filterConfig.getInitParameter("replaced-extensions") match {
 50      case null =>
 51      case x =>
 52        finder.replacedExtensions = x.split(":+").toList
 53    }
 54
 55    // register the template engine so they can be easily resolved from elsewhere
 56    ServletTemplateEngine(filterConfig.getServletContext) = engine
 57  }
 58
 59  /**
 60   * Called by the servlet engine on shut down.
 61   */
 62  def destroy = {
 63  }
 64
 65  /**
 66   * Performs the actual filter
 67   */
 68  def doFilter(request: ServletRequest, response: ServletResponse, chain: FilterChain): Unit = {
 69    (request, response) match {
 70      case (request: HttpServletRequest, response: HttpServletResponse) =>
 71        val request_wrapper = wrap(request)
 72
 73        debug("Checking '%s'", request.getRequestURI)
 74        findTemplate(request.getRequestURI.substring(request.getContextPath.length)) match {
 75          case Some(template) =>
 76            debug("Rendering '%s' using template '%s'", request.getRequestURI, template)
 77            val context = new ServletRenderContext(engine, request_wrapper, response, config.getServletContext)
 78
 79            try {
 80              context.include(template, true)
 81            } catch {
 82              case e: Throwable => showErrorPage(request_wrapper, response, e)
 83            }
 84
 85          case None =>
 86            chain.doFilter(request_wrapper, response)
 87        }
 88
 89      case _ =>
 90        chain.doFilter(request, response)
 91    }
 92  }
 93
 94  def showErrorPage(request: HttpServletRequest, response: HttpServletResponse, e: Throwable): Unit = {
 95
 96    info(e, "failure: %s", e)
 97
 98    // we need to expose all the errors property here...
 99    request.setAttribute("javax.servlet.error.exception", e)
100    request.setAttribute("javax.servlet.error.exception_type", e.getClass)
101    request.setAttribute("javax.servlet.error.message", e.getMessage)
102    request.setAttribute("javax.servlet.error.request_uri", request.getRequestURI)
103    request.setAttribute("javax.servlet.error.servlet_name", request.getServerName)
104    request.setAttribute("javax.servlet.error.status_code", 500)
105    response.setStatus(500)
106
107    errorUris.find(x => findTemplate(x).isDefined) match {
108      case Some(template) =>
109        val context = new ServletRenderContext(engine, request, response, config.getServletContext)
110        context.include(template, true)
111        // since we directly rendered the error page.. remove the attributes
112        // since they screw /w tomcat.
113        request.removeAttribute("javax.servlet.error.exception")
114        request.removeAttribute("javax.servlet.error.exception_type")
115        request.removeAttribute("javax.servlet.error.message")
116        request.removeAttribute("javax.servlet.error.request_uri")
117        request.removeAttribute("javax.servlet.error.servlet_name")
118        request.removeAttribute("javax.servlet.error.status_code")
119      case None =>
120        throw e;
121    }
122  }
123
124  /**
125   * Allow derived filters to override and customize the template engine from the configuration
126   */
127  protected def createTemplateEngine(config: FilterConfig): ServletTemplateEngine = {
128    new ServletTemplateEngine(config)
129  }
130
131  protected def findTemplate(name: String) = finder.findTemplate(name)
132
133  def wrap(request: HttpServletRequest) = new ScalateServletRequestWrapper(request)
134
135  class ScalateServletRequestWrapper(request: HttpServletRequest) extends HttpServletRequestWrapper(request) {
136    override def getRequestDispatcher(path: String) = {
137      findTemplate(path).map(new ScalateRequestDispatcher(_)).getOrElse(request.getRequestDispatcher(path))
138    }
139  }
140
141  class ScalateRequestDispatcher(template: String) extends RequestDispatcher {
142    def forward(request: ServletRequest, response: ServletResponse): Unit = include(request, response)
143    def include(request: ServletRequest, response: ServletResponse): Unit = {
144      (request, response) match {
145        case (request: HttpServletRequest, response: HttpServletResponse) =>
146          val context = new ServletRenderContext(engine, wrap(request), response, config.getServletContext)
147          context.include(template, true)
148        case _ =>
149          None
150      }
151    }
152  }
153
154}