PageRenderTime 55ms CodeModel.GetById 15ms RepoModel.GetById 0ms app.codeStats 0ms

/theano/misc/pkl_utils.py

https://github.com/lamblin/Theano
Python | 405 lines | 360 code | 14 blank | 31 comment | 8 complexity | af74591644aa8faa72d9b3cce52075ce MD5 | raw file
  1. """
  2. Utility classes and methods to pickle parts of symbolic graph.
  3. These pickled graphs can be used, for instance, as cases for
  4. unit tests or regression tests.
  5. """
  6. from __future__ import absolute_import, print_function, division
  7. import numpy as np
  8. import os
  9. import pickle
  10. import sys
  11. import tempfile
  12. import zipfile
  13. import warnings
  14. from collections import defaultdict
  15. from contextlib import closing
  16. from pickle import HIGHEST_PROTOCOL
  17. from six import BytesIO
  18. try:
  19. from pickle import DEFAULT_PROTOCOL
  20. except ImportError:
  21. DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
  22. import theano
  23. from theano import config
  24. from theano.compat import PY3
  25. from six import string_types
  26. from theano.compile.sharedvalue import SharedVariable
  27. __docformat__ = "restructuredtext en"
  28. __authors__ = "Pascal Lamblin"
  29. __copyright__ = "Copyright 2013, Universite de Montreal"
  30. __license__ = "3-clause BSD"
  31. min_recursion = 3000
  32. if sys.getrecursionlimit() < min_recursion:
  33. sys.setrecursionlimit(min_recursion)
  34. Pickler = pickle.Pickler
  35. class StripPickler(Pickler):
  36. """
  37. Subclass of Pickler that strips unnecessary attributes from Theano objects.
  38. .. versionadded:: 0.8
  39. Example of use::
  40. fn_args = dict(inputs=inputs,
  41. outputs=outputs,
  42. updates=updates)
  43. dest_pkl = 'my_test.pkl'
  44. f = open(dest_pkl, 'wb')
  45. strip_pickler = StripPickler(f, protocol=-1)
  46. strip_pickler.dump(fn_args)
  47. f.close()
  48. """
  49. def __init__(self, file, protocol=0, extra_tag_to_remove=None):
  50. # Can't use super as Pickler isn't a new style class
  51. Pickler.__init__(self, file, protocol)
  52. self.tag_to_remove = ['trace', 'test_value']
  53. if extra_tag_to_remove:
  54. self.tag_to_remove.extend(extra_tag_to_remove)
  55. def save(self, obj):
  56. # Remove the tag.trace attribute from Variable and Apply nodes
  57. if isinstance(obj, theano.gof.utils.scratchpad):
  58. for tag in self.tag_to_remove:
  59. if hasattr(obj, tag):
  60. del obj.__dict__[tag]
  61. # Remove manually-added docstring of Elemwise ops
  62. elif (isinstance(obj, theano.tensor.Elemwise)):
  63. if '__doc__' in obj.__dict__:
  64. del obj.__dict__['__doc__']
  65. return Pickler.save(self, obj)
  66. # Make an unpickler that tries encoding byte streams before raising TypeError.
  67. # This is useful with python 3, in order to unpickle files created with
  68. # python 2.
  69. # This code is taken from Pandas, https://github.com/pydata/pandas,
  70. # under the same 3-clause BSD license.
  71. def load_reduce(self):
  72. stack = self.stack
  73. args = stack.pop()
  74. func = stack[-1]
  75. try:
  76. value = func(*args)
  77. except Exception:
  78. # try to reencode the arguments
  79. if self.encoding is not None:
  80. new_args = []
  81. for arg in args:
  82. if isinstance(arg, string_types):
  83. new_args.append(arg.encode(self.encoding))
  84. else:
  85. new_args.append(arg)
  86. args = tuple(new_args)
  87. try:
  88. stack[-1] = func(*args)
  89. return
  90. except Exception:
  91. pass
  92. # if self.is_verbose:
  93. # print(sys.exc_info())
  94. # print(func, args)
  95. raise
  96. stack[-1] = value
  97. if PY3:
  98. class CompatUnpickler(pickle._Unpickler):
  99. """
  100. Allow to reload in python 3 some pickled numpy ndarray.
  101. .. versionadded:: 0.8
  102. Examples
  103. --------
  104. ::
  105. with open(fname, 'rb') as fp:
  106. if PY3:
  107. u = CompatUnpickler(fp, encoding="latin1")
  108. else:
  109. u = CompatUnpickler(fp)
  110. mat = u.load()
  111. """
  112. pass
  113. # Register `load_reduce` defined above in CompatUnpickler
  114. CompatUnpickler.dispatch[pickle.REDUCE[0]] = load_reduce
  115. else:
  116. class CompatUnpickler(pickle.Unpickler):
  117. """
  118. Allow to reload in python 3 some pickled numpy ndarray.
  119. .. versionadded:: 0.8
  120. Examples
  121. --------
  122. ::
  123. with open(fname, 'rb') as fp:
  124. if PY3:
  125. u = CompatUnpickler(fp, encoding="latin1")
  126. else:
  127. u = CompatUnpickler(fp)
  128. mat = u.load()
  129. """
  130. pass
  131. class PersistentNdarrayID(object):
  132. """Persist ndarrays in an object by saving them to a zip file.
  133. :param zip_file: A zip file handle that the NumPy arrays will be saved to.
  134. :type zip_file: :class:`zipfile.ZipFile`
  135. .. note:
  136. The convention for persistent ids given by this class and its derived
  137. classes is that the name should take the form `type.name` where `type`
  138. can be used by the persistent loader to determine how to load the
  139. object, while `name` is human-readable and as descriptive as possible.
  140. """
  141. def __init__(self, zip_file):
  142. self.zip_file = zip_file
  143. self.count = 0
  144. self.seen = {}
  145. def _resolve_name(self, obj):
  146. """Determine the name the object should be saved under."""
  147. name = 'array_{0}'.format(self.count)
  148. self.count += 1
  149. return name
  150. def __call__(self, obj):
  151. if type(obj) is np.ndarray:
  152. if id(obj) not in self.seen:
  153. def write_array(f):
  154. np.lib.format.write_array(f, obj)
  155. name = self._resolve_name(obj)
  156. zipadd(write_array, self.zip_file, name)
  157. self.seen[id(obj)] = 'ndarray.{0}'.format(name)
  158. return self.seen[id(obj)]
  159. class PersistentGpuArrayID(PersistentNdarrayID):
  160. def __call__(self, obj):
  161. from theano.gpuarray.type import _name_for_ctx
  162. try:
  163. import pygpu
  164. except ImportError:
  165. pygpu = None
  166. if (pygpu and
  167. isinstance(obj, pygpu.gpuarray.GpuArray)):
  168. if id(obj) not in self.seen:
  169. def write_array(f):
  170. pickle.dump(_name_for_ctx(obj.context), f, 2)
  171. np.lib.format.write_array(f, np.asarray(obj))
  172. name = self._resolve_name(obj)
  173. zipadd(write_array, self.zip_file, name)
  174. self.seen[id(obj)] = 'gpuarray.{0}'.format(name)
  175. return self.seen[id(obj)]
  176. return super(PersistentGpuArrayID, self).__call__(obj)
  177. class PersistentSharedVariableID(PersistentGpuArrayID):
  178. """Uses shared variable names when persisting to zip file.
  179. If a shared variable has a name, this name is used as the name of the
  180. NPY file inside of the zip file. NumPy arrays that aren't matched to a
  181. shared variable are persisted as usual (i.e. `array_0`, `array_1`,
  182. etc.)
  183. :param allow_unnamed: Allow shared variables without a name to be
  184. persisted. Defaults to ``True``.
  185. :type allow_unnamed: bool, optional
  186. :param allow_duplicates: Allow multiple shared variables to have the same
  187. name, in which case they will be numbered e.g. `x`, `x_2`, `x_3`, etc.
  188. Defaults to ``True``.
  189. :type allow_duplicates: bool, optional
  190. :raises ValueError
  191. If an unnamed shared variable is encountered and `allow_unnamed` is
  192. ``False``, or if two shared variables have the same name, and
  193. `allow_duplicates` is ``False``.
  194. """
  195. def __init__(self, zip_file, allow_unnamed=True, allow_duplicates=True):
  196. super(PersistentSharedVariableID, self).__init__(zip_file)
  197. self.name_counter = defaultdict(int)
  198. self.ndarray_names = {}
  199. self.allow_unnamed = allow_unnamed
  200. self.allow_duplicates = allow_duplicates
  201. def _resolve_name(self, obj):
  202. if id(obj) in self.ndarray_names:
  203. name = self.ndarray_names[id(obj)]
  204. count = self.name_counter[name]
  205. self.name_counter[name] += 1
  206. if count:
  207. if not self.allow_duplicates:
  208. raise ValueError("multiple shared variables with the name "
  209. "`{0}` found".format(name))
  210. name = '{0}_{1}'.format(name, count + 1)
  211. return name
  212. return super(PersistentSharedVariableID, self)._resolve_name(obj)
  213. def __call__(self, obj):
  214. if isinstance(obj, SharedVariable):
  215. if obj.name:
  216. if obj.name == 'pkl':
  217. ValueError("can't pickle shared variable with name `pkl`")
  218. self.ndarray_names[id(obj.container.storage[0])] = obj.name
  219. elif not self.allow_unnamed:
  220. raise ValueError("unnamed shared variable, {0}".format(obj))
  221. return super(PersistentSharedVariableID, self).__call__(obj)
  222. class PersistentNdarrayLoad(object):
  223. """Load NumPy arrays that were persisted to a zip file when pickling.
  224. :param zip_file: The zip file handle in which the NumPy arrays are saved.
  225. :type zip_file: :class:`zipfile.ZipFile`
  226. """
  227. def __init__(self, zip_file):
  228. self.zip_file = zip_file
  229. self.cache = {}
  230. def __call__(self, persid):
  231. from theano.gpuarray.type import get_context
  232. from theano.gpuarray import pygpu
  233. array_type, name = persid.split('.')
  234. if name in self.cache:
  235. return self.cache[name]
  236. ret = None
  237. if array_type == 'gpuarray':
  238. with self.zip_file.open(name) as f:
  239. ctx_name = pickle.load(f)
  240. array = np.lib.format.read_array(f)
  241. if config.experimental.unpickle_gpu_on_cpu:
  242. # directly return numpy array
  243. warnings.warn("config.experimental.unpickle_gpu_on_cpu is set "
  244. "to True. Unpickling GpuArray as numpy.ndarray")
  245. ret = array
  246. elif pygpu:
  247. ret = pygpu.array(array, context=get_context(ctx_name))
  248. else:
  249. raise ImportError("pygpu not found. Cannot unpickle GpuArray")
  250. else:
  251. with self.zip_file.open(name) as f:
  252. ret = np.lib.format.read_array(f)
  253. self.cache[name] = ret
  254. return ret
  255. def dump(obj, file_handler, protocol=DEFAULT_PROTOCOL,
  256. persistent_id=PersistentSharedVariableID):
  257. """Pickles an object to a zip file using external persistence.
  258. :param obj: The object to pickle.
  259. :type obj: object
  260. :param file_handler: The file handle to save the object to.
  261. :type file_handler: file
  262. :param protocol: The pickling protocol to use. Unlike Python's built-in
  263. pickle, the default is set to `2` instead of 0 for Python 2. The
  264. Python 3 default (level 3) is maintained.
  265. :type protocol: int, optional
  266. :param persistent_id: The callable that persists certain objects in the
  267. object hierarchy to separate files inside of the zip file. For example,
  268. :class:`PersistentNdarrayID` saves any :class:`numpy.ndarray` to a
  269. separate NPY file inside of the zip file.
  270. :type persistent_id: callable
  271. .. versionadded:: 0.8
  272. .. note::
  273. The final file is simply a zipped file containing at least one file,
  274. `pkl`, which contains the pickled object. It can contain any other
  275. number of external objects. Note that the zip files are compatible with
  276. NumPy's :func:`numpy.load` function.
  277. >>> import theano
  278. >>> foo_1 = theano.shared(0, name='foo')
  279. >>> foo_2 = theano.shared(1, name='foo')
  280. >>> with open('model.zip', 'wb') as f:
  281. ... dump((foo_1, foo_2, np.array(2)), f)
  282. >>> np.load('model.zip').keys()
  283. ['foo', 'foo_2', 'array_0', 'pkl']
  284. >>> np.load('model.zip')['foo']
  285. array(0)
  286. >>> with open('model.zip', 'rb') as f:
  287. ... foo_1, foo_2, array = load(f)
  288. >>> array
  289. array(2)
  290. """
  291. with closing(zipfile.ZipFile(file_handler, 'w', zipfile.ZIP_DEFLATED,
  292. allowZip64=True)) as zip_file:
  293. def func(f):
  294. p = pickle.Pickler(f, protocol=protocol)
  295. p.persistent_id = persistent_id(zip_file)
  296. p.dump(obj)
  297. zipadd(func, zip_file, 'pkl')
  298. def load(f, persistent_load=PersistentNdarrayLoad):
  299. """Load a file that was dumped to a zip file.
  300. :param f: The file handle to the zip file to load the object from.
  301. :type f: file
  302. :param persistent_load: The persistent loading function to use for
  303. unpickling. This must be compatible with the `persisten_id` function
  304. used when pickling.
  305. :type persistent_load: callable, optional
  306. .. versionadded:: 0.8
  307. """
  308. with closing(zipfile.ZipFile(f, 'r')) as zip_file:
  309. p = pickle.Unpickler(BytesIO(zip_file.open('pkl').read()))
  310. p.persistent_load = persistent_load(zip_file)
  311. return p.load()
  312. def zipadd(func, zip_file, name):
  313. """Calls a function with a file object, saving it to a zip file.
  314. :param func: The function to call.
  315. :type func: callable
  316. :param zip_file: The zip file that `func` should write its data to.
  317. :type zip_file: :class:`zipfile.ZipFile`
  318. :param name: The name of the file inside of the zipped archive that `func`
  319. should save its data to.
  320. :type name: str
  321. """
  322. with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file:
  323. func(temp_file)
  324. temp_file.close()
  325. zip_file.write(temp_file.name, arcname=name)
  326. if os.path.isfile(temp_file.name):
  327. os.remove(temp_file.name)