#### /neurolab/error.py

https://github.com/liujiantong/neurolab
Python | 158 lines | 30 code | 18 blank | 110 comment | 0 complexity | 5bc21cdb3cd4dc5e8482231f3d98323f MD5 | raw file
```  1# -*- coding: utf-8 -*-
2""" Train error functions with derivatives
3
4    :Example:
5        >>> msef = MSE()
6        >>> x = np.array([[1.0, 0.0], [2.0, 0.0]])
7        >>> msef(x)
8        1.25
9        >>> # calc derivative:
10        >>> msef.deriv(x[0])
11        array([ 1.,  0.])
12
13"""
14
15import numpy as np
16
17
class MSE:
    """
    Mean squared error function

    The error is the sum of squared elements of ``e`` divided by the total
    number of elements (``e.size``), whatever the shape of ``e``.

    :Parameters:
        e: ndarray
            current errors: target - output
    :Returns:
        v: float
            Error value
    :Example:
        >>> f = MSE()
        >>> x = np.array([[1.0, 0.0], [2.0, 0.0]])
        >>> f(x)
        1.25

    """

    def __call__(self, e):
        # float() keeps this a true division even for integer input on
        # interpreters where int / int truncates.
        return np.sum(np.square(e)) / float(e.size)

    def deriv(self, e):
        """
        Derivative of MSE error function

        Computed element-wise as ``2 * e / N`` where ``N = e.size``, i.e. the
        same element count that ``__call__`` averages over.

        :Parameters:
            e: ndarray
                current errors: target - output
        :Returns:
            d: ndarray
                Derivative, same shape as ``e``
        :Example:
            >>> f = MSE()
            >>> x = np.array([1.0, 0.0])
            >>> # calc derivative:
            >>> f.deriv(x)
            array([ 1.,  0.])

        """
        # Use e.size rather than len(e): len() only counts the first axis,
        # which disagrees with __call__ for multi-dimensional input.
        # The 2.0 literal avoids integer truncation of 2 / N.
        return e * (2.0 / e.size)
63
64
class SSE:
    """
    Sum squared error function

    Returns half of the sum of the squared elements of ``e``.

    :Parameters:
        e: ndarray
            current errors: target - output
    :Returns:
        v: float
            Error value

    """

    def __call__(self, e):
        squared = np.square(e)
        return 0.5 * np.sum(squared)

    def deriv(self, e):
        """
        Derivative of SSE error function

        Because of the 0.5 factor in the error value, the derivative with
        respect to ``e`` is ``e`` itself; the input is returned unchanged
        (no copy is made).

        :Parameters:
            e: ndarray
                current errors: target - output
        :Returns:
            d: ndarray
                Derivative (the same object as ``e``)

        """
        return e
95
96
class SAE:
    """
    Sum absolute error function

    Returns the sum of the absolute values of the elements of ``e``.

    :Parameters:
        e: ndarray
            current errors: target - output
    :Returns:
        v: float
            Error value
    """

    def __call__(self, e):
        magnitudes = np.abs(e)
        return np.sum(magnitudes)

    def deriv(self, e):
        """
        Derivative of SAE error function

        Element-wise sign of ``e``: -1, 0 or +1.

        :Parameters:
            e: ndarray
                current errors: target - output
        :Returns:
            d: ndarray
                Derivative, same shape as ``e``

        """
        return np.sign(e)
127
128
class MAE:
    """
    Mean absolute error function

    Returns the sum of the absolute values of the elements of ``e`` divided
    by the total number of elements (``e.size``).

    :Parameters:
        e: ndarray
            current errors: target - output
    :Returns:
        v: float
            Error value
    """

    def __call__(self, e):
        total = np.sum(np.abs(e))
        return total / e.size

    def deriv(self, e):
        """
        Derivative of MAE error function

        Element-wise sign of ``e`` divided by the total number of elements
        (``e.size``).

        :Parameters:
            e: ndarray
                current errors: target - output
        :Returns:
            d: ndarray
                Derivative, same shape as ``e``

        """
        signs = np.sign(e)
        return signs / e.size
```