PageRenderTime 27ms CodeModel.GetById 27ms RepoModel.GetById 0ms app.codeStats 0ms

/akka-remote/src/main/scala/akka/remote/netty/Client.scala

https://github.com/joshlong/akka
Scala | 364 lines | 281 code | 60 blank | 23 comment | 30 complexity | af15f83f4d8cfca1ca68492440244ed7 MD5 | raw file
  1. /**
  2. * Copyright (C) 2009-2011 Typesafe Inc. <http://www.typesafe.com>
  3. */
  4. package akka.remote.netty
  5. import java.net.InetSocketAddress
  6. import org.jboss.netty.util.HashedWheelTimer
  7. import org.jboss.netty.bootstrap.ClientBootstrap
  8. import org.jboss.netty.channel.group.DefaultChannelGroup
  9. import org.jboss.netty.channel.{ ChannelHandler, StaticChannelPipeline, SimpleChannelUpstreamHandler, MessageEvent, ExceptionEvent, ChannelStateEvent, ChannelPipelineFactory, ChannelPipeline, ChannelHandlerContext, ChannelFuture, Channel }
  10. import org.jboss.netty.handler.codec.frame.{ LengthFieldPrepender, LengthFieldBasedFrameDecoder }
  11. import org.jboss.netty.handler.execution.ExecutionHandler
  12. import akka.remote.RemoteProtocol.{ RemoteControlProtocol, CommandType, AkkaRemoteProtocol }
  13. import akka.remote.{ RemoteProtocol, RemoteMessage, RemoteLifeCycleEvent, RemoteClientStarted, RemoteClientShutdown, RemoteClientException, RemoteClientError, RemoteClientDisconnected, RemoteClientConnected }
  14. import akka.actor.{ simpleName, Address }
  15. import akka.AkkaException
  16. import akka.event.Logging
  17. import akka.util.Switch
  18. import akka.actor.ActorRef
  19. import org.jboss.netty.channel.ChannelFutureListener
  20. import akka.remote.RemoteClientWriteFailed
  21. import java.net.InetAddress
  22. import org.jboss.netty.util.TimerTask
  23. import org.jboss.netty.util.Timeout
  24. import java.util.concurrent.TimeUnit
  25. import org.jboss.netty.handler.timeout.{ IdleState, IdleStateEvent, IdleStateAwareChannelHandler, IdleStateHandler }
  /**
   * AkkaException subtype carrying a message and an underlying cause.
   *
   * NOTE(review): the name suggests it signals problems with remote client message
   * buffering; nothing in the visible part of this file throws it - confirm usage.
   *
   * @param message human readable description of the failure
   * @param cause   the underlying throwable (null when constructed via the one-argument constructor)
   */
  class RemoteClientMessageBufferException(message: String, cause: Throwable)
    extends AkkaException(message, cause) {

    /** Auxiliary constructor for failures without an underlying cause; passes `null` as the cause. */
    def this(msg: String) = this(msg, null)
  }
  /**
   * Abstract base class for netty remote clients. Two implementations are defined in this
   * file: ActiveRemoteClient, which establishes its own connection, and PassiveRemoteClient,
   * which is handed an already existing channel through its constructor.
   *
   * @param netty         the transport this client belongs to; used for settings, logging and listener notification
   * @param remoteAddress the address of the remote node this client talks to
   */
  abstract class RemoteClient private[akka] (
    val netty: NettyRemoteTransport,
    val remoteAddress: Address) {

    val log = Logging(netty.system, "RemoteClient")

    /** Display name: the simple class name of this client followed by `@` and the remote address. */
    val name = simpleName(this) + "@" + remoteAddress

    // Guards the started/stopped state of this client; flipped by connect() and shutdown() in subclasses.
    private[remote] val runSwitch = new Switch()

    private[remote] def isRunning = runSwitch.isOn

    /** The channel that outgoing messages are written to. */
    protected def currentChannel: Channel

    def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean

    def shutdown(): Boolean

    /** True when this client targets exactly the given address. */
    def isBoundTo(address: Address): Boolean = remoteAddress == address

    /**
     * Converts the message to the wireprotocol and sends the message across the wire.
     *
     * When the client is not running, a RemoteClientError is published to the transport's
     * listeners and the same exception is thrown to the caller.
     *
     * @param message      the payload to send
     * @param senderOption the optional sending actor
     * @param recipient    the remote actor the message is addressed to
     * @throws RemoteClientException if this client has not been started
     */
    def send(message: Any, senderOption: Option[ActorRef], recipient: ActorRef): Unit = if (isRunning) {
      if (netty.remoteSettings.LogSend) log.debug("Sending message {} from {} to {}", message, senderOption, recipient)
      send((message, senderOption, recipient))
    } else {
      val exception = new RemoteClientException("RemoteModule client is not running, make sure you have invoked 'RemoteClient.connect()' before using it.", netty, remoteAddress)
      netty.notifyListeners(RemoteClientError(exception, netty, remoteAddress))
      throw exception
    }

    /**
     * Sends the message across the wire.
     *
     * The request tuple is written to `currentChannel`. A listener on the write future
     * publishes RemoteClientWriteFailed when the write was cancelled or did not succeed.
     * Exceptions thrown while writing are published as RemoteClientError and not rethrown.
     */
    private def send(request: (Any, Option[ActorRef], ActorRef)): Unit = {
      try {
        val channel = currentChannel
        val f = channel.write(request)
        f.addListener(
          new ChannelFutureListener {
            def operationComplete(future: ChannelFuture): Unit = {
              if (future.isCancelled || !future.isSuccess) {
                netty.notifyListeners(RemoteClientWriteFailed(request, future.getCause, netty, remoteAddress))
              }
            }
          })
        // Check if we should back off
        if (!channel.isWritable) {
          val backoff = netty.settings.BackoffTimeout
          // await returns false when the write did not complete within the backoff period
          if (backoff.length > 0 && !f.await(backoff.length, backoff.unit)) f.cancel() //Waited as long as we could, now back off
        }
      } catch {
        case e: Exception => netty.notifyListeners(RemoteClientError(e, netty, remoteAddress))
      }
    }

    override def toString = name
  }
  /**
   * RemoteClient represents a connection to an Akka node. Is used to send messages to remote actors on the node.
   *
   * This implementation opens (and, when asked, re-opens) its own netty connection to
   * `remoteAddress` and announces itself to the remote side with a CONNECT control message
   * that carries `localAddress` as origin.
   *
   * @param netty         the owning transport (settings, executor, channel factory, listeners)
   * @param remoteAddress the node to connect to; its host and port must be defined
   * @param localAddress  the address sent as origin in the handshake; its host and port must be defined
   */
  class ActiveRemoteClient private[akka] (
    netty: NettyRemoteTransport,
    remoteAddress: Address,
    localAddress: Address)
    extends RemoteClient(netty, remoteAddress) {

    import netty.settings

    //TODO rewrite to a wrapper object (minimize volatile access and maximize encapsulation)
    @volatile
    private var bootstrap: ClientBootstrap = _
    @volatile
    private var connection: ChannelFuture = _
    @volatile
    private[remote] var openChannels: DefaultChannelGroup = _
    @volatile
    private var executionHandler: ExecutionHandler = _
    // Timestamp (millis) of the first reconnection check; 0L means no reconnection window is open.
    @volatile
    private var reconnectionTimeWindowStart = 0L

    /** Forwards a life cycle event to the transport's listeners. */
    def notifyListeners(msg: RemoteLifeCycleEvent): Unit = netty.notifyListeners(msg)

    /** The channel of the most recent connection attempt. */
    def currentChannel = connection.getChannel

    /**
     * Connect to remote server.
     *
     * The first call creates the channel group, execution handler and bootstrap, connects
     * and sends the handshake. When the client is already started and
     * `reconnectIfAlreadyConnected` is true, the current channel is closed and a new
     * connection attempt is made with the existing bootstrap.
     *
     * NOTE(review): the outer match inspects what `runSwitch switchOn` returns. Whether that
     * is the block's own result or only "the switch was flipped" depends on akka.util.Switch,
     * which is not visible here - confirm.
     *
     * @param reconnectIfAlreadyConnected whether to reconnect when the client is already started
     */
    def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean = {

      // Sends the CONNECT control message, including the secure cookie when one is configured.
      def sendSecureCookie(connection: ChannelFuture): Unit = {
        val handshake = RemoteControlProtocol.newBuilder.setCommandType(CommandType.CONNECT)
        settings.SecureCookie.foreach(cookie => handshake.setCookie(cookie))
        handshake.setOrigin(RemoteProtocol.AddressProtocol.newBuilder
          .setSystem(localAddress.system)
          .setHostname(localAddress.host.get)
          .setPort(localAddress.port.get)
          .build)
        connection.getChannel.write(netty.createControlEnvelope(handshake.build))
      }

      // Re-uses the existing bootstrap to open a fresh connection; true when it succeeded.
      def attemptReconnect(): Boolean = {
        val remoteIP = InetAddress.getByName(remoteAddress.host.get)
        log.debug("Remote client reconnecting to [{}|{}]", remoteAddress, remoteIP)
        connection = bootstrap.connect(new InetSocketAddress(remoteIP, remoteAddress.port.get))
        openChannels.add(connection.awaitUninterruptibly.getChannel) // Wait until the connection attempt succeeds or fails.

        if (!connection.isSuccess) {
          notifyListeners(RemoteClientError(connection.getCause, netty, remoteAddress))
          false
        } else {
          sendSecureCookie(connection)
          true
        }
      }

      runSwitch switchOn {
        openChannels = new DefaultDisposableChannelGroup(classOf[RemoteClient].getName)

        executionHandler = new ExecutionHandler(netty.executor)

        val b = new ClientBootstrap(netty.clientChannelFactory)
        b.setPipelineFactory(new ActiveRemoteClientPipelineFactory(name, b, executionHandler, remoteAddress, localAddress, this))
        b.setOption("tcpNoDelay", true)
        b.setOption("keepAlive", true)
        b.setOption("connectTimeoutMillis", settings.ConnectionTimeout.toMillis)
        // Bind outgoing connections to the configured local address (ephemeral port) when one is set.
        settings.OutboundLocalAddress.foreach(s => b.setOption("localAddress", new InetSocketAddress(s, 0)))
        bootstrap = b

        val remoteIP = InetAddress.getByName(remoteAddress.host.get)
        log.debug("Starting remote client connection to [{}|{}]", remoteAddress, remoteIP)

        connection = bootstrap.connect(new InetSocketAddress(remoteIP, remoteAddress.port.get))

        openChannels.add(connection.awaitUninterruptibly.getChannel) // Wait until the connection attempt succeeds or fails.

        if (!connection.isSuccess) {
          notifyListeners(RemoteClientError(connection.getCause, netty, remoteAddress))
          false
        } else {
          sendSecureCookie(connection)
          notifyListeners(RemoteClientStarted(netty, remoteAddress))
          true
        }
      } match {
        case true => true
        case false if reconnectIfAlreadyConnected =>
          connection.getChannel.close()
          openChannels.remove(connection.getChannel)

          log.debug("Remote client reconnecting to [{}]", remoteAddress)
          attemptReconnect()

        case false => false
      }
    }

    // Please note that this method does _not_ remove the ARC from the NettyRemoteClientModule's map of clients
    /**
     * Publishes RemoteClientShutdown, closes the current channel and every channel in
     * `openChannels`, then clears the connection and execution handler references.
     */
    def shutdown() = runSwitch switchOff {
      log.debug("Shutting down remote client [{}]", name)

      notifyListeners(RemoteClientShutdown(netty, remoteAddress))
      try {
        if ((connection ne null) && (connection.getChannel ne null))
          connection.getChannel.close()
      } finally {
        try {
          if (openChannels ne null) openChannels.close.awaitUninterruptibly()
        } finally {
          // Always drop the references, even when closing threw.
          connection = null
          executionHandler = null
        }
      }

      log.debug("[{}] has been shut down", name)
    }

    /**
     * Tells whether another reconnection attempt is still allowed.
     *
     * The first call after a reset opens the window (records the current time) and returns
     * true. Later calls return true while less than `settings.ReconnectionTimeWindow` has
     * elapsed since the window was opened, logging the remaining time in milliseconds.
     */
    private[akka] def isWithinReconnectionTimeWindow: Boolean = {
      if (reconnectionTimeWindowStart == 0L) {
        reconnectionTimeWindowStart = System.currentTimeMillis
        true
      } else {
        val elapsedMillis = System.currentTimeMillis - reconnectionTimeWindowStart
        val millisLeft = settings.ReconnectionTimeWindow.toMillis - elapsedMillis
        val withinWindow = millisLeft > 0
        if (withinWindow)
          log.info("Will try to reconnect to remote server for another [{}] milliseconds", millisLeft)

        withinWindow
      }
    }

    /** Closes the reconnection window so that the next check opens a fresh one. */
    private[akka] def resetReconnectionTimeWindow: Unit = reconnectionTimeWindowStart = 0L
  }
  /**
   * Netty upstream handler for an ActiveRemoteClient: reacts to idle events, incoming
   * messages, and channel state changes, and reports them to the client's listeners.
   *
   * @param name          display name of the owning client
   * @param bootstrap     the bootstrap the owning client connects with
   * @param remoteAddress address of the remote node
   * @param localAddress  address sent as origin in heartbeats; its host and port must be defined
   * @param timer         timer used to run work off the calling thread and to delay reconnects
   * @param client        the owning client
   */
  @ChannelHandler.Sharable
  class ActiveRemoteClientHandler(
    val name: String,
    val bootstrap: ClientBootstrap,
    val remoteAddress: Address,
    val localAddress: Address,
    val timer: HashedWheelTimer,
    val client: ActiveRemoteClient)
    extends IdleStateAwareChannelHandler {

    /**
     * Schedules `thunk` on the timer with zero delay.
     *
     * The parameter is by-name so that the body is evaluated inside the timer task rather
     * than eagerly on the caller's thread.
     */
    def runOnceNow(thunk: => Unit): Unit = timer.newTimeout(new TimerTask() {
      def run(timeout: Timeout): Unit = try { thunk } finally { timeout.cancel() }
    }, 0, TimeUnit.MILLISECONDS)

    /**
     * Reader/all idle: shut the client connection down (via the timer).
     * Writer idle: write a HEARTBEAT control message to the channel.
     */
    override def channelIdle(ctx: ChannelHandlerContext, e: IdleStateEvent): Unit = {
      import IdleState._

      // Builds the HEARTBEAT control envelope, attaching the cookie when one is configured.
      def createHeartBeat(localAddress: Address, cookie: Option[String]): AkkaRemoteProtocol = {
        val beat = RemoteControlProtocol.newBuilder.setCommandType(CommandType.HEARTBEAT)
        cookie.foreach(c => beat.setCookie(c))

        client.netty.createControlEnvelope(
          beat.setOrigin(RemoteProtocol.AddressProtocol.newBuilder
            .setSystem(localAddress.system)
            .setHostname(localAddress.host.get)
            .setPort(localAddress.port.get)
            .build).build)
      }

      e.getState match {
        case READER_IDLE | ALL_IDLE => runOnceNow { client.netty.shutdownClientConnection(remoteAddress) }
        case WRITER_IDLE            => e.getChannel.write(createHeartBeat(localAddress, client.netty.settings.SecureCookie))
      }
    }

    /**
     * Dispatches an incoming protocol message: a SHUTDOWN instruction shuts the client
     * connection down, other instructions are ignored, and payload messages are handed to
     * the transport. Anything else, and any exception raised while handling, is published
     * as a RemoteClientError.
     */
    override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent): Unit = {
      try {
        event.getMessage match {
          case arp: AkkaRemoteProtocol if arp.hasInstruction =>
            val rcp = arp.getInstruction
            rcp.getCommandType match {
              case CommandType.SHUTDOWN => runOnceNow { client.netty.shutdownClientConnection(remoteAddress) }
              case _                    => //Ignore others
            }

          case arp: AkkaRemoteProtocol if arp.hasMessage =>
            client.netty.receiveMessage(new RemoteMessage(arp.getMessage, client.netty.system))

          case other =>
            throw new RemoteClientException("Unknown message received in remote client handler: " + other, client.netty, client.remoteAddress)
        }
      } catch {
        case e: Exception => client.notifyListeners(RemoteClientError(e, client.netty, client.remoteAddress))
      }
    }

    /**
     * While the client is switched on: schedule a reconnect after the configured delay if
     * the reconnection window is still open, otherwise shut the client connection down.
     */
    override def channelClosed(ctx: ChannelHandlerContext, event: ChannelStateEvent): Unit = client.runSwitch ifOn {
      if (client.isWithinReconnectionTimeWindow) {
        timer.newTimeout(new TimerTask() {
          def run(timeout: Timeout): Unit =
            // Re-check: the client may have been shut down while the reconnect was pending.
            if (client.isRunning) {
              client.openChannels.remove(event.getChannel)
              client.connect(reconnectIfAlreadyConnected = true)
            }
        }, client.netty.settings.ReconnectDelay.toMillis, TimeUnit.MILLISECONDS)
      } else runOnceNow {
        client.netty.shutdownClientConnection(remoteAddress) // spawn in another thread
      }
    }

    /** Publishes RemoteClientConnected and closes the reconnection window. */
    override def channelConnected(ctx: ChannelHandlerContext, event: ChannelStateEvent): Unit = {
      try {
        client.notifyListeners(RemoteClientConnected(client.netty, client.remoteAddress))
        client.resetReconnectionTimeWindow
      } catch {
        case e: Exception => client.notifyListeners(RemoteClientError(e, client.netty, client.remoteAddress))
      }
    }

    /** Publishes RemoteClientDisconnected. */
    override def channelDisconnected(ctx: ChannelHandlerContext, event: ChannelStateEvent): Unit = {
      client.notifyListeners(RemoteClientDisconnected(client.netty, client.remoteAddress))
    }

    /** Publishes the failure (or a placeholder when the event has no cause) and closes the channel. */
    override def exceptionCaught(ctx: ChannelHandlerContext, event: ExceptionEvent): Unit = {
      val cause = if (event.getCause ne null) event.getCause else new Exception("Unknown cause")
      client.notifyListeners(RemoteClientError(cause, client.netty, client.remoteAddress))
      event.getChannel.close()
    }
  }
  /**
   * Creates the netty channel pipeline for an ActiveRemoteClient.
   *
   * Handlers are installed in this order: idle-state detection, length-field frame decoder,
   * message decoder, length-field prepender, message encoder, the supplied execution
   * handler, and a newly created ActiveRemoteClientHandler.
   *
   * @param name             display name handed to the client handler
   * @param bootstrap        bootstrap handed to the client handler
   * @param executionHandler execution handler placed in front of the client handler
   * @param remoteAddress    address of the remote node
   * @param localAddress     address of the local node
   * @param client           the client the pipeline is built for; also the source of settings and timer
   */
  class ActiveRemoteClientPipelineFactory(
    name: String,
    bootstrap: ClientBootstrap,
    executionHandler: ExecutionHandler,
    remoteAddress: Address,
    localAddress: Address,
    client: ActiveRemoteClient) extends ChannelPipelineFactory {

    import client.netty.settings

    /** Assembles a fresh pipeline; every call creates new handler instances (except the execution handler). */
    def getPipeline: ChannelPipeline = {
      // Idle detection works in whole seconds, taken from the read/write/all timeout settings.
      val idleDetector = new IdleStateHandler(
        client.netty.timer,
        settings.ReadTimeout.toSeconds.toInt,
        settings.WriteTimeout.toSeconds.toInt,
        settings.AllTimeout.toSeconds.toInt)

      // Inbound framing: frames are capped at MessageFrameSize; length-field arguments are (0, 4, 0, 4).
      val frameDecoder = new LengthFieldBasedFrameDecoder(settings.MessageFrameSize, 0, 4, 0, 4)
      // Outbound framing: prepends a 4 byte length field.
      val framePrepender = new LengthFieldPrepender(4)

      val protocolDecoder = new RemoteMessageDecoder
      val protocolEncoder = new RemoteMessageEncoder(client.netty)

      val clientHandler =
        new ActiveRemoteClientHandler(name, bootstrap, remoteAddress, localAddress, client.netty.timer, client)

      new StaticChannelPipeline(
        idleDetector,
        frameDecoder,
        protocolDecoder,
        framePrepender,
        protocolEncoder,
        executionHandler,
        clientHandler)
    }
  }
  /**
   * RemoteClient that writes to a channel supplied by the caller instead of opening its
   * own connection. Starting and stopping only flip the run switch, notify the transport's
   * listeners and log; the channel itself is not touched by either method.
   *
   * @param currentChannel the channel outgoing messages are written to
   * @param netty          the owning transport
   * @param remoteAddress  address of the remote node
   */
  class PassiveRemoteClient(
    val currentChannel: Channel,
    netty: NettyRemoteTransport,
    remoteAddress: Address)
    extends RemoteClient(netty, remoteAddress) {

    /**
     * Switches the client on; when the switch runs the block, RemoteClientStarted is
     * published and the start is logged. The flag argument is not used here.
     */
    def connect(reconnectIfAlreadyConnected: Boolean = false): Boolean =
      runSwitch.switchOn {
        netty.notifyListeners(RemoteClientStarted(netty, remoteAddress))
        log.debug("Starting remote client connection to [{}]", remoteAddress)
      }

    /**
     * Switches the client off; when the switch runs the block, RemoteClientShutdown is
     * published between the two log statements.
     */
    def shutdown() =
      runSwitch.switchOff {
        log.debug("Shutting down remote client [{}]", name)
        netty.notifyListeners(RemoteClientShutdown(netty, remoteAddress))
        log.debug("[{}] has been shut down", name)
      }
  }