PageRenderTime 61ms CodeModel.GetById 26ms RepoModel.GetById 0ms app.codeStats 0ms

/python/caffe/io.py

https://gitlab.com/philokeys/caffe
Python | 383 lines | 316 code | 18 blank | 49 comment | 22 complexity | c7c965845f7915d61358da91e35e5b60 MD5 | raw file
  1. import numpy as np
  2. import skimage.io
  3. from scipy.ndimage import zoom
  4. from skimage.transform import resize
  5. try:
  6. # Python3 will most likely not be able to load protobuf
  7. from caffe.proto import caffe_pb2
  8. except:
  9. import sys
  10. if sys.version_info >= (3, 0):
  11. print("Failed to include caffe_pb2, things might go wrong!")
  12. else:
  13. raise
  14. ## proto / datum / ndarray conversion
  15. def blobproto_to_array(blob, return_diff=False):
  16. """
  17. Convert a blob proto to an array. In default, we will just return the data,
  18. unless return_diff is True, in which case we will return the diff.
  19. """
  20. # Read the data into an array
  21. if return_diff:
  22. data = np.array(blob.diff)
  23. else:
  24. data = np.array(blob.data)
  25. # Reshape the array
  26. if blob.HasField('num') or blob.HasField('channels') or blob.HasField('height') or blob.HasField('width'):
  27. # Use legacy 4D shape
  28. return data.reshape(blob.num, blob.channels, blob.height, blob.width)
  29. else:
  30. return data.reshape(blob.shape.dim)
  31. def array_to_blobproto(arr, diff=None):
  32. """Converts a N-dimensional array to blob proto. If diff is given, also
  33. convert the diff. You need to make sure that arr and diff have the same
  34. shape, and this function does not do sanity check.
  35. """
  36. blob = caffe_pb2.BlobProto()
  37. blob.shape.dim.extend(arr.shape)
  38. blob.data.extend(arr.astype(float).flat)
  39. if diff is not None:
  40. blob.diff.extend(diff.astype(float).flat)
  41. return blob
  42. def arraylist_to_blobprotovector_str(arraylist):
  43. """Converts a list of arrays to a serialized blobprotovec, which could be
  44. then passed to a network for processing.
  45. """
  46. vec = caffe_pb2.BlobProtoVector()
  47. vec.blobs.extend([array_to_blobproto(arr) for arr in arraylist])
  48. return vec.SerializeToString()
  49. def blobprotovector_str_to_arraylist(str):
  50. """Converts a serialized blobprotovec to a list of arrays.
  51. """
  52. vec = caffe_pb2.BlobProtoVector()
  53. vec.ParseFromString(str)
  54. return [blobproto_to_array(blob) for blob in vec.blobs]
  55. def array_to_datum(arr, label=None):
  56. """Converts a 3-dimensional array to datum. If the array has dtype uint8,
  57. the output data will be encoded as a string. Otherwise, the output data
  58. will be stored in float format.
  59. """
  60. if arr.ndim != 3:
  61. raise ValueError('Incorrect array shape.')
  62. datum = caffe_pb2.Datum()
  63. datum.channels, datum.height, datum.width = arr.shape
  64. if arr.dtype == np.uint8:
  65. datum.data = arr.tostring()
  66. else:
  67. datum.float_data.extend(arr.flat)
  68. if label is not None:
  69. datum.label = label
  70. return datum
  71. def datum_to_array(datum):
  72. """Converts a datum to an array. Note that the label is not returned,
  73. as one can easily get it by calling datum.label.
  74. """
  75. if len(datum.data):
  76. return np.fromstring(datum.data, dtype=np.uint8).reshape(
  77. datum.channels, datum.height, datum.width)
  78. else:
  79. return np.array(datum.float_data).astype(float).reshape(
  80. datum.channels, datum.height, datum.width)
  81. ## Pre-processing
  82. class Transformer:
  83. """
  84. Transform input for feeding into a Net.
  85. Note: this is mostly for illustrative purposes and it is likely better
  86. to define your own input preprocessing routine for your needs.
  87. Parameters
  88. ----------
  89. net : a Net for which the input should be prepared
  90. """
  91. def __init__(self, inputs):
  92. self.inputs = inputs
  93. self.transpose = {}
  94. self.channel_swap = {}
  95. self.raw_scale = {}
  96. self.mean = {}
  97. self.input_scale = {}
  98. def __check_input(self, in_):
  99. if in_ not in self.inputs:
  100. raise Exception('{} is not one of the net inputs: {}'.format(
  101. in_, self.inputs))
  102. def preprocess(self, in_, data):
  103. """
  104. Format input for Caffe:
  105. - convert to single
  106. - resize to input dimensions (preserving number of channels)
  107. - transpose dimensions to K x H x W
  108. - reorder channels (for instance color to BGR)
  109. - scale raw input (e.g. from [0, 1] to [0, 255] for ImageNet models)
  110. - subtract mean
  111. - scale feature
  112. Parameters
  113. ----------
  114. in_ : name of input blob to preprocess for
  115. data : (H' x W' x K) ndarray
  116. Returns
  117. -------
  118. caffe_in : (K x H x W) ndarray for input to a Net
  119. """
  120. self.__check_input(in_)
  121. caffe_in = data.astype(np.float32, copy=False)
  122. transpose = self.transpose.get(in_)
  123. channel_swap = self.channel_swap.get(in_)
  124. raw_scale = self.raw_scale.get(in_)
  125. mean = self.mean.get(in_)
  126. input_scale = self.input_scale.get(in_)
  127. in_dims = self.inputs[in_][2:]
  128. if caffe_in.shape[:2] != in_dims:
  129. caffe_in = resize_image(caffe_in, in_dims)
  130. if transpose is not None:
  131. caffe_in = caffe_in.transpose(transpose)
  132. if channel_swap is not None:
  133. caffe_in = caffe_in[channel_swap, :, :]
  134. if raw_scale is not None:
  135. caffe_in *= raw_scale
  136. if mean is not None:
  137. caffe_in -= mean
  138. if input_scale is not None:
  139. caffe_in *= input_scale
  140. return caffe_in
  141. def deprocess(self, in_, data):
  142. """
  143. Invert Caffe formatting; see preprocess().
  144. """
  145. self.__check_input(in_)
  146. decaf_in = data.copy().squeeze()
  147. transpose = self.transpose.get(in_)
  148. channel_swap = self.channel_swap.get(in_)
  149. raw_scale = self.raw_scale.get(in_)
  150. mean = self.mean.get(in_)
  151. input_scale = self.input_scale.get(in_)
  152. if input_scale is not None:
  153. decaf_in /= input_scale
  154. if mean is not None:
  155. decaf_in += mean
  156. if raw_scale is not None:
  157. decaf_in /= raw_scale
  158. if channel_swap is not None:
  159. decaf_in = decaf_in[np.argsort(channel_swap), :, :]
  160. if transpose is not None:
  161. decaf_in = decaf_in.transpose(np.argsort(transpose))
  162. return decaf_in
  163. def set_transpose(self, in_, order):
  164. """
  165. Set the input channel order for e.g. RGB to BGR conversion
  166. as needed for the reference ImageNet model.
  167. Parameters
  168. ----------
  169. in_ : which input to assign this channel order
  170. order : the order to transpose the dimensions
  171. """
  172. self.__check_input(in_)
  173. if len(order) != len(self.inputs[in_]) - 1:
  174. raise Exception('Transpose order needs to have the same number of '
  175. 'dimensions as the input.')
  176. self.transpose[in_] = order
  177. def set_channel_swap(self, in_, order):
  178. """
  179. Set the input channel order for e.g. RGB to BGR conversion
  180. as needed for the reference ImageNet model.
  181. N.B. this assumes the channels are the first dimension AFTER transpose.
  182. Parameters
  183. ----------
  184. in_ : which input to assign this channel order
  185. order : the order to take the channels.
  186. (2,1,0) maps RGB to BGR for example.
  187. """
  188. self.__check_input(in_)
  189. if len(order) != self.inputs[in_][1]:
  190. raise Exception('Channel swap needs to have the same number of '
  191. 'dimensions as the input channels.')
  192. self.channel_swap[in_] = order
  193. def set_raw_scale(self, in_, scale):
  194. """
  195. Set the scale of raw features s.t. the input blob = input * scale.
  196. While Python represents images in [0, 1], certain Caffe models
  197. like CaffeNet and AlexNet represent images in [0, 255] so the raw_scale
  198. of these models must be 255.
  199. Parameters
  200. ----------
  201. in_ : which input to assign this scale factor
  202. scale : scale coefficient
  203. """
  204. self.__check_input(in_)
  205. self.raw_scale[in_] = scale
  206. def set_mean(self, in_, mean):
  207. """
  208. Set the mean to subtract for centering the data.
  209. Parameters
  210. ----------
  211. in_ : which input to assign this mean.
  212. mean : mean ndarray (input dimensional or broadcastable)
  213. """
  214. self.__check_input(in_)
  215. ms = mean.shape
  216. if mean.ndim == 1:
  217. # broadcast channels
  218. if ms[0] != self.inputs[in_][1]:
  219. raise ValueError('Mean channels incompatible with input.')
  220. mean = mean[:, np.newaxis, np.newaxis]
  221. else:
  222. # elementwise mean
  223. if len(ms) == 2:
  224. ms = (1,) + ms
  225. if len(ms) != 3:
  226. raise ValueError('Mean shape invalid')
  227. if ms != self.inputs[in_][1:]:
  228. raise ValueError('Mean shape incompatible with input shape.')
  229. self.mean[in_] = mean
  230. def set_input_scale(self, in_, scale):
  231. """
  232. Set the scale of preprocessed inputs s.t. the blob = blob * scale.
  233. N.B. input_scale is done AFTER mean subtraction and other preprocessing
  234. while raw_scale is done BEFORE.
  235. Parameters
  236. ----------
  237. in_ : which input to assign this scale factor
  238. scale : scale coefficient
  239. """
  240. self.__check_input(in_)
  241. self.input_scale[in_] = scale
  242. ## Image IO
  243. def load_image(filename, color=True):
  244. """
  245. Load an image converting from grayscale or alpha as needed.
  246. Parameters
  247. ----------
  248. filename : string
  249. color : boolean
  250. flag for color format. True (default) loads as RGB while False
  251. loads as intensity (if image is already grayscale).
  252. Returns
  253. -------
  254. image : an image with type np.float32 in range [0, 1]
  255. of size (H x W x 3) in RGB or
  256. of size (H x W x 1) in grayscale.
  257. """
  258. img = skimage.img_as_float(skimage.io.imread(filename, as_grey=not color)).astype(np.float32)
  259. if img.ndim == 2:
  260. img = img[:, :, np.newaxis]
  261. if color:
  262. img = np.tile(img, (1, 1, 3))
  263. elif img.shape[2] == 4:
  264. img = img[:, :, :3]
  265. return img
  266. def resize_image(im, new_dims, interp_order=1):
  267. """
  268. Resize an image array with interpolation.
  269. Parameters
  270. ----------
  271. im : (H x W x K) ndarray
  272. new_dims : (height, width) tuple of new dimensions.
  273. interp_order : interpolation order, default is linear.
  274. Returns
  275. -------
  276. im : resized ndarray with shape (new_dims[0], new_dims[1], K)
  277. """
  278. if im.shape[-1] == 1 or im.shape[-1] == 3:
  279. im_min, im_max = im.min(), im.max()
  280. if im_max > im_min:
  281. # skimage is fast but only understands {1,3} channel images
  282. # in [0, 1].
  283. im_std = (im - im_min) / (im_max - im_min)
  284. resized_std = resize(im_std, new_dims, order=interp_order)
  285. resized_im = resized_std * (im_max - im_min) + im_min
  286. else:
  287. # the image is a constant -- avoid divide by 0
  288. ret = np.empty((new_dims[0], new_dims[1], im.shape[-1]),
  289. dtype=np.float32)
  290. ret.fill(im_min)
  291. return ret
  292. else:
  293. # ndimage interpolates anything but more slowly.
  294. scale = tuple(np.array(new_dims, dtype=float) / np.array(im.shape[:2]))
  295. resized_im = zoom(im, scale + (1,), order=interp_order)
  296. return resized_im.astype(np.float32)
  297. def oversample(images, crop_dims):
  298. """
  299. Crop images into the four corners, center, and their mirrored versions.
  300. Parameters
  301. ----------
  302. image : iterable of (H x W x K) ndarrays
  303. crop_dims : (height, width) tuple for the crops.
  304. Returns
  305. -------
  306. crops : (10*N x H x W x K) ndarray of crops for number of inputs N.
  307. """
  308. # Dimensions and center.
  309. im_shape = np.array(images[0].shape)
  310. crop_dims = np.array(crop_dims)
  311. im_center = im_shape[:2] / 2.0
  312. # Make crop coordinates
  313. h_indices = (0, im_shape[0] - crop_dims[0])
  314. w_indices = (0, im_shape[1] - crop_dims[1])
  315. crops_ix = np.empty((5, 4), dtype=int)
  316. curr = 0
  317. for i in h_indices:
  318. for j in w_indices:
  319. crops_ix[curr] = (i, j, i + crop_dims[0], j + crop_dims[1])
  320. curr += 1
  321. crops_ix[4] = np.tile(im_center, (1, 2)) + np.concatenate([
  322. -crop_dims / 2.0,
  323. crop_dims / 2.0
  324. ])
  325. crops_ix = np.tile(crops_ix, (2, 1))
  326. # Extract crops
  327. crops = np.empty((10 * len(images), crop_dims[0], crop_dims[1],
  328. im_shape[-1]), dtype=np.float32)
  329. ix = 0
  330. for im in images:
  331. for crop in crops_ix:
  332. crops[ix] = im[crop[0]:crop[2], crop[1]:crop[3], :]
  333. ix += 1
  334. crops[ix-5:ix] = crops[ix-5:ix, :, ::-1, :] # flip for mirrors
  335. return crops