PageRenderTime 66ms CodeModel.GetById 25ms RepoModel.GetById 0ms app.codeStats 0ms

/theano/gof/tests/test_op.py

https://github.com/mrocklin/Theano
Python | 341 lines | 298 code | 36 blank | 7 comment | 12 complexity | 95269b40cab11ecfe61e55e21a935cf4 MD5 | raw file
  1. from copy import copy
  2. import unittest
  3. import numpy
  4. import theano
  5. import theano.gof.op as op
  6. from theano.gof.type import Type, Generic
  7. from theano.gof.graph import Apply, Variable
  8. import theano.tensor as T
  9. from theano import scalar
  10. from theano import shared
# Shorthand for the global Theano configuration object; the tests below
# read and temporarily override ``config.compute_test_value``.
config = theano.config
# Module-level aliases into ``theano.gof.op`` (imported above as ``op``).
Op = op.Op
utils = op.utils
  14. def as_variable(x):
  15. assert isinstance(x, Variable)
  16. return x
  17. class MyType(Type):
  18. def __init__(self, thingy):
  19. self.thingy = thingy
  20. def __eq__(self, other):
  21. return type(other) == type(self) and other.thingy == self.thingy
  22. def __str__(self):
  23. return str(self.thingy)
  24. def __repr__(self):
  25. return str(self.thingy)
  26. def filter(self, x, strict=False, allow_downcast=None):
  27. # Dummy filter: we want this type to represent strings that
  28. # start with `self.thingy`.
  29. if not isinstance(x, basestring):
  30. raise TypeError("Invalid type")
  31. if not x.startswith(self.thingy):
  32. raise ValueError("Invalid value")
  33. return x
  34. class MyOp(Op):
  35. def make_node(self, *inputs):
  36. inputs = map(as_variable, inputs)
  37. for input in inputs:
  38. if not isinstance(input.type, MyType):
  39. raise Exception("Error 1")
  40. outputs = [MyType(sum([input.type.thingy for input in inputs]))()]
  41. return Apply(self, inputs, outputs)
  42. MyOp = MyOp()
  43. class NoInputOp(Op):
  44. """An Op to test the corner-case of an Op with no input."""
  45. def __eq__(self, other):
  46. return type(self) == type(other)
  47. def __hash__(self):
  48. return hash(type(self))
  49. def make_node(self):
  50. return Apply(self, [], [MyType('test')()])
  51. def perform(self, node, inputs, output_storage):
  52. output_storage[0][0] = 'test Op no input'
  53. class TestOp:
  54. # Sanity tests
  55. def test_sanity_0(self):
  56. r1, r2 = MyType(1)(), MyType(2)()
  57. node = MyOp.make_node(r1, r2)
  58. assert [x for x in node.inputs] == [r1, r2] # Are the inputs what I provided?
  59. assert [x.type for x in node.outputs] == [MyType(3)] # Are the outputs what I expect?
  60. assert node.outputs[0].owner is node and node.outputs[0].index == 0
  61. # validate
  62. def test_validate(self):
  63. try:
  64. MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances
  65. raise Exception("Expected an exception")
  66. except Exception, e:
  67. if str(e) != "Error 1":
  68. raise
  69. def test_op_no_input(self):
  70. x = NoInputOp()()
  71. f = theano.function([], x)
  72. rval = f()
  73. assert rval == 'test Op no input'
  74. class TestMakeThunk(unittest.TestCase):
  75. def test_no_c_code(self):
  76. class IncOnePython(Op):
  77. """An Op with only a Python (perform) implementation"""
  78. def __eq__(self, other):
  79. return type(self) == type(other)
  80. def __hash__(self):
  81. return hash(type(self))
  82. def make_node(self, input):
  83. input = scalar.as_scalar(input)
  84. output = input.type()
  85. return Apply(self, [input], [output])
  86. def perform(self, node, inputs, outputs):
  87. input, = inputs
  88. output, = outputs
  89. output[0] = input + 1
  90. i = scalar.int32('i')
  91. o = IncOnePython()(i)
  92. # Check that the c_code function is not implemented
  93. self.assertRaises((NotImplementedError, utils.MethodNotDefined),
  94. o.owner.op.c_code,
  95. o.owner, 'o', ['x'], 'z', {'fail': ''})
  96. storage_map = {
  97. i: [numpy.int32(3)],
  98. o: [None]}
  99. compute_map = {
  100. i: [True],
  101. o: [False]}
  102. thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
  103. no_recycling=[])
  104. required = thunk()
  105. # Check everything went OK
  106. assert not required # We provided all inputs
  107. assert compute_map[o][0]
  108. assert storage_map[o][0] == 4
  109. def test_no_perform(self):
  110. class IncOneC(Op):
  111. """An Op with only a C (c_code) implementation"""
  112. def __eq__(self, other):
  113. return type(self) == type(other)
  114. def __hash__(self):
  115. return hash(type(self))
  116. def make_node(self, input):
  117. input = scalar.as_scalar(input)
  118. output = input.type()
  119. return Apply(self, [input], [output])
  120. def c_code(self, node, name, inputs, outputs, sub):
  121. x, = inputs
  122. z, = outputs
  123. return "%(z)s = %(x)s + 1;" % locals()
  124. i = scalar.int32('i')
  125. o = IncOneC()(i)
  126. # Check that the perform function is not implemented
  127. self.assertRaises((NotImplementedError, utils.MethodNotDefined),
  128. o.owner.op.perform,
  129. o.owner, 0, [None])
  130. storage_map = {
  131. i: [numpy.int32(3)],
  132. o: [None]}
  133. compute_map = {
  134. i: [True],
  135. o: [False]}
  136. thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
  137. no_recycling=[])
  138. if theano.config.cxx:
  139. required = thunk()
  140. # Check everything went OK
  141. assert not required # We provided all inputs
  142. assert compute_map[o][0]
  143. assert storage_map[o][0] == 4
  144. else:
  145. self.assertRaises((NotImplementedError, utils.MethodNotDefined),
  146. thunk)
  147. def test_test_value_python_objects():
  148. for x in (range(3), 0, 0.5, 1):
  149. assert (op.get_test_value(x) == x).all()
  150. def test_test_value_ndarray():
  151. x = numpy.zeros((5,5))
  152. v = op.get_test_value(x)
  153. assert (v == x).all()
  154. def test_test_value_constant():
  155. x = T.as_tensor_variable(numpy.zeros((5,5)))
  156. v = op.get_test_value(x)
  157. assert numpy.all(v == numpy.zeros((5,5)))
  158. def test_test_value_shared():
  159. x = shared(numpy.zeros((5,5)))
  160. v = op.get_test_value(x)
  161. assert numpy.all(v == numpy.zeros((5,5)))
  162. def test_test_value_op():
  163. try:
  164. prev_value = config.compute_test_value
  165. config.compute_test_value = 'raise'
  166. x = T.log(numpy.ones((5,5)))
  167. v = op.get_test_value(x)
  168. assert numpy.allclose(v, numpy.zeros((5,5)))
  169. finally:
  170. config.compute_test_value = prev_value
  171. def test_get_debug_values_no_debugger():
  172. 'get_debug_values should return [] when debugger is off'
  173. prev_value = config.compute_test_value
  174. try:
  175. config.compute_test_value = 'off'
  176. x = T.vector()
  177. for x_val in op.get_debug_values(x):
  178. assert False
  179. finally:
  180. config.compute_test_value = prev_value
  181. def test_get_det_debug_values_ignore():
  182. """get_debug_values should return [] when debugger is ignore
  183. and some values are missing """
  184. prev_value = config.compute_test_value
  185. try:
  186. config.compute_test_value = 'ignore'
  187. x = T.vector()
  188. for x_val in op.get_debug_values(x):
  189. assert False
  190. finally:
  191. config.compute_test_value = prev_value
  192. def test_get_debug_values_success():
  193. """tests that get_debug_value returns values when available
  194. (and the debugger is on)"""
  195. prev_value = config.compute_test_value
  196. for mode in [ 'ignore', 'warn', 'raise' ]:
  197. try:
  198. config.compute_test_value = mode
  199. x = T.vector()
  200. x.tag.test_value = numpy.zeros((4,), dtype=config.floatX)
  201. y = numpy.zeros((5,5))
  202. iters = 0
  203. for x_val, y_val in op.get_debug_values(x, y):
  204. assert x_val.shape == (4,)
  205. assert y_val.shape == (5,5)
  206. iters += 1
  207. assert iters == 1
  208. finally:
  209. config.compute_test_value = prev_value
  210. def test_get_debug_values_exc():
  211. """tests that get_debug_value raises an exception when
  212. debugger is set to raise and a value is missing """
  213. prev_value = config.compute_test_value
  214. try:
  215. config.compute_test_value = 'raise'
  216. x = T.vector()
  217. try:
  218. for x_val in op.get_debug_values(x):
  219. #this assert catches the case where we
  220. #erroneously get a value returned
  221. assert False
  222. raised = False
  223. except AttributeError:
  224. raised = True
  225. #this assert catches the case where we got []
  226. #returned, and possibly issued a warning,
  227. #rather than raising an exception
  228. assert raised
  229. finally:
  230. config.compute_test_value = prev_value
  231. def test_debug_error_message():
  232. """tests that debug_error_message raises an
  233. exception when it should."""
  234. prev_value = config.compute_test_value
  235. for mode in [ 'ignore', 'raise' ]:
  236. try:
  237. config.compute_test_value = mode
  238. try:
  239. op.debug_error_message('msg')
  240. raised = False
  241. except ValueError:
  242. raised = True
  243. assert raised
  244. finally:
  245. config.compute_test_value = prev_value
# Script entry point: run this module's unittest.TestCase classes when the
# file is executed directly.
# NOTE(review): unittest.main() only collects TestCase subclasses, so the
# plain TestOp class and the module-level test_* functions presumably rely
# on an external runner (e.g. nose) -- confirm.
if __name__ == '__main__':
    unittest.main()