PageRenderTime 27ms CodeModel.GetById 19ms app.highlight 7ms RepoModel.GetById 0ms app.codeStats 0ms

/ru/experimental/nnetwork.h

https://bitbucket.org/VladimirL/robotutils
C Header | 87 lines | 58 code | 17 blank | 12 comment | 2 complexity | 07d981284081ada2ef9a5b5caba09d12 MD5 | raw file
 1#ifndef NLAYER_H
 2#define NLAYER_H
 3
 4#include <stdlib.h>
 5#include <vector>
 6#include <cmath>
 7
 8#include <QDebug>
 9#include <QString>
10
11using namespace std;
12
//! Negated slope coefficient (-a) of the activation function.
//! NOTE(review): presumably used in a sigmoid such as 1/(1+exp(-a*x)) — confirm in the implementation.
static const double sigma_a = -2.0;
//const double speed = 1.0;
15
//! Normalizes val into the range 0..1.
//! val is first clamped to [min, max], then mapped linearly so that
//! min -> 0 and max -> 1.
//! @param val  value to normalize
//! @param min  lower bound of the source range
//! @param max  upper bound of the source range
//! @return     normalized value (in [0, 1] when min < max); 0 when the
//!             range is degenerate (max == min)
inline double normalize(double val, double min, double max)
{
    if (val > max) val = max;
    if (val < min) val = min;
    const double range = max - min;
    // A degenerate range would otherwise produce 0/0 = NaN, which silently
    // propagates through every later computation that uses the result.
    if (range == 0.0) return 0.0;
    return (val - min) / range;
}
23
//! Inverse of normalize(): maps a value from the range 0..1 back onto
//! [min, max] (0 -> min, 1 -> max). The input is NOT clamped, so values
//! outside 0..1 extrapolate linearly.
//! @param val  normalized value
//! @param min  lower bound of the target range
//! @param max  upper bound of the target range
//! @return     min + val * (max - min)
inline double denormalize(double val, double min, double max)
{
    const double span = max - min;
    return min + val * span;
}
29
//! Data set consumed by NNetwork::train().
//! NOTE(review): the original description was mis-encoded; the meaning of
//! rows/cols/data below is inferred from the names — confirm in the implementation.
class Dataset
{
public:
    Dataset(int rows, int cols);
    ~Dataset();

    int rows, cols;   //!< table dimensions as passed to the constructor (presumably)
    int row_width;    //!< width of one row, in elements
    double *data;     //!< raw storage; NOTE(review): no copy operations are declared,
                      //!< so a compiler-generated copy shares this pointer
};
41
//! A layer of neurons.
class NLayer
{
public:
    //! @param percept_count   number of neurons in the layer (presumably)
    //! @param previous_layer  layer feeding this one; NULL by default
    NLayer(int percept_count, NLayer *previous_layer = NULL);

    //! Copy constructor (defined out of line).
    //! NOTE(review): the original comment here carried a warning about
    //! prev_layer (text mis-encoded) — check how the definition copies it.
    NLayer(const NLayer& layer);

    ~NLayer();

    // NOTE(review): no copy-assignment operator is declared, so the
    // compiler-generated one copies the raw pointers below shallowly.

    double *weights;          //!< weight coefficients; per the partially mis-encoded original
                              //!< comment, one of them is the bias b, paired with input xn = +1
    int percept_cnt,          //!< number of neurons in this layer
        weights_cnt;          //!< total number of weights
    double *values,           //!< neuron output values
           *temp_val;         //!< intermediate per-neuron values (presumably scratch data for training)

    NLayer *prev_layer;       //!< previous layer; NULL when there is none (presumably the input layer)
    int activation_func;      //!< activation function selector, 0 or 1 (mapping mis-encoded in original — confirm)
};
58
59//!< ????????? ????
60class NNetwork
61{
62public:
63    //! ?????????? ???????? ?? ??????? ????, ?????? ?????????? ???????? ? ??????? ?????, ?????????? ???????? ? ???????? ????
64    NNetwork(int inputs, vector<int> hidden_layers, int outputs);
65    NNetwork(const NNetwork& nn);   //! ???????????
66
67    ~NNetwork();
68
69    //! ?????????? ???????? ???????? ????
70    vector<double> calc_output(vector<double>& input);
71
72    //! ??????? ????????? (???? ????????)
73    void train(const Dataset &dataset, double speed);
74
75    //! ?????? ?????????????
76    void dumpCoeff();
77
78    //! ????????? ????????? ????????????? ???? (??? ????????????? ?????????)
79    void mutate(double amp, int freq);
80
81protected:
82    void updateState();
83    // ????
84    vector<NLayer*> layers;
85};
86
87#endif // NLAYER_H