PageRenderTime 40ms CodeModel.GetById 15ms RepoModel.GetById 0ms app.codeStats 0ms

/cs/cs_parallel/VowpalWabbitAsync.cs

https://gitlab.com/admin-github-cloud/vowpal_wabbit
C# | 322 lines | 187 code | 40 blank | 95 comment | 37 complexity | f4d5101ff45f54092a515971df17164b MD5 | raw file
  1. // --------------------------------------------------------------------------------------------------------------------
  2. // <copyright file="VowpalWabbitAsync.cs">
  3. // Copyright (c) by respective owners including Yahoo!, Microsoft, and
  4. // individual contributors. All rights reserved. Released under a BSD
  5. // license as described in the file LICENSE.
  6. // </copyright>
  7. // --------------------------------------------------------------------------------------------------------------------
  8. using System;
  9. using System.Collections.Generic;
  10. using System.Diagnostics.Contracts;
  11. using System.Linq;
  12. using System.Text;
  13. using System.Threading.Tasks;
  14. using VW.Labels;
  15. using VW.Serializer;
  16. namespace VW
  17. {
  /// <summary>
  /// An async wrapper for VW supporting data ingest using the declarative serializer infrastructure.
  /// Used with <see cref="VowpalWabbitThreadedLearning"/>.
  /// </summary>
  /// <typeparam name="TExample">The user type to be serialized.</typeparam>
  public class VowpalWabbitAsync<TExample> : IDisposable
  {
      /// <summary>
      /// The owning manager. All work is dispatched through its <c>Post</c> method.
      /// </summary>
      private readonly VowpalWabbitThreadedLearning manager;

      /// <summary>
      /// The serializers are not thread-safe. Thus we need to allocate one for each VW instance.
      /// The array is indexed by <c>vw.Settings.Node</c> and set to <c>null</c> once disposed.
      /// </summary>
      private IVowpalWabbitSerializer<TExample>[] serializers;

      /// <summary>
      /// Initializes a new instance bound to the given <paramref name="manager"/> and creates
      /// one serializer per VW instance exposed by <c>manager.VowpalWabbits</c>.
      /// </summary>
      /// <param name="manager">The owning manager.</param>
      /// <exception cref="ArgumentNullException">Thrown if <paramref name="manager"/> is <c>null</c>.</exception>
      internal VowpalWabbitAsync(VowpalWabbitThreadedLearning manager)
      {
          // Runtime guard (same style as the two-type-parameter wrapper): Contract.Requires alone
          // is compiled away unless the contracts rewriter is enabled.
          if (manager == null)
              throw new ArgumentNullException("manager");
          Contract.Ensures(this.serializers != null);
          Contract.EndContractBlock();

          this.manager = manager;

          // create a serializer for each instance - maintaining separate example caches
          var serializer = VowpalWabbitSerializerFactory.CreateSerializer<TExample>(manager.Settings);
          this.serializers = this.manager.VowpalWabbits
              .Select(vw => serializer.Create(vw))
              .ToArray();
      }

      /// <summary>
      /// Learns from the given example.
      /// </summary>
      /// <param name="example">The example to learn.</param>
      /// <param name="label">
      /// The label for this <paramref name="example"/>. Optional; when omitted, <c>null</c> is
      /// handed to the serializer together with the example.
      /// </param>
      /// <remarks>
      /// The method only enqueues the example for learning and returns immediately.
      /// You must not re-use the example.
      /// </remarks>
      public void Learn(TExample example, ILabel label = null)
      {
          // No non-null requirement on label: the parameter defaults to null, so requiring a
          // value here would make the default argument itself a contract violation.
          Contract.Requires(example != null);

          manager.Post(vw =>
          {
              // Pick the serializer that belongs to the VW instance executing this work item.
              using (var ex = this.serializers[vw.Settings.Node].Serialize(example, label))
              {
                  ex.Learn();
              }
          });
      }

      /// <summary>
      /// Predicts for the given example.
      /// </summary>
      /// <param name="example">The example to predict for.</param>
      /// <remarks>
      /// The method only enqueues the example for prediction and returns immediately.
      /// You must not re-use the example.
      /// </remarks>
      public void Predict(TExample example)
      {
          Contract.Requires(example != null);

          manager.Post(vw =>
          {
              using (var ex = this.serializers[vw.Settings.Node].Serialize(example))
              {
                  ex.Predict();
              }
          });
      }

      /// <summary>
      /// Learns from the given example and returns the prediction for it.
      /// </summary>
      /// <typeparam name="TPrediction">The type of the prediction produced by <paramref name="predictionFactory"/>.</typeparam>
      /// <param name="example">The example to learn.</param>
      /// <param name="label">The label for this <paramref name="example"/>.</param>
      /// <param name="predictionFactory">The prediction factory to be used. See <see cref="VowpalWabbitPredictionType"/>.</param>
      /// <returns>The prediction for the given <paramref name="example"/>.</returns>
      /// <remarks>
      /// The method only enqueues the example for learning and returns immediately.
      /// Await the returned task to receive the prediction result.
      /// </remarks>
      public Task<TPrediction> Learn<TPrediction>(TExample example, ILabel label, IVowpalWabbitPredictionFactory<TPrediction> predictionFactory)
      {
          Contract.Requires(example != null);
          Contract.Requires(label != null);
          Contract.Requires(predictionFactory != null);

          return manager.Post(vw =>
          {
              using (var ex = this.serializers[vw.Settings.Node].Serialize(example, label))
              {
                  return ex.Learn(predictionFactory);
              }
          });
      }

      /// <summary>
      /// Predicts for the given example.
      /// </summary>
      /// <typeparam name="TPrediction">The type of the prediction produced by <paramref name="predictionFactory"/>.</typeparam>
      /// <param name="example">The example to predict for.</param>
      /// <param name="predictionFactory">The prediction factory to be used. See <see cref="VowpalWabbitPredictionType"/>.</param>
      /// <returns>The prediction for the given <paramref name="example"/>.</returns>
      /// <remarks>
      /// The method only enqueues the example for prediction and returns immediately.
      /// Await the returned task to receive the prediction result.
      /// </remarks>
      public Task<TPrediction> Predict<TPrediction>(TExample example, IVowpalWabbitPredictionFactory<TPrediction> predictionFactory)
      {
          Contract.Requires(example != null);
          Contract.Requires(predictionFactory != null);

          return manager.Post(vw =>
          {
              using (var ex = this.serializers[vw.Settings.Node].Serialize(example))
              {
                  return ex.Predict(predictionFactory);
              }
          });
      }

      /// <summary>
      /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
      /// </summary>
      public void Dispose()
      {
          this.Dispose(true);
          GC.SuppressFinalize(this);
      }

      /// <summary>
      /// Disposes every per-instance serializer and clears the array so a repeated call is a no-op.
      /// </summary>
      /// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>; nothing is released otherwise.</param>
      private void Dispose(bool disposing)
      {
          if (disposing)
          {
              if (this.serializers != null)
              {
                  foreach (var serializer in this.serializers)
                  {
                      // free cached examples
                      serializer.Dispose();
                  }

                  this.serializers = null;
              }
          }
      }
  }
  /// <summary>
  /// An async VW wrapper for multiline ingest.
  /// </summary>
  /// <typeparam name="TExample">The user type of the shared feature.</typeparam>
  /// <typeparam name="TActionDependentFeature">The user type for each action dependent feature.</typeparam>
  public class VowpalWabbitAsync<TExample, TActionDependentFeature> : IDisposable
  {
      /// <summary>
      /// The owning manager. All work is dispatched through its <c>Post</c> method.
      /// </summary>
      private readonly VowpalWabbitThreadedLearning manager;

      /// <summary>
      /// Serializers for the shared example. The serializers are not thread-safe, thus we need
      /// to allocate one for each VW instance. Indexed by <c>vw.Settings.Node</c>.
      /// </summary>
      private VowpalWabbitSingleExampleSerializer<TExample>[] serializers;

      /// <summary>
      /// Serializers for the action dependent features. The serializers are not thread-safe, thus
      /// we need to allocate one for each VW instance. Indexed by <c>vw.Settings.Node</c>.
      /// </summary>
      private VowpalWabbitSingleExampleSerializer<TActionDependentFeature>[] actionDependentFeatureSerializers;

      /// <summary>
      /// Initializes a new instance bound to the given <paramref name="manager"/> and creates,
      /// for each VW instance exposed by <c>manager.VowpalWabbits</c>, one serializer for the shared
      /// example and one for the action dependent features.
      /// </summary>
      /// <param name="manager">The owning manager.</param>
      /// <exception cref="ArgumentNullException">
      /// Thrown if <paramref name="manager"/>, its settings or its parallel options are <c>null</c>.
      /// </exception>
      /// <exception cref="ArgumentOutOfRangeException">
      /// Thrown if the configured maximum degree of parallelism is not greater than zero.
      /// </exception>
      /// <exception cref="ArgumentException">
      /// Thrown if either type parameter does not yield a single-example serializer.
      /// </exception>
      internal VowpalWabbitAsync(VowpalWabbitThreadedLearning manager)
      {
          if (manager == null)
              throw new ArgumentNullException("manager");
          if (manager.Settings == null)
              throw new ArgumentNullException("manager.Settings");
          if (manager.Settings.ParallelOptions == null)
              throw new ArgumentNullException("manager.Settings.ParallelOptions");
          if (manager.Settings.ParallelOptions.MaxDegreeOfParallelism <= 0)
              // Use the (paramName, message) overload: the single-string constructor would
              // interpret the text as the parameter name and leave the message generic.
              throw new ArgumentOutOfRangeException(
                  "manager.Settings.ParallelOptions.MaxDegreeOfParallelism",
                  "MaxDegreeOfParallelism must be greater than zero.");
          Contract.Ensures(this.serializers != null);
          Contract.Ensures(this.actionDependentFeatureSerializers != null);
          Contract.EndContractBlock();

          this.manager = manager;

          // create a serializer for each instance - maintaining separate example caches
          var serializer = VowpalWabbitSerializerFactory.CreateSerializer<TExample>(manager.Settings) as VowpalWabbitSingleExampleSerializerCompiler<TExample>;
          if (serializer == null)
              throw new ArgumentException(string.Format(
                  "{0} maps to a multiline example. Use VowpalWabbitAsync<{0}> instead.",
                  typeof(TExample)));

          var adfSerializer = VowpalWabbitSerializerFactory.CreateSerializer<TActionDependentFeature>(manager.Settings) as VowpalWabbitSingleExampleSerializerCompiler<TActionDependentFeature>;
          if (adfSerializer == null)
              throw new ArgumentException(string.Format(
                  "{0} maps to a multiline example. Use VowpalWabbitAsync<{0}> instead.",
                  typeof(TActionDependentFeature)));

          this.serializers = this.manager.VowpalWabbits
              .Select(vw => serializer.Create(vw))
              .ToArray();

          this.actionDependentFeatureSerializers = this.manager.VowpalWabbits
              .Select(vw => adfSerializer.Create(vw))
              .ToArray();
      }

      /// <summary>
      /// Learn from the given example.
      /// </summary>
      /// <param name="example">The shared example.</param>
      /// <param name="actionDependentFeatures">The action dependent features.</param>
      /// <param name="index">The index of the example to learn within <paramref name="actionDependentFeatures"/>.</param>
      /// <param name="label">The label for the example to learn.</param>
      /// <remarks>
      /// The work is handed to the owning manager via <c>Post</c>; no result is returned to the caller.
      /// </remarks>
      public void Learn(TExample example, IReadOnlyCollection<TActionDependentFeature> actionDependentFeatures, int index, ILabel label)
      {
          Contract.Requires(example != null);
          Contract.Requires(actionDependentFeatures != null);
          Contract.Requires(index >= 0);
          Contract.Requires(label != null);

          // Both serializers are selected for the VW instance that executes the work item.
          manager.Post(vw => VowpalWabbitMultiLine.Learn(
              vw,
              this.serializers[vw.Settings.Node],
              this.actionDependentFeatureSerializers[vw.Settings.Node],
              example,
              actionDependentFeatures,
              index,
              label));
      }

      /// <summary>
      /// Learn from the given example and return the current prediction for it.
      /// </summary>
      /// <param name="example">The shared example.</param>
      /// <param name="actionDependentFeatures">The action dependent features.</param>
      /// <param name="index">The index of the example to learn within <paramref name="actionDependentFeatures"/>.</param>
      /// <param name="label">The label for the example to learn.</param>
      /// <returns>The ranked prediction for the given examples.</returns>
      public Task<ActionDependentFeature<TActionDependentFeature>[]> LearnAndPredict(TExample example, IReadOnlyCollection<TActionDependentFeature> actionDependentFeatures, int index, ILabel label)
      {
          Contract.Requires(example != null);
          Contract.Requires(actionDependentFeatures != null);
          Contract.Requires(index >= 0);
          Contract.Requires(label != null);

          return manager.Post(vw => VowpalWabbitMultiLine.LearnAndPredict(
              vw,
              this.serializers[vw.Settings.Node],
              this.actionDependentFeatureSerializers[vw.Settings.Node],
              example,
              actionDependentFeatures,
              index,
              label));
      }

      /// <summary>
      /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
      /// </summary>
      public void Dispose()
      {
          this.Dispose(true);
          GC.SuppressFinalize(this);
      }

      /// <summary>
      /// Disposes both serializer arrays and clears them so a repeated call is a no-op.
      /// </summary>
      /// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>; nothing is released otherwise.</param>
      private void Dispose(bool disposing)
      {
          if (disposing)
          {
              if (this.serializers != null)
              {
                  foreach (var serializer in this.serializers)
                  {
                      // free cached examples
                      serializer.Dispose();
                  }

                  this.serializers = null;
              }

              if (this.actionDependentFeatureSerializers != null)
              {
                  foreach (var serializer in this.actionDependentFeatureSerializers)
                  {
                      // free cached examples
                      serializer.Dispose();
                  }

                  this.actionDependentFeatureSerializers = null;
              }
          }
      }
  }
  282. }