PageRenderTime 66ms CodeModel.GetById 22ms app.highlight 20ms RepoModel.GetById 19ms app.codeStats 1ms

/scalate-core/src/main/scala/org/fusesource/scalate/servlet/ServletRenderContext.scala

http://github.com/scalate/scalate
Scala | 258 lines | 146 code | 41 blank | 71 comment | 11 complexity | 0dae7f1239a07464108c60cdc096965e MD5 | raw file
  1/**
  2 * Copyright (C) 2009-2011 the original author or authors.
  3 * See the notice.md file distributed with this work for additional
  4 * information regarding copyright ownership.
  5 *
  6 * Licensed under the Apache License, Version 2.0 (the "License");
  7 * you may not use this file except in compliance with the License.
  8 * You may obtain a copy of the License at
  9 *
 10 *     http://www.apache.org/licenses/LICENSE-2.0
 11 *
 12 * Unless required by applicable law or agreed to in writing, software
 13 * distributed under the License is distributed on an "AS IS" BASIS,
 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 * See the License for the specific language governing permissions and
 16 * limitations under the License.
 17 */
 18package org.fusesource.scalate.servlet
 19
 20import java.io._
 21import java.util.Locale
 22
 23import _root_.org.fusesource.scalate.util.URIs._
 24import javax.servlet.http._
 25import javax.servlet.{ ServletContext, ServletException, ServletOutputStream }
 26import org.fusesource.scalate.{ AttributeMap, DefaultRenderContext, RenderContext, TemplateEngine }
 27
 28import scala.collection.JavaConverters._
 29import scala.collection.Set
 30import scala.collection.mutable.HashSet
 31
 32/**
 33 * Easy access to servlet request state.
 34 *
 35 * If you add the following code to your program
 36 * <code>import org.fusesource.scalate.servlet.ServletRequestContext._</code>
 37 * then you can access the current renderContext, request, response, servletContext
 38 */
 39object ServletRenderContext {
 40
 41  /**
 42   * Returns the currently active render context in this thread
 43   * @throws IllegalArgumentException if there is no suitable render context available in this thread
 44   */
 45  def renderContext: ServletRenderContext = RenderContext() match {
 46    case s: ServletRenderContext => s
 47    case n => throw new IllegalArgumentException("This threads RenderContext is not a ServletRenderContext as it is: " + n)
 48  }
 49
 50  def request: HttpServletRequest = renderContext.request
 51
 52  def response: HttpServletResponse = renderContext.response
 53
 54  def servletContext: ServletContext = renderContext.servletContext
 55}
 56
 57/**
 58 * A template context for use in servlets
 59 *
 60 * @version $Revision : 1.1 $
 61 */
 62class ServletRenderContext(
 63  engine: TemplateEngine,
 64  out: PrintWriter,
 65  val request: HttpServletRequest,
 66  val response: HttpServletResponse,
 67  val servletContext: ServletContext) extends DefaultRenderContext(request.getRequestURI, engine, out) {
 68
 69  def this(
 70    engine: TemplateEngine,
 71    request: HttpServletRequest,
 72    response: HttpServletResponse,
 73    servletContext: ServletContext) = {
 74    this(engine, response.getWriter, request, response, servletContext)
 75  }
 76
 77  viewPrefixes = List("WEB-INF", "")
 78
 79  override val attributes = new AttributeMap {
 80    request.setAttribute("context", ServletRenderContext.this)
 81
 82    def get(key: String): Option[Any] = {
 83      val value = apply(key)
 84      Option(value)
 85    }
 86
 87    def apply(key: String): Any = key match {
 88      case "context" => ServletRenderContext.this
 89      case _ => request.getAttribute(key)
 90    }
 91
 92    def update(key: String, value: Any): Unit = value match {
 93      case null => request.removeAttribute(key)
 94      case _ => request.setAttribute(key, value)
 95    }
 96
 97    def remove(key: String) = {
 98      val answer = get(key)
 99      request.removeAttribute(key)
100      answer
101    }
102
103    def keySet: Set[String] = {
104      val answer = new HashSet[String]()
105      for (a <- request.getAttributeNames.asScala) {
106        answer.add(a.toString)
107      }
108      answer
109    }
110
111    override def toString = keySet.map(k => "" + k + " -> " + apply(k)).mkString("{", ", ", "}")
112  }
113
114  /**
115   * Named servletConfig for historical reasons; actually returns a Config, which presents  a unified view of either a
116   * ServletConfig or a FilterConfig.
117   *
118   * @return a Config, if the servlet engine is a ServletTemplateEngine
119   * @throws IllegalStateException if the servlet engine is not a ServletTemplateEngine
120   */
121  def servletConfig: Config = engine match {
122    case servletEngine: ServletTemplateEngine => servletEngine.config
123    case _ => throw new IllegalArgumentException("render context not created with ServletTemplateEngine so cannot provide a ServletConfig")
124  }
125
126  override def locale: Locale = {
127    val locale = request.getLocale
128    if (locale == null) Locale.getDefault else locale
129  }
130
131  /**
132   * Forwards this request to the given page
133   */
134  def forward(page: String, escape: Boolean = false) = {
135    val newResponse = wrappedResponse
136    requestDispatcher(page).forward(wrappedRequest, newResponse)
137    newResponse.output(this, escape)
138  }
139
140  /**
141   * Includes the given servlet page
142   */
143  def servlet(page: String, escape: Boolean = false) = {
144    val newResponse = wrappedResponse
145    requestDispatcher(page).include(wrappedRequest, newResponse)
146    newResponse.output(this, escape)
147  }
148
149  /**
150   * Creates a URI which if the uri starts with / then the link is prefixed with the web applications context
151   */
152  override def uri(uri: String) = {
153    if (uri.startsWith("/")) {
154      request.getContextPath + uri
155    } else {
156      uri
157    }
158  }
159
160  /**
161   * Returns the current URI with new query arguments (separated with &)
162   */
163  def currentUriPlus(newQueryArgs: String) = {
164    uriPlus(requestUri, queryString, newQueryArgs)
165  }
166
167  /**
168   * Returns the current URI with query arguments (separated with &) removed
169   */
170  def currentUriMinus(newQueryArgs: String) = {
171    uriMinus(requestUri, queryString, newQueryArgs)
172  }
173
174  /**
175   * Returns all of the parameter values
176   */
177  def parameterValues(name: String): Array[String] = {
178    val answer = request.getParameterValues(name)
179    if (answer != null) {
180      answer
181    } else {
182      Array[String]()
183    }
184  }
185
186  /**
187   * Returns the first parameter
188   */
189  def parameter(name: String) = { request.getParameter(name) }
190
191  /**
192   * Returns the forwarded request uri or the current request URI if its not forwarded
193   */
194  override def requestUri: String = attributes.get("javax.servlet.forward.request_uri") match {
195    case Some(value: String) => value
196    case _ => request.getRequestURI
197  }
198
199  /**
200   * Returns the forwarded query string or the current query string if its not forwarded
201   */
202  def queryString: String = attributes.get("javax.servlet.forward.query_string") match {
203    case Some(value: String) => value
204    case _ => request.getQueryString
205  }
206
207  /**
208   * Returns the forwarded context path or the current context path if its not forwarded
209   */
210  def contextPath: String = attributes.get("javax.servlet.forward.context_path") match {
211    case Some(value: String) => value
212    case _ => request.getContextPath
213  }
214
215  protected def wrappedRequest = new WrappedRequest(request)
216
217  protected def wrappedResponse = new WrappedResponse(response)
218
219  protected def requestDispatcher(page: String) = {
220    // lets flush first to avoid missing current output
221    flush
222
223    val dispatcher = request.getRequestDispatcher(page)
224    if (dispatcher == null) {
225      throw new ServletException("No dispatcher available for path: " + page)
226    }
227    dispatcher
228  }
229}
230
/**
 * A request wrapper that delegates everything to the wrapped request except `getMethod`, which always
 * reports `"GET"`.
 *
 * NOTE(review): presumably this is so that resources reached through `forward`/`include` treat the
 * dispatch as a GET regardless of the original HTTP method — confirm.
 */
class WrappedRequest(request: HttpServletRequest) extends HttpServletRequestWrapper(request) {
  override def getMethod: String = {
    val method = "GET"
    method
  }
}
234
/**
 * A response wrapper that captures everything written through `getWriter` or `getOutputStream` into an in-memory
 * buffer instead of passing it to the wrapped response.
 *
 * The captured content can be read back as raw bytes ([[bytes]]) or text ([[text]]), or written into a render
 * context ([[output]]).
 *
 * Text written via the writer is encoded as UTF-8, and [[text]] decodes the buffer as UTF-8. The round trip is
 * therefore lossless and independent of the JVM's platform-default charset.
 * NOTE(review): bytes written directly to `getOutputStream` are assumed to be UTF-8 too.
 */
class WrappedResponse(response: HttpServletResponse) extends HttpServletResponseWrapper(response) {
  /** Charset used both to encode writer output and to decode the captured buffer. */
  private[this] val charsetName = "UTF-8"

  private[this] val bos = new ByteArrayOutputStream()

  private[this] val sos = new ServletOutputStream {
    def write(b: Int): Unit = bos.write(b)

    // Copy bulk writes straight into the buffer instead of one byte at a time.
    override def write(b: Array[Byte], off: Int, len: Int): Unit = bos.write(b, off, len)
  }

  private[this] val writer = new PrintWriter(new OutputStreamWriter(bos, charsetName))

  override def getWriter: PrintWriter = writer

  override def getOutputStream: ServletOutputStream = sos

  /**
   * All bytes captured so far; the writer is flushed first so buffered characters are included.
   */
  def bytes: Array[Byte] = {
    writer.flush()
    bos.toByteArray
  }

  /**
   * The captured output decoded as UTF-8 text.
   */
  def text: String = new String(bytes, charsetName)

  /**
   * Writes the captured text into `context`, passing it through `context.value` with the given `escape` flag.
   */
  def output(context: RenderContext, escape: Boolean = false): Unit = {
    context << context.value(text, escape)
  }
}