PageRenderTime 58ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/neurolab/error.py

https://github.com/liujiantong/neurolab
Python | 158 lines | 30 code | 18 blank | 110 comment | 0 complexity | 5bc21cdb3cd4dc5e8482231f3d98323f MD5 | raw file
# -*- coding: utf-8 -*-
"""
Train error functions with derivatives.

Each error class is called on an array of errors ``e`` (target - output) and
returns a scalar error value; its ``deriv`` method returns an array with the
same shape as ``e``.

:Example:
    >>> msef = MSE()
    >>> x = np.array([[1.0, 0.0], [2.0, 0.0]])
    >>> float(msef(x))
    1.25
    >>> # calc derivative:
    >>> msef.deriv(x[0]).tolist()
    [1.0, 0.0]

"""
import numpy as np
  13. class MSE():
  14. """
  15. Mean squared error function
  16. :Parameters:
  17. e: ndarray
  18. current errors: target - output
  19. :Returns:
  20. v: float
  21. Error value
  22. :Example:
  23. >>> f = MSE()
  24. >>> x = np.array([[1.0, 0.0], [2.0, 0.0]])
  25. >>> f(x)
  26. 1.25
  27. """
  28. def __call__(self, e):
  29. N = e.size
  30. v = np.sum(np.square(e)) / N
  31. return v
  32. def deriv(self, e):
  33. """
  34. Derivative of MSE error function
  35. :Parameters:
  36. e: ndarray
  37. current errors: target - output
  38. :Returns:
  39. d: ndarray
  40. Derivative: dE/d_out
  41. :Example:
  42. >>> f = MSE()
  43. >>> x = np.array([1.0, 0.0])
  44. >>> # calc derivative:
  45. >>> f.deriv(x)
  46. array([ 1., 0.])
  47. """
  48. N = len(e)
  49. d = e * (2 / N)
  50. return d
  51. class SSE:
  52. """
  53. Sum squared error function
  54. :Parameters:
  55. e: ndarray
  56. current errors: target - output
  57. :Returns:
  58. v: float
  59. Error value
  60. """
  61. def __call__(self, e):
  62. v = 0.5 * np.sum(np.square(e))
  63. return v
  64. def deriv(self, e):
  65. """
  66. Derivative of SSE error function
  67. :Parameters:
  68. e: ndarray
  69. current errors: target - output
  70. :Returns:
  71. d: ndarray
  72. Derivative: dE/d_out
  73. """
  74. return e
  75. class SAE:
  76. """
  77. Sum absolute error function
  78. :Parameters:
  79. e: ndarray
  80. current errors: target - output
  81. :Returns:
  82. v: float
  83. Error value
  84. """
  85. def __call__(self, e):
  86. v = np.sum(np.abs(e))
  87. return v
  88. def deriv(self, e):
  89. """
  90. Derivative of SAE error function
  91. :Parameters:
  92. e: ndarray
  93. current errors: target - output
  94. :Returns:
  95. d: ndarray
  96. Derivative: dE/d_out
  97. """
  98. d = np.sign(e)
  99. return d
  100. class MAE:
  101. """
  102. Mean absolute error function
  103. :Parameters:
  104. e: ndarray
  105. current errors: target - output
  106. :Returns:
  107. v: float
  108. Error value
  109. """
  110. def __call__(self, e):
  111. v = np.sum(np.abs(e)) / e.size
  112. return v
  113. def deriv(self, e):
  114. """
  115. Derivative of SAE error function
  116. :Parameters:
  117. e: ndarray
  118. current errors: target - output
  119. :Returns:
  120. d: ndarray
  121. Derivative: dE/d_out
  122. """
  123. d = np.sign(e) / e.size
  124. return d