/encog-core-silverlight/encog-core-silverlight/Neural/Networks/Training/Concurrent/ConcurrentTrainingManager.cs

http://encog-cs.googlecode.com/ · C# · 387 lines · 236 code · 52 blank · 99 comment · 20 complexity · ce3615b6734fe356f8408691c1cc53ac MD5 · raw file

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using System.Threading;
  6. using Encog.Neural.Networks.Training.Concurrent.Jobs;
  7. using Encog.Neural.Networks.Training.Concurrent.Performers;
  8. using Encog.Engine;
  9. using Encog.Neural.NeuralData;
  10. using Encog.Neural.Networks.Training.Strategy;
  11. using Encog.Engine.Network.Train.Prop;
  12. using Encog.Engine.Opencl;
  namespace Encog.Neural.Networks.Training.Concurrent
  {
      /// <summary>
      /// Concurrent training manager. This class allows you to queue up network
      /// training tasks to be executed either by the CPU cores or OpenCL devices.
      /// This allows the CPU/GPU to train neural networks at the same time.
      /// </summary>
      /// <remarks>
      /// The performer list, the job queue and the finished-job counter are guarded
      /// by a single private lock object. The same object doubles as a condition
      /// variable: waiters block in <c>Monitor.Wait</c> on it (which releases the
      /// lock while blocked) and <see cref="JobDone"/> wakes them with
      /// <c>Monitor.PulseAll</c>.
      /// </remarks>
      public class ConcurrentTrainingManager
      {
          /// <summary>
          /// Upper bound, in milliseconds, on how long a waiter blocks before it
          /// re-checks the performers' <c>Ready</c> flags even without a signal.
          /// NOTE(review): the order in which a performer flips <c>Ready</c> and calls
          /// <see cref="JobDone"/> is not visible from this class (and a failing job
          /// may never call it), so a bounded wait avoids a lost wake-up hanging
          /// the manager.
          /// </summary>
          private const int PollIntervalMilliseconds = 100;

          /// <summary>
          /// The singleton instance. It is created by the static field initializer,
          /// which the runtime executes exactly once, so concurrent first calls to
          /// <see cref="Instance"/> can never create two managers.
          /// </summary>
          private static readonly ConcurrentTrainingManager instance =
              new ConcurrentTrainingManager();

          /// <summary>
          /// Lock guarding the performers, the queue and the job counter. It is also
          /// the monitor that is pulsed whenever a job finishes.
          /// </summary>
          private readonly Object accessLock = new Object();

          /// <summary>
          /// The number of jobs reported finished since <see cref="Run"/> last reset it.
          /// </summary>
          private int jobNumber;

          /// <summary>
          /// The performers to use.
          /// </summary>
          private readonly IList<IConcurrentTrainingPerformer> performers =
              new List<IConcurrentTrainingPerformer>();

          /// <summary>
          /// The training jobs to execute.
          /// </summary>
          private readonly IList<TrainingJob> queue = new List<TrainingJob>();

          /// <summary>
          /// The background thread created by <see cref="Start"/>; null until then.
          /// </summary>
          private Thread thread;

          /// <summary>
          /// An object used to report status.
          /// </summary>
          private IStatusReportable report = new NullStatusReportable();

          /// <summary>
          /// Private constructor; obtain the manager through <see cref="Instance"/>.
          /// </summary>
          private ConcurrentTrainingManager()
          {
          }

          /// <summary>
          /// True, if training should be run single threaded.
          /// <see cref="DetectPerformers(bool, int)"/> sets this to its
          /// <c>splitCores</c> argument.
          /// </summary>
          public bool SingleThreaded { get; set; }

          /// <summary>
          /// The singleton instance.
          /// </summary>
          public static ConcurrentTrainingManager Instance
          {
              get { return instance; }
          }

          /// <summary>
          /// Setup the object to report status to.
          /// </summary>
          public IStatusReportable Report
          {
              get { return this.report; }
              set { this.report = value; }
          }

          /// <summary>
          /// Add a performer and make this manager its owner.
          /// </summary>
          /// <param name="performer">The performer to add.</param>
          public void AddPerformer(IConcurrentTrainingPerformer performer)
          {
              lock (this.accessLock)
              {
                  this.performers.Add(performer);
                  performer.Manager = this;
              }
          }

          /// <summary>
          /// Add a training job.
          /// </summary>
          /// <param name="job">The training job to add.</param>
          public void AddTrainingJob(TrainingJob job)
          {
              lock (this.accessLock)
              {
                  this.queue.Add(job);
              }
          }

          /// <summary>
          /// Clear all of the performers.
          /// </summary>
          public void ClearPerformers()
          {
              lock (this.accessLock)
              {
                  this.performers.Clear();
              }
          }

          /// <summary>
          /// Clear the workload.
          /// </summary>
          public void ClearQueue()
          {
              lock (this.accessLock)
              {
                  this.queue.Clear();
              }
          }

          /// <summary>
          /// Detect performers. Create one performer for each OpenCL device, and
          /// a single CPU performer. If there is an OpenCL device already for the
          /// CPU, do not create a CPU performer.
          /// </summary>
          public void DetectPerformers()
          {
              DetectPerformers(false, 0);
          }

          /// <summary>
          /// Detect performers. Any existing performers are cleared first. One
          /// performer is created for each OpenCL device (non-Silverlight builds
          /// only), plus CPU performers unless OpenCL already covers the CPU.
          /// </summary>
          /// <param name="splitCores">True, if a CPU performer should be created for
          /// each core; false creates a single CPU performer. Also assigned to
          /// <see cref="SingleThreaded"/>.</param>
          /// <param name="forceCoreCount">When <paramref name="splitCores"/> is true and
          /// this is positive, the number of CPU performers to create; 0 means use
          /// <c>Environment.ProcessorCount</c>. A negative value creates no CPU
          /// performers at all.</param>
          public void DetectPerformers(bool splitCores, int forceCoreCount)
          {
              lock (this.accessLock)
              {
                  bool useCPU = true;
                  ClearPerformers();
                  int clCount = 1;
                  int cpuCount = 1;
                  this.SingleThreaded = splitCores;
  #if !SILVERLIGHT
                  // handle OpenCL mode
                  if (EncogFramework.Instance.CL != null)
                  {
                      // If OpenCL already drives the CPU, do not add CPU performers.
                      if (EncogFramework.Instance.CL.AreCPUsPresent())
                      {
                          useCPU = false;
                      }

                      // add a performer for each OpenCL device.
                      foreach (EncogCLDevice device in EncogFramework.Instance.CL.Devices)
                      {
                          AddPerformer(new ConcurrentTrainingPerformerOpenCL(clCount++, device));
                      }
                  }
  #endif
                  // now create CPU performers; a negative forceCoreCount disables them.
                  if (useCPU && forceCoreCount >= 0)
                  {
                      int threads = 1;
                      if (splitCores)
                      {
                          threads = (forceCoreCount > 0) ? forceCoreCount : Environment.ProcessorCount;
                      }

                      for (int i = 0; i < threads; i++)
                      {
                          AddPerformer(new ConcurrentTrainingPerformerCPU(cpuCount++));
                      }
                  }
              }
          }

          /// <summary>
          /// Start the manager: runs <see cref="Run"/> on a new thread.
          /// </summary>
          public void Start()
          {
              this.thread = new Thread(this.Run);
              this.thread.Start();
          }

          /// <summary>
          /// Wait for all tasks to finish. Returns immediately if
          /// <see cref="Start"/> has never been called.
          /// </summary>
          public void Join()
          {
              Thread worker = this.thread;
              if (worker != null)
              {
                  worker.Join();
              }
          }

          /// <summary>
          /// Perform the training. Called internally by <see cref="Start"/>.
          /// Each queued job (as of the moment this method starts) is handed to the
          /// next free performer; the method then blocks until every performer is
          /// ready again.
          /// </summary>
          /// <exception cref="NeuralNetworkError">A job reported an error before all
          /// jobs were submitted. Errors from jobs still running when submission
          /// ends are not rethrown here; inspect <c>TrainingJob.Error</c>.</exception>
          public void Run()
          {
              // Snapshot the queue under the lock so concurrent Add/ClearQueue calls
              // cannot invalidate the enumeration below.
              List<TrainingJob> jobs;
              lock (this.accessLock)
              {
                  this.jobNumber = 0;
                  jobs = new List<TrainingJob>(this.queue);
              }

              this.report.Report(jobs.Count, 0, "Starting first job");
              int count = 0;
              foreach (TrainingJob job in jobs)
              {
                  // find a performer and hand it this job
                  WaitForFreePerformer(job);
                  count++;
                  ReportErrors();
              }

              // now wait for all performers to finish
              this.report.Report(jobs.Count, count,
                  "No more jobs to submit, waiting for last job.");
              lock (this.accessLock)
              {
                  while (!AllPerformersReady())
                  {
                      // Monitor.Wait releases accessLock while blocked, so JobDone
                      // can acquire it and pulse us (no deadlock).
                      Monitor.Wait(this.accessLock, PollIntervalMilliseconds);
                  }
              }

              this.report.Report(jobs.Count, count, "All training done.");
          }

          /// <summary>
          /// Wait for a free performer and assign the job to it.
          /// Blocks until some performer reports <c>Ready</c>; if no performers are
          /// registered this never returns.
          /// </summary>
          /// <param name="job">The job to hand to the first free performer.</param>
          /// <returns>The performer that received the job.</returns>
          public IConcurrentTrainingPerformer WaitForFreePerformer(TrainingJob job)
          {
              lock (this.accessLock)
              {
                  while (true)
                  {
                      foreach (IConcurrentTrainingPerformer performer in this.performers)
                      {
                          if (performer.Ready)
                          {
                              // Dispatch to exactly one performer, then stop looking.
                              performer.Perform(job);
                              return performer;
                          }
                      }

                      // Releases the lock while waiting for JobDone (or the timeout).
                      Monitor.Wait(this.accessLock, PollIntervalMilliseconds);
                  }
              }
          }

          /// <summary>
          /// Report that a job is done: bumps the finished-job counter, reports
          /// status and wakes any thread waiting for a free performer.
          /// </summary>
          /// <param name="time">Time taken by the job, in milliseconds.</param>
          /// <param name="perf">The performer that did the job.</param>
          public void JobDone(long time, ConcurrentTrainingPerformerCPU perf)
          {
              lock (this.accessLock)
              {
                  this.jobNumber++;
                  this.ReportStatus("Job finished in " + time + "ms, on " + perf.ToString());
                  Monitor.PulseAll(this.accessLock);
              }
          }

          /// <summary>
          /// If an error has been reported by any queued job, throw it (wrapped in a
          /// <see cref="NeuralNetworkError"/>) on the calling thread.
          /// </summary>
          private void ReportErrors()
          {
              lock (this.accessLock)
              {
                  foreach (TrainingJob job in this.queue)
                  {
                      if (job.Error != null)
                      {
                          throw new NeuralNetworkError(job.Error);
                      }
                  }
              }
          }

          /// <summary>
          /// Report the status, using the queue size and the finished-job counter.
          /// </summary>
          /// <param name="str">The status to report.</param>
          private void ReportStatus(String str)
          {
              this.report.Report(this.queue.Count, jobNumber, str);
          }

          /// <summary>
          /// True when every registered performer reports <c>Ready</c>.
          /// Caller must hold <see cref="accessLock"/>.
          /// </summary>
          private bool AllPerformersReady()
          {
              foreach (IConcurrentTrainingPerformer performer in this.performers)
              {
                  if (!performer.Ready)
                  {
                      return false;
                  }
              }
              return true;
          }

          /// <inheritdoc/>
          public override String ToString()
          {
              StringBuilder builder = new StringBuilder();
              lock (this.accessLock)
              {
                  int index = 1;
                  foreach (IConcurrentTrainingPerformer performer in this.performers)
                  {
                      builder.Append("Performer ");
                      builder.Append(index++);
                      builder.Append(": ");
                      builder.Append(performer.ToString());
                      builder.Append("\n");
                  }
              }
              return builder.ToString();
          }
      }
  }