/tags/rel-0-4-0/FreeSpeech/VQ/src/VQTrain.cc

# · C++ · 107 lines · 56 code · 12 blank · 39 comment · 6 complexity · 75bc253ea1b1565fa9e773db1e9656fc MD5 · raw file

  1. // Copyright (C) 1999 Jean-Marc Valin
  2. //
  3. // This program is free software; you can redistribute it and/or modify
  4. // it under the terms of the GNU General Public License as published by
  5. // the Free Software Foundation; either version 2, or (at your option)
  6. // any later version.
  7. //
  8. // This program is distributed in the hope that it will be useful, but
  9. // WITHOUT ANY WARRANTY; without even the implied warranty of
  10. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  11. // General Public License for more details.
  12. //
  13. // You should have received a copy of the GNU General Public License
  14. // along with this file. If not, write to the Free Software Foundation,
  15. // 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
  16. #include "VQTrain.h"
  17. #include "net_types.h"
  18. #include "GrowingBuffer.h"
  19. #include "kmeans.h"
  20. #include "Vector.h"
  21. //#include "multithread.h"
  // NOTE(review): DECLARE_NODE is defined outside this file; presumably it
// registers VQTrain with the framework's node factory — confirm against its definition.
DECLARE_NODE(VQTrain)
/*Node
*
* @name VQTrain
* @category VQ
* @description Trains a k-means (KMeans) vector-quantization model on the frames read from the FRAMES input.
*
* @input_name FRAMES
* @input_description GrowingBuffer of Vector<float> frames; every frame must have the length of the first frame.
*
* @output_name OUTPUT
* @output_description The trained KMeans object.
*
* @parameter_name MEANS
* @parameter_description Number of means to train (int), passed to KMeans::train.
*
* @parameter_name BINARY
* @parameter_description Optional bool forwarded to KMeans::train (false when not set).
*
END*/
  42. VQTrain::VQTrain(string nodeName, ParameterSet params)
  43. : Node(nodeName, params)
  44. {
  45. try {
  46. //cerr << "VQTrain initialize\n";
  47. outputID = addOutput("OUTPUT");
  48. framesInputID = addInput("FRAMES");
  49. //cerr << "VQTrain initialization done\n";
  50. nbMeans = dereference_cast<int> (parameters.get("MEANS"));
  51. } catch (BaseException *e)
  52. {
  53. //e->print(cerr);
  54. throw e->add(new NodeException(NULL, "Exception caught in VQTrain constructor", __FILE__, __LINE__));
  55. }
  56. }
  57. void VQTrain::specificInitialize()
  58. {
  59. this->Node::specificInitialize();
  60. }
  61. void VQTrain::reset()
  62. {
  63. this->Node::reset();
  64. }
  65. ObjectRef VQTrain::getOutput(int output_id, int count)
  66. {
  67. //cerr << "Getting output in VQTrain\n";
  68. if (output_id==outputID)
  69. {
  70. if (count != processCount)
  71. {
  72. bool binary = false;
  73. if (parameters.exist("BINARY"))
  74. binary = dereference_cast<bool> (parameters.get("BINARY"));
  75. int i;
  76. NodeInput framesInput = inputs[framesInputID];
  77. cerr << "getting frames..." << endl;
  78. ObjectRef matRef = framesInput.node->getOutput(framesInput.outputID,count);
  79. cerr << "got frames..." << endl;
  80. GrowingBuffer &mat = object_cast<GrowingBuffer> (matRef);
  81. KMeans *vq = new KMeans;
  82. vector <float *> data(mat.getCurrentPos()+1);
  83. for (i=0;i<=mat.getCurrentPos();i++)
  84. data[i]= &object_cast <Vector<float> > (mat[i])[0];
  85. int length = object_cast <Vector<float> > (mat[0]).size();
  86. cerr << "training..." << endl;
  87. vq->train(nbMeans,data,length,binary);
  88. cerr << "training complete." << endl;
  89. current = ObjectRef(vq);
  90. }
  91. return current;
  92. }
  93. else
  94. throw new NodeException (this, "VQTrain: Unknown output id", __FILE__, __LINE__);
  95. }