/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

https://github.com/microsoft/onnxruntime · C++ Header

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "attention_base.h"
#include "attention_helper.h"

#include "core/common/common.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

class AttentionCPUBase : public AttentionBase {
 protected:
  AttentionCPUBase(const OpKernelInfo& info) : AttentionBase(info) {}

  template <typename T>
  Status ApplyAttention(const T* Q,                // Q data with size BxNxSxH
                        const T* K,                // K data with size BxNxSxH
                        const T* V,                // V data with size BxNxSxH
                        const Tensor* mask_index,  // mask index. nullptr if no mask or its size is B
                        const Tensor* past,        // past state
                        Tensor* output,            // output tensor
                        int batch_size,            // batch size
                        int sequence_length,       // sequence length
                        int head_size,             // head size
                        int hidden_size,           // hidden size
                        OpKernelContext* context) const {
    AllocatorPtr allocator;
    ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

    auto* tp = context->GetOperatorThreadPool();

    int past_sequence_length = 0;
    Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);

    // Total sequence length including that of past state: S* = S' + S
    const int all_sequence_length = past_sequence_length + sequence_length;
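    // For example, with a past state of length S' = 3 and new input of length S = 2, scores are computed
    // against S* = 5 key/value positions for each of the 2 query positions.
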
    // Compute the attention score. It does 2 things:
    //         I. attention_probs(B, N, S, S*) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S*, H -> B, N, H, S*) +
    //                                           1 x mask_data(B, N, S, S*)
    //         II. attention_probs(B, N, S, S*) = Softmax(attention_probs)
    size_t attention_probs_bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T);
    auto attention_probs = allocator->Alloc(attention_probs_bytes);
    BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));

    void* mask_data = nullptr;
    if (mask_index != nullptr || (is_unidirectional_ && sequence_length > 1)) {
      size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * all_sequence_length * sizeof(T);
      mask_data = allocator->Alloc(mask_data_bytes);
      memset(mask_data, 0, mask_data_bytes);
    }
    BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
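    // When allocated, mask_data is filled by PrepareMask (called from ComputeAttentionProbs) with the additive
    // mask of term I above; masked positions receive a large negative value so they end up with (near) zero
    // probability after softmax.
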
    const int32_t* mask_index_data = mask_index != nullptr ? mask_index->template Data<int32_t>() : nullptr;
    const std::vector<int64_t>* mask_index_dims = mask_index != nullptr ? &(mask_index->Shape().GetDims()) : nullptr;
    const T* past_data = past != nullptr ? past->template Data<T>() : nullptr;
    T* present_data = present != nullptr ? present->template MutableData<T>() : nullptr;

    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
                             mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
                             batch_size, sequence_length, past_sequence_length, head_size,
                             past_data, present_data, tp);

    // Compute the attentionScore * Value. It does: out_tmp(B, N, S, H) = attention_probs(B, N, S, S*) x V(B, N, S*, H)
    auto out_tmp_data =
        allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
    BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));

    ComputeVxAttentionScore(output->template MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs), V,
                            batch_size, sequence_length, past_sequence_length, head_size, hidden_size,
                            past_data, present_data, tp);

    return Status::OK();
  }
 private:
  // Helper function to compute the attention probs. It does 2 things:
  //         I. attention_probs(B, N, S, S*) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S*, H -> B, N, H, S*) +
  //                                           1 x mask_data(B, N, S, S*)
  //         II. attention_probs(B, N, S, S*) = Softmax(attention_probs)
  template <typename T>
  void ComputeAttentionProbs(T* attention_probs,                           // output buffer for the attention probs. Its size is BxNxSxS*
                             const T* Q,                                   // Q data. Its size is BxNxSxH
                             const T* K,                                   // K data. Its size is BxNxSxH
                             const int32_t* mask_index,                    // mask index. nullptr if no mask or its size is B
                             const std::vector<int64_t>* mask_index_dims,  // mask index shape
                             T* mask_data,                                 // buffer for mask data. Its size is BxSxS* when mask_index or is_unidirectional_ is set; nullptr otherwise
                             int batch_size,                               // batch size of self-attention
                             int sequence_length,                          // sequence length of self-attention
                             int past_sequence_length,                     // sequence length of past state
                             int head_size,                                // head size of self-attention
                             const T* past,                                // past state
                             T* present,                                   // present state
                             ThreadPool* tp) const {
    const int all_sequence_length = past_sequence_length + sequence_length;                  // S* = S' + S
    const size_t past_chunk_length = static_cast<size_t>(past_sequence_length * head_size);  // S' x H
    const size_t input_chunk_length = static_cast<size_t>(sequence_length * head_size);      // S x H
    const size_t present_chunk_length = past_chunk_length + input_chunk_length;              // S* x H
    {
      if (mask_data != nullptr) {
        PrepareMask(mask_index, mask_index_dims, mask_data, is_unidirectional_, batch_size, sequence_length, past_sequence_length);
      } else {  // no mask
        memset(attention_probs, 0, batch_size * num_heads_ * sequence_length * all_sequence_length * sizeof(T));
      }

      const int loop_len = batch_size * num_heads_;
      const float alpha = 1.0f / sqrt(static_cast<float>(head_size));
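      // alpha folds the 1/sqrt(H) scaling of scaled dot-product attention into the GEMM below.
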
      // The cost of Gemm
      const double cost = static_cast<double>(head_size * sequence_length * all_sequence_length);

      ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
        for (std::ptrdiff_t i = begin; i != end; ++i) {
          const std::ptrdiff_t batch_index = i / num_heads_;

          // broadcast mask data: (Bx)SxS* -> (BxNx)SxS*
          if (mask_data != nullptr) {
            const T* broadcast_data_src = reinterpret_cast<T*>(mask_data) + batch_index * sequence_length * all_sequence_length;
            T* broadcast_data_dest = reinterpret_cast<T*>(attention_probs) + sequence_length * all_sequence_length * i;
            memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * all_sequence_length * sizeof(T));
          }

          const T* k = K + input_chunk_length * i;
          if (nullptr != present) {
            // concatenate past_K and K : (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH
            k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i);
          }
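          // (ConcatStateChunk copies this head's S'xH past chunk and SxH new chunk into the present buffer
          // and returns a pointer to the concatenated S*xH block, so the GEMM below reads all S* key rows.)
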
          // gemm

          //  original                 transposed               each iteration
          // A: Q                  (B x N x) S x H          (B x N x) S x H          S x H
          // B: K'                 (B x N x) S* x H         (B x N x) H x S*         H x S*
          // C: attention_probs    (B x N x) S x S*         (B x N x) S x S*         S x S*
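          // beta is 1.0, so the scaled product accumulates onto attention_probs, which already holds the
          // broadcast mask values (or zeros when there is no mask).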
          math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, all_sequence_length, head_size, alpha,
                                    Q + input_chunk_length * i, k, 1.0,
                                    reinterpret_cast<T*>(attention_probs) + sequence_length * all_sequence_length * i, nullptr);
        }
      });
    }

    // attention_probs(B, N, S, S*) = Softmax(attention_probs)
    {
      const int N = batch_size * num_heads_ * sequence_length;
      const int D = all_sequence_length;
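      // Each of the N rows (one per batch x head x query position) has length D = S* and is normalized independently.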
      ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
    }
  }
  template <typename T>
  void ComputeVxAttentionScore(T* output,                 // buffer for the result with size BxSxNxH
                               T* tmp_buffer,             // buffer for temp use with size BxNxSxH
                               const T* attention_probs,  // attention probs with size BxNxSxS*
                               const T* V,                // V data with size BxNxSxH
                               int batch_size,            // batch size
                               int sequence_length,       // sequence length
                               int past_sequence_length,  // sequence length in past state
                               int head_size,             // head size
                               int hidden_size,           // hidden size
                               const T* past,             // past state
                               T* present,                // present state
                               ThreadPool* tp) const {
    const int all_sequence_length = past_sequence_length + sequence_length;                  // S* = S' + S
    const size_t past_chunk_length = static_cast<size_t>(past_sequence_length * head_size);  // S' x H
    const size_t input_chunk_length = static_cast<size_t>(sequence_length * head_size);      // S x H
    const size_t present_chunk_length = past_chunk_length + input_chunk_length;              // S* x H

    // Move the past and present pointers to the start of the V values.
    if (nullptr != past) {
      past += batch_size * num_heads_ * past_sequence_length * head_size;
    }
    if (nullptr != present) {
      present += batch_size * num_heads_ * all_sequence_length * head_size;
    }
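    // The past/present tensors store all K values followed by all V values, so skipping BxNxS'xH
    // (resp. BxNxS*xH) elements lands on the first V value.
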
    const double cost =
        static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(sequence_length);

    ThreadPool::TryParallelFor(tp, batch_size * num_heads_, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
      for (std::ptrdiff_t i = begin; i != end; ++i) {
        const T* v = V + input_chunk_length * i;
        if (nullptr != present) {
          // concatenate past_V and V: (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH
          v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
        }

        T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + input_chunk_length * i;
        math::MatMul<T>(sequence_length, head_size, all_sequence_length,
                        attention_probs + sequence_length * all_sequence_length * i,
                        v, current_tmp_data, nullptr);

        // transpose: out(B, S, N, H) = transpose out_tmp(B, N, S, H)
        const int batch_index = static_cast<int>(i / num_heads_);
        const int head_index = static_cast<int>(i % num_heads_);
        T* src = current_tmp_data;
        T* dest = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
        const auto bytes_to_copy = SafeInt<size_t>(head_size) * sizeof(T);
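        // Copy one head's SxH result row by row: src walks the contiguous BxNxSxH temp buffer while dest
        // advances by hidden_size (= N x H), so each H-sized row lands at this head's offset in the BxSxNxH output.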
        for (int j = 0; j < sequence_length; j++) {
          memcpy(dest, src, bytes_to_copy);
          src += head_size;
          dest += hidden_size;
        }
      }
    });
  }
};

}  // namespace contrib
}  // namespace onnxruntime