PageRenderTime 49ms CodeModel.GetById 20ms RepoModel.GetById 1ms app.codeStats 0ms

/src/Microsoft.AspNet.SignalR.Core/Hubs/HubDispatcher.cs

https://github.com/mip1983/SignalR
C# | 439 lines | 348 code | 66 blank | 25 comment | 33 complexity | 639a8a955499de52088a2d10743bb9b3 MD5 | raw file
Possible License(s): Apache-2.0, CC-BY-SA-3.0
  // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.md in the project root for license information.
  using System;
  using System.Collections.Concurrent;
  using System.Collections.Generic;
  using System.Diagnostics;
  using System.Globalization;
  using System.Linq;
  using System.Linq.Expressions;
  using System.Reflection;
  using System.Threading.Tasks;
  using Microsoft.AspNet.SignalR.Infrastructure;
  namespace Microsoft.AspNet.SignalR.Hubs
  {
      /// <summary>
      /// Handles all communication over the hubs persistent connection.
      /// </summary>
      public class HubDispatcher : PersistentConnection
      {
          // Hubs named in the request's "connectionData" that the pipeline authorized this connection to use.
          private readonly List<HubDescriptor> _hubs = new List<HubDescriptor>();
          private readonly string _url;
          private IJavaScriptProxyGenerator _proxyGenerator;
          private IHubManager _manager;
          private IHubRequestParser _requestParser;
          private IParameterResolver _binder;
          private IHubPipelineInvoker _pipelineInvoker;
          private IPerformanceCounterManager _counters;
          private bool _isDebuggingEnabled;

          // The open generic ContinueWith<T> helper defined below. It is looked up by name, so no other
          // non-public static method in this class may be named "ContinueWith".
          private static readonly MethodInfo _continueWithMethod = typeof(HubDispatcher).GetMethod("ContinueWith", BindingFlags.NonPublic | BindingFlags.Static);

          // Compiled (task, tcs) => ContinueWith((Task<T>)task, tcs) delegates keyed by T, so each expression
          // tree is compiled once per result type instead of once per hub method invocation.
          private static readonly ConcurrentDictionary<Type, Action<object, TaskCompletionSource<object>>> _continueWithInvokers =
              new ConcurrentDictionary<Type, Action<object, TaskCompletionSource<object>>>();

          /// <summary>
          /// Initializes an instance of the <see cref="HubDispatcher"/> class.
          /// </summary>
          /// <param name="url">The base url of the connection url.</param>
          public HubDispatcher(string url)
          {
              _url = url;
          }

          /// <summary>
          /// The trace source named "SignalR.HubDispatcher".
          /// </summary>
          protected override TraceSource Trace
          {
              get
              {
                  return _trace["SignalR.HubDispatcher"];
              }
          }

          /// <summary>
          /// Resolves the dispatcher's services from <paramref name="resolver"/>, initializes the base connection,
          /// and records every hub listed in the request's "connectionData" value that the pipeline authorizes
          /// for this connection.
          /// </summary>
          /// <param name="resolver">The dependency resolver used to obtain the hub services.</param>
          /// <param name="context">The host context of the current request.</param>
          public override void Initialize(IDependencyResolver resolver, HostContext context)
          {
              _proxyGenerator = resolver.Resolve<IJavaScriptProxyGenerator>();
              _manager = resolver.Resolve<IHubManager>();
              _binder = resolver.Resolve<IParameterResolver>();
              _requestParser = resolver.Resolve<IHubRequestParser>();
              _pipelineInvoker = resolver.Resolve<IHubPipelineInvoker>();
              _counters = resolver.Resolve<IPerformanceCounterManager>();

              // Call base initializer before populating _hubs so the _jsonSerializer is initialized
              base.Initialize(resolver, context);

              // Populate _hubs
              string data = context.Request.QueryStringOrForm("connectionData");

              if (!String.IsNullOrEmpty(data))
              {
                  var clientHubInfo = _jsonSerializer.Parse<IEnumerable<ClientHubInfo>>(data);
                  if (clientHubInfo != null)
                  {
                      foreach (var hubInfo in clientHubInfo)
                      {
                          // Try to find the associated hub type
                          HubDescriptor hubDescriptor = _manager.EnsureHub(hubInfo.Name,
                              _counters.ErrorsHubResolutionTotal,
                              _counters.ErrorsHubResolutionPerSec,
                              _counters.ErrorsAllTotal,
                              _counters.ErrorsAllPerSec);

                          if (_pipelineInvoker.AuthorizeConnect(hubDescriptor, context.Request))
                          {
                              // Add this to the list of hub descriptors this connection is interested in
                              _hubs.Add(hubDescriptor);
                          }
                      }
                  }
              }
          }

          /// <summary>
          /// Processes the hub's incoming method calls.
          /// </summary>
          /// <param name="request">The current request.</param>
          /// <param name="connectionId">The id of the connection that sent the call.</param>
          /// <param name="data">The raw hub request payload.</param>
          /// <returns>A task that completes after the response has been sent and the hub disposed.</returns>
          /// <exception cref="InvalidOperationException">The requested hub method could not be resolved.</exception>
          protected override Task OnReceivedAsync(IRequest request, string connectionId, string data)
          {
              HubRequest hubRequest = _requestParser.Parse(data);

              // Create the hub
              HubDescriptor descriptor = _manager.EnsureHub(hubRequest.Hub,
                  _counters.ErrorsHubInvocationTotal,
                  _counters.ErrorsHubInvocationPerSec,
                  _counters.ErrorsAllTotal,
                  _counters.ErrorsAllPerSec);

              IJsonValue[] parameterValues = hubRequest.ParameterValues;

              // Resolve the method
              MethodDescriptor methodDescriptor = _manager.GetHubMethod(descriptor.Name, hubRequest.Method, parameterValues);
              if (methodDescriptor == null)
              {
                  _counters.ErrorsHubInvocationTotal.Increment();
                  _counters.ErrorsHubInvocationPerSec.Increment();
                  throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, "'{0}' method could not be resolved.", hubRequest.Method));
              }

              // Resolving the actual state object
              var state = new TrackingDictionary(hubRequest.State);
              var hub = CreateHub(request, descriptor, connectionId, state, throwIfFailedToCreate: true);

              return InvokeHubPipeline(hub, hubRequest, parameterValues, methodDescriptor, state)
                  .ContinueWith(task =>
                  {
                      // ResolveHub may hand back null without throwing (see CreateHub), so guard the dispose.
                      if (hub != null)
                      {
                          hub.Dispose();
                      }
                  }, TaskContinuationOptions.ExecuteSynchronously);
          }

          /// <summary>
          /// Binds the call's parameters, runs the incoming hub pipeline and sends the resulting response.
          /// Binding failures, pipeline faults and cancellation are all reported to the caller as error responses
          /// rather than thrown.
          /// </summary>
          private Task InvokeHubPipeline(IHub hub, HubRequest hubRequest, IJsonValue[] parameterValues, MethodDescriptor methodDescriptor, TrackingDictionary state)
          {
              Task<object> pipelineInvocation;

              try
              {
                  var args = _binder.ResolveMethodParameters(methodDescriptor, parameterValues);
                  var context = new HubInvokerContext(hub, state, methodDescriptor, args);

                  // Invoke the pipeline
                  pipelineInvocation = _pipelineInvoker.Invoke(context);
              }
              catch (Exception ex)
              {
                  // A synchronous failure (e.g. a parameter that cannot be bound) is surfaced like any other hub
                  // error, so the caller receives an error response instead of the exception escaping OnReceivedAsync.
                  var failed = new TaskCompletionSource<object>();
                  failed.SetException(ex);
                  pipelineInvocation = failed.Task;
              }

              return pipelineInvocation
                  .ContinueWith(task => ProcessTaskResult(state, hubRequest, task))
                  .FastUnwrap();
          }

          /// <summary>
          /// Serves the generated JavaScript proxy for requests whose path ends with "/hubs"; otherwise records
          /// whether debugging is enabled for the request and defers to the base connection.
          /// </summary>
          /// <param name="context">The host context of the current request.</param>
          public override Task ProcessRequestAsync(HostContext context)
          {
              // Generate the proxy
              if (context.Request.Url.LocalPath.EndsWith("/hubs", StringComparison.OrdinalIgnoreCase))
              {
                  context.Response.ContentType = "application/x-javascript";
                  return context.Response.EndAsync(_proxyGenerator.GenerateProxy(_url));
              }

              _isDebuggingEnabled = context.IsDebuggingEnabled();

              return base.ProcessRequestAsync(context);
          }

          /// <summary>Invokes <see cref="IHub.OnConnected"/> on <paramref name="hub"/>.</summary>
          internal static Task Connect(IHub hub)
          {
              return hub.OnConnected();
          }

          /// <summary>Invokes <see cref="IHub.OnReconnected"/> on <paramref name="hub"/>.</summary>
          internal static Task Reconnect(IHub hub)
          {
              return hub.OnReconnected();
          }

          /// <summary>Invokes <see cref="IHub.OnDisconnected"/> on <paramref name="hub"/>.</summary>
          internal static Task Disconnect(IHub hub)
          {
              return hub.OnDisconnected();
          }

          /// <summary>
          /// Invokes the hub method described by <paramref name="context"/> and adapts its return value to a
          /// <see cref="Task{Object}"/>. For methods returning <see cref="Task"/> or <see cref="Task{T}"/>, the
          /// returned task completes with the method task's outcome. Any other return value completes it
          /// immediately. An exception thrown by the invocation faults it.
          /// </summary>
          internal static Task<object> Incoming(IHubIncomingInvokerContext context)
          {
              var tcs = new TaskCompletionSource<object>();

              try
              {
                  var result = context.MethodDescriptor.Invoker.Invoke(context.Hub, context.Args);
                  Type returnType = context.MethodDescriptor.ReturnType;

                  if (typeof(Task).IsAssignableFrom(returnType))
                  {
                      var task = (Task)result;
                      if (!returnType.IsGenericType)
                      {
                          task.ContinueWith(tcs);
                      }
                      else
                      {
                          // Get the <T> in Task<T>
                          Type resultType = returnType.GetGenericArguments().Single();

                          // Equivalent to ContinueWith((Task<T>)result, tcs), via a delegate compiled once per T
                          GetContinueWithInvoker(resultType).Invoke(result, tcs);
                      }
                  }
                  else
                  {
                      tcs.TrySetResult(result);
                  }
              }
              catch (Exception ex)
              {
                  tcs.TrySetException(ex);
              }

              return tcs.Task;
          }

          /// <summary>
          /// Returns the cached delegate that calls <c>ContinueWith&lt;T&gt;</c> for <paramref name="resultType"/>,
          /// compiling it on first use.
          /// </summary>
          private static Action<object, TaskCompletionSource<object>> GetContinueWithInvoker(Type resultType)
          {
              return _continueWithInvokers.GetOrAdd(resultType, type => CreateContinueWithInvoker(type));
          }

          /// <summary>
          /// Compiles <c>(task, tcs) =&gt; ContinueWith&lt;T&gt;((Task&lt;T&gt;)task, tcs)</c> for T = <paramref name="resultType"/>.
          /// </summary>
          private static Action<object, TaskCompletionSource<object>> CreateContinueWithInvoker(Type resultType)
          {
              Type genericTaskType = typeof(Task<>).MakeGenericType(resultType);
              MethodInfo continueWithMethod = _continueWithMethod.MakeGenericMethod(resultType);

              var taskParameter = Expression.Parameter(typeof(object), "task");
              var tcsParameter = Expression.Parameter(typeof(TaskCompletionSource<object>), "tcs");

              Expression body = Expression.Call(continueWithMethod,
                                                Expression.Convert(taskParameter, genericTaskType),
                                                tcsParameter);

              return Expression.Lambda<Action<object, TaskCompletionSource<object>>>(body, taskParameter, tcsParameter).Compile();
          }

          /// <summary>
          /// Sends an outgoing hub invocation to <c>context.Signal</c> as a <see cref="ConnectionMessage"/>,
          /// excluding the signals listed in <c>context.ExcludedSignals</c>.
          /// </summary>
          internal static Task Outgoing(IHubOutgoingInvokerContext context)
          {
              var message = new ConnectionMessage(context.Signal, context.Invocation)
              {
                  ExcludedSignals = context.ExcludedSignals
              };

              return context.Connection.Send(message);
          }

          /// <summary>Runs the pipeline's connect step on every hub this connection uses.</summary>
          protected override Task OnConnectedAsync(IRequest request, string connectionId)
          {
              return ExecuteHubEventAsync(request, connectionId, hub => _pipelineInvoker.Connect(hub));
          }

          /// <summary>Runs the pipeline's reconnect step on every hub this connection uses.</summary>
          protected override Task OnReconnectedAsync(IRequest request, string connectionId)
          {
              return ExecuteHubEventAsync(request, connectionId, hub => _pipelineInvoker.Reconnect(hub));
          }

          /// <summary>
          /// For each hub this connection uses, asks the pipeline which of the hub's previous groups to rejoin.
          /// Incoming group names are stripped of their hub prefix before the call, and the prefix is re-applied
          /// to the pipeline's answer.
          /// </summary>
          /// <param name="request">The current request.</param>
          /// <param name="groups">The fully qualified group names the connection previously belonged to.</param>
          /// <param name="connectionId">The id of the connection.</param>
          /// <returns>The fully qualified group names to rejoin.</returns>
          protected override IEnumerable<string> OnRejoiningGroups(IRequest request, IEnumerable<string> groups, string connectionId)
          {
              return _hubs.Select(hubDescriptor =>
              {
                  // Use the same name the hub's GroupManager is created with in CreateHub (descriptor.Name);
                  // Type.Name differs from it for hubs with a custom name, whose groups would then never match.
                  string groupPrefix = hubDescriptor.Name + ".";

                  // Group names are identifiers, so compare ordinally rather than with the current culture.
                  var hubGroups = groups.Where(g => g.StartsWith(groupPrefix, StringComparison.Ordinal))
                                        .Select(g => g.Substring(groupPrefix.Length))
                                        .ToList();

                  return _pipelineInvoker.RejoiningGroups(hubDescriptor, request, hubGroups)
                                         .Select(g => groupPrefix + g);
              }).SelectMany(groupsToRejoin => groupsToRejoin);
          }

          /// <summary>Runs the pipeline's disconnect step on every hub this connection uses.</summary>
          protected override Task OnDisconnectAsync(IRequest request, string connectionId)
          {
              return ExecuteHubEventAsync(request, connectionId, hub => _pipelineInvoker.Disconnect(hub));
          }

          /// <summary>
          /// Returns each hub's name and its connection-qualified name, followed by the base connection's signals.
          /// </summary>
          protected override IEnumerable<string> GetSignals(string connectionId)
          {
              return _hubs.SelectMany(info => new[] { info.Name, info.CreateQualifiedName(connectionId) })
                          .Concat(base.GetSignals(connectionId));
          }

          /// <summary>
          /// Creates this connection's hubs and runs <paramref name="action"/> on each. The hubs are disposed
          /// once all operations finish. The returned task is faulted with the first fault, canceled if any
          /// operation was canceled, and otherwise completed.
          /// </summary>
          private Task ExecuteHubEventAsync(IRequest request, string connectionId, Func<IHub, Task> action)
          {
              var hubs = GetHubs(request, connectionId).ToList();
              var operations = hubs.Select(instance => action(instance).Catch().OrEmpty()).ToArray();

              if (operations.Length == 0)
              {
                  DisposeHubs(hubs);
                  return TaskAsyncHelper.Empty;
              }

              var tcs = new TaskCompletionSource<object>();
              Task.Factory.ContinueWhenAll(operations, tasks =>
              {
                  DisposeHubs(hubs);
                  var faulted = tasks.FirstOrDefault(t => t.IsFaulted);
                  if (faulted != null)
                  {
                      tcs.SetException(faulted.Exception);
                  }
                  else if (tasks.Any(t => t.IsCanceled))
                  {
                      tcs.SetCanceled();
                  }
                  else
                  {
                      tcs.SetResult(null);
                  }
              });

              return tcs.Task;
          }

          /// <summary>
          /// Resolves the hub named by <paramref name="descriptor"/> and attaches its caller context,
          /// client proxy (using <paramref name="state"/>, or a fresh dictionary) and group manager.
          /// </summary>
          /// <returns>
          /// The hub; null if the manager returned none, or if creation failed and
          /// <paramref name="throwIfFailedToCreate"/> is false.
          /// </returns>
          private IHub CreateHub(IRequest request, HubDescriptor descriptor, string connectionId, TrackingDictionary state = null, bool throwIfFailedToCreate = false)
          {
              try
              {
                  var hub = _manager.ResolveHub(descriptor.Name);

                  if (hub != null)
                  {
                      state = state ?? new TrackingDictionary();

                      hub.Context = new HubCallerContext(request, connectionId);
                      hub.Clients = new HubConnectionContext(_pipelineInvoker, Connection, descriptor.Name, connectionId, state);
                      hub.Groups = new GroupManager(Connection, descriptor.Name);
                  }

                  return hub;
              }
              catch (Exception ex)
              {
                  // Pass the message as an argument: concatenating it into the format string made TraceInformation
                  // throw FormatException (masking the real error) whenever the message contained '{' or '}'.
                  Trace.TraceInformation("Error creating hub {0}. {1}", descriptor.Name, ex.Message);

                  if (throwIfFailedToCreate)
                  {
                      throw;
                  }

                  return null;
              }
          }

          /// <summary>Creates every hub this connection uses, skipping those that could not be created.</summary>
          private IEnumerable<IHub> GetHubs(IRequest request, string connectionId)
          {
              return from descriptor in _hubs
                     select CreateHub(request, descriptor, connectionId) into hub
                     where hub != null
                     select hub;
          }

          /// <summary>Disposes every hub in <paramref name="hubs"/>.</summary>
          private void DisposeHubs(IEnumerable<IHub> hubs)
          {
              foreach (var hub in hubs)
              {
                  hub.Dispose();
              }
          }

          /// <summary>
          /// Sends the response for a finished hub invocation: the fault, a cancellation error, or the task's result.
          /// </summary>
          private Task ProcessTaskResult<T>(TrackingDictionary state, HubRequest request, Task<T> task)
          {
              if (task.IsFaulted)
              {
                  return ProcessResponse(state, null, request, task.Exception);
              }

              if (task.IsCanceled)
              {
                  // Reading Result on a canceled task throws, so report the cancellation as an error instead.
                  return ProcessResponse(state, null, request, new OperationCanceledException());
              }

              return ProcessResponse(state, task.Result, request, null);
          }

          /// <summary>
          /// Sends a <see cref="HubResponse"/> over the transport carrying the state changes, the result and the
          /// request id. When <paramref name="error"/> is set, the response also carries its (unwrapped) message
          /// and, if debugging is enabled, its stack trace, and the hub invocation error counters are incremented.
          /// </summary>
          private Task ProcessResponse(TrackingDictionary state, object result, HubRequest request, Exception error)
          {
              var exception = error.Unwrap();
              string stackTrace = (exception != null && _isDebuggingEnabled) ? exception.StackTrace : null;
              string errorMessage = exception != null ? exception.Message : null;

              if (exception != null)
              {
                  _counters.ErrorsHubInvocationTotal.Increment();
                  _counters.ErrorsHubInvocationPerSec.Increment();
                  _counters.ErrorsAllTotal.Increment();
                  _counters.ErrorsAllPerSec.Increment();
              }

              var hubResult = new HubResponse
              {
                  State = state.GetChanges(),
                  Result = result,
                  Id = request.Id,
                  Error = errorMessage,
                  StackTrace = stackTrace
              };

              return _transport.Send(hubResult);
          }

          /// <summary>
          /// Forwards the outcome of <paramref name="task"/> into <paramref name="tcs"/>, inline if the task has
          /// already completed and via a continuation otherwise. Located by reflection through _continueWithMethod.
          /// </summary>
          private static void ContinueWith<T>(Task<T> task, TaskCompletionSource<object> tcs)
          {
              if (task.IsCompleted)
              {
                  // Fast path for tasks that completed synchronously
                  ContinueSync<T>(task, tcs);
              }
              else
              {
                  ContinueAsync<T>(task, tcs);
              }
          }

          /// <summary>Copies a completed task's fault, cancellation or result into <paramref name="tcs"/>.</summary>
          private static void ContinueSync<T>(Task<T> task, TaskCompletionSource<object> tcs)
          {
              if (task.IsFaulted)
              {
                  tcs.TrySetException(task.Exception);
              }
              else if (task.IsCanceled)
              {
                  tcs.TrySetCanceled();
              }
              else
              {
                  tcs.TrySetResult(task.Result);
              }
          }

          /// <summary>Copies the task's fault, cancellation or result into <paramref name="tcs"/> once it completes.</summary>
          private static void ContinueAsync<T>(Task<T> task, TaskCompletionSource<object> tcs)
          {
              task.ContinueWith(t =>
              {
                  if (t.IsFaulted)
                  {
                      tcs.TrySetException(t.Exception);
                  }
                  else if (t.IsCanceled)
                  {
                      tcs.TrySetCanceled();
                  }
                  else
                  {
                      tcs.TrySetResult(t.Result);
                  }
              });
          }

          // Shape of each entry in the "connectionData" JSON array read by Initialize.
          private class ClientHubInfo
          {
              public string Name { get; set; }
          }
      }
  }