PageRenderTime 18ms CodeModel.GetById 13ms app.highlight 3ms RepoModel.GetById 1ms app.codeStats 0ms

/neurolab/error.py

https://github.com/liujiantong/neurolab
Python | 158 lines | 30 code | 18 blank | 110 comment | 0 complexity | 5bc21cdb3cd4dc5e8482231f3d98323f MD5 | raw file
  1# -*- coding: utf-8 -*-
  2""" Train error functions with derivatives
  3
  4    :Example:
  5        >>> msef = MSE()
  6        >>> x = np.array([[1.0, 0.0], [2.0, 0.0]])
  7        >>> msef(x)
  8        1.25
  9        >>> # calc derivative:
 10        >>> msef.deriv(x[0])
 11        array([ 1.,  0.])
 12
 13"""
 14
 15import numpy as np
 16
 17
 18class MSE():
 19    """
 20    Mean squared error function
 21
 22    :Parameters:
 23        e: ndarray
 24            current errors: target - output
 25    :Returns:
 26        v: float
 27            Error value
 28    :Example:
 29        >>> f = MSE()
 30        >>> x = np.array([[1.0, 0.0], [2.0, 0.0]])
 31        >>> f(x)
 32        1.25
 33
 34    """
 35
 36    def __call__(self, e):
 37        N = e.size
 38        v =  np.sum(np.square(e)) / N
 39        return v
 40
 41    def deriv(self, e):
 42        """
 43        Derivative of MSE error function
 44
 45        :Parameters:
 46            e: ndarray
 47                current errors: target - output
 48        :Returns:
 49            d: ndarray
 50                Derivative: dE/d_out
 51        :Example:
 52            >>> f = MSE()
 53            >>> x = np.array([1.0, 0.0])
 54            >>> # calc derivative:
 55            >>> f.deriv(x)
 56            array([ 1.,  0.])
 57
 58        """
 59
 60        N = len(e)
 61        d = e * (2 / N)
 62        return d
 63
 64
 65class SSE:
 66    """
 67    Sum squared error function
 68
 69    :Parameters:
 70        e: ndarray
 71            current errors: target - output
 72    :Returns:
 73        v: float
 74            Error value
 75
 76    """
 77
 78    def __call__(self, e):
 79        v = 0.5 * np.sum(np.square(e))
 80        return v
 81
 82    def deriv(self, e):
 83        """
 84        Derivative of SSE error function
 85
 86        :Parameters:
 87            e: ndarray
 88                current errors: target - output
 89        :Returns:
 90            d: ndarray
 91                Derivative: dE/d_out
 92
 93        """
 94        return e
 95
 96
 97class SAE:
 98    """
 99    Sum absolute error function
100
101    :Parameters:
102        e: ndarray
103            current errors: target - output
104    :Returns:
105        v: float
106            Error value
107    """
108
109    def __call__(self, e):
110        v = np.sum(np.abs(e))
111        return v
112
113    def deriv(self, e):
114        """
115        Derivative of SAE error function
116
117        :Parameters:
118            e: ndarray
119                current errors: target - output
120        :Returns:
121            d: ndarray
122                Derivative: dE/d_out
123
124        """
125        d = np.sign(e)
126        return d
127
128
class MAE:
    """
    Mean absolute error function

    The error value is the sum of the absolute values of the entries of
    ``e`` divided by the total number of entries (``e.size``).

    :Parameters:
        e: ndarray
            current errors: target - output
    :Returns:
        v: float
            Error value
    """

    def __call__(self, e):
        # A float divisor keeps true division for integer arrays on every
        # Python version.
        return np.sum(np.abs(e)) / float(e.size)

    def deriv(self, e):
        """
        Derivative of MAE error function

        :Parameters:
            e: ndarray
                current errors: target - output
        :Returns:
            d: ndarray
                Element-wise sign of ``e`` divided by ``e.size``,
                same shape as ``e``

        """
        # Same normalisation as __call__; the float divisor stops integer
        # signs from being floor-divided to zero.
        return np.sign(e) / float(e.size)