/scalate-core/src/main/scala/org/fusesource/scalate/servlet/ServletRenderContext.scala

http://github.com/scalate/scalate · Scala · 258 lines · 146 code · 41 blank · 71 comment · 11 complexity · 0dae7f1239a07464108c60cdc096965e MD5 · raw file

  1. /**
  2. * Copyright (C) 2009-2011 the original author or authors.
  3. * See the notice.md file distributed with this work for additional
  4. * information regarding copyright ownership.
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. package org.fusesource.scalate.servlet
  19. import java.io._
  20. import java.util.Locale
  21. import _root_.org.fusesource.scalate.util.URIs._
  22. import javax.servlet.http._
  23. import javax.servlet.{ ServletContext, ServletException, ServletOutputStream }
  24. import org.fusesource.scalate.{ AttributeMap, DefaultRenderContext, RenderContext, TemplateEngine }
  25. import scala.collection.JavaConverters._
  26. import scala.collection.Set
  27. import scala.collection.mutable.HashSet
  28. /**
  29. * Easy access to servlet request state.
  30. *
  31. * If you add the following code to your program
  32. * <code>import org.fusesource.scalate.servlet.ServletRequestContext._</code>
  33. * then you can access the current renderContext, request, response, servletContext
  34. */
  35. object ServletRenderContext {
  36. /**
  37. * Returns the currently active render context in this thread
  38. * @throws IllegalArgumentException if there is no suitable render context available in this thread
  39. */
  40. def renderContext: ServletRenderContext = RenderContext() match {
  41. case s: ServletRenderContext => s
  42. case n => throw new IllegalArgumentException("This threads RenderContext is not a ServletRenderContext as it is: " + n)
  43. }
  44. def request: HttpServletRequest = renderContext.request
  45. def response: HttpServletResponse = renderContext.response
  46. def servletContext: ServletContext = renderContext.servletContext
  47. }
  48. /**
  49. * A template context for use in servlets
  50. *
  51. * @version $Revision : 1.1 $
  52. */
  53. class ServletRenderContext(
  54. engine: TemplateEngine,
  55. out: PrintWriter,
  56. val request: HttpServletRequest,
  57. val response: HttpServletResponse,
  58. val servletContext: ServletContext) extends DefaultRenderContext(request.getRequestURI, engine, out) {
  59. def this(
  60. engine: TemplateEngine,
  61. request: HttpServletRequest,
  62. response: HttpServletResponse,
  63. servletContext: ServletContext) = {
  64. this(engine, response.getWriter, request, response, servletContext)
  65. }
  66. viewPrefixes = List("WEB-INF", "")
  67. override val attributes = new AttributeMap {
  68. request.setAttribute("context", ServletRenderContext.this)
  69. def get(key: String): Option[Any] = {
  70. val value = apply(key)
  71. Option(value)
  72. }
  73. def apply(key: String): Any = key match {
  74. case "context" => ServletRenderContext.this
  75. case _ => request.getAttribute(key)
  76. }
  77. def update(key: String, value: Any): Unit = value match {
  78. case null => request.removeAttribute(key)
  79. case _ => request.setAttribute(key, value)
  80. }
  81. def remove(key: String) = {
  82. val answer = get(key)
  83. request.removeAttribute(key)
  84. answer
  85. }
  86. def keySet: Set[String] = {
  87. val answer = new HashSet[String]()
  88. for (a <- request.getAttributeNames.asScala) {
  89. answer.add(a.toString)
  90. }
  91. answer
  92. }
  93. override def toString = keySet.map(k => "" + k + " -> " + apply(k)).mkString("{", ", ", "}")
  94. }
  95. /**
  96. * Named servletConfig for historical reasons; actually returns a Config, which presents a unified view of either a
  97. * ServletConfig or a FilterConfig.
  98. *
  99. * @return a Config, if the servlet engine is a ServletTemplateEngine
  100. * @throws IllegalStateException if the servlet engine is not a ServletTemplateEngine
  101. */
  102. def servletConfig: Config = engine match {
  103. case servletEngine: ServletTemplateEngine => servletEngine.config
  104. case _ => throw new IllegalArgumentException("render context not created with ServletTemplateEngine so cannot provide a ServletConfig")
  105. }
  106. override def locale: Locale = {
  107. val locale = request.getLocale
  108. if (locale == null) Locale.getDefault else locale
  109. }
  110. /**
  111. * Forwards this request to the given page
  112. */
  113. def forward(page: String, escape: Boolean = false) = {
  114. val newResponse = wrappedResponse
  115. requestDispatcher(page).forward(wrappedRequest, newResponse)
  116. newResponse.output(this, escape)
  117. }
  118. /**
  119. * Includes the given servlet page
  120. */
  121. def servlet(page: String, escape: Boolean = false) = {
  122. val newResponse = wrappedResponse
  123. requestDispatcher(page).include(wrappedRequest, newResponse)
  124. newResponse.output(this, escape)
  125. }
  126. /**
  127. * Creates a URI which if the uri starts with / then the link is prefixed with the web applications context
  128. */
  129. override def uri(uri: String) = {
  130. if (uri.startsWith("/")) {
  131. request.getContextPath + uri
  132. } else {
  133. uri
  134. }
  135. }
  136. /**
  137. * Returns the current URI with new query arguments (separated with &)
  138. */
  139. def currentUriPlus(newQueryArgs: String) = {
  140. uriPlus(requestUri, queryString, newQueryArgs)
  141. }
  142. /**
  143. * Returns the current URI with query arguments (separated with &) removed
  144. */
  145. def currentUriMinus(newQueryArgs: String) = {
  146. uriMinus(requestUri, queryString, newQueryArgs)
  147. }
  148. /**
  149. * Returns all of the parameter values
  150. */
  151. def parameterValues(name: String): Array[String] = {
  152. val answer = request.getParameterValues(name)
  153. if (answer != null) {
  154. answer
  155. } else {
  156. Array[String]()
  157. }
  158. }
  159. /**
  160. * Returns the first parameter
  161. */
  162. def parameter(name: String) = { request.getParameter(name) }
  163. /**
  164. * Returns the forwarded request uri or the current request URI if its not forwarded
  165. */
  166. override def requestUri: String = attributes.get("javax.servlet.forward.request_uri") match {
  167. case Some(value: String) => value
  168. case _ => request.getRequestURI
  169. }
  170. /**
  171. * Returns the forwarded query string or the current query string if its not forwarded
  172. */
  173. def queryString: String = attributes.get("javax.servlet.forward.query_string") match {
  174. case Some(value: String) => value
  175. case _ => request.getQueryString
  176. }
  177. /**
  178. * Returns the forwarded context path or the current context path if its not forwarded
  179. */
  180. def contextPath: String = attributes.get("javax.servlet.forward.context_path") match {
  181. case Some(value: String) => value
  182. case _ => request.getContextPath
  183. }
  184. protected def wrappedRequest = new WrappedRequest(request)
  185. protected def wrappedResponse = new WrappedResponse(response)
  186. protected def requestDispatcher(page: String) = {
  187. // lets flush first to avoid missing current output
  188. flush
  189. val dispatcher = request.getRequestDispatcher(page)
  190. if (dispatcher == null) {
  191. throw new ServletException("No dispatcher available for path: " + page)
  192. }
  193. dispatcher
  194. }
  195. }
  196. class WrappedRequest(request: HttpServletRequest) extends HttpServletRequestWrapper(request) {
  197. override def getMethod = "GET"
  198. }
  199. class WrappedResponse(response: HttpServletResponse) extends HttpServletResponseWrapper(response) {
  200. private[this] val bos = new ByteArrayOutputStream()
  201. private[this] val sos = new ServletOutputStream {
  202. def write(b: Int) = bos.write(b)
  203. }
  204. private[this] val writer = new PrintWriter(new OutputStreamWriter(bos))
  205. override def getWriter = writer
  206. override def getOutputStream = sos
  207. def bytes = {
  208. writer.flush
  209. bos.toByteArray
  210. }
  211. def text = {
  212. new String(bytes)
  213. }
  214. def output(context: RenderContext, escape: Boolean = false): Unit = {
  215. context << context.value(text, escape)
  216. }
  217. }