/ru/experimental/nnetwork.h

https://bitbucket.org/VladimirL/robotutils · C Header · 87 lines · 58 code · 17 blank · 12 comment · 2 complexity · 07d981284081ada2ef9a5b5caba09d12 MD5 · raw file

  1. #ifndef NLAYER_H
  2. #define NLAYER_H
  3. #include <stdlib.h>
  4. #include <vector>
  5. #include <cmath>
  6. #include <QDebug>
  7. #include <QString>
  8. using namespace std;
  //! Constant used by the activation function. The original comment reads "-a",
  //! i.e. the value stored here is the negated coefficient a.
  //! NOTE(review): presumably the steepness of a sigmoid — confirm against the .cpp.
  const double sigma_a(-2.0);
  // Former global step-size constant, kept disabled; NNetwork::train() receives
  // `speed` as an argument.
  //const double speed = 1.0;
  //! Map a value into the range 0..1 relative to [min, max].
  //! The value is first clamped to [min, max], so finite inputs with min <= max
  //! always produce a result in 0..1.
  //! @param val  value to normalize
  //! @param min  value that maps to 0
  //! @param max  value that maps to 1
  //! @return (clamped val - min) / (max - min), or 0 when max == min
  inline double normalize(double val, double min, double max)
  {
      if (val > max) val = max;
      if (val < min) val = min;
      const double range = max - min;
      // A zero-width range would compute 0/0 = NaN, and that NaN would then
      // propagate silently through every layer fed with this value.
      if (range == 0.0)
          return 0.0;
      return (val - min) / range;
  }
  //! Inverse of normalize(): map a value from the 0..1 range back onto [min, max].
  //! No clamping is applied, so values outside 0..1 extrapolate linearly.
  //! @param val  normalized value
  //! @param min  value that 0 maps to
  //! @param max  value that 1 maps to
  //! @return min + val * (max - min)
  inline double denormalize(double val, double min, double max)
  {
      const double span = max - min;
      return min + span * val;
  }
  //! Container for training data (original comment was mis-encoded).
  //! NOTE(review): the class holds a raw `data` pointer and a user-declared
  //! destructor, but declares no copy constructor or copy assignment. Copies
  //! therefore share `data`; if ~Dataset() frees it, copying a Dataset
  //! double-frees. Pass it by reference (as NNetwork::train does) until the
  //! ownership rules in the .cpp are confirmed.
  class Dataset
  {
  public:
      Dataset(int rows, int cols);
      ~Dataset();

      int rows;         // row count
      int cols;         // column count
      int row_width;    // size of one row, in elements (original comment mis-encoded)
      double* data;     // sample storage; allocation and layout are defined in the .cpp
  };
  //! A single layer of neurons (perceptrons).
  class NLayer
  {
  public:
      //! Create a layer of `percept_count` neurons fed by `previous_layer`
      //! (NULL for the input layer). Marked `explicit` so a bare int can no
      //! longer be silently converted into an NLayer.
      explicit NLayer(int percept_count, NLayer *previous_layer = NULL);
      //! Copy constructor. The original comment (partly unreadable) warns about
      //! prev_layer. NOTE(review): confirm whether the copy keeps or clears
      //! prev_layer; the owner is expected to re-link it.
      NLayer(const NLayer& layer);
      ~NLayer();
      // NOTE(review): no copy-assignment operator is declared, so `a = b` uses the
      // compiler-generated memberwise copy. The two layers then share the
      // weights/values/temp_val buffers, and if ~NLayer() frees them the memory
      // is released twice. Provide or disable operator= once the .cpp ownership
      // and prev_layer rules are confirmed.
      double *weights;     // weight coefficients; includes the bias b, whose input xn is taken as +1
      int percept_cnt;     // number of neurons in this layer
      int weights_cnt;     // total number of weights
      double *values;      // neuron output values
      double *temp_val;    // intermediate per-neuron values (original comment mis-encoded; appears to be used during training)
      NLayer *prev_layer;  // previous layer; NULL for the input layer
      int activation_func; // 0 or 1 selects the activation function (names mis-encoded in the original)
  };
  49. //!< ????????? ????
  50. class NNetwork
  51. {
  52. public:
  53. //! ?????????? ???????? ?? ??????? ????, ?????? ?????????? ???????? ? ??????? ?????, ?????????? ???????? ? ???????? ????
  54. NNetwork(int inputs, vector<int> hidden_layers, int outputs);
  55. NNetwork(const NNetwork& nn); //! ???????????
  56. ~NNetwork();
  57. //! ?????????? ???????? ???????? ????
  58. vector<double> calc_output(vector<double>& input);
  59. //! ??????? ????????? (???? ????????)
  60. void train(const Dataset &dataset, double speed);
  61. //! ?????? ?????????????
  62. void dumpCoeff();
  63. //! ????????? ????????? ????????????? ???? (??? ????????????? ?????????)
  64. void mutate(double amp, int freq);
  65. protected:
  66. void updateState();
  67. // ????
  68. vector<NLayer*> layers;
  69. };
  70. #endif // NLAYER_H