PageRenderTime 45ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 0ms

/netty-websockets/src/main/scala/plans.scala

https://github.com/dwestheide/Unfiltered
Scala | 316 lines | 237 code | 55 blank | 24 comment | 11 complexity | 874ef15004b3edcccdc74b7b29ca0ca1 MD5 | raw file
  1. package unfiltered.netty.websockets
  2. import unfiltered.request._
  3. import unfiltered.response._
  4. import unfiltered.netty._
  5. import org.jboss.{netty => jnetty}
  6. import jnetty.channel.{Channel, ChannelEvent, ChannelFuture, ChannelFutureListener,
  7. ChannelHandlerContext, MessageEvent, SimpleChannelUpstreamHandler}
  8. import jnetty.buffer.{ChannelBuffer, ChannelBuffers}
  9. import jnetty.handler.codec.http.HttpHeaders
  10. import jnetty.handler.codec.http.{HttpRequest => NHttpRequest,
  11. HttpResponse => NHttpResponse,
  12. DefaultHttpResponse}
  13. import jnetty.handler.codec.http.HttpVersion.HTTP_1_1
  14. import jnetty.handler.codec.http.HttpResponseStatus.FORBIDDEN
  15. import jnetty.handler.codec.http.HttpHeaders.setContentLength
  16. import jnetty.util.CharsetUtil
  object Plan {
    /** The transition from http request handling to websocket request handling.
     * Note: This can not be an Async.Intent because RequestBinding is a Responder for HttpResponses */
    type Intent =
      PartialFunction[RequestBinding, SocketIntent]

    /** WebSockets may be responded to asynchronously, thus their handler does not need to return */
    type SocketIntent =
      PartialFunction[SocketCallback, Unit]

    /** Equivalent of an HttpResponse's Pass fn.
     * A SocketIntent that accepts every callback and does nothing with it */
    val Pass = ({
      case _ => ()
    }: SocketIntent)

    /** Handler for channel events that a plan does not take over itself. */
    type PassHandler = (ChannelHandlerContext, ChannelEvent) => Unit

    /** Default pass behavior.
     * - An HTTP request message is answered with `403 Forbidden` (the status line as a UTF-8
     *   body, with Content-Length set), and the channel is closed once the write completes.
     * - Any other message type is rejected with an error.
     * - Non-message events are ignored. */
    val DefaultPassHandler = ({ (ctx, event) =>
      event match {
        case me: MessageEvent =>
          me.getMessage match {
            case request: NHttpRequest =>
              val res = new DefaultHttpResponse(HTTP_1_1, FORBIDDEN)
              res.setContent(ChannelBuffers.copiedBuffer(res.getStatus.toString, CharsetUtil.UTF_8))
              setContentLength(res, res.getContent.readableBytes)
              ctx.getChannel.write(res).addListener(ChannelFutureListener.CLOSE)
            case msg =>
              // sys.error replaces the deprecated (and later removed) Predef.error
              sys.error("Invalid type of event message (%s) for Plan pass handling".format(
                msg.getClass.getName))
          }
        case _ => () // we really only care about MessageEvents but need to support the more generic ChannelEvent
      }
    }: PassHandler)
  }
  /** Exception-handling mixin: prints the throwable's stack trace, then closes the
   * channel the failure happened on. */
  trait CloseOnException { self: ExceptionHandler =>
    def onException(ctx: ChannelHandlerContext, t: Throwable) {
      t.printStackTrace()
      val channel = ctx.getChannel
      channel.close()
    }
  }
  /** A light wrapper around both the Sec-WebSocket-Draft and Sec-WebSocket-Version request
   * headers. It yields the draft header's value when present, and otherwise the version
   * header's value (None when neither is set). */
  private [websockets] object Version {
    def apply[T](req: HttpRequest[T]) = {
      val draft = IetfDrafts.SecWebSocketDraft.unapply(req)
      draft orElse IetfDrafts.SecWebSocketVersion.unapply(req)
    }
  }
  /** Request/response headers and the server handshake for the IETF hybi drafts.
   * See also http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-14 */
  private [websockets] object IetfDrafts {
    /** Server handshake as described in
     * http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-14#section-4.2.2 */
    object Handshake {
      import java.security.MessageDigest
      import org.apache.commons.codec.binary.Base64.encodeBase64
      /** Fixed GUID concatenated to the client key before hashing. */
      val GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
      /** Name of the MessageDigest algorithm used by `sign`. */
      val Sha1 = "SHA1"
      /** Base64-encoded SHA-1 digest of the trimmed key followed by GUID.
       * The text is encoded with an explicit charset rather than the JVM default, so the
       * signature does not vary with the platform's default encoding. */
      def sign(key: String): Array[Byte] =
        encodeBase64(MessageDigest.getInstance(Sha1).digest((key.trim + GUID).getBytes(CharsetUtil.UTF_8)))
      /** The Sec-WebSocket-Accept response header for the given client key.
       * Base64 output is ASCII, so decoding it as UTF-8 is lossless. */
      def apply(key: String) = WebSocketAccept(new String(sign(key), CharsetUtil.UTF_8))
    }
    // request headers
    object SecWebSocketKey extends StringHeader("Sec-WebSocket-Key")
    object SecWebSocketVersion extends StringHeader("Sec-WebSocket-Version")
    /** Prior to draft 04, the websocket spec provided an optional draft header
     * http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-03#section-10.8 */
    object SecWebSocketDraft extends StringHeader("Sec-WebSocket-Draft")
    // response headers
    object WebSocketAccept extends HeaderName("Sec-WebSocket-Accept")
    object SecWebSocketVersionName extends HeaderName("Sec-WebSocket-Version")
  }
  /** Headers and the server handshake for the pre-IETF ("hixie") websocket drafts. */
  private [websockets] object HixieDrafts {
    import java.security.MessageDigest

    /** Sec-WebSocket-Key(1/2) were included in the hixie drafts and later removed in the ietf
     * drafts; see the latter in drafts 00-03
     * http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-03#section-1.3 */
    object SecKeyOne extends StringHeader("Sec-WebSocket-Key1")
    object SecKeyTwo extends StringHeader("Sec-WebSocket-Key2")

    /** When both hixie keys are present, writes the MD5 challenge response into the response
     * body. When either key is missing, the response is left untouched. */
    case class Handshake(binding: RequestBinding) extends Responder[NHttpResponse] {

      /** The key's digits read as one number, divided by the number of spaces in the key,
       * then truncated to 32 bits. A key without spaces throws ArithmeticException; a key
       * without digits throws NumberFormatException. */
      private def keyNumber(key: String): Int = {
        val digits = key.replaceAll("[^0-9]", "").toLong
        val spaces = key.replaceAll("[^ ]", "").length
        (digits / spaces).toInt
      }

      def respond(res: HttpResponse[NHttpResponse]) {
        val first = HixieDrafts.SecKeyOne(binding)
        val second = HixieDrafts.SecKeyTwo(binding)
        for (k1 <- first; k2 <- second) {
          // 16-byte challenge: key number 1, key number 2, then 8 bytes read from the request body
          val challenge = ChannelBuffers.buffer(16)
          challenge.writeInt(keyNumber(k1))
          challenge.writeInt(keyNumber(k2))
          challenge.writeLong(binding.underlying.request.getContent().readLong)
          val digest = MessageDigest.getInstance("MD5").digest(challenge.array)
          res.underlying.setContent(ChannelBuffers.wrappedBuffer(digest))
        }
      }
    }
  }
  /** Request header named by HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL. */
  private [websockets] object ProtocolRequestHeader extends StringHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL)

  /** Request header named by HttpHeaders.Names.ORIGIN. */
  private [websockets] object OriginRequestHeader extends StringHeader(HttpHeaders.Names.ORIGIN)
  /** Matches requests whose Connection header contains, among its comma-separated values,
   * a token equal (ignoring case and surrounding whitespace) to HttpHeaders.Values.UPGRADE. */
  private [websockets] object ConnectionUpgrade {
    def unapply[T](req: HttpRequest[T]) =
      unfiltered.request.Connection(req).filter { header =>
        header.split(",").exists(token => token.trim.equalsIgnoreCase(HttpHeaders.Values.UPGRADE))
      }
  }
  /** Matches requests with an Upgrade header value equal (ignoring case) to
   * HttpHeaders.Values.WEBSOCKET, extracting the request itself. */
  private [websockets] object UpgradeWebsockets {
    def unapply[T](req: HttpRequest[T]) = {
      val websocketUpgrade = Upgrade(req).find(_.equalsIgnoreCase(HttpHeaders.Values.WEBSOCKET))
      websocketUpgrade.map(_ => req)
    }
  }
  /** Builds the `ws://` location of a request from its Host header and URI.
   * Throws NoSuchElementException when the request has no Host header. */
  private [websockets] object WSLocation {
    def apply[T](r: HttpRequest[T]) = {
      val host = Host(r).get
      "ws://" + host + r.uri
    }
  }
  /** Netty upstream handler that upgrades matching HTTP GET requests to websocket connections.
   * Non-HTTP messages, non-upgrade requests, malformed handshakes, and requests the `intent`
   * passes on are all handed to `pass`. */
  trait Plan extends SimpleChannelUpstreamHandler with ExceptionHandler {
    import jnetty.channel.{ChannelStateEvent, ExceptionEvent}
    import jnetty.handler.codec.http.websocket.{DefaultWebSocketFrame, WebSocketFrame,
                                               WebSocketFrameDecoder => LegacyWebSocketFrameDecoder,
                                               WebSocketFrameEncoder => LegacyWebSocketFrameEncoder}
    import jnetty.handler.codec.http.{HttpHeaders, HttpMethod, HttpRequest => NHttpRequest,
                                      HttpResponseStatus, HttpVersion, DefaultHttpResponse}
    import HttpHeaders._
    import HttpHeaders.Names.{CONNECTION, ORIGIN, HOST, UPGRADE}
    import HttpHeaders.Values._

    val SecWebSocketLocation = "Sec-WebSocket-Location"
    val SecWebSocketOrigin = "Sec-WebSocket-Origin"
    val SecWebSocketProtocol = "Sec-WebSocket-Protocol"

    /** Maps an upgrade request to the socket intent that will handle the connection. */
    def intent: Plan.Intent
    /** Handles events this plan does not take over. */
    def pass: Plan.PassHandler

    override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) =
      event.getMessage match {
        case http: NHttpRequest => upgrade(ctx, http, event)
        case _ => pass(ctx, event)
      }

    /** Picks the handshake flavor from the request's draft/version header value.
     * Some(true): hixie (legacy) handshake, used when the header is absent or the version is below 4.
     * Some(false): ietf handshake, used for version 4 or later.
     * None: the header value is not an integer, so the handshake is malformed. */
    private def legacyHandshake(version: Option[String]): Option[Boolean] = version match {
      case None => Some(true)
      case Some(v) =>
        try Some(v.trim.toInt < 4)
        catch { case _: NumberFormatException => None }
    }

    /** Performs the websocket handshake for an HTTP request, or passes it on. */
    private def upgrade(ctx: ChannelHandlerContext, request: NHttpRequest,
                        event: MessageEvent) = {
      val msg = ReceivedMessage(request, ctx, event)
      val binding = new RequestBinding(msg)
      val version = Version(binding)
      (binding, legacyHandshake(version)) match {
        // Only well-formed handshakes are upgraded. The version header must parse, and an
        // ietf-style handshake must carry a Sec-WebSocket-Key. Anything else falls through to
        // `pass` before the pipeline is touched, instead of throwing mid-handshake.
        case (GET(ConnectionUpgrade(_) & UpgradeWebsockets(_)), Some(legacy))
          if legacy || IetfDrafts.SecWebSocketKey.unapply(binding).isDefined =>
          intent.orElse({ case _ => Plan.Pass }: Plan.Intent)(binding) match {
            case Plan.Pass =>
              pass(ctx, event)
            case socketIntent =>
              val response = msg.response(
                new DefaultHttpResponse(
                  HttpVersion.HTTP_1_1,
                  new HttpResponseStatus(101, "Web Socket Protocol Handshake")
                )
              )_
              // the socket intent made total: callbacks it does not handle are ignored
              def attempt = socketIntent.orElse({ case _ => () }: Plan.SocketIntent)
              // echoes the request's sub-protocol header back on the response, when present
              val Protocol = new Responder[NHttpResponse] {
                def respond(res: HttpResponse[NHttpResponse]) {
                  ProtocolRequestHeader(binding) match {
                    case Some(protocol) =>
                      res.header(SecWebSocketProtocol, protocol)
                    case _ => ()
                  }
                }
              }
              val pipe = ctx.getChannel.getPipeline
              // drop the handler registered as "aggregator", if any (presumably an HTTP chunk aggregator)
              if(pipe.get("aggregator") != null) pipe.remove("aggregator")
              pipe.replace("decoder", "wsdecoder",
                if(legacy) new LegacyWebSocketFrameDecoder
                else new Draft14WebSocketFrameDecoder)
              ctx.getChannel.write(
                response(
                  new HeaderName(UPGRADE)(Values.WEBSOCKET) ~>
                  new HeaderName(CONNECTION)(Values.UPGRADE) ~>
                  new HeaderName(SecWebSocketOrigin)(OriginRequestHeader(binding).getOrElse("*")) ~>
                  new HeaderName(SecWebSocketLocation)(WSLocation(binding)) ~>
                  Protocol ~> (
                    if(legacy) HixieDrafts.Handshake(binding)
                    else {
                      // safe: the match guard above ensures the key is present on this path
                      IetfDrafts.Handshake(IetfDrafts.SecWebSocketKey.unapply(binding).get) ~>
                      IetfDrafts.SecWebSocketVersionName(version.getOrElse("0"))
                    })
                )
              )
              // report channel closure to the socket intent
              ctx.getChannel.getCloseFuture.addListener(new ChannelFutureListener {
                def operationComplete(future: ChannelFuture) = {
                  attempt(Close(WebSocket(ctx.getChannel)))
                }
              })
              // swapped only after the handshake write (presumably so the response is still
              // encoded as HTTP)
              pipe.replace("encoder", "wsencoder",
                if(legacy) new LegacyWebSocketFrameEncoder
                else new Draft14WebSocketFrameEncoder)
              attempt(Open(WebSocket(ctx.getChannel)))
              // from here on, this channel's events go to a SocketPlan running the socket intent
              pipe.replace(this, ctx.getName, SocketPlan(socketIntent, pass))
          }
        case _ =>
          pass(ctx, event)
      }
    }

    /** By default, when a websocket handler `passes` it writes an Http Forbidden response
     * to the channel. To override that behavior, call this method with a function to handle
     * the ChannelEvent with custom behavior */
    def onPass(handler: Plan.PassHandler) = Planify(intent, handler)
  }
  /** Concrete Plan built from an intent and a pass handler; exceptions close the channel. */
  class Planify(val intent: Plan.Intent, val pass: Plan.PassHandler)
    extends Plan with CloseOnException

  /** Factories for Planify plans. */
  object Planify {
    /** Creates a plan that hands unhandled events to `pass`. */
    def apply(intent: Plan.Intent, pass: Plan.PassHandler) = new Planify(intent, pass)

    /** Creates a WebSocketHandler that, when `Passing`, will return a forbidden
     * response to the client */
    def apply(intent: Plan.Intent): Plan =
      apply(intent, Plan.DefaultPassHandler)
  }
  /** Upstream handler installed on a channel after a successful upgrade. It routes websocket
   * frames, and errors, to the socket intent; other messages go to `pass`. */
  case class SocketPlan(intent: Plan.SocketIntent,
                        pass: Plan.PassHandler) extends SimpleChannelUpstreamHandler {
    import jnetty.channel.{ChannelFuture, ChannelFutureListener, ExceptionEvent}
    import jnetty.handler.codec.http.websocket.WebSocketFrame

    /** Converts a frame to a message: frames reporting `isText` become Text
     * (0x00-0x7F typed frames, UTF-8); all others become Binary (0x80-0xFF typed frames). */
    implicit def wsf2msg(wsf: WebSocketFrame): Msg =
      if(wsf.isText) Text(wsf.getTextData) else Binary(wsf.getBinaryData)

    /** The socket intent made total: callbacks it does not handle are ignored. */
    def attempt = intent.orElse({ case _ => () }: Plan.SocketIntent)

    override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) =
      event.getMessage match {
        // echo the closing frame, then close the channel once that write completes
        case c @ ClosingFrame(_) =>
          ctx.getChannel.write(c).addListener(new ChannelFutureListener {
            override def operationComplete(f: ChannelFuture) = ctx.getChannel.close
          })
        // NOTE(review): ping and pong frames are written back unchanged; confirm the frame
        // types/encoder turn a ping echo into a proper pong reply
        case p @ PingFrame(_) => ctx.getChannel.write(p)
        case p @ PongFrame(_) => ctx.getChannel.write(p)
        case f: WebSocketFrame => f.getType match {
          case 0xFF => /* binary not impl */()
          case _ => attempt(Message(WebSocket(ctx.getChannel), f))
        }
        case _ => pass(ctx, event)
      }

    /** Prints the cause, reports it to the socket intent as an Error callback, and closes
     * the channel. The close sits in `finally` so that an Error callback that throws cannot
     * leave the connection open. */
    // todo: if there's an error we may want to bubble this upstream
    override def exceptionCaught(ctx: ChannelHandlerContext, event: ExceptionEvent): Unit = {
      event.getCause.printStackTrace
      try attempt(Error(WebSocket(ctx.getChannel), event.getCause))
      finally event.getChannel.close
    }
  }
  261. }