PageRenderTime 46ms CodeModel.GetById 20ms RepoModel.GetById 1ms app.codeStats 0ms

/tensorflow/core/kernels/sparse_dense_binary_op_shared.cc

https://gitlab.com/hrishikeshvganu/tensorflow
C++ | 179 lines | 119 code | 18 blank | 42 comment | 9 complexity | 43bfc14d001e21309d5f4c5923aee9cf MD5 | raw file
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // SparseDenseBinaryOpShared is the shared code for binary coefficient-wise
  13. // (cwise) operations of the following form:
  14. //
  15. // sparse_t <binary cwise op> dense_t -> new sparse_t
  16. //
  17. // where:
  18. //
  19. // (1) "binary cwise op" can be, for example, cdiv, cmul, cfloordiv, etc.
  20. // (2) LIMITATION: we only support broadcasting the dense side to the sparse
  21. // side. In other words, NumDims(sparse_t) >= NumDims(dense_t), and if
  22. // they are equal, each dim size of sparse_t >= that of dense_t.
  23. // (3) Note that the result is a new sparse tensor, which means the implicitly
  24. // zero elements of sparse_t do not participate. (Hence, this should not
  25. // be used for, say, cadd.)
  26. //
  27. // The only output is a vector of flat values with shape [nnz], since this op
  28. // changes neither the indices nor the shape of the sparse operand.
  29. //
  30. // See docs of all registered ops in ../ops/sparse_ops.cc.
  31. #define EIGEN_USE_THREADS
  32. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
  33. #include "tensorflow/core/framework/op_kernel.h"
  34. #include "tensorflow/core/framework/register_types.h"
  35. #include "tensorflow/core/framework/tensor.h"
  36. #include "tensorflow/core/framework/tensor_util.h"
  37. #include "tensorflow/core/framework/types.h"
  38. #include "tensorflow/core/kernels/cwise_ops.h"
  39. #include "tensorflow/core/kernels/cwise_ops_common.h"
  40. #include "tensorflow/core/util/bcast.h"
  41. using Eigen::TensorRef;
  42. using tensorflow::gtl::ArraySlice;
  43. namespace tensorflow {
  44. typedef Eigen::ThreadPoolDevice CPUDevice;
  45. template <typename Device, typename T, typename Functor>
  46. class SparseDenseBinaryOpShared : public OpKernel {
  47. public:
  48. explicit SparseDenseBinaryOpShared(OpKernelConstruction *ctx)
  49. : OpKernel(ctx) {}
  50. void Compute(OpKernelContext *ctx) override {
  51. const Tensor *indices_t, *values_t, *shape_t, *dense_t;
  52. OP_REQUIRES_OK(ctx, ctx->input("sp_indices", &indices_t));
  53. OP_REQUIRES_OK(ctx, ctx->input("sp_values", &values_t));
  54. OP_REQUIRES_OK(ctx, ctx->input("sp_shape", &shape_t));
  55. OP_REQUIRES_OK(ctx, ctx->input("dense", &dense_t));
  56. // Validations.
  57. OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices_t->shape()),
  58. errors::InvalidArgument(
  59. "Input sp_indices should be a matrix but received shape: ",
  60. indices_t->shape().DebugString()));
  61. OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values_t->shape()) &&
  62. TensorShapeUtils::IsVector(shape_t->shape()),
  63. errors::InvalidArgument(
  64. "Inputs sp_values and sp_shape should be vectors "
  65. "but received shapes: ",
  66. values_t->shape().DebugString(), " and ",
  67. shape_t->shape().DebugString()));
  68. OP_REQUIRES(ctx, indices_t->dim_size(0) < std::numeric_limits<int>::max(),
  69. errors::InvalidArgument(
  70. "Number of non-zero elements exceeds int32 range"));
  71. const auto indices_mat = indices_t->matrix<int64>();
  72. const auto shape_vec = shape_t->vec<int64>();
  73. const auto lhs_dims = BCast::FromShape(TensorShape(shape_vec));
  74. const auto rhs_dims = BCast::FromShape(dense_t->shape());
  75. BCast b(lhs_dims, rhs_dims, false); // false for keeping the same num dims.
  76. // True iff (size(lhs) > size(rhs)), or (sizes equal, lhs cwise rhs).
  77. auto VecGreaterEq = [](ArraySlice<int64> lhs, ArraySlice<int64> rhs) {
  78. if (lhs.size() > rhs.size()) return true;
  79. if (lhs.size() < rhs.size()) return false;
  80. for (int i = 0; i < lhs.size(); ++i) {
  81. if (lhs[i] < rhs[i]) return false;
  82. }
  83. return true;
  84. };
  85. OP_REQUIRES(ctx, VecGreaterEq(lhs_dims, rhs_dims) && b.IsValid(),
  86. errors::InvalidArgument(
  87. "SparseDenseBinaryOpShared broadcasts dense to sparse "
  88. "only; got incompatible shapes: [",
  89. str_util::Join(lhs_dims, ","), "] vs. [",
  90. str_util::Join(rhs_dims, ","), "]"));
  91. Tensor *output_values = nullptr;
  92. Tensor dense_gathered;
  93. const int nnz = static_cast<int>(indices_t->dim_size(0));
  94. OP_REQUIRES_OK(ctx,
  95. ctx->allocate_output(0, TensorShape({nnz}), &output_values));
  96. OP_REQUIRES_OK(
  97. ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, TensorShape({nnz}),
  98. &dense_gathered));
  99. // Pulls relevant entries from the dense side, with reshape and broadcasting
  100. // *of the dense side* taken into account. Use a TensorRef to avoid blowing
  101. // up memory.
  102. //
  103. // We can directly use the sparse indices to look up dense side, because
  104. // "b.y_reshape()" and "b.y_bcast()" are guaranteed to have rank "ndims".
  105. auto dense_gathered_flat = dense_gathered.flat<T>();
  106. const int ndims = lhs_dims.size();
  107. switch (ndims) {
  108. #define CASE(NDIM) \
  109. case NDIM: { \
  110. TensorRef<Eigen::Tensor<const T, NDIM, Eigen::RowMajor>> rhs_ref = \
  111. dense_t->shaped<T, NDIM>(b.y_reshape()) \
  112. .broadcast(BCast::ToIndexArray<NDIM>(b.y_bcast())); \
  113. Eigen::array<Eigen::DenseIndex, NDIM> idx; \
  114. bool indices_valid = true; \
  115. for (int i = 0; i < nnz; ++i) { \
  116. for (int d = 0; d < NDIM; ++d) { \
  117. idx[d] = internal::SubtleMustCopy(indices_mat(i, d)); \
  118. if (!FastBoundsCheck(idx[d], rhs_ref.dimension(d))) { \
  119. indices_valid = false; \
  120. } \
  121. } \
  122. OP_REQUIRES( \
  123. ctx, indices_valid, \
  124. errors::InvalidArgument("Provided indices are out-of-bounds w.r.t. " \
  125. "dense side with broadcasted shape")); \
  126. dense_gathered_flat(i) = rhs_ref.coeff(idx); \
  127. } \
  128. break; \
  129. }
  130. CASE(1);
  131. CASE(2);
  132. CASE(3);
  133. CASE(4);
  134. CASE(5);
  135. default:
  136. OP_REQUIRES(ctx, false, errors::InvalidArgument(
  137. "Only tensors with ranks between 1 and 5 "
  138. "are currently supported. Tensor rank: ",
  139. ndims));
  140. #undef CASE
  141. }
  142. output_values->flat<T>().device(ctx->eigen_device<Device>()) =
  143. values_t->flat<T>().binaryExpr(dense_gathered_flat,
  144. typename Functor::func());
  145. }
  146. };
// TODO(zongheng): extend to other eligible cwise operations as requested.
// Registers the two CPU kernels that share the implementation above:
// SparseDenseCwiseMul (functor::mul) and SparseDenseCwiseDiv (functor::div),
// one pair per element type T.
#define REGISTER_KERNELS(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("SparseDenseCwiseMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseDenseBinaryOpShared<CPUDevice, T, functor::mul<T>>)              \
                                                                             \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("SparseDenseCwiseDiv").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseDenseBinaryOpShared<CPUDevice, T, functor::div<T>>)

// Instantiate the registrations for every real numeric type (float, double,
// integer types, ...), then retire the helper macro.
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
  158. } // namespace tensorflow