PageRenderTime 24ms CodeModel.GetById 14ms RepoModel.GetById 0ms app.codeStats 0ms

/tags/cvs_final/octave-forge/main/nnet/tests/MLP/MLP9_2_1.m

#
MATLAB | 126 lines | 118 code | 8 blank | 0 comment | 2 complexity | aa550454e71f3cc929cb4c3f45126090 MD5 | raw file
Possible License(s): GPL-2.0, BSD-3-Clause, LGPL-2.1, GPL-3.0, LGPL-3.0
  1. ## Copyright (C) 2006 Michel D. Schmid <michaelschmid@users.sourceforge.net>
  2. ##
  3. ##
  4. ## This program is free software; you can redistribute it and/or modify it
  5. ## under the terms of the GNU General Public License as published by
  6. ## the Free Software Foundation; either version 2, or (at your option)
  7. ## any later version.
  8. ##
  9. ## This program is distributed in the hope that it will be useful, but
  10. ## WITHOUT ANY WARRANTY; without even the implied warranty of
  11. ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. ## General Public License for more details.
  13. ##
  14. ## You should have received a copy of the GNU General Public License
  15. ## along with this program; see the file COPYING. If not, write to the Free
  16. ## Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
  17. ## 02110-1301, USA.
  18. ## This is a test that trains a 9-2-1 MLP (taken from a real project).
  19. ## author: Michel D. Schmid <michaelschmid@users.sourceforge.net>
  20. ## for debug purpose only
  21. global DEBUG = 0;
  22. ## comments to DEBUG:
  23. # 0 or not exist means NO DEBUG
  24. # 1 means, DEBUG, write to command window
  25. # 2 means, DEBUG, write to files...
  26. ## load data
  27. mData = load("mData.txt","mData");
  28. mData = mData.mData;
  29. [nRows, nColumns] = size(mData);
  30. # this file contains 13 columns. The first 12 columns are the inputs
  31. # the last column is the output
  32. # remove column 4, 8 and 12
  33. # 89 rows
  34. ## first permute the whole matrix in row wise
  35. ## this won't be used right now for debug and test purpose
  36. # order = randperm(nRows);
  37. # mData(order,:) = mData;
  38. mOutput = mData(:,end);
  39. mInput = mData(:,1:end-1);
  40. mInput(:,[4 8 12]) = []; # delete column 4, 8 and 12
  41. ## now prepare data
  42. mInput = mInput';
  43. mOutput = mOutput';
  44. %mOutput = [mOutput; mOutput*4];
  45. # now split the data matrix in 3 pieces, train data, test data and validate data
  46. # the proportion should be about 1/2 train, 1/3 test and 1/6 validate data
  47. # in this neural network we have 12 weights, for each weight at least 3 train sets..
  48. # (that's a rule of thumb like 1/2, 1/3 and 1/6)
  49. # 1/2 of 89 = 44.5; let's take 44 for training
  50. nTrainSets = floor(nRows/2);
  51. # now the rest of the sets are again 100%
  52. # ==> 2/3 for test sets and 1/3 for validate sets
  53. nTestSets = (nRows-nTrainSets)/3*2;
  54. nValiSets = nRows-nTrainSets-nTestSets;
  55. mValiInput = mInput(:,1:nValiSets);
  56. mValliOutput = mOutput(:,1:nValiSets);
  57. mInput(:,1:nValiSets) = [];
  58. mOutput(:,1:nValiSets) = [];
  59. mTestInput = mInput(:,1:nTestSets);
  60. mTestOutput = mOutput(:,1:nTestSets);
  61. mInput(:,1:nTestSets) = [];
  62. mOutput(:,1:nTestSets) = [];
  63. mTrainInput = mInput(:,1:nTrainSets);
  64. mTrainOutput = mOutput(:,1:nTrainSets);
  65. [mTrainInputN,cMeanInput,cStdInput] = prestd(mTrainInput);# standardize inputs
  66. ## comments: there is no reason to standardize the outputs because we have only
  67. # one output ...
  68. # define the max and min inputs for each row
  69. mMinMaxElements = min_max(mTrainInputN); % input matrix with (R x 2)...
  70. ## define network
  71. nHiddenNeurons = 2;
  72. nOutputNeurons = 1;
  73. MLPnet = newff(mMinMaxElements,[nHiddenNeurons nOutputNeurons],{"tansig","purelin"},"trainlm","learngdm","mse");
  74. ## for test purpose, define weights by hand
  75. MLPnet.IW{1,1}(1,:) = 0.5;
  76. MLPnet.IW{1,1}(2,:) = 1.5;
  77. MLPnet.LW{2,1}(:) = 0.5;
  78. MLPnet.b{1,1}(1,:) = 0.5;
  79. MLPnet.b{1,1}(2,:) = 1.5;
  80. MLPnet.b{2,1}(:) = 0.5;
  81. saveMLPStruct(MLPnet,"MLP3test.txt");
  82. #disp("network structure saved, press any key to continue...")
  83. #pause
  84. ## define validation data new, for matlab compatibility
  85. VV.P = mValiInput;
  86. VV.T = mValliOutput;
  87. ## standardize also the validate data
  88. VV.P = trastd(VV.P,cMeanInput,cStdInput);
  89. #[net,tr,out,E] = train(MLPnet,mInputN,mOutput,[],[],VV);
  90. [net] = train(MLPnet,mTrainInputN,mTrainOutput,[],[],VV);
  91. # saveMLPStruct(net,"MLP3testNachTraining.txt");
  92. # disp("network structure saved, press any key to continue...")
  93. # pause
  94. # % the names in matlab help, see "train" in help for more informations
  95. # tr.perf(max(tr.epoch)+1);
  96. #
  97. # % make preparations for net test and test MLPnet
  98. # % standardise input & output test data
  99. [mTestInputN] = trastd(mTestInput,cMeanInput,cStdInput);
  100. # % [mTestOutputN] = trastd(mTestOutput,cMeanOutput,cStdOutput);
  101. # % define unused parameters to get E-variable of simulation
  102. # Pi = zeros(6,0);
  103. # Ai = zeros(6,0);
  104. # % simulate net
  105. [simOut] = sim(net,mTestInputN);%,Pi,Ai,mTestOutput);
  106. simOut