PageRenderTime 65ms CodeModel.GetById 40ms RepoModel.GetById 1ms app.codeStats 0ms

/mona/src/pong/pongLens.cpp

#
C++ | 335 lines | 286 code | 31 blank | 18 comment | 64 complexity | b31dd3247f5abd7eace70ed735e4b3a7 MD5 | raw file
  1. // For conditions of distribution and use, see copyright notice in pong.hpp
  2. // Pong train and test with Lens NN.
  3. // Reference: http://web.stanford.edu/group/mbc/LENSManual/Manual
  4. #ifdef __cplusplus
  5. extern "C" {
  6. #include "../../lens/Src/lens.h"
  7. }
  8. #endif
  9. #ifdef WIN32
  10. #include <process.h>
  11. #endif
  12. #include <stdio.h>
  13. #include <stdlib.h>
  14. #include <string.h>
  15. #include <assert.h>
  16. #include <vector>
  17. using namespace std;
  18. // Pong game examples files.
  19. char *PongTrainingExamplesFile = NULL;
  20. char *PongTestingExamplesFile = NULL;
  21. char *Usage[] =
  22. {
  23. (char *)"pong_lens\n",
  24. (char *)" -trainingExamples <Pong game training examples file>\n",
  25. (char *)" [-trainingEpochs <number of training epochs>]\n",
  26. (char *)" -testingExamples <Pong game testing examples file>\n",
  27. NULL
  28. };
  29. void printUsage()
  30. {
  31. for (int i = 0; Usage[i] != NULL; i++)
  32. {
  33. fprintf(stderr, "%s", Usage[i]);
  34. }
  35. }
  36. // Number of training epochs.
  37. int TrainingEpochs = 1;
  38. // Lens parameters.
  39. int HIDDEN_UNITS = 20;
  40. real LEARNING_RATE = 0.2f;
  41. // Lens callback.
  42. void lensCallback();
  43. bool PrintCallback = false;
  44. // Testing target outputs.
  45. vector<vector<vector<float> > > TestingTargetOutputs;
  46. int TestNumber;
  47. int TickCounter;
  48. float ErrorAccumulator;
  49. int GameCorrectLengthAccum;
  50. int GameLengthAccum;
  51. bool GameError;
  52. int GameCorrectLength;
  53. int GameLength;
  54. int
  55. main(int argc, char *argv[])
  56. {
  57. int i, n, s, intervals;
  58. FILE *fp;
  59. char buf[BUFSIZ];
  60. int t[7];
  61. for (i = 1; i < argc; i++)
  62. {
  63. if (strcmp(argv[i], "-trainingExamples") == 0)
  64. {
  65. i++;
  66. if (i >= argc)
  67. {
  68. printUsage();
  69. return(1);
  70. }
  71. PongTrainingExamplesFile = argv[i];
  72. continue;
  73. }
  74. if (strcmp(argv[i], "-trainingEpochs") == 0)
  75. {
  76. i++;
  77. if (i >= argc)
  78. {
  79. printUsage();
  80. return(1);
  81. }
  82. TrainingEpochs = atoi(argv[i]);
  83. if (TrainingEpochs < 0)
  84. {
  85. printUsage();
  86. return(1);
  87. }
  88. continue;
  89. }
  90. if (strcmp(argv[i], "-testingExamples") == 0)
  91. {
  92. i++;
  93. if (i >= argc)
  94. {
  95. printUsage();
  96. return(1);
  97. }
  98. PongTestingExamplesFile = argv[i];
  99. continue;
  100. }
  101. printUsage();
  102. return(1);
  103. }
  104. if ((PongTrainingExamplesFile == NULL) || (PongTestingExamplesFile == NULL))
  105. {
  106. printUsage();
  107. return(1);
  108. }
  109. // Start lens.
  110. if (startLens(argv[0]))
  111. {
  112. fprintf(stderr, "Lens failed to start\n");
  113. return(1);
  114. }
  115. // Determine max intervals.
  116. if ((fp = fopen(PongTrainingExamplesFile, "r")) == NULL)
  117. {
  118. fprintf(stderr, "Cannot open %s\n", PongTrainingExamplesFile);
  119. return(1);
  120. }
  121. intervals = 0;
  122. while (fgets(buf, BUFSIZ, fp) != NULL)
  123. {
  124. if (strncmp(buf, "name:", 5) == 0)
  125. {
  126. for (i = 0; buf[i] != '}' && buf[i] != '\0'; i++)
  127. {
  128. }
  129. if (buf[i] == '}')
  130. {
  131. i++;
  132. i = atoi(&buf[i]);
  133. if (intervals < i)
  134. {
  135. intervals = i;
  136. }
  137. }
  138. }
  139. }
  140. fclose(fp);
  141. if ((fp = fopen(PongTestingExamplesFile, "r")) == NULL)
  142. {
  143. fprintf(stderr, "Cannot open %s\n", PongTestingExamplesFile);
  144. return(1);
  145. }
  146. n = -1;
  147. while (fgets(buf, BUFSIZ, fp) != NULL)
  148. {
  149. if (strncmp(buf, "name:", 5) == 0)
  150. {
  151. for (i = 0; buf[i] != '}' && buf[i] != '\0'; i++)
  152. {
  153. }
  154. if (buf[i] == '}')
  155. {
  156. i++;
  157. i = atoi(&buf[i]);
  158. if (intervals < i)
  159. {
  160. intervals = i;
  161. }
  162. n++;
  163. TestingTargetOutputs.resize(n + 1);
  164. s = -1;
  165. }
  166. }
  167. else
  168. {
  169. s++;
  170. TestingTargetOutputs[n].resize(s + 1);
  171. for (i = 0; buf[i] != 'T' && buf[i] != '\0'; i++)
  172. {
  173. }
  174. if (buf[i] == 'T')
  175. {
  176. i += 3;
  177. sscanf(&buf[i], "%d %d %d %d %d %d %d", &t[0], &t[1], &t[2], &t[3], &t[4], &t[5], &t[6]);
  178. for (i = 0; i < 7; i++)
  179. {
  180. TestingTargetOutputs[n][s].push_back((float)t[i]);
  181. }
  182. }
  183. }
  184. }
  185. fclose(fp);
  186. // Load pong examples.
  187. suppressLensOutput = 0;
  188. sprintf(buf, (char *)"loadExamples %s -set pong_training_examples -exmode PERMUTED", PongTrainingExamplesFile);
  189. lens(buf);
  190. sprintf(buf, (char *)"loadExamples %s -set pong_testing_examples", PongTestingExamplesFile);
  191. lens(buf);
  192. // Create LENS NN.
  193. sprintf(buf, (char *)"addNet pong_net -i %d 8 %d ELMAN 7 SOFT_MAX", intervals, HIDDEN_UNITS);
  194. lens(buf);
  195. // Train.
  196. sprintf(buf, (char *)"useTrainingSet pong_training_examples");
  197. lens(buf);
  198. sprintf(buf, (char *)"setObj learningRate %f", LEARNING_RATE);
  199. lens(buf);
  200. sprintf(buf, (char *)"setObj numUpdates %d", TrainingEpochs);
  201. lens(buf);
  202. lens((char *)"train");
  203. // Set up LENS output callback.
  204. netInputs = new real *[8];
  205. assert(netInputs != NULL);
  206. netOutputs = new real *[7];
  207. assert(netOutputs != NULL);
  208. clientProc = lensCallback;
  209. suppressLensOutput = 0;
  210. PrintCallback = true;
  211. TestNumber = -1;
  212. TickCounter = 0;
  213. ErrorAccumulator = 0.0f;
  214. GameCorrectLengthAccum = 0;
  215. GameLengthAccum = 0;
  216. GameError = false;
  217. GameCorrectLength = 0;
  218. GameLength = 0;
  219. // Test.
  220. sprintf(buf, (char *)"useTestingSet pong_testing_examples");
  221. lens(buf);
  222. lens((char *)"test");
  223. GameCorrectLengthAccum += GameCorrectLength;
  224. GameLengthAccum += GameLength;
  225. // Print mean error.
  226. printf("Mean error = ");
  227. if (TickCounter > 0)
  228. {
  229. printf("%f", ErrorAccumulator / (float)TickCounter);
  230. }
  231. else
  232. {
  233. printf("unavailable");
  234. }
  235. printf("\n");
  236. printf("Mean correct in game play sequence = ");
  237. if (GameLengthAccum > 0)
  238. {
  239. printf("%f", (float)GameCorrectLengthAccum / (float)GameLengthAccum);
  240. }
  241. else
  242. {
  243. printf("unavailable");
  244. }
  245. printf("\n");
  246. // Drop callback.
  247. delete netInputs;
  248. netInputs = NULL;
  249. delete netOutputs;
  250. netInputs = NULL;
  251. clientProc = NULL;
  252. return(0);
  253. }
  254. // Lens callback.
  255. void lensCallback()
  256. {
  257. int i;
  258. float error;
  259. if (exampleTick == 0)
  260. {
  261. TestNumber++;
  262. GameCorrectLengthAccum += GameCorrectLength;
  263. GameLengthAccum += GameLength;
  264. GameError = false;
  265. GameCorrectLength = 0;
  266. GameLength = 0;
  267. }
  268. TickCounter++;
  269. error = 0.0f;
  270. for (i = 0; i < 7; i++)
  271. {
  272. error += fabs(*netOutputs[i] - TestingTargetOutputs[TestNumber][exampleTick][i]);
  273. }
  274. ErrorAccumulator += error;
  275. GameLength++;
  276. if (error > 0.25f)
  277. {
  278. GameError = true;
  279. }
  280. if (!GameError)
  281. {
  282. GameCorrectLength++;
  283. }
  284. if (PrintCallback)
  285. {
  286. printf("tick=%d\n", exampleTick);
  287. printf("input: ");
  288. for (i = 0; i < 8; i++)
  289. {
  290. printf("%f ", *netInputs[i]);
  291. }
  292. printf("output: ");
  293. for (i = 0; i < 7; i++)
  294. {
  295. printf("%f ", *netOutputs[i]);
  296. }
  297. printf("target: ");
  298. for (i = 0; i < 7; i++)
  299. {
  300. printf("%f ", TestingTargetOutputs[TestNumber][exampleTick][i]);
  301. }
  302. printf("\n");
  303. }
  304. }