/tags/rel-0-4-3/FreeSpeech/NNet/src/NNetTrain.cc

# · C++ · 209 lines · 95 code · 41 blank · 73 comment · 20 complexity · 54fb5b7d41e0eea2252b50fed0d12ccc MD5 · raw file

  1. // Copyright (C) 1999 Jean-Marc Valin
  2. //
  3. // This program is free software; you can redistribute it and/or modify
  4. // it under the terms of the GNU General Public License as published by
  5. // the Free Software Foundation; either version 2, or (at your option)
  6. // any later version.
  7. //
  8. // This program is distributed in the hope that it will be useful, but
  9. // WITHOUT ANY WARRANTY; without even the implied warranty of
  10. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  11. // General Public License for more details.
  12. //
  13. // You should have received a copy of the GNU General Public License
  14. // along with this file. If not, write to the Free Software Foundation,
  15. // 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
  16. #include "Node.h"
  17. #include "ObjectRef.h"
  18. #include "FFNet.h"
  // Forward declaration: lets the macro below name NNetTrain before the class
  // body (which follows the node metadata comment) has been seen.
  19. class NNetTrain;
  // DECLARE_NODE is defined outside this file.
  // NOTE(review): presumably it registers NNetTrain with the framework's node
  // factory so it can be instantiated by name -- confirm against Node.h.
  20. DECLARE_NODE(NNetTrain)
  21. /*Node
  22. *
  23. * @name NNetTrain
  24. * @category NNet
  25. * @description No description available
  26. *
  27. * @input_name TRAIN_IN
  28. * @input_description No description available
  29. *
  30. * @input_name TRAIN_OUT
  31. * @input_description No description available
  32. *
  33. * @input_name NNET
  34. * @input_description No description available
  35. *
  36. * @output_name OUTPUT
  37. * @output_description No description available
  38. *
  39. * @parameter_name MAX_EPOCH
  40. * @parameter_description No description available
  41. *
  42. * @parameter_name LEARN_RATE
  43. * @parameter_description No description available
  44. *
  45. * @parameter_name MOMENTUM
  46. * @parameter_description No description available
  47. *
  48. * @parameter_name INCREASE
  49. * @parameter_description No description available
  50. *
  51. * @parameter_name DECREASE
  52. * @parameter_description No description available
  53. *
  54. * @parameter_name ERR_RATIO
  55. * @parameter_description No description available
  56. *
  57. * @parameter_name BATCH_SETS
  58. * @parameter_description No description available
  59. *
  60. END*/
  61. class NNetTrain : public Node {
  62. protected:
  63. /**The ID of the 'trainIN' input*/
  64. int trainInID;
  65. /**The ID of the 'trainOut' input*/
  66. int trainOutID;
  67. /**The ID of the 'output' output*/
  68. int outputID;
  69. /**The ID of the 'nnet' input*/
  70. int netInputID;
  71. /**Reference to the current stream*/
  72. ObjectRef currentNet;
  73. int maxEpoch;
  74. double learnRate;
  75. double momentum;
  76. double decrease;
  77. double increase;
  78. double errRatio;
  79. int nbSets;
  80. public:
  81. /**Constructor, takes the name of the node and a set of parameters*/
  82. NNetTrain(string nodeName, ParameterSet params)
  83. : Node(nodeName, params)
  84. {
  85. outputID = addOutput("OUTPUT");
  86. netInputID = addInput("NNET");
  87. trainInID = addInput("TRAIN_IN");
  88. trainOutID = addInput("TRAIN_OUT");
  89. if (parameters.exist("MAX_EPOCH"))
  90. maxEpoch = dereference_cast<int> (parameters.get("MAX_EPOCH"));
  91. else maxEpoch = 200;
  92. if (parameters.exist("LEARN_RATE"))
  93. learnRate = dereference_cast<float> (parameters.get("LEARN_RATE"));
  94. else learnRate = .00001;
  95. if (parameters.exist("MOMENTUM"))
  96. momentum = dereference_cast<float> (parameters.get("MOMENTUM"));
  97. else momentum = .9;
  98. if (parameters.exist("INCREASE"))
  99. increase = dereference_cast<float> (parameters.get("INCREASE"));
  100. else increase = 1.05;
  101. if (parameters.exist("DECREASE"))
  102. decrease = dereference_cast<float> (parameters.get("DECREASE"));
  103. else decrease = .7;
  104. if (parameters.exist("ERR_RATIO"))
  105. errRatio = dereference_cast<float> (parameters.get("ERR_RATIO"));
  106. else errRatio = 1.04;
  107. if (parameters.exist("BATCH_SETS"))
  108. nbSets = dereference_cast<int> (parameters.get("BATCH_SETS"));
  109. else nbSets = 1;
  110. }
  111. /**Class specific initialization routine.
  112. Each class will call its subclass specificInitialize() method*/
  113. virtual void specificInitialize()
  114. {
  115. NodeInput trainInInput = inputs[trainInID];
  116. cerr << "in name = " << trainInInput.outputID << endl ;
  117. NodeInput trainOutInput = inputs[trainOutID];
  118. cerr << "out name = " << trainOutInput.outputID << endl;
  119. this->Node::specificInitialize();
  120. }
  121. /**Class reset routine.
  122. Each class will call its superclass reset() method*/
  123. virtual void reset()
  124. {
  125. this->Node::reset();
  126. }
  127. /**Ask for the node's output which ID (number) is output_id
  128. and for the 'count' iteration */
  129. virtual ObjectRef getOutput(int output_id, int count)
  130. {
  131. if (output_id==outputID)
  132. {
  133. if (count != processCount)
  134. {
  135. cerr << "getOutput in NNetTrain\n";
  136. int i,j;
  137. NodeInput trainInInput = inputs[trainInID];
  138. ObjectRef trainInValue = trainInInput.node->getOutput(trainInInput.outputID,count);
  139. NodeInput trainOutInput = inputs[trainOutID];
  140. ObjectRef trainOutValue = trainOutInput.node->getOutput(trainOutInput.outputID,count);
  141. NodeInput netInput = inputs[netInputID];
  142. ObjectRef netValue = netInput.node->getOutput(netInput.outputID,count);
  143. //cerr << "inputs calculated\n";
  144. Vector<ObjectRef> &inBuff = object_cast<Vector<ObjectRef> > (trainInValue);
  145. Vector<ObjectRef> &outBuff = object_cast<Vector<ObjectRef> > (trainOutValue);
  146. //cerr << "inputs converted\n";
  147. vector <float *> in(inBuff.size());
  148. for (i=0;i<inBuff.size();i++)
  149. in[i]=&object_cast <Vector<float> > (inBuff[i])[0];
  150. vector <float *> out(outBuff.size());
  151. for (i=0;i<outBuff.size();i++)
  152. out[i]=&object_cast <Vector<float> > (outBuff[i])[0];
  153. //FFNet *net = new FFNet( topo );
  154. FFNet &net = object_cast<FFNet> (netValue);
  155. net.train(in, out, maxEpoch, learnRate, momentum, increase, decrease, errRatio, nbSets);
  156. //net->trainlm(in, out, maxEpoch);
  157. currentNet = netValue;
  158. //exit(1);
  159. }
  160. return currentNet;
  161. }
  162. else
  163. throw new NodeException (this, "NNetTrain: Unknown output id", __FILE__, __LINE__);
  164. }
  165. protected:
  166. /**Default constructor, should not be used*/
  167. NNetTrain() {throw new GeneralException("NNetTrain copy constructor should not be called",__FILE__,__LINE__);}
  168. };