PageRenderTime 50ms CodeModel.GetById 17ms RepoModel.GetById 1ms app.codeStats 0ms

/Raven.Database/Server/Connections/WebSocketsTransport.cs

https://github.com/nwendel/ravendb
C# | 420 lines | 345 code | 65 blank | 10 comment | 57 complexity | 112d6b19f12fc8a917b178070ea79264 MD5 | raw file
Possible License(s): MPL-2.0-no-copyleft-exception, BSD-3-Clause, CC-BY-SA-3.0
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Net;
  7. using System.Security.Principal;
  8. using System.Threading;
  9. using System.Threading.Tasks;
  10. using Microsoft.Owin;
  11. using Raven.Abstractions;
  12. using Raven.Abstractions.Data;
  13. using Raven.Abstractions.Logging;
  14. using Raven.Abstractions.Util;
  15. using Raven.Database.Counters;
  16. using Raven.Database.Extensions;
  17. using Raven.Database.Server.Security;
  18. using Raven.Database.Server.Tenancy;
  19. using Raven.Imports.Newtonsoft.Json;
  20. using Raven.Json.Linq;
  21. namespace Raven.Database.Server.Connections
  22. {
  23. /*
  24. * This is a really ugly way to go about it, but it is the interface that OWIN
  25. * gives us, as defined by the OWIN WebSocket extension:
  26. * http://owin.org/extensions/owin-WebSocket-Extension-v0.4.0.htm
  27. *
  28. */
  29. using WebSocketAccept = Action<IDictionary<string, object>, // options
  30. Func<IDictionary<string, object>, Task>>; // callback
  31. using WebSocketCloseAsync =
  32. Func<int /* closeStatus */,
  33. string /* closeDescription */,
  34. CancellationToken /* cancel */,
  35. Task>;
  36. using WebSocketReceiveAsync =
  37. Func<ArraySegment<byte> /* data */,
  38. CancellationToken /* cancel */,
  39. Task<Tuple<int /* messageType */,
  40. bool /* endOfMessage */,
  41. int /* count */>>>;
  42. using WebSocketSendAsync =
  43. Func<ArraySegment<byte> /* data */,
  44. int /* messageType */,
  45. bool /* endOfMessage */,
  46. CancellationToken /* cancel */,
  47. Task>;
  48. using WebSocketReceiveResult = Tuple<int, // type
  49. bool, // end of message?
  50. int>;
  51. using Raven.Database.Server.RavenFS; // count
  /// <summary>
  /// Event transport that pushes queued messages to a client over an OWIN WebSocket
  /// (http://owin.org/extensions/owin-WebSocket-Extension-v0.4.0.htm).
  /// Messages are queued by <see cref="SendAsync"/> and written as JSON text frames by <see cref="Run"/>;
  /// when nothing is queued for <see cref="HeartbeatIntervalInMilliseconds"/> a heartbeat is sent instead.
  /// </summary>
  /// <remarks>
  /// When the client passes a positive <c>coolDownWithDataLoss</c> query value, messages arriving within that
  /// many milliseconds of the previous send are dropped, except for the most recent one, which is sent
  /// together with the next heartbeat.
  /// </remarks>
  public class WebSocketsTransport : IEventsTransport
  {
      private static readonly ILog log = LogManager.GetCurrentClassLogger();

      /// <summary>OWIN WebSocket message type of a close frame.</summary>
      private const int WebSocketCloseMessageType = 8;

      /// <summary>OWIN WebSocket message type of a text frame; every message is sent as a single JSON text frame.</summary>
      private const int WebSocketTextMessageType = 1;

      // Close status/description the client must send for us to complete the closing handshake.
      private const int NormalClosureCode = 1000;
      private const string NormalClosureMessage = "CLOSE_NORMAL";

      /// <summary>How long <see cref="Run"/> waits for a queued message before sending a heartbeat.</summary>
      private const int HeartbeatIntervalInMilliseconds = 5000;

      private readonly IOwinContext _context;
      private readonly RavenDBOptions _options;
      private readonly AsyncManualResetEvent manualResetEvent = new AsyncManualResetEvent();
      private readonly ConcurrentQueue<object> msgs = new ConcurrentQueue<object>();
      private readonly Func<string, WebSocketsTransport, IPrincipal, bool> RegistrationLogicAction;
      private readonly Func<object, object> MessageFormatter;

      // The two fields below are only read and written by the Run loop and its helpers.
      // Environment.TickCount captured when the last message was actually sent; null until the first send.
      private int? lastMessageSentTick;
      // Most recent message dropped during a cool-down window; flushed on the next heartbeat.
      private object lastMessageEnqueuedAndNotSent;

      public string Id { get; private set; }

      public bool Connected { get; set; }

      public long CoolDownWithDataLossInMiliseconds { get; set; }

      /// <summary>Name of the database / file system / counter storage this transport is bound to; set by <see cref="TrySetupRequest"/>.</summary>
      public string ResourceName { get; set; }

      public event Action Disconnected;

      /// <param name="options">Server options used to resolve resources and authorize the request.</param>
      /// <param name="context">The OWIN request; its <c>id</c> and <c>coolDownWithDataLoss</c> query values are read here.</param>
      /// <param name="registrationLogics">Optional custom registration; when set it replaces the default TransportState registration.</param>
      /// <param name="messageFormatter">Optional projection applied to every queued message before it is serialized.</param>
      public WebSocketsTransport(RavenDBOptions options, IOwinContext context, Func<string, WebSocketsTransport, IPrincipal, bool> registrationLogics = null, Func<object, object> messageFormatter = null)
      {
          _options = options;
          _context = context;
          Connected = true;
          Id = context.Request.Query["id"];

          // A missing or malformed value parses to 0, which disables the cool-down.
          long waitTimeBetweenMessages;
          long.TryParse(context.Request.Query["coolDownWithDataLoss"],
              System.Globalization.NumberStyles.Integer,
              System.Globalization.CultureInfo.InvariantCulture,
              out waitTimeBetweenMessages);
          CoolDownWithDataLossInMiliseconds = waitTimeBetweenMessages;

          RegistrationLogicAction = registrationLogics;
          MessageFormatter = messageFormatter;
      }

      public void Dispose()
      {
          // Intentionally empty: nothing is released here.
      }

      /// <summary>Queues <paramref name="msg"/> and wakes the <see cref="Run"/> loop; it does not send synchronously.</summary>
      public void SendAsync(object msg)
      {
          msgs.Enqueue(msg);
          manualResetEvent.Set();
      }

      /// <summary>
      /// Main send loop for an accepted WebSocket. Runs until <c>websocket.CallCancelled</c> is signaled,
      /// then raises <see cref="Disconnected"/> (also when sending throws).
      /// </summary>
      /// <param name="websocketContext">The OWIN WebSocket environment dictionary.</param>
      public async Task Run(IDictionary<string, object> websocketContext)
      {
          try
          {
              var sendAsync = (WebSocketSendAsync)websocketContext["websocket.SendAsync"];
              var callCancelled = (CancellationToken)websocketContext["websocket.CallCancelled"];
              var serializer = new JsonSerializer
              {
                  Converters = { new EtagJsonConverter() }
              };

              using (var memoryStream = new MemoryStream())
              {
                  CreateWaitForClientCloseTask(websocketContext, callCancelled);

                  while (callCancelled.IsCancellationRequested == false)
                  {
                      var signaled = await manualResetEvent.WaitAsync(HeartbeatIntervalInMilliseconds);
                      if (callCancelled.IsCancellationRequested)
                          break;

                      if (signaled == false)
                      {
                          // Idle period: keep the connection alive and flush a message held back by the cool-down.
                          await SendMessage(memoryStream, serializer,
                              new { Type = "Heartbeat", Time = SystemTime.UtcNow },
                              sendAsync, callCancelled);

                          if (lastMessageEnqueuedAndNotSent != null)
                              await SendFormattedMessage(memoryStream, serializer, lastMessageEnqueuedAndNotSent, sendAsync, callCancelled);

                          continue;
                      }

                      // Reset before draining, so a message queued while we drain re-signals the event.
                      manualResetEvent.Reset();

                      object message;
                      while (msgs.TryDequeue(out message))
                      {
                          if (callCancelled.IsCancellationRequested)
                              break;

                          if (IsInCoolDownPeriod())
                          {
                              // Data loss by design: only the latest message inside the window survives.
                              lastMessageEnqueuedAndNotSent = message;
                              continue;
                          }

                          await SendFormattedMessage(memoryStream, serializer, message, sendAsync, callCancelled);
                      }
                  }
              }
          }
          finally
          {
              OnDisconnection();
          }
      }

      /// <summary>True when a cool-down is configured and the previous send happened less than that many milliseconds ago.</summary>
      private bool IsInCoolDownPeriod()
      {
          if (CoolDownWithDataLossInMiliseconds <= 0 || lastMessageSentTick.HasValue == false)
              return false;

          // Environment.TickCount wraps around (and can be negative); unchecked int subtraction still yields
          // the correct elapsed milliseconds across the wrap.
          long elapsed = unchecked(Environment.TickCount - lastMessageSentTick.Value);
          return elapsed < CoolDownWithDataLossInMiliseconds;
      }

      /// <summary>Applies <see cref="MessageFormatter"/> (if any), sends the message and records it as the latest send.</summary>
      private async Task SendFormattedMessage(MemoryStream memoryStream, JsonSerializer serializer, object message, WebSocketSendAsync sendAsync, CancellationToken callCancelled)
      {
          var messageToSend = MessageFormatter != null ? MessageFormatter(message) : message;
          await SendMessage(memoryStream, serializer, messageToSend, sendAsync, callCancelled);
          lastMessageEnqueuedAndNotSent = null;
          lastMessageSentTick = Environment.TickCount;
      }

      /// <summary>
      /// Starts a fire-and-forget receive loop that watches for the client's close frame and, for a normal
      /// closure (1000 / "CLOSE_NORMAL"), answers it with CloseAsync. Receive errors are logged as warnings.
      /// </summary>
      private void CreateWaitForClientCloseTask(IDictionary<string, object> websocketContext, CancellationToken callCancelled)
      {
          // Task.Run(Func<Task>) keeps the async lambda a real Task. The previous new Task(async () => ...)
          // bound it to Action (async void), so an exception escaping the lambda would crash the process.
          Task.Run(async () =>
          {
              try
              {
                  var buffer = new ArraySegment<byte>(new byte[1024]);
                  var receiveAsync = (WebSocketReceiveAsync)websocketContext["websocket.ReceiveAsync"];
                  var closeAsync = (WebSocketCloseAsync)websocketContext["websocket.CloseAsync"];

                  while (callCancelled.IsCancellationRequested == false)
                  {
                      WebSocketReceiveResult receiveResult = await receiveAsync(buffer, callCancelled);
                      if (receiveResult.Item1 != WebSocketCloseMessageType)
                          continue;

                      var clientCloseStatus = (int)websocketContext["websocket.ClientCloseStatus"];
                      var clientCloseDescription = (string)websocketContext["websocket.ClientCloseDescription"];
                      if (clientCloseStatus == NormalClosureCode && clientCloseDescription == NormalClosureMessage)
                      {
                          await closeAsync(clientCloseStatus, clientCloseDescription, callCancelled);
                      }

                      // The WebSocket is now in the 'CloseReceived' state; no more messages will arrive.
                      break;
                  }
              }
              catch (Exception e)
              {
                  log.WarnException("Error when recieving message from web socket transport", e);
              }
          });
      }

      /// <summary>
      /// Serializes <paramref name="message"/> as JSON into the reused <paramref name="memoryStream"/> and sends it
      /// as one complete text frame.
      /// </summary>
      private async Task SendMessage(MemoryStream memoryStream, JsonSerializer serializer, object message, WebSocketSendAsync sendAsync, CancellationToken callCancelled)
      {
          // Overwrite the shared buffer from the start; Position then marks the end of this message.
          memoryStream.Position = 0;

          // The writers are deliberately not disposed: disposing the StreamWriter would close memoryStream.
          var jsonTextWriter = new JsonTextWriter(new StreamWriter(memoryStream));
          serializer.Serialize(jsonTextWriter, message);
          jsonTextWriter.Flush();

          var payload = new ArraySegment<byte>(memoryStream.GetBuffer(), 0, (int)memoryStream.Position);
          await sendAsync(payload, WebSocketTextMessageType, true, callCancelled);
      }

      private void OnDisconnection()
      {
          Connected = false;
          var onDisconnected = Disconnected;
          if (onDisconnected != null)
              onDisconnected();
      }

      /// <summary>
      /// Validates the request, resolves the target resource, authorizes the caller and registers this transport.
      /// </summary>
      /// <returns>
      /// <c>true</c> when the transport was registered (or the custom registration logic accepted it);
      /// <c>false</c> otherwise, usually with an error status already written to the response.
      /// </returns>
      public async Task<bool> TrySetupRequest()
      {
          if (string.IsNullOrEmpty(Id))
          {
              _context.Response.StatusCode = 400;
              _context.Response.ReasonPhrase = "BadRequest";
              _context.Response.Write("{ 'Error': 'Id is mandatory' }");
              return false;
          }

          var documentDatabase = await GetDatabase();
          var fileSystem = await GetFileSystem();
          var counterStorage = await GetCounterStorage();

          if (documentDatabase == null && fileSystem == null && counterStorage == null)
              return false;

          // GetDatabase falls back to the system database for non-/databases/ paths, so a file system or
          // counter storage that was requested but could not be loaded must not silently bind to it.
          if ((fileSystem == null && GetFileSystemName() != null) ||
              (counterStorage == null && GetCounterStorageName() != null))
              return false;

          // Assign the property (the old code declared a local with the same name, so it was never set).
          ResourceName = fileSystem != null ? fileSystem.Name
              : counterStorage != null ? counterStorage.Name
              : documentDatabase != null ? documentDatabase.Name
              : null;

          IPrincipal user = null;
          var singleUseToken = _context.Request.Query["singleUseAuthToken"];
          if (string.IsNullOrEmpty(singleUseToken) == false)
          {
              object msg;
              HttpStatusCode code;
              if (_options.MixedModeRequestAuthorizer.TryAuthorizeSingleUseAuthToken(singleUseToken, ResourceName, out msg, out code, out user) == false)
              {
                  _context.Response.StatusCode = (int)code;
                  _context.Response.ReasonPhrase = code.ToString();
                  _context.Response.Write(RavenJToken.FromObject(msg).ToString(Formatting.Indented));
                  return false;
              }
          }
          else
          {
              switch (_options.SystemDatabase.Configuration.AnonymousUserAccessMode)
              {
                  case AnonymousUserAccessMode.Admin:
                  case AnonymousUserAccessMode.All:
                  case AnonymousUserAccessMode.Get:
                      // this is effectively a GET request, so we'll allow it
                      // under these circumstances
                      user = CurrentOperationContext.User.Value;
                      break;
                  case AnonymousUserAccessMode.None:
                      _context.Response.StatusCode = 403;
                      _context.Response.ReasonPhrase = "Forbidden";
                      _context.Response.Write("{'Error': 'Single use token is required for authenticated web sockets connections' }");
                      return false;
                  default:
                      throw new ArgumentOutOfRangeException(_options.SystemDatabase.Configuration.AnonymousUserAccessMode.ToString());
              }
          }

          // execute custom registration logic received as a parameter
          if (RegistrationLogicAction != null)
              return RegistrationLogicAction(ResourceName, this, user);

          if (fileSystem != null)
              fileSystem.TransportState.Register(this);
          else if (counterStorage != null)
              counterStorage.TransportState.Register(this);
          else if (documentDatabase != null)
              documentDatabase.TransportState.Register(this);

          return true;
      }

      /// <summary>Database named in a /databases/{name}/ path, or the system database when the path has none; null on load failure.</summary>
      private async Task<DocumentDatabase> GetDatabase()
      {
          var dbName = GetDatabaseName();
          if (dbName == null)
              return _options.SystemDatabase;

          return await LoadResource<DocumentDatabase>(() => _options.DatabaseLandlord.GetDatabaseInternal(dbName));
      }

      /// <summary>File system named in a /fs/{name}/ path; null when not requested or on load failure.</summary>
      private async Task<RavenFileSystem> GetFileSystem()
      {
          var fsName = GetFileSystemName();
          if (fsName == null)
              return null;

          return await LoadResource<RavenFileSystem>(() => _options.FileSystemLandlord.GetFileSystemInternal(fsName));
      }

      /// <summary>Counter storage named in a /counters/{name}/ path; null when not requested or on load failure.</summary>
      private async Task<CounterStorage> GetCounterStorage()
      {
          var csName = GetCounterStorageName();
          if (csName == null)
              return null;

          return await LoadResource<CounterStorage>(() => _options.CountersLandlord.GetCounterInternal(csName));
      }

      /// <summary>
      /// Awaits <paramref name="load"/>; if it throws, writes a 500 response containing the exception text
      /// and returns null.
      /// </summary>
      private async Task<T> LoadResource<T>(Func<Task<T>> load) where T : class
      {
          try
          {
              return await load();
          }
          catch (Exception e)
          {
              _context.Response.StatusCode = 500;
              _context.Response.ReasonPhrase = "InternalServerError";
              _context.Response.Write(e.ToString());
              return null;
          }
      }

      private string GetDatabaseName()
      {
          const string databasesPrefix = "/databases/";
          return GetResourceName(_context.Request.Uri.LocalPath, databasesPrefix);
      }

      private string GetFileSystemName()
      {
          const string fileSystemPrefix = "/fs/";
          return GetResourceName(_context.Request.Uri.LocalPath, fileSystemPrefix);
      }

      private string GetCounterStorageName()
      {
          const string counterStoragePrefix = "/counters/";
          return GetResourceName(_context.Request.Uri.LocalPath, counterStoragePrefix);
      }

      /// <summary>
      /// Returns the path segment between <paramref name="prefix"/> and the next '/', e.g.
      /// "/databases/Northwind/changes" with prefix "/databases/" gives "Northwind".
      /// Returns null when the path does not start with the prefix, the name is empty, or no '/' follows it.
      /// </summary>
      private string GetResourceName(string localPath, string prefix)
      {
          if (localPath.StartsWith(prefix, StringComparison.Ordinal) == false)
              return null;

          // Search from prefix.Length: prefix.Length + 1 threw ArgumentOutOfRangeException for a path equal to the prefix.
          var indexOf = localPath.IndexOf('/', prefix.Length);
          return indexOf > prefix.Length
              ? localPath.Substring(prefix.Length, indexOf - prefix.Length)
              : null;
      }
  }
  355. }