/onnxruntime/core/framework/execution_frame.cc

https://github.com/microsoft/onnxruntime

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/framework/execution_frame.h"

#include <sstream>

#include "core/framework/mem_pattern_planner.h"
#include "core/framework/execution_plan_base.h"
#include "core/framework/sequential_execution_plan.h"
#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/node_index_info.h"
#include "core/framework/op_kernel.h"
#include "core/framework/session_state.h"
#include "core/framework/TensorSeq.h"
#include "core/framework/utils.h"

using namespace onnxruntime::common;
namespace onnxruntime {

IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map,
                                 const NodeIndexInfo& node_index_info,
                                 const std::vector<int>& fetch_mlvalue_idxs)
    : node_index_info_(node_index_info),
      all_values_size_(static_cast<size_t>(ort_value_idx_map.MaxIdx()) + 1),
      fetch_mlvalue_idxs_(fetch_mlvalue_idxs) {
  ORT_ENFORCE(node_index_info_.GetMaxMLValueIdx() == ort_value_idx_map.MaxIdx(),
              "node_index_info and ort_value_idx_map are out of sync and cannot be used");
}

IExecutionFrame::~IExecutionFrame() = default;
// Returns nullptr if the index maps to a value that is an unused optional input/output.
const OrtValue* IExecutionFrame::GetNodeInputOrOutputMLValue(int index) const {
  int ort_value_idx = GetNodeIdxToMLValueIdx(index);
  return ort_value_idx != NodeIndexInfo::kInvalidEntry ? &all_values_[ort_value_idx] : nullptr;
}

OrtValue* IExecutionFrame::GetMutableNodeInputOrOutputMLValue(int index) {
  return const_cast<OrtValue*>(GetNodeInputOrOutputMLValue(index));
}
// TODO: make this thread safe. This method is currently NOT thread safe.
// Returns Status::OK() and sets p_ort_value to nullptr if the index maps to a value that is an
// unused optional input/output.
Status IExecutionFrame::GetOrCreateNodeOutputMLValue(int index, const TensorShape* shape, OrtValue*& p_ort_value,
                                                     size_t nnz) {
  auto status = Status::OK();
  int ort_value_idx = GetNodeIdxToMLValueIdx(index);

  // return nullptr if the value is an unused optional input/output
  if (ort_value_idx == NodeIndexInfo::kInvalidEntry) {
    p_ort_value = nullptr;
  } else {
    p_ort_value = &all_values_[ort_value_idx];

    if (p_ort_value->IsAllocated()) {
      // already allocated. verify the shape matches if it's a tensor.
      if (p_ort_value->IsTensor()) {
        const Tensor& tensor = p_ort_value->Get<Tensor>();
        ORT_ENFORCE(shape && tensor.Shape() == *shape,
                    "OrtValue shape verification failed. Current shape:", tensor.Shape(),
                    " Requested shape:", shape ? shape->ToString() : "null");
      }
    } else {
      status = CreateNodeOutputMLValueImpl(*p_ort_value, ort_value_idx, shape, nnz);
    }
  }

  return status;
}
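// Illustrative sketch (hypothetical, not part of this file and excluded from the build):
// roughly how an op kernel's output path could drive GetOrCreateNodeOutputMLValue. The
// helper name and the `output_arg_start` parameter are assumptions for illustration only;
// the real call site lives in the OpKernelContext implementation.
#if 0
static Tensor* GetOutputTensorSketch(IExecutionFrame& frame, int output_arg_start, int output_index,
                                     const TensorShape& shape) {
  OrtValue* p_ort_value = nullptr;
  // the NodeIndexInfo index is the node's first output slot plus the output ordinal
  Status status = frame.GetOrCreateNodeOutputMLValue(output_arg_start + output_index, &shape, p_ort_value, /*nnz*/ 0);
  if (!status.IsOK() || p_ort_value == nullptr) {
    return nullptr;  // allocation failed, or this is an unused optional output
  }
  return p_ort_value->GetMutable<Tensor>();
}
#endif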
bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const {
  // By default there is no information about the inferred shape, so this default
  // implementation always returns false. A derived class of IExecutionFrame
  // can override this function to provide, for example, activations' shape information.
  return false;
}
AllocatorPtr IExecutionFrame::GetAllocator(const OrtMemoryInfo& info) const {
  return GetAllocatorImpl(info);
}

Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); }
Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
  if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
  }

  // If a fence is available, check whether the async read has completed.
  Fence_t fence = GetMLValue(ort_value_idx).Fence();
  if (fence && !fence->CanRelease()) {
    // Async data reading is not done yet; defer the memory release until the end of Session.Run().
    return Status::OK();
  }

  all_values_[ort_value_idx] = OrtValue();
  return Status::OK();
}
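// Illustrative sketch (hypothetical, not part of this file and excluded from the build):
// the fence check above means a release can be a deliberate no-op. If an async device copy
// tied to the value has not completed, the OrtValue stays alive and is reclaimed when the
// frame is destroyed at the end of Session::Run instead. The helper name is an assumption.
#if 0
static void TryReleaseSketch(IExecutionFrame& frame, int ort_value_idx) {
  // Returns OK either way; whether the memory was actually freed depends on the fence state.
  ORT_THROW_IF_ERROR(frame.ReleaseMLValue(ort_value_idx));
}
#endif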
int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const {
  // the validity of the index is checked by GetMLValueIndex
  int ort_value_idx = node_index_info_.GetMLValueIndex(index);
  return ort_value_idx;
}
void IExecutionFrame::Init(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
                           const std::unordered_map<int, OrtValue>& initializers,
                           const std::vector<OrtValue>& fetches) {
  ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size());
  ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size());

  // 1. resize the all_values_ vector
  all_values_.resize(all_values_size_);

  // 2. handle a non-empty output vector
  if (!fetches.empty()) {
    auto num_fetches = fetch_mlvalue_idxs_.size();

    for (size_t idx = 0; idx < num_fetches; ++idx) {
      int ort_value_idx = fetch_mlvalue_idxs_[idx];
      all_values_[ort_value_idx] = fetches[idx];
    }
  }

  // 3. handle the weights.
  // We do this after the fetches to handle an edge case where an initializer is an output.
  // e.g. A Constant node gets lifted to an initializer so there's no Node producing the value as an output during
  // Graph execution (i.e. Graph execution won't write the value to all_values_).
  // A non-empty fetches vector would overwrite the actual weight in all_values_[ort_value_idx] if we did this earlier.
  // This makes the ONNX Constant test (onnx\backend\test\data\node\test_constant) happy as that
  // involves a graph with a single Constant node.
  for (const auto& entry : initializers) {
    int ort_value_index = entry.first;

    // if the initializer is an output we need to allocate or use a provided fetch buffer and copy the data
    // so it can be returned to the caller.
    //
    // The alternative to handling this as a special case would be to disallow an initializer providing a graph output.
    // There's nothing in the ONNX spec that says a graph output must come from a node output though.
    // If we took that approach we'd need to:
    //   - reject a model with an initializer or Constant node (as we convert those to initializers in Graph::Graph)
    //     that produces a graph output, even though it conforms to the ONNX spec
    //   - update optimizers to not convert something that is a graph output into an initializer
    //     (e.g. constant folding)
    if (IsOutput(ort_value_index)) {
      const Tensor& src = entry.second.Get<Tensor>();  // all initializers in ONNX are tensors
      OrtValue& dest = all_values_[ort_value_index];

      if (!dest.IsAllocated()) {
        // NOTE: This doesn't need to support ExecutionFrame custom allocators as they only come into play
        // for a subgraph with an output of unknown shape that needs to be accumulated by the control flow node.
        // If the initializer is providing the output, the shape is known.
        AllocatorPtr allocator = GetAllocator(src.Location());
        auto p_tensor = onnxruntime::make_unique<Tensor>(src.DataType(), src.Shape(), allocator);
        auto ml_tensor = DataTypeImpl::GetType<Tensor>();
        dest.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc());
      }

      ORT_THROW_IF_ERROR(CopyTensor(src, *dest.GetMutable<Tensor>()));
    } else {
      all_values_[ort_value_index] = entry.second;
    }
  }

  // 4. handle the feed values. these can override initializer values so they must be handled last.
  for (size_t idx = 0, end = feed_mlvalue_idxs.size(); idx < end; ++idx) {
    int ort_value_idx = feed_mlvalue_idxs[idx];
    // we are sharing the underlying tensor/object for the OrtValue
    all_values_[ort_value_idx] = feeds[idx];
  }
}
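// Illustrative sketch (hypothetical, not part of this file and excluded from the build):
// the write order in Init above determines precedence. Later writes win, so a feed that maps
// to the same OrtValue index as an initializer overrides it. The toy vector and index values
// below stand in for all_values_ and are made up for illustration.
#if 0
static void InitOrderingSketch() {
  std::vector<OrtValue> all_values(4);
  OrtValue fetch_buffer, weight, user_feed;  // assume these are populated elsewhere
  all_values[2] = fetch_buffer;  // step 2: caller-provided fetch buffers
  all_values[3] = weight;        // step 3: initializers (copied into a fetch buffer if also an output)
  all_values[3] = user_feed;     // step 4: feeds are written last and override the initializer at index 3
}
#endif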
Status IExecutionFrame::GetOutputs(std::vector<OrtValue>& fetches) {
  auto num_fetches = fetch_mlvalue_idxs_.size();

  if (fetches.empty()) {
    fetches.resize(num_fetches);
  } else {
    // if there's a mismatch, things are out of sync, so fail
    if (fetches.size() != num_fetches) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Fetches vector passed to GetOutputs contains ", fetches.size(),
                             " entries which doesn't match the number of fetches the frame was initialized with of ",
                             num_fetches);
    }
  }

  for (size_t idx = 0; idx < num_fetches; ++idx) {
    fetches[idx] = GetMLValue(fetch_mlvalue_idxs_[idx]);
  }

  return Status::OK();
}
bool IExecutionFrame::IsOutput(int ort_value_idx) const {
  return std::find(fetch_mlvalue_idxs_.begin(), fetch_mlvalue_idxs_.end(), ort_value_idx) != fetch_mlvalue_idxs_.end();
}
ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
                               const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
                               const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
                               const SessionState& session_state)
    : IExecutionFrame(session_state.GetOrtValueNameIdxMap(), session_state.GetNodeIndexInfo(), fetch_mlvalue_idxs),
      session_state_(session_state),
      mem_patterns_(nullptr),
      planner_(nullptr) {
  Init(feed_mlvalue_idxs, feeds, session_state.GetInitializedTensors(), fetches);

  // map the custom allocators to ort_value_idx entries
  if (!fetch_allocators.empty()) {
    for (size_t idx = 0, end = fetch_mlvalue_idxs.size(); idx < end; ++idx) {
      int ort_value_idx = fetch_mlvalue_idxs[idx];

      auto custom_alloc_entry = fetch_allocators.find(idx);
      if (custom_alloc_entry != fetch_allocators.cend()) {
        custom_allocators_[ort_value_idx] = custom_alloc_entry->second;
      }
    }
  }

  // If the session enables the memory pattern optimization and an execution plan has been
  // generated, try to set up the memory pattern optimization.
  if (session_state.GetEnableMemoryPattern() && session_state.GetExecutionPlan()) {
    std::vector<std::reference_wrapper<const TensorShape>> input_shapes;
    bool all_tensors = true;
    // Reserve memory to avoid re-allocation.
    input_shapes.reserve(feeds.size());
    for (const auto& feed : feeds) {
      if (!(feed.IsTensor())) {
        all_tensors = false;
        break;
      }
      auto& tensor = feed.Get<Tensor>();
      input_shapes.push_back(std::cref(tensor.Shape()));
    }

    // if any of the inputs are traditional ML value types, disable the memory pattern optimization.
    if (all_tensors) {
      mem_patterns_ = session_state.GetMemoryPatternGroup(input_shapes, feed_mlvalue_idxs, inferred_shapes_);
      // if there are no existing patterns, generate one in this execution frame
      if (!mem_patterns_) {
        planner_ = onnxruntime::make_unique<OrtValuePatternPlanner>(*session_state.GetExecutionPlan());
      } else {
        // pre-allocate the big chunks requested by the memory pattern.
        // all of the internal kernels' input/output tensors will be allocated from these buffers.
        for (size_t i = 0; i < mem_patterns_->locations.size(); i++) {
          const auto& location = mem_patterns_->locations[i];
          ORT_ENFORCE(buffers_.find(location) == buffers_.end());
          if (mem_patterns_->patterns[i].PeakSize() > 0) {
            AllocatorPtr alloc = GetAllocator(location);
            void* buffer = nullptr;
            // it's possible we can't allocate the large block. if we have memory patterns we know we have successfully
            // executed once before, so if there's an arena involved it probably has smaller blocks available.
            // due to that we can still run and use those blocks (inside the arena logic) instead of one large one.
            // it's less efficient (the arena will add some overhead to coalesce individual allocations
            // back into blocks on 'free'), but better than failing completely.
            ORT_TRY {
              auto peak_size = mem_patterns_->patterns[i].PeakSize();
              // Planning of one memory type should only happen once.
              ORT_ENFORCE(
                  static_activation_memory_sizes_in_byte_.find(location.name) ==
                      static_activation_memory_sizes_in_byte_.end(),
                  "Memory type ",
                  location.name,
                  " should only appear once.");
              // static_activation_memory_sizes_in_byte_ is the maximum virtual memory size the planner computes.
              // Memory dynamically allocated while executing kernels is not recorded in this field.
              static_activation_memory_sizes_in_byte_[location.name] = peak_size;
              buffer = alloc->Alloc(peak_size);
              // handle an allocator that doesn't throw
              if (buffer == nullptr) {
                // INFO level as this may fire on every run and there may not be much a user can do
                LOGS(session_state_.Logger(), INFO) << "Allocation of memory pattern buffer for "
                                                    << location.ToString() << " returned nullptr";
              }
            }
            ORT_CATCH(const OnnxRuntimeException& ex) {
              ORT_HANDLE_EXCEPTION([&]() {
                LOGS(session_state_.Logger(), INFO) << "Allocation of memory pattern buffer for "
                                                    << location.ToString() << " failed. Error:" << ex.what();
              });
            }

            if (buffer != nullptr) {
              buffers_[location] = BufferUniquePtr(buffer, alloc);
            }

            // log size of activation. Keep it commented out for now to avoid log flooding.
            // VLOGS(session_state_.Logger(), 1) << "Allocated memory for activations, size: "
            //                                   << mem_patterns_->patterns[i].PeakSize();
          }
        }
      }
    }
  }
}
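// Illustrative sketch (hypothetical, not part of this file and excluded from the build):
// the two-phase memory pattern flow that the constructor above sets up. On the first run
// there is no cached pattern, so planner_ traces allocations and GeneratePatterns() records
// them; on a later run with the same input shapes, mem_patterns_ is found and one large
// buffer per memory location is pre-allocated into buffers_ up front. The function name
// below is an assumption.
#if 0
static void MemoryPatternFlowSketch(ExecutionFrame& first_run_frame, MemoryPatternGroup& patterns) {
  // after the first run finishes, capture the traced allocations as a pattern
  ORT_THROW_IF_ERROR(first_run_frame.GeneratePatterns(&patterns));
  // the session caches the pattern keyed on the input shapes; the next ExecutionFrame with
  // matching shapes picks it up via SessionState::GetMemoryPatternGroup and sub-allocates
  // kernel tensors from the pre-allocated chunks instead of calling the allocator per tensor.
}
#endif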
ExecutionFrame::~ExecutionFrame() = default;

Status ExecutionFrame::CopyTensor(const Tensor& src, Tensor& dest) const {
  return session_state_.GetDataTransferMgr().CopyTensor(src, dest);
}

Status ExecutionFrame::AllocateMLValueTensorSelfOwnBuffer(OrtValue& ort_value, int ort_value_index,
                                                          MLDataType element_type, const OrtMemoryInfo& location,
                                                          const TensorShape& shape, bool create_fence) {
  return AllocateMLValueTensorSelfOwnBufferHelper(ort_value, ort_value_index, element_type, location, shape,
                                                  create_fence);
}
Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_value, int ort_value_index,
                                                                MLDataType element_type,
                                                                const OrtMemoryInfo& location,
                                                                const TensorShape& shape, bool create_fence) {
  if (ort_value_index == NodeIndexInfo::kInvalidEntry) {
    return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs");
  }

  size_t size;
  int64_t len = shape.Size();
  if (len < 0) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape cannot contain any negative value");
  }
  if (static_cast<uint64_t>(len) > std::numeric_limits<size_t>::max()) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Tensor shape is too large");
  }
  if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(static_cast<size_t>(len), element_type->Size(), &size)) {
    return Status(ONNXRUNTIME, FAIL, "size overflow");
  }

  // Lazily get the allocator only if needed.
  AllocatorPtr alloc = nullptr;

  // create a fence if needed
  if (create_fence) {
    ORT_ENFORCE(ort_value.Fence() == nullptr);
    alloc = GetAllocator(location);
    FencePtr f = alloc->CreateFence(&session_state_);
    // it is OK for the fence to be nullptr if the execution provider has no async execution
    // and allocator::CreateFence returns nullptr
    ort_value.SetFence(f);
  }

  // if we have a pre-calculated memory pattern, and the ort_value is not an output mlvalue,
  // try to allocate from the pre-allocated big chunk.
  const auto& per_alloc_plan = GetAllocationPlan(ort_value_index);
  if (mem_patterns_ && per_alloc_plan.alloc_kind != AllocKind::kAllocateOutput) {
    auto pattern = mem_patterns_->GetPatterns(location);
    if (pattern) {
      auto block = pattern->GetBlock(ort_value_index);
      // if the block is not found, fall back to the default behavior
      if (block) {
        auto it = buffers_.find(location);
        if (it != buffers_.end()) {
          // if the block is not correct, log a message and then fall back to the default behavior
          if (block->size_ == size) {
            void* buffer = it->second.get();
            auto status = AllocateTensorWithPreAllocateBufferHelper(
                ort_value, static_cast<void*>(static_cast<char*>(buffer) + block->offset_), element_type, location,
                shape);
            return status;
          } else {
            // the block size may vary, especially if the model has NonZero ops or different sequence lengths are
            // fed in, so use VERBOSE as the log level as it's expected.
            // TODO: Should we re-use the block if the size is large enough? We would probably need to allow it
            // to be freed if the size difference was too large so our memory usage doesn't stick at a high water mark.
            LOGS(session_state_.Logger(), VERBOSE) << "For ort_value with index: " << ort_value_index
                                                   << ", block in memory pattern size is: " << block->size_
                                                   << " but the actual size is: " << size
                                                   << ", fall back to default allocation behavior";
          }
        }
        // else { we couldn't allocate the large block for the buffer so we didn't insert an entry }
      }
    }
  }

  // no memory pattern, or the pattern is not correct.
  if (!alloc) alloc = GetAllocator(location);

  std::unique_ptr<Tensor> p_tensor = onnxruntime::make_unique<Tensor>(element_type, shape, alloc);

  {
    auto ml_tensor = DataTypeImpl::GetType<Tensor>();
    ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc());
  }

  // trace the memory allocation.
  // don't trace the memory allocation on string tensors, as they need placement new;
  // we don't support that in the memory pattern optimization.
  if (!utils::IsDataTypeString(element_type)) {
    TraceAllocate(ort_value_index, size);
  }

  {
    // This code block is not thread-safe.
    // The dynamic activation size would be accessed by multiple threads
    // if the parallel executor is used.
    std::unique_lock<std::mutex> lock(mtx_);
    dynamic_activation_memory_sizes_in_byte_[location.name] += size;
  }

  return Status::OK();
}
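// Illustrative sketch (hypothetical, not part of this file and excluded from the build):
// the pre-allocated-chunk path above reduces to pointer arithmetic into the per-location
// buffer, and it is only taken when the planned block size matches the size needed now.
// The helper name is an assumption; the offset_/size_ fields are the ones used above.
#if 0
static void* SubAllocateFromChunkSketch(void* chunk_base, const MemoryBlock& block, size_t required_size) {
  if (block.size_ != required_size) {
    return nullptr;  // size mismatch: the caller falls back to a fresh allocation from the allocator
  }
  return static_cast<char*>(chunk_base) + block.offset_;
}
#endif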
Status ExecutionFrame::AllocateMLValueTensorPreAllocateBuffer(OrtValue& ort_value, int ort_value_index_reuse,
                                                              MLDataType element_type, const OrtMemoryInfo& location,
                                                              const TensorShape& shape, bool create_fence) {
  OrtValue& ort_value_reuse = GetMutableMLValue(ort_value_index_reuse);

  auto* reuse_tensor = ort_value_reuse.GetMutable<Tensor>();
  auto buffer_num_elements = reuse_tensor->Shape().Size();
  auto required_num_elements = shape.Size();

  // check that the number of elements matches. the shape may not be an exact match (e.g. the Reshape op).
  if (buffer_num_elements != required_num_elements) {
    // could be an allocation planner bug (less likely), or the model incorrectly uses something like 'None'
    // as a dim_param, or -1 as a dim_value in multiple places, making the planner think those shapes are equal.
    auto message = onnxruntime::MakeString(
        "Shape mismatch attempting to re-use buffer. ",
        reuse_tensor->Shape(), " != ", shape,
        ". Validate usage of dim_value (values should be > 0) and "
        "dim_param (all values with the same string should equate to the same size) in shapes in the model.");

    // be generous and use the buffer if it's large enough. log a warning though as it indicates a bad model
    if (buffer_num_elements >= required_num_elements) {
      // The View operator reuses a buffer that is bigger than the required size.
      // Disabling the warning message for now. The op is in the process of being deprecated.
#ifndef ENABLE_TRAINING
      LOGS(session_state_.Logger(), WARNING) << message;
#endif
    } else {
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
    }
  }

  void* reuse_buffer = reuse_tensor->MutableDataRaw();

  // create a fence on the reused ort_value if needed
  // TODO: differentiate reuse and alias, by adding AllocKind::kAlias?
  if (create_fence && ort_value_reuse.Fence() == nullptr) {
    FencePtr f = GetAllocator(location)->CreateFence(&session_state_);
    ort_value_reuse.SetFence(f);
  }

  // the reused OrtValue shares the same fence
  ort_value.ShareFenceWith(ort_value_reuse);
  return AllocateTensorWithPreAllocateBufferHelper(ort_value, reuse_buffer, element_type, location, shape);
}
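// Illustrative sketch (hypothetical, not part of this file and excluded from the build):
// buffer reuse above is keyed on element count, not on exact shape. For example, a buffer
// planned for shape {2, 6} can back a {3, 4} tensor because both hold 12 elements; a larger
// planned buffer is accepted with a warning, while a smaller one is rejected. The helper
// name is an assumption.
#if 0
static bool CanReuseBufferSketch(const TensorShape& planned, const TensorShape& required) {
  // mirrors the check above: equal element counts are ideal, a larger planned buffer is tolerated
  return planned.Size() >= required.Size();
}
#endif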
Status ExecutionFrame::AllocateTensorWithPreAllocateBufferHelper(OrtValue& ort_value, void* pBuffer,
                                                                 MLDataType element_type,
                                                                 const OrtMemoryInfo& location,
                                                                 const TensorShape& shape) {
  auto ml_tensor = DataTypeImpl::GetType<Tensor>();
  auto p_tensor = onnxruntime::make_unique<Tensor>(element_type, shape, pBuffer, location);
  ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc());

  return Status::OK();
}

static Status AllocateTraditionalMLValue(OrtValue& ort_value, const NonTensorTypeBase& type) {
  auto creator = type.GetCreateFunc();
  ort_value.Init(creator(), &type, type.GetDeleteFunc());
  return Status::OK();
}

static Status AllocateTensorSequence(OrtValue& ort_value) {
  auto ml_tensor_sequence = DataTypeImpl::GetType<TensorSeq>();
  auto p_tensor_sequence = onnxruntime::make_unique<TensorSeq>();
  ort_value.Init(p_tensor_sequence.release(), ml_tensor_sequence, ml_tensor_sequence->GetDeleteFunc());

  return Status::OK();
}
#if !defined(ORT_MINIMAL_BUILD)
static Status AllocateSparseTensor(MLValue& mlvalue, const DataTypeImpl& ml_type, AllocatorPtr allocator,
                                   const TensorShape& shape, size_t nnz, bool create_fence,
                                   const SessionState& session_state) {
  auto element_type = ml_type.AsSparseTensorType()->GetElementType();
  auto sparse = onnxruntime::make_unique<SparseTensor>(element_type, shape, nnz, allocator);
  auto deleter = DataTypeImpl::GetType<SparseTensor>()->GetDeleteFunc();
  mlvalue.Init(sparse.release(), DataTypeImpl::GetType<SparseTensor>(), deleter);

  // create a fence if needed
  if (create_fence) {
    ORT_ENFORCE(mlvalue.Fence() == nullptr);
    FencePtr f = allocator->CreateFence(&session_state);
    mlvalue.SetFence(f);
  }

  return Status::OK();
}
#endif
// This method is not thread safe!
Status ExecutionFrame::AllocateAsPerAllocationPlan(OrtValue& ort_value, int ort_value_index, const TensorShape* shape,
                                                   size_t nnz) {
  const SequentialExecutionPlan* p_seq_exec_plan = session_state_.GetExecutionPlan();
  const auto& alloc_plan = p_seq_exec_plan->allocation_plan;
  ORT_ENFORCE(ort_value_index >= 0 && static_cast<size_t>(ort_value_index) < alloc_plan.size());
  const auto& per_alloc_plan = alloc_plan[ort_value_index];

  const auto& alloc_info = per_alloc_plan.location;
  const auto* ml_type = per_alloc_plan.value_type;
  if (ml_type == nullptr) {
    return Status(
        ONNXRUNTIME, INVALID_ARGUMENT,
        "Tried to allocate without valid type information, ort_value index=" + std::to_string(ort_value_index));
  }

  // if there is a custom allocator for this ort_value_index, call it to do the allocation
  auto custom_alloc_entry = custom_allocators_.find(ort_value_index);
  if (custom_alloc_entry != custom_allocators_.cend()) {
    ORT_ENFORCE(shape, "We don't expect custom allocators for non-tensor types, so a shape is mandatory here.");
    bool allocated = false;
    // see if the custom allocator can handle the allocation
    auto status = (custom_alloc_entry->second)(*shape, alloc_info, ort_value, allocated);
    if (allocated || !status.IsOK())
      return status;
  }

  if (ml_type->IsTensorType()) {
    ORT_ENFORCE(shape, "Allocation of tensor types requires a shape.");

    // tensors
    const auto* ml_data_type = static_cast<const TensorTypeBase*>(ml_type)->GetElementType();

    AllocKind alloc_kind = per_alloc_plan.alloc_kind;
    switch (alloc_kind) {
      // Right now kAllocate and kAllocateOutput are handled the same way.
      // In the future we may want to handle them differently.
      case AllocKind::kAllocateOutput:
      case AllocKind::kAllocate: {
        ORT_RETURN_IF_ERROR(AllocateMLValueTensorSelfOwnBuffer(ort_value, ort_value_index, ml_data_type, alloc_info,
                                                               *shape, per_alloc_plan.create_fence_if_async));
        break;
      }
      case AllocKind::kReuse: {
        int reuse_mlvalue_index = per_alloc_plan.reused_buffer;

        // In case OrtRunOptions.only_execute_path_to_fetches == true, it is possible that 'reuse_value'
        // is not allocated (its upstream op was not executed due to the option).
        // In this case we need to allocate 'reuse_value' and then let 'ort_value' reuse it.
        OrtValue& reuse_value = GetMutableMLValue(reuse_mlvalue_index);
        if (!reuse_value.IsAllocated()) {
          ORT_RETURN_IF_ERROR(AllocateAsPerAllocationPlan(reuse_value, reuse_mlvalue_index, shape, nnz));
        }

        ORT_RETURN_IF_ERROR(AllocateMLValueTensorPreAllocateBuffer(
            ort_value, reuse_mlvalue_index, ml_data_type, alloc_info, *shape, per_alloc_plan.create_fence_if_async));
        break;
      }
      case AllocKind::kShare: {
        int reuse_mlvalue_index = per_alloc_plan.reused_buffer;
        // copy at the OrtValue level so the shared_ptr for the data is shared between the two OrtValue instances
        ort_value = GetMutableMLValue(reuse_mlvalue_index);
        break;
      }
      default: {
        std::ostringstream ostr;
        ostr << "Invalid allocation kind: " << static_cast<std::underlying_type<AllocKind>::type>(alloc_kind);
        return Status(ONNXRUNTIME, FAIL, ostr.str());
      }
    }

    return Status::OK();
  } else if (ml_type->IsSparseTensorType()) {
#if !defined(ORT_MINIMAL_BUILD)
    return AllocateSparseTensor(ort_value, *ml_type, GetAllocator(alloc_info),
                                *shape, nnz, per_alloc_plan.create_fence_if_async, session_state_);
#else
    // Model load should have failed, so this should be unreachable
    ORT_THROW("SparseTensor is not supported in this build.");
#endif
  } else if (ml_type->IsTensorSequenceType()) {
    return AllocateTensorSequence(ort_value);
  } else {
    return AllocateTraditionalMLValue(ort_value, *static_cast<const NonTensorTypeBase*>(ml_type));
  }
}
AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtMemoryInfo& info) const {
  return session_state_.GetAllocator(info);
}

// This method is not thread safe!
// Called from GetOrCreateNodeOutputMLValue, which has already handled the unused optional
// input/output case, so this simply allocates according to the allocation plan.
Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx,
                                                   const TensorShape* shape, size_t nnz) {
  return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape, nnz);
}

Status ExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
  ORT_RETURN_IF_ERROR(IExecutionFrame::ReleaseMLValueImpl(ort_value_idx));
  TraceFree(ort_value_idx);
  return Status::OK();
}
const AllocPlanPerValue& ExecutionFrame::GetAllocationPlan(int ort_value_idx) {
  const SequentialExecutionPlan* p_seq_exec_plan = session_state_.GetExecutionPlan();
  const auto& alloc_plan = p_seq_exec_plan->allocation_plan;
  ORT_ENFORCE(ort_value_idx >= 0 && static_cast<size_t>(ort_value_idx) < alloc_plan.size());
  return alloc_plan[ort_value_idx];
}
void ExecutionFrame::TraceAllocate(int ort_value_idx, size_t size) {
  if (planner_) {
    // don't trace the output tensors.
    auto& allocation_plan = GetAllocationPlan(ort_value_idx);
    if (allocation_plan.alloc_kind == AllocKind::kAllocateOutput) return;
    auto status = planner_->TraceAllocation(ort_value_idx, size);
    if (!status.IsOK())
      LOGS(session_state_.Logger(), WARNING) << "TraceAllocation for ort_value_idx=" << ort_value_idx
                                             << " size=" << size << " failed: " << status.ErrorMessage();
  }
}
void ExecutionFrame::TraceFree(int ort_value_idx) {
  // don't trace free on output tensors.
  if (planner_ && !IsOutput(ort_value_idx)) {
    const SequentialExecutionPlan* p_seq_exec_plan = session_state_.GetExecutionPlan();
    const auto& alloc_plan = p_seq_exec_plan->allocation_plan;
    ORT_ENFORCE(ort_value_idx >= 0 && static_cast<size_t>(ort_value_idx) < alloc_plan.size());
    const auto& per_alloc_plan = alloc_plan[ort_value_idx];

    // only trace tensors
    auto ml_type = per_alloc_plan.value_type;
    if (ml_type->IsTensorType()) {
      // tensors
      auto ml_data_type = static_cast<const TensorTypeBase*>(ml_type)->GetElementType();
      // don't trace string tensors
      if (!utils::IsDataTypeString(ml_data_type)) {
        auto status = planner_->TraceFree(ort_value_idx);
        if (!status.IsOK()) {
          LOGS(session_state_.Logger(), WARNING)
              << "TraceFree for ort_value_idx=" << ort_value_idx << " failed: " << status.ErrorMessage();
        }
      }
    }
  }
}
// generate the memory pattern based on the tracing of memory allocation/free in the current execution.
// return an error if the planner is not set up.
Status ExecutionFrame::GeneratePatterns(MemoryPatternGroup* out) const {
  if (!planner_) {
    return Status(ONNXRUNTIME, FAIL, "Memory pattern planner is not enabled on this execution framework.");
  }

  return planner_->GeneratePatterns(out);
}
bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {
  // Map the NodeArg index to an OrtValue index.
  int ort_value_idx = GetNodeIdxToMLValueIdx(index);

  // Check if the index is valid.
  if (ort_value_idx == NodeIndexInfo::kInvalidEntry) {
    return false;
  }

  // Search for the inferred shape.
  // If an inferred shape is found, it's assigned to "shape" so that the caller can use it.
  auto it = inferred_shapes_.find(ort_value_idx);
  if (it != inferred_shapes_.end()) {
    shape = it->second;
    return true;
  }

  // The search failed; tell the caller.
  return false;
}

}  // namespace onnxruntime