/libnd4j/include/array/impl/ShapeDescriptor.cpp

https://github.com/deeplearning4j/deeplearning4j · C++ · 380 lines · 280 code · 71 blank · 29 comment · 62 complexity · 63a27e914e0804ee14ad9bfd757d9267 MD5 · raw file

  1. /*******************************************************************************
  2. * Copyright (c) 2015-2018 Skymind, Inc.
  3. *
  4. * This program and the accompanying materials are made available under the
  5. * terms of the Apache License, Version 2.0 which is available at
  6. * https://www.apache.org/licenses/LICENSE-2.0.
  7. *
  8. * Unless required by applicable law or agreed to in writing, software
  9. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  10. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  11. * License for the specific language governing permissions and limitations
  12. * under the License.
  13. *
  14. * SPDX-License-Identifier: Apache-2.0
  15. ******************************************************************************/
  16. //
  17. // @author raver119@gmail.com
  18. //
  19. #include <array/ShapeDescriptor.h>
  20. #include <helpers/shape.h>
  21. #include <helpers/ShapeBuilders.h>
  22. namespace sd {
  23. //////////////////////////////////////////////////////////////////////////
  24. // equal to operator
  25. bool ShapeDescriptor::operator==(const ShapeDescriptor &other) const {
  26. if (_empty != other._empty)
  27. return false;
  28. if (_rank != other._rank)
  29. return false;
  30. if (_order != other._order)
  31. return false;
  32. if (_dataType != other._dataType)
  33. return false;
  34. if (_ews != other._ews)
  35. return false;
  36. if (_shape != other._shape)
  37. return false;
  38. if (_strides != other._strides)
  39. return false;
  40. return true;
  41. }
  42. //////////////////////////////////////////////////////////////////////////
  43. // less than operator
  44. bool ShapeDescriptor::operator<(const ShapeDescriptor &other) const {
  45. return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) <
  46. std::tie(other._empty, other._rank, other._dataType, other._ews, other._order, other._shape,
  47. other._strides);
  48. }
  49. Nd4jLong *ShapeDescriptor::toShapeInfo() const {
  50. if (_empty) {
  51. if (_rank == 0)
  52. return ShapeBuilders::emptyShapeInfo(_dataType);
  53. else {
  54. return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape);
  55. }
  56. }
  57. switch (_rank) {
  58. case 0: {
  59. auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType);
  60. shapeInfo[2] = _ews;
  61. return shapeInfo;
  62. }
  63. case 1: {
  64. auto shapeInfo = ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]);
  65. shapeInfo[2 + _rank * 2] = _ews;
  66. shapeInfo[2] = _strides[0];
  67. shapeInfo[2 + _rank * 2 + 1] = _order;
  68. return shapeInfo;
  69. }
  70. default: {
  71. auto shapeInfo = ShapeBuilders::createShapeInfo(_dataType, _order, _shape);
  72. for (int e = 0; e < _rank; e++)
  73. shapeInfo[e + 1 + _rank] = _strides[e];
  74. shapeInfo[2 + _rank * 2] = _ews;
  75. return shapeInfo;
  76. }
  77. }
  78. }
  79. ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank)
  80. : _dataType(type), _order(order), _rank(rank), _ews(1) {
  81. _shape.resize(rank);
  82. _strides.resize(rank);
  83. for (int e = 0; e < rank; e++)
  84. _shape[e] = shape[e];
  85. if (order == 'c')
  86. shape::calcStrides(_shape.data(), _shape.size(), _strides.data());
  87. else
  88. shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data());
  89. for (auto v:_shape) {
  90. if (v == 0) {
  91. _empty = true;
  92. break;
  93. }
  94. }
  95. }
  96. ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape,
  97. const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty) {
  98. _shape.resize(rank);
  99. _strides.resize(rank);
  100. _dataType = type;
  101. _order = order;
  102. _rank = rank;
  103. _empty = empty;
  104. _ews = ews;
  105. for (int e = 0; e < rank; e++)
  106. _shape[e] = shape[e];
  107. for (int e = 0; e < rank; e++)
  108. _strides[e] = strides[e];
  109. for (auto v:_shape) {
  110. if (v == 0) {
  111. _empty = true;
  112. break;
  113. }
  114. }
  115. }
  116. //////////////////////////////////////////////////////////////////////////
  117. ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape)
  118. : _dataType(type), _order(order), _shape(shape) {
  119. _rank = shape.size();
  120. _ews = 1;
  121. if (_rank > 0) {
  122. _strides.resize(_rank);
  123. for (auto v:_shape) {
  124. if (v == 0) {
  125. _empty = true;
  126. break;
  127. }
  128. }
  129. // no point calculating strides for empty arrays
  130. if (!_empty) {
  131. if (order == 'c')
  132. shape::calcStrides(_shape.data(), shape.size(), _strides.data());
  133. else
  134. shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
  135. } else {
  136. // all strides set to 0
  137. memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size());
  138. }
  139. }
  140. }
  141. //////////////////////////////////////////////////////////////////////////
  142. ShapeDescriptor::ShapeDescriptor(const DataType type, const char order,
  143. const std::initializer_list<Nd4jLong> &shape) : _dataType(type), _order(order),
  144. _shape(shape) {
  145. _rank = shape.size();
  146. _ews = 1;
  147. _strides.resize(shape.size());
  148. if (order == 'c')
  149. shape::calcStrides(_shape.data(), shape.size(), _strides.data());
  150. else
  151. shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
  152. for (auto v:_shape) {
  153. if (v == 0) {
  154. _empty = true;
  155. break;
  156. }
  157. }
  158. }
  159. //////////////////////////////////////////////////////////////////////////
  160. ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape,
  161. const std::vector<Nd4jLong> &strides, const Nd4jLong ews) : ShapeDescriptor(type,
  162. order,
  163. shape,
  164. strides) {
  165. _ews = ews;
  166. }
  167. ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) : _dataType(type), _ews(1),
  168. _order('c'), _rank(1),
  169. _empty(false) {
  170. _shape = {length};
  171. _strides = {1};
  172. }
  173. ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) {
  174. _order = shape::order(shapeInfo);
  175. _ews = shape::elementWiseStride(shapeInfo);
  176. _rank = shape::rank(shapeInfo);
  177. if (inheritDtype)
  178. _dataType = ArrayOptions::dataType(shapeInfo);
  179. _empty = shape::isEmpty(shapeInfo);
  180. for (int e = 0; e < _rank; e++) {
  181. _shape.emplace_back(shapeInfo[e + 1]);
  182. if (shapeInfo[e + 1] == 0)
  183. _empty = true;
  184. }
  185. for (int e = 0; e < _rank; e++)
  186. _strides.emplace_back(shapeInfo[e + 1 + _rank]);
  187. }
  188. ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const sd::DataType dtypeOverride)
  189. : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) {
  190. _dataType = dtypeOverride;
  191. }
  192. ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride)
  193. : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) {
  194. //
  195. }
  196. ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride,
  197. const Nd4jLong *orderOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo,
  198. ArrayOptions::dataType(
  199. dtypeOverride)) {
  200. _order = shape::order(orderOverride);
  201. }
  202. int ShapeDescriptor::rank() const {
  203. return _rank;
  204. }
  205. Nd4jLong ShapeDescriptor::ews() const {
  206. return _ews;
  207. }
  208. Nd4jLong ShapeDescriptor::arrLength() const {
  209. Nd4jLong len = 1;
  210. for (const auto &dim : const_cast<ShapeDescriptor *>(this)->shape())
  211. len *= dim;
  212. return len;
  213. }
  214. char ShapeDescriptor::order() const {
  215. return _order;
  216. }
  217. DataType ShapeDescriptor::dataType() const {
  218. return _dataType;
  219. }
  220. bool ShapeDescriptor::isEmpty() const {
  221. return _empty;
  222. }
  223. std::vector<Nd4jLong> &ShapeDescriptor::shape() {
  224. return _shape;
  225. }
  226. std::vector<Nd4jLong> &ShapeDescriptor::strides() {
  227. return _strides;
  228. }
  229. ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) {
  230. _rank = other._rank;
  231. _ews = other._ews;
  232. _empty = other._empty;
  233. _dataType = other._dataType;
  234. _order = other._order;
  235. _shape = other._shape;
  236. _strides = other._strides;
  237. }
  238. //////////////////////////////////////////////////////////////////////////
  239. ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape,
  240. const std::vector<Nd4jLong> &strides) : _dataType(type), _order(order),
  241. _shape(shape) {
  242. if (strides.empty() && !shape.empty()) {
  243. _strides.resize(shape.size());
  244. if (order == 'c')
  245. shape::calcStrides(_shape.data(), shape.size(), _strides.data());
  246. else
  247. shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
  248. } else {
  249. _strides = strides;
  250. }
  251. for (auto v:_shape) {
  252. if (v == 0) {
  253. _empty = true;
  254. break;
  255. }
  256. }
  257. }
  258. ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) {
  259. ShapeDescriptor descriptor;
  260. descriptor._dataType = type;
  261. descriptor._empty = true;
  262. descriptor._rank = 0;
  263. descriptor._order = 'c';
  264. descriptor._ews = 1;
  265. return descriptor;
  266. }
  267. ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) {
  268. ShapeDescriptor descriptor;
  269. descriptor._dataType = type;
  270. descriptor._empty = false;
  271. descriptor._rank = 0;
  272. descriptor._order = 'c';
  273. descriptor._ews = 1;
  274. return descriptor;
  275. }
  276. ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const DataType type) {
  277. ShapeDescriptor descriptor;
  278. descriptor._dataType = type;
  279. descriptor._shape.emplace_back(length);
  280. if (length > 0)
  281. descriptor._strides.emplace_back(1);
  282. else {
  283. descriptor._strides.emplace_back(0);
  284. descriptor._empty = true;
  285. }
  286. descriptor._order = 'c';
  287. descriptor._ews = 1;
  288. descriptor._rank = 1;
  289. return descriptor;
  290. }
  291. }
  292. namespace std {
  293. size_t hash<sd::ShapeDescriptor>::operator()(const sd::ShapeDescriptor &k) const {
  294. auto res = std::hash<Nd4jLong>()(k.arrLength());
  295. res ^= std::hash<char>()(k.order()) + 0x9e3779b9 + (res << 6) + (res >> 2);
  296. res ^= k.dataType() + 0x9e3779b9 + (res << 6) + (res >> 2);
  297. res ^= std::hash<int>()(k.rank()) + 0x9e3779b9 + (res << 6) + (res >> 2);
  298. res ^= std::hash<Nd4jLong>()(k.ews()) + 0x9e3779b9 + (res << 6) + (res >> 2);
  299. auto shapes = const_cast<sd::ShapeDescriptor&>(k).shape();
  300. auto strides = const_cast<sd::ShapeDescriptor&>(k).strides();
  301. for (auto s: shapes) {
  302. res ^= std::hash<Nd4jLong>()(s) + 0x9e3779b9 + (res << 6) + (res >> 2);
  303. }
  304. for (auto s: strides) {
  305. res ^= std::hash<Nd4jLong>()(s) + 0x9e3779b9 + (res << 6) + (res >> 2);
  306. }
  307. return res;
  308. }
  309. }