
/python/mxnet/io.py

https://gitlab.com/alvinahmadov2/mxnet
# coding: utf-8
# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments, W0221, W0201
"""Data iterator interface of mxnet."""
from __future__ import absolute_import

from collections import namedtuple, OrderedDict
import ctypes
import sys
import numpy as np
import logging
from .base import _LIB
from .base import c_array, c_str, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import check_call, ctypes2docstring
from .ndarray import NDArray
from .ndarray import array
class DataIter(object):
    """DataIter object in mxnet."""

    def __init__(self):
        pass

    def __iter__(self):
        return self

    def reset(self):
        """Reset the iterator."""
        pass

    def next(self):
        """Get next data batch from iterator.

        Returns
        -------
        data : NDArray
            The data of next batch.
        label : NDArray
            The label of next batch.
        """
        pass

    def __next__(self):
        return self.next()

    def iter_next(self):
        """Iterate to next batch.

        Returns
        -------
        has_next : boolean
            Whether the move is successful.
        """
        pass

    def getdata(self, index=0):
        """Get data of current batch.

        Parameters
        ----------
        index : int
            The index of the data source to retrieve.

        Returns
        -------
        data : NDArray
            The data of current batch.
        """
        pass

    def getlabel(self):
        """Get label of current batch.

        Returns
        -------
        label : NDArray
            The label of current batch.
        """
        return self.getdata(-1)

    def getindex(self):
        """Get index of current batch.

        Returns
        -------
        index : numpy.array
            The index of current batch.
        """
        pass

    def getpad(self):
        """Get the number of padding examples in current batch.

        Returns
        -------
        pad : int
            Number of padding examples in current batch.
        """
        pass
DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad', 'index'])
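# Illustrative sketch (not executed): every concrete DataIter yields DataBatch
# instances through the standard Python iteration protocol, roughly
#
#     for batch in data_iter:              # next()/__next__ until StopIteration
#         use(batch.data, batch.label)     # lists of NDArray for one mini-batch
#     data_iter.reset()                    # rewind before a second pass
#
# where `use` is a placeholder for arbitrary per-batch work.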
def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form."""
    assert (data is not None) or allow_empty
    if data is None:
        data = []
    if isinstance(data, (np.ndarray, NDArray)):
        data = [data]
    if isinstance(data, list):
        if not allow_empty:
            assert(len(data) > 0)
        if len(data) == 1:
            data = OrderedDict([(default_name, data[0])])
        else:
            data = OrderedDict([('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, " + \
                        "a list of them or dict with them as values")
    for k, v in data.items():
        if isinstance(v, NDArray):
            data[k] = v.asnumpy()
    for k, v in data.items():
        if not isinstance(v, np.ndarray):
            raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
                            "should be NDArray or numpy.ndarray")
    return list(data.items())
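# Illustrative note (not executed): _init_data normalizes every accepted input
# into a list of (name, numpy.ndarray) pairs, e.g.
#
#     _init_data(np.zeros((10, 3)), allow_empty=False, default_name='data')
#         -> [('data', <ndarray of shape (10, 3)>)]
#     _init_data([a, b], allow_empty=False, default_name='data')
#         -> [('_0_data', a), ('_1_data', b)]
#
# where `a` and `b` stand for arbitrary NDArray/numpy.ndarray inputs
# (NDArray values are converted to numpy via asnumpy()).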
class NDArrayIter(DataIter):
    """NDArrayIter object in mxnet. Takes an NDArray or numpy array to build a data iterator.

    Parameters
    ----------
    data: NDArray or numpy.ndarray, a list of them, or a dict of string to them
        Input data.
    label: NDArray or numpy.ndarray, a list of them, or a dict of string to them, optional
        Label corresponding to the data.
    batch_size: int
        Batch size.
    shuffle: bool
        Whether to shuffle the data.
    last_batch_handle: 'pad', 'discard' or 'roll_over'
        How to handle the last batch.

    Note
    ----
    This iterator will pad, discard or roll over the last batch if
    the size of data does not match batch_size. Roll over is intended
    for training and can cause problems if used for prediction.
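
    Examples
    --------
    A minimal usage sketch (the shapes, batch size and the `process` helper
    below are illustrative assumptions, not part of this module):

    >>> import numpy as np
    >>> data = np.random.rand(10, 3)
    >>> label = np.arange(10)
    >>> data_iter = NDArrayIter(data, label, batch_size=4, last_batch_handle='pad')
    >>> for batch in data_iter:
    ...     # batch.data / batch.label are lists of NDArray; batch.pad counts
    ...     # the wrapped-around examples in the final, padded batch.
    ...     process(batch.data[0], batch.label[0], batch.pad)
    >>> data_iter.reset()  # rewind before iterating again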
  128. """
    def __init__(self, data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad'):
        # pylint: disable=W0201
        super(NDArrayIter, self).__init__()
        self.data = _init_data(data, allow_empty=False, default_name='data')
        self.label = _init_data(label, allow_empty=True, default_name='softmax_label')

        # shuffle data
        if shuffle:
            idx = np.arange(self.data[0][1].shape[0])
            np.random.shuffle(idx)
            self.data = [(k, v[idx]) for k, v in self.data]
            self.label = [(k, v[idx]) for k, v in self.label]

        # batching: drop the tail that does not fill a complete batch
        if last_batch_handle == 'discard':
            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
            self.data = [(k, v[:new_n]) for k, v in self.data]
            self.label = [(k, v[:new_n]) for k, v in self.label]

        self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
        self.num_source = len(self.data_list)
        self.num_data = self.data_list[0].shape[0]
        assert self.num_data >= batch_size, \
            "batch_size needs to be smaller than data size when not padding."
        self.cursor = -batch_size
        self.batch_size = batch_size
        self.last_batch_handle = last_batch_handle
    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator"""
        return [(k, tuple([self.batch_size] + list(v.shape[1:]))) for k, v in self.data]

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator"""
        return [(k, tuple([self.batch_size] + list(v.shape[1:]))) for k, v in self.label]

    def hard_reset(self):
        """Ignore roll over data and set to start."""
        self.cursor = -self.batch_size

    def reset(self):
        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
            self.cursor = -self.batch_size + (self.cursor % self.num_data) % self.batch_size
        else:
            self.cursor = -self.batch_size

    def iter_next(self):
        self.cursor += self.batch_size
        if self.cursor < self.num_data:
            return True
        else:
            return False

    def next(self):
        if self.iter_next():
            return DataBatch(data=self.getdata(), label=self.getlabel(), \
                             pad=self.getpad(), index=None)
        else:
            raise StopIteration

    def _getdata(self, data_source):
        """Load data from underlying arrays, internal use only."""
        assert(self.cursor < self.num_data), "DataIter needs reset."
        if self.cursor + self.batch_size <= self.num_data:
            return [array(x[1][self.cursor:self.cursor + self.batch_size]) for x in data_source]
        else:
            # the batch runs past the end: wrap around and report the overlap as pad
            pad = self.batch_size - self.num_data + self.cursor
            return [array(np.concatenate((x[1][self.cursor:], x[1][:pad]),
                                         axis=0)) for x in data_source]

    def getdata(self):
        return self._getdata(self.data)

    def getlabel(self):
        return self._getdata(self.label)

    def getpad(self):
        if self.last_batch_handle == 'pad' and \
           self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        else:
            return 0
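# Illustrative sketch (not executed): with last_batch_handle='pad' the final
# batch wraps around to keep a fixed batch_size, so a consumer doing inference
# can use the pad count to drop the padded tail of its per-batch outputs, e.g.
#
#     outputs = model_forward(batch.data)          # `model_forward` is a hypothetical helper
#     outputs = outputs[:batch_size - batch.pad]   # keep only the real examples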
class MXDataIter(DataIter):
    """DataIter built in MXNet. List all the needed functions here.

    Parameters
    ----------
    handle : DataIterHandle
        the handle to the underlying C++ Data Iterator
    """

    def __init__(self, handle, data_name='data', label_name='softmax_label', **_):
        super(MXDataIter, self).__init__()
        self.handle = handle
        # debug option, used to test the speed with io effect eliminated
        self._debug_skip_load = False

        # load the first batch to get shape information
        self.first_batch = None
        self.first_batch = self.next()
        data = self.first_batch.data[0]
        label = self.first_batch.label[0]

        # properties
        self.provide_data = [(data_name, data.shape)]
        self.provide_label = [(label_name, label.shape)]
        self.batch_size = data.shape[0]
    def __del__(self):
        check_call(_LIB.MXDataIterFree(self.handle))

    def debug_skip_load(self):
        """Set the iterator to always return the first batch.

        Notes
        -----
        This can be used to test the speed of the network without taking
        the loading delay into account.
        """
        self._debug_skip_load = True
        logging.info('Set debug_skip_load to be true, will simply return first batch')

    def reset(self):
        self._debug_at_begin = True
        self.first_batch = None
        check_call(_LIB.MXDataIterBeforeFirst(self.handle))

    def next(self):
        if self._debug_skip_load and not self._debug_at_begin:
            return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(),
                             index=self.getindex())
        if self.first_batch is not None:
            batch = self.first_batch
            self.first_batch = None
            return batch
        self._debug_at_begin = False
        next_res = ctypes.c_int(0)
        check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
        if next_res.value:
            return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(),
                             index=self.getindex())
        else:
            raise StopIteration

    def iter_next(self):
        if self.first_batch is not None:
            return True
        next_res = ctypes.c_int(0)
        check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
        return next_res.value

    def getdata(self):
        hdl = NDArrayHandle()
        check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl)))
        return NDArray(hdl, False)

    def getlabel(self):
        hdl = NDArrayHandle()
        check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
        return NDArray(hdl, False)

    def getindex(self):
        index_size = ctypes.c_uint64(0)
        index_data = ctypes.POINTER(ctypes.c_uint64)()
        check_call(_LIB.MXDataIterGetIndex(self.handle,
                                           ctypes.byref(index_data),
                                           ctypes.byref(index_size)))
        address = ctypes.addressof(index_data.contents)
        dbuffer = (ctypes.c_uint64 * index_size.value).from_address(address)
        np_index = np.frombuffer(dbuffer, dtype=np.uint64)
        return np_index.copy()

    def getpad(self):
        pad = ctypes.c_int(0)
        check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad)))
        return pad.value
def _make_io_iterator(handle):
    """Create an io iterator by handle."""
    name = ctypes.c_char_p()
    desc = ctypes.c_char_p()
    num_args = mx_uint()
    arg_names = ctypes.POINTER(ctypes.c_char_p)()
    arg_types = ctypes.POINTER(ctypes.c_char_p)()
    arg_descs = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXDataIterGetIterInfo( \
        handle, ctypes.byref(name), ctypes.byref(desc), \
        ctypes.byref(num_args), \
        ctypes.byref(arg_names), \
        ctypes.byref(arg_types), \
        ctypes.byref(arg_descs)))
    iter_name = py_str(name.value)
    param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
    doc_str = ('%s\n\n' +
               '%s\n' +
               'name : string, required.\n' +
               '    Name of the resulting data iterator.\n\n' +
               'Returns\n' +
               '-------\n' +
               'iterator: DataIter\n' +
               '    The result iterator.')
    doc_str = doc_str % (desc.value, param_str)

    def creator(*args, **kwargs):
        """Create an iterator.
        The parameters listed below can be passed in as keyword arguments.

        Parameters
        ----------
        name : string, required.
            Name of the resulting data iterator.

        Returns
        -------
        dataiter : DataIter
            The resulting data iterator.
        """
        param_keys = []
        param_vals = []
        for k, val in kwargs.items():
            param_keys.append(c_str(k))
            param_vals.append(c_str(str(val)))
        # create the iterator handle from the keyword arguments
        param_keys = c_array(ctypes.c_char_p, param_keys)
        param_vals = c_array(ctypes.c_char_p, param_vals)
        iter_handle = DataIterHandle()
        check_call(_LIB.MXDataIterCreateIter(
            handle,
            mx_uint(len(param_keys)),
            param_keys, param_vals,
            ctypes.byref(iter_handle)))
        if len(args):
            raise TypeError('%s can only accept keyword arguments' % iter_name)
        return MXDataIter(iter_handle, **kwargs)

    creator.__name__ = iter_name
    creator.__doc__ = doc_str
    return creator
def _init_io_module():
    """List and add all the data iterators to current module."""
    plist = ctypes.POINTER(ctypes.c_void_p)()
    size = ctypes.c_uint()
    check_call(_LIB.MXListDataIters(ctypes.byref(size), ctypes.byref(plist)))
    module_obj = sys.modules[__name__]
    for i in range(size.value):
        hdl = ctypes.c_void_p(plist[i])
        dataiter = _make_io_iterator(hdl)
        setattr(module_obj, dataiter.__name__, dataiter)

# Initialize the io module at import time
_init_io_module()
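# Illustrative sketch (not executed): _init_io_module() registers one factory
# function per C++ iterator reported by libmxnet; which names exist (for
# example MNISTIter or ImageRecordIter) depends on how the library was built.
# Each registered factory accepts keyword arguments only and returns an
# MXDataIter, e.g.
#
#     train_iter = mxnet.io.MNISTIter(image='train-images-idx3-ubyte',
#                                     label='train-labels-idx1-ubyte',
#                                     batch_size=100)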