/theano/compile/tests/test_shared.py

https://github.com/ynd/Theano · Python · 311 lines · 203 code · 68 blank · 40 comment · 8 complexity · 4c439902b0b43a85442b6aa5ff34db2a MD5 · raw file

  1. import numpy
  2. import unittest
  3. import theano
  4. from theano.tensor import Tensor, TensorType
  5. from theano.compile.sharedvalue import *
  6. class Test_SharedVariable(unittest.TestCase):
  7. def test_ctors(self):
  8. if 0:
  9. # when using an implementation that handles scalars with
  10. # Scalar type
  11. assert shared(7).type == Scalar('int64')
  12. assert shared(7.0).type == Scalar('float64')
  13. assert shared(7, dtype='float64').type == Scalar('float64')
  14. else:
  15. if theano.gof.python_int_bitwidth() == 32:
  16. assert shared(7).type == theano.tensor.iscalar, shared(7).type
  17. else:
  18. assert shared(7).type == theano.tensor.lscalar, shared(7).type
  19. assert shared(7.0).type == theano.tensor.dscalar
  20. assert shared(numpy.float32(7)).type == theano.tensor.fscalar
  21. # test tensor constructor
  22. b = shared(numpy.zeros((5, 5), dtype='int32'))
  23. assert b.type == TensorType('int32', broadcastable=[False, False])
  24. b = shared(numpy.random.rand(4, 5))
  25. assert b.type == TensorType('float64', broadcastable=[False, False])
  26. b = shared(numpy.random.rand(5, 1, 2))
  27. assert b.type == TensorType('float64',
  28. broadcastable=[False, False, False])
  29. assert shared([]).type == generic
  30. def badfunc():
  31. shared(7, bad_kw=False)
  32. self.assertRaises(TypeError, badfunc)
  33. def test_strict_generic(self):
  34. #this should work, because
  35. # generic can hold anything even when strict=True
  36. u = shared('asdf', strict=False)
  37. v = shared('asdf', strict=True)
  38. u.set_value(88)
  39. v.set_value(88)
  40. def test_create_numpy_strict_false(self):
  41. # here the value is perfect, and we're not strict about it,
  42. # so creation should work
  43. SharedVariable(
  44. name='u',
  45. type=Tensor(broadcastable=[False], dtype='float64'),
  46. value=numpy.asarray([1., 2.]),
  47. strict=False)
  48. # here the value is castable, and we're not strict about it,
  49. # so creation should work
  50. SharedVariable(
  51. name='u',
  52. type=Tensor(broadcastable=[False], dtype='float64'),
  53. value=[1., 2.],
  54. strict=False)
  55. # here the value is castable, and we're not strict about it,
  56. # so creation should work
  57. SharedVariable(
  58. name='u',
  59. type=Tensor(broadcastable=[False], dtype='float64'),
  60. value=[1, 2], # different dtype and not a numpy array
  61. strict=False)
  62. # here the value is not castable, and we're not strict about it,
  63. # this is beyond strictness, it must fail
  64. try:
  65. SharedVariable(
  66. name='u',
  67. type=Tensor(broadcastable=[False], dtype='float64'),
  68. value=dict(), # not an array by any stretch
  69. strict=False)
  70. assert 0
  71. except TypeError:
  72. pass
  73. def test_use_numpy_strict_false(self):
  74. # here the value is perfect, and we're not strict about it,
  75. # so creation should work
  76. u = SharedVariable(
  77. name='u',
  78. type=Tensor(broadcastable=[False], dtype='float64'),
  79. value=numpy.asarray([1., 2.]),
  80. strict=False)
  81. # check that assignments to value are cast properly
  82. u.set_value([3, 4])
  83. assert type(u.get_value()) is numpy.ndarray
  84. assert str(u.get_value(borrow=True).dtype) == 'float64'
  85. assert numpy.all(u.get_value() == [3, 4])
  86. # check that assignments of nonsense fail
  87. try:
  88. u.set_value('adsf')
  89. assert 0
  90. except ValueError:
  91. pass
  92. # check that an assignment of a perfect value results in no copying
  93. uval = theano._asarray([5, 6, 7, 8], dtype='float64')
  94. u.set_value(uval, borrow=True)
  95. assert u.get_value(borrow=True) is uval
  96. def test_scalar_strict(self):
  97. def f(var, val):
  98. var.set_value(val)
  99. b = shared(numpy.int64(7), strict=True)
  100. assert b.type == theano.tensor.lscalar
  101. self.assertRaises(TypeError, f, b, 8.23)
  102. b = shared(numpy.int32(7), strict=True)
  103. assert b.type == theano.tensor.iscalar
  104. self.assertRaises(TypeError, f, b, 8.23)
  105. b = shared(numpy.int16(7), strict=True)
  106. assert b.type == theano.tensor.wscalar
  107. self.assertRaises(TypeError, f, b, 8.23)
  108. b = shared(numpy.int8(7), strict=True)
  109. assert b.type == theano.tensor.bscalar
  110. self.assertRaises(TypeError, f, b, 8.23)
  111. b = shared(numpy.float64(7.234), strict=True)
  112. assert b.type == theano.tensor.dscalar
  113. self.assertRaises(TypeError, f, b, 8)
  114. b = shared(numpy.float32(7.234), strict=True)
  115. assert b.type == theano.tensor.fscalar
  116. self.assertRaises(TypeError, f, b, 8)
  117. b = shared(numpy.float(7.234), strict=True)
  118. assert b.type == theano.tensor.dscalar
  119. self.assertRaises(TypeError, f, b, 8)
  120. b = shared(7.234, strict=True)
  121. assert b.type == theano.tensor.dscalar
  122. self.assertRaises(TypeError, f, b, 8)
  123. b = shared(numpy.zeros((5, 5), dtype='float32'))
  124. self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
  125. def test_tensor_strict(self):
  126. def f(var, val):
  127. var.set_value(val)
  128. b = shared(numpy.int64([7]), strict=True)
  129. assert b.type == theano.tensor.lvector
  130. self.assertRaises(TypeError, f, b, 8.23)
  131. b = shared(numpy.int32([7]), strict=True)
  132. assert b.type == theano.tensor.ivector
  133. self.assertRaises(TypeError, f, b, 8.23)
  134. b = shared(numpy.int16([7]), strict=True)
  135. assert b.type == theano.tensor.wvector
  136. self.assertRaises(TypeError, f, b, 8.23)
  137. b = shared(numpy.int8([7]), strict=True)
  138. assert b.type == theano.tensor.bvector
  139. self.assertRaises(TypeError, f, b, 8.23)
  140. b = shared(numpy.float64([7.234]), strict=True)
  141. assert b.type == theano.tensor.dvector
  142. self.assertRaises(TypeError, f, b, 8)
  143. b = shared(numpy.float32([7.234]), strict=True)
  144. assert b.type == theano.tensor.fvector
  145. self.assertRaises(TypeError, f, b, 8)
  146. #numpy.float([7.234]) don't work
  147. # b = shared(numpy.float([7.234]), strict=True)
  148. # assert b.type == theano.tensor.dvector
  149. # self.assertRaises(TypeError, f, b, 8)
  150. #This generate a generic type. Should we cast? I don't think.
  151. # b = shared([7.234], strict=True)
  152. # assert b.type == theano.tensor.dvector
  153. # self.assertRaises(TypeError, f, b, 8)
  154. b = shared(numpy.zeros((5, 5), dtype='float32'))
  155. self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
  156. def test_scalar_floatX(self):
  157. # the test should assure that floatX is not used in the shared
  158. # constructor for scalars Shared values can change, and since we don't
  159. # know the range they might take, we should keep the same
  160. # bit width / precision as the original value used to create the
  161. # shared variable.
  162. # Since downcasting of a value now raises an Exception,
  163. def f(var, val):
  164. var.set_value(val)
  165. b = shared(numpy.int64(7), allow_downcast=True)
  166. assert b.type == theano.tensor.lscalar
  167. f(b, 8.23)
  168. assert b.get_value() == 8
  169. b = shared(numpy.int32(7), allow_downcast=True)
  170. assert b.type == theano.tensor.iscalar
  171. f(b, 8.23)
  172. assert b.get_value() == 8
  173. b = shared(numpy.int16(7), allow_downcast=True)
  174. assert b.type == theano.tensor.wscalar
  175. f(b, 8.23)
  176. assert b.get_value() == 8
  177. b = shared(numpy.int8(7), allow_downcast=True)
  178. assert b.type == theano.tensor.bscalar
  179. f(b, 8.23)
  180. assert b.get_value() == 8
  181. b = shared(numpy.float64(7.234), allow_downcast=True)
  182. assert b.type == theano.tensor.dscalar
  183. f(b, 8)
  184. assert b.get_value() == 8
  185. b = shared(numpy.float32(7.234), allow_downcast=True)
  186. assert b.type == theano.tensor.fscalar
  187. f(b, 8)
  188. assert b.get_value() == 8
  189. b = shared(numpy.float(7.234), allow_downcast=True)
  190. assert b.type == theano.tensor.dscalar
  191. f(b, 8)
  192. assert b.get_value() == 8
  193. b = shared(7.234, allow_downcast=True)
  194. assert b.type == theano.tensor.dscalar
  195. f(b, 8)
  196. assert b.get_value() == 8
  197. b = shared(numpy.zeros((5, 5), dtype='float32'))
  198. self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
  199. def test_tensor_floatX(self):
  200. def f(var, val):
  201. var.set_value(val)
  202. b = shared(numpy.int64([7]), allow_downcast=True)
  203. assert b.type == theano.tensor.lvector
  204. f(b, [8.23])
  205. assert b.get_value() == 8
  206. b = shared(numpy.int32([7]), allow_downcast=True)
  207. assert b.type == theano.tensor.ivector
  208. f(b, [8.23])
  209. assert b.get_value() == 8
  210. b = shared(numpy.int16([7]), allow_downcast=True)
  211. assert b.type == theano.tensor.wvector
  212. f(b, [8.23])
  213. assert b.get_value() == 8
  214. b = shared(numpy.int8([7]), allow_downcast=True)
  215. assert b.type == theano.tensor.bvector
  216. f(b, [8.23])
  217. assert b.get_value() == 8
  218. b = shared(numpy.float64([7.234]), allow_downcast=True)
  219. assert b.type == theano.tensor.dvector
  220. f(b, [8])
  221. assert b.get_value() == 8
  222. b = shared(numpy.float32([7.234]), allow_downcast=True)
  223. assert b.type == theano.tensor.fvector
  224. f(b, [8])
  225. assert b.get_value() == 8
  226. #numpy.float([7.234]) don't work
  227. # b = shared(numpy.float([7.234]))
  228. # assert b.type == theano.tensor.dvector
  229. # f(b,[8])
  230. #This generate a generic type. Should we cast? I don't think.
  231. # b = shared([7.234])
  232. # assert b.type == theano.tensor.dvector
  233. # f(b,[8])
  234. b = shared(numpy.asarray([7.234], dtype=theano.config.floatX),
  235. allow_downcast=True)
  236. assert b.dtype == theano.config.floatX
  237. f(b, [8])
  238. assert b.get_value() == 8
  239. b = shared(numpy.zeros((5, 5), dtype='float32'))
  240. self.assertRaises(TypeError, f, b, numpy.random.rand(5, 5))
  241. def test_err_symbolic_variable(self):
  242. self.assertRaises(TypeError, shared, theano.tensor.ones((2, 3)))
  243. shared(numpy.ones((2, 4)))