PageRenderTime 35ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/tests/unit/nupic/encoders/adaptivescalar_test.py

https://gitlab.com/github-cloud-corporation/nupic
Python | 257 lines | 229 code | 7 blank | 21 comment | 4 complexity | 7e547be4055f8d1a4fbcacf81dec7201 MD5 | raw file
  1. #!/usr/bin/env python
  2. # ----------------------------------------------------------------------
  3. # Numenta Platform for Intelligent Computing (NuPIC)
  4. # Copyright (C) 2013, Numenta, Inc. Unless you have an agreement
  5. # with Numenta, Inc., for a separate license for this software code, the
  6. # following terms and conditions apply:
  7. #
  8. # This program is free software: you can redistribute it and/or modify
  9. # it under the terms of the GNU Affero Public License version 3 as
  10. # published by the Free Software Foundation.
  11. #
  12. # This program is distributed in the hope that it will be useful,
  13. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  15. # See the GNU Affero Public License for more details.
  16. #
  17. # You should have received a copy of the GNU Affero Public License
  18. # along with this program. If not, see http://www.gnu.org/licenses.
  19. #
  20. # http://numenta.org/licenses/
  21. # ----------------------------------------------------------------------
  22. import tempfile
  23. import unittest
  24. import numpy
  25. from nupic.data import SENTINEL_VALUE_FOR_MISSING_DATA
  26. from nupic.encoders.adaptivescalar import AdaptiveScalarEncoder
  27. from nupic.encoders.base import defaultDtype
  28. try:
  29. import capnp
  30. except ImportError:
  31. capnp = None
  32. if capnp:
  33. from nupic.encoders.adaptivescalar_capnp import AdaptiveScalarEncoderProto
  34. class AdaptiveScalarTest(unittest.TestCase):
  35. """Tests for AdaptiveScalarEncoder"""
  36. def setUp(self):
  37. # forced: it's strongly recommended to use w>=21, in the example we force
  38. # skip the check for readibility
  39. self._l = AdaptiveScalarEncoder(name="scalar", n=14, w=5, minval=1,
  40. maxval=10, periodic=False, forced=True)
  41. def testMissingValues(self):
  42. """missing values"""
  43. # forced: it's strongly recommended to use w>=21, in the example we force
  44. # skip the check for readib.
  45. mv = AdaptiveScalarEncoder(name="mv", n=14, w=3, minval=1, maxval=8,
  46. periodic=False, forced=True)
  47. empty = mv.encode(SENTINEL_VALUE_FOR_MISSING_DATA)
  48. self.assertEqual(empty.sum(), 0)
  49. def testNonPeriodicEncoderMinMaxSpec(self):
  50. """Non-periodic encoder, min and max specified"""
  51. self.assertTrue(numpy.array_equal(
  52. self._l.encode(1),
  53. numpy.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  54. dtype=defaultDtype)))
  55. self.assertTrue(numpy.array_equal(
  56. self._l.encode(2),
  57. numpy.array([0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  58. dtype=defaultDtype)))
  59. self.assertTrue(numpy.array_equal(
  60. self._l.encode(10),
  61. numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
  62. dtype=defaultDtype)))
  63. def testTopDownDecode(self):
  64. """Test the input description generation and topDown decoding"""
  65. l = self._l
  66. v = l.minval
  67. while v < l.maxval:
  68. output = l.encode(v)
  69. decoded = l.decode(output)
  70. (fieldsDict, _) = decoded
  71. self.assertEqual(len(fieldsDict), 1)
  72. (ranges, _) = fieldsDict.values()[0]
  73. self.assertEqual(len(ranges), 1)
  74. (rangeMin, rangeMax) = ranges[0]
  75. self.assertEqual(rangeMin, rangeMax)
  76. self.assertLess(abs(rangeMin - v), l.resolution)
  77. topDown = l.topDownCompute(output)[0]
  78. self.assertLessEqual(abs(topDown.value - v), l.resolution)
  79. # Test bucket support
  80. bucketIndices = l.getBucketIndices(v)
  81. topDown = l.getBucketInfo(bucketIndices)[0]
  82. self.assertLessEqual(abs(topDown.value - v), l.resolution / 2)
  83. self.assertEqual(topDown.value, l.getBucketValues()[bucketIndices[0]])
  84. self.assertEqual(topDown.scalar, topDown.value)
  85. self.assertTrue(numpy.array_equal(topDown.encoding, output))
  86. # Next value
  87. v += l.resolution / 4
  88. def testFillHoles(self):
  89. """Make sure we can fill in holes"""
  90. l=self._l
  91. decoded = l.decode(numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1]))
  92. (fieldsDict, _) = decoded
  93. self.assertEqual(len(fieldsDict), 1)
  94. (ranges, _) = fieldsDict.values()[0]
  95. self.assertEqual(len(ranges), 1)
  96. self.assertSequenceEqual(ranges[0], [10, 10])
  97. decoded = l.decode(numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1]))
  98. (fieldsDict, _) = decoded
  99. self.assertEqual(len(fieldsDict), 1)
  100. (ranges, _) = fieldsDict.values()[0]
  101. self.assertEqual(len(ranges), 1)
  102. self.assertSequenceEqual(ranges[0], [10, 10])
  103. def testNonPeriodicEncoderMinMaxNotSpec(self):
  104. """Non-periodic encoder, min and max not specified"""
  105. l = AdaptiveScalarEncoder(name="scalar", n=14, w=5, minval=None,
  106. maxval=None, periodic=False, forced=True)
  107. def _verify(v, encoded, expV=None):
  108. if expV is None:
  109. expV = v
  110. self.assertTrue(numpy.array_equal(
  111. l.encode(v),
  112. numpy.array(encoded, dtype=defaultDtype)))
  113. self.assertLessEqual(
  114. abs(l.getBucketInfo(l.getBucketIndices(v))[0].value - expV),
  115. l.resolution/2)
  116. def _verifyNot(v, encoded):
  117. self.assertFalse(numpy.array_equal(
  118. l.encode(v), numpy.array(encoded, dtype=defaultDtype)))
  119. _verify(1, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  120. _verify(2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  121. _verify(10, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  122. _verify(3, [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0])
  123. _verify(-9, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  124. _verify(-8, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  125. _verify(-7, [0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
  126. _verify(-6, [0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
  127. _verify(-5, [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0])
  128. _verify(0, [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
  129. _verify(8, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0])
  130. _verify(8, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0])
  131. _verify(10, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  132. _verify(11, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  133. _verify(12, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  134. _verify(13, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  135. _verify(14, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  136. _verify(15, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  137. #"""Test switching learning off"""
  138. l = AdaptiveScalarEncoder(name="scalar", n=14, w=5, minval=1, maxval=10,
  139. periodic=False, forced=True)
  140. _verify(1, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  141. _verify(10, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  142. _verify(20, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  143. _verify(10, [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
  144. l.setLearning(False)
  145. _verify(30, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], expV=20)
  146. _verify(20, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  147. _verify(-10, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], expV=1)
  148. _verify(-1, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], expV=1)
  149. l.setLearning(True)
  150. _verify(30, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  151. _verifyNot(20, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
  152. _verify(-10, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  153. _verifyNot(-1, [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
  154. def testSetFieldStats(self):
  155. """Test setting the min and max using setFieldStats"""
  156. def _dumpParams(enc):
  157. return (enc.n, enc.w, enc.minval, enc.maxval, enc.resolution,
  158. enc._learningEnabled, enc.recordNum,
  159. enc.radius, enc.rangeInternal, enc.padding, enc.nInternal)
  160. sfs = AdaptiveScalarEncoder(name='scalar', n=14, w=5, minval=1, maxval=10,
  161. periodic=False, forced=True)
  162. reg = AdaptiveScalarEncoder(name='scalar', n=14, w=5, minval=1, maxval=100,
  163. periodic=False, forced=True)
  164. self.assertNotEqual(_dumpParams(sfs), _dumpParams(reg),
  165. ("Params should not be equal, since the two encoders "
  166. "were instantiated with different values."))
  167. # set the min and the max using sFS to 1,100 respectively.
  168. sfs.setFieldStats("this", {"this":{"min":1, "max":100}})
  169. #Now the parameters for both should be the same
  170. self.assertEqual(_dumpParams(sfs), _dumpParams(reg),
  171. ("Params should now be equal, but they are not. sFS "
  172. "should be equivalent to initialization."))
  173. @unittest.skipUnless(
  174. capnp, "pycapnp is not installed, skipping serialization test.")
  175. def testReadWrite(self):
  176. originalValue = self._l.encode(1)
  177. proto1 = AdaptiveScalarEncoderProto.new_message()
  178. self._l.write(proto1)
  179. # Write the proto to a temp file and read it back into a new proto
  180. with tempfile.TemporaryFile() as f:
  181. proto1.write(f)
  182. f.seek(0)
  183. proto2 = AdaptiveScalarEncoderProto.read(f)
  184. encoder = AdaptiveScalarEncoder.read(proto2)
  185. self.assertIsInstance(encoder, AdaptiveScalarEncoder)
  186. self.assertEqual(encoder.recordNum, self._l.recordNum)
  187. self.assertDictEqual(encoder.slidingWindow.__dict__,
  188. self._l.slidingWindow.__dict__)
  189. self.assertEqual(encoder.w, self._l.w)
  190. self.assertEqual(encoder.minval, self._l.minval)
  191. self.assertEqual(encoder.maxval, self._l.maxval)
  192. self.assertEqual(encoder.periodic, self._l.periodic)
  193. self.assertEqual(encoder.n, self._l.n)
  194. self.assertEqual(encoder.radius, self._l.radius)
  195. self.assertEqual(encoder.resolution, self._l.resolution)
  196. self.assertEqual(encoder.name, self._l.name)
  197. self.assertEqual(encoder.verbosity, self._l.verbosity)
  198. self.assertEqual(encoder.clipInput, self._l.clipInput)
  199. self.assertTrue(numpy.array_equal(encoder.encode(1), originalValue))
  200. self.assertEqual(self._l.decode(encoder.encode(1)),
  201. encoder.decode(self._l.encode(1)))
  202. # Feed in a new value and ensure the encodings match
  203. result1 = self._l.encode(7)
  204. result2 = encoder.encode(7)
  205. self.assertTrue(numpy.array_equal(result1, result2))
  206. if __name__ == '__main__':
  207. unittest.main()