onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1440 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
+ import collections
+ import collections.abc
+ import os
+ import typing
+ import warnings
+ from collections.abc import Callable, Sequence
+ from typing import Any
+
+ from onnxruntime.capi import _pybind_state as C
+
+ if typing.TYPE_CHECKING:
+     import numpy as np
+     import numpy.typing as npt
+
+     import onnxruntime
+
+
+ def get_ort_device_type(device_type: str) -> int:
+     if device_type == "cuda":
+         return C.OrtDevice.cuda()
+     elif device_type == "cann":
+         return C.OrtDevice.cann()
+     elif device_type == "cpu":
+         return C.OrtDevice.cpu()
+     elif device_type == "dml":
+         return C.OrtDevice.dml()
+     elif device_type == "webgpu":
+         return C.OrtDevice.webgpu()
+     elif device_type == "gpu":
+         return C.OrtDevice.gpu()
+     elif device_type == "npu":
+         return C.OrtDevice.npu()
+     else:
+         raise Exception("Unsupported device type: " + device_type)
+
+
+ class AdapterFormat:
+     """
+     This class is used to create adapter files from Python structures.
+     """
+
+     def __init__(self, adapter=None) -> None:
+         if adapter is None:
+             self._adapter = C.AdapterFormat()
+         else:
+             self._adapter = adapter
+
+     @staticmethod
+     def read_adapter(file_path: os.PathLike) -> AdapterFormat:
+         return AdapterFormat(C.AdapterFormat.read_adapter(file_path))
+
+     def export_adapter(self, file_path: os.PathLike):
+         """
+         This function writes a file at the specified location
+         in the onnxruntime adapter format, containing LoRA parameters.
+
+         :param file_path: absolute path for the adapter
+         """
+         self._adapter.export_adapter(file_path)
+
+     def get_format_version(self) -> int:
+         return self._adapter.format_version
+
+     def set_adapter_version(self, adapter_version: int) -> None:
+         self._adapter.adapter_version = adapter_version
+
+     def get_adapter_version(self) -> int:
+         return self._adapter.adapter_version
+
+     def set_model_version(self, model_version: int) -> None:
+         self._adapter.model_version = model_version
+
+     def get_model_version(self) -> int:
+         return self._adapter.model_version
+
+     def set_parameters(self, params: dict[str, OrtValue]) -> None:
+         self._adapter.parameters = {k: v._ortvalue for k, v in params.items()}
+
+     def get_parameters(self) -> dict[str, OrtValue]:
+         return {k: OrtValue(v) for k, v in self._adapter.parameters.items()}
+
+
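A minimal sketch of the round trip this class supports, assuming ``AdapterFormat`` and ``OrtValue`` are re-exported from the top-level ``onnxruntime`` package; the parameter name, shape, and file name are illustrative::

    import numpy as np
    import onnxruntime as ort

    adapter = ort.AdapterFormat()
    adapter.set_adapter_version(1)
    adapter.set_model_version(1)
    adapter.set_parameters({
        "lora_A": ort.OrtValue.ortvalue_from_numpy(np.zeros((768, 16), dtype=np.float32)),
    })
    adapter.export_adapter("example.onnx_adapter")

    # Read it back and inspect the stored metadata and parameters.
    restored = ort.AdapterFormat.read_adapter("example.onnx_adapter")
    print(restored.get_adapter_version(), list(restored.get_parameters()))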
+ def check_and_normalize_provider_args(
+     providers: Sequence[str | tuple[str, dict[Any, Any]]] | None,
+     provider_options: Sequence[dict[Any, Any]] | None,
+     available_provider_names: Sequence[str],
+ ):
+     """
+     Validates the 'providers' and 'provider_options' arguments and returns a
+     normalized version.
+
+     :param providers: Optional sequence of providers in order of decreasing
+         precedence. Values can either be provider names or tuples of
+         (provider name, options dict).
+     :param provider_options: Optional sequence of options dicts corresponding
+         to the providers listed in 'providers'.
+     :param available_provider_names: The available provider names.
+
+     :return: Tuple of (normalized 'providers' sequence, normalized
+         'provider_options' sequence).
+
+     'providers' can contain either names or names and options. When any options
+     are given in 'providers', 'provider_options' should not be used.
+
+     The normalized result is a tuple of:
+     1. Sequence of provider names in the same order as 'providers'.
+     2. Sequence of corresponding provider options dicts with string keys and
+        values. Unspecified provider options yield empty dicts.
+     """
+     if providers is None:
+         return [], []
+
+     provider_name_to_options = collections.OrderedDict()
+
+     def set_provider_options(name, options):
+         if name not in available_provider_names:
+             warnings.warn(
+                 "Specified provider '{}' is not in available provider names. Available providers: '{}'".format(
+                     name, ", ".join(available_provider_names)
+                 )
+             )
+
+         if name in provider_name_to_options:
+             warnings.warn(f"Duplicate provider '{name}' encountered, ignoring.")
+             return
+
+         normalized_options = {str(key): str(value) for key, value in options.items()}
+         provider_name_to_options[name] = normalized_options
+
+     if not isinstance(providers, collections.abc.Sequence):
+         raise ValueError("'providers' should be a sequence.")
+
+     if provider_options is not None:
+         if not isinstance(provider_options, collections.abc.Sequence):
+             raise ValueError("'provider_options' should be a sequence.")
+
+         if len(providers) != len(provider_options):
+             raise ValueError("'providers' and 'provider_options' should be the same length if both are given.")
+
+         if not all(isinstance(provider, str) for provider in providers):
+             raise ValueError("Only string values for 'providers' are supported if 'provider_options' is given.")
+
+         if not all(isinstance(options_for_provider, dict) for options_for_provider in provider_options):
+             raise ValueError("'provider_options' values must be dicts.")
+
+         for name, options in zip(providers, provider_options, strict=False):
+             set_provider_options(name, options)
+
+     else:
+         for provider in providers:
+             if isinstance(provider, str):
+                 set_provider_options(provider, {})
+             elif (
+                 isinstance(provider, tuple)
+                 and len(provider) == 2
+                 and isinstance(provider[0], str)
+                 and isinstance(provider[1], dict)
+             ):
+                 set_provider_options(provider[0], provider[1])
+             else:
+                 raise ValueError("'providers' values must be either strings or (string, dict) tuples.")
+
+     return list(provider_name_to_options.keys()), list(provider_name_to_options.values())
+
+
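To illustrate the normalization performed above: option values are stringified, and both calling conventions collapse to a pair of parallel lists (the provider names here are illustrative)::

    providers, options = check_and_normalize_provider_args(
        ["CPUExecutionProvider", ("DmlExecutionProvider", {"device_id": 0})],
        None,
        ["DmlExecutionProvider", "CPUExecutionProvider"],
    )
    # providers -> ['CPUExecutionProvider', 'DmlExecutionProvider']
    # options   -> [{}, {'device_id': '0'}]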
+ class Session:
+     """
+     This is the main class used to run a model.
+     """
+
+     def __init__(self, enable_fallback: bool = True):
+         # self._sess is managed by the derived class and relies on bindings from C.InferenceSession
+         self._sess = None
+         self._enable_fallback = enable_fallback
+
+     def get_session_options(self) -> onnxruntime.SessionOptions:
+         "Return the session options. See :class:`onnxruntime.SessionOptions`."
+         return self._sess_options
+
+     def get_inputs(self) -> Sequence[onnxruntime.NodeArg]:
+         "Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
+         return self._inputs_meta
+
+     def get_outputs(self) -> Sequence[onnxruntime.NodeArg]:
+         "Return the outputs metadata as a list of :class:`onnxruntime.NodeArg`."
+         return self._outputs_meta
+
+     def get_overridable_initializers(self) -> Sequence[onnxruntime.NodeArg]:
+         "Return the inputs (including initializers) metadata as a list of :class:`onnxruntime.NodeArg`."
+         return self._overridable_initializers
+
+     def get_modelmeta(self) -> onnxruntime.ModelMetadata:
+         "Return the metadata. See :class:`onnxruntime.ModelMetadata`."
+         return self._model_meta
+
+     def get_input_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
+         "Return the memory info for the inputs."
+         return self._input_meminfos
+
+     def get_output_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
+         "Return the memory info for the outputs."
+         return self._output_meminfos
+
+     def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
+         "Return the execution provider devices for the inputs."
+         return self._input_epdevices
+
+     def get_providers(self) -> Sequence[str]:
+         "Return the list of registered execution providers."
+         return self._providers
+
+     def get_provider_options(self):
+         "Return the registered execution providers' configurations."
+         return self._provider_options
+
+     def get_provider_graph_assignment_info(self) -> Sequence[onnxruntime.OrtEpAssignedSubgraph]:
+         """
+         Get information about the subgraphs assigned to each execution provider and the nodes within.
+
+         The application must enable recording of graph assignment information by setting the session
+         configuration key "session.record_ep_graph_assignment_info" to "1".
+         """
+         return self._sess.get_provider_graph_assignment_info()
+
+     def set_providers(self, providers=None, provider_options=None) -> None:
+         """
+         Register the given list of execution providers. The underlying session is re-created.
+
+         :param providers: Optional sequence of providers in order of decreasing
+             precedence. Values can either be provider names or tuples of
+             (provider name, options dict). If not provided, then all available
+             providers are used with the default precedence.
+         :param provider_options: Optional sequence of options dicts corresponding
+             to the providers listed in 'providers'.
+
+         'providers' can contain either names or names and options. When any options
+         are given in 'providers', 'provider_options' should not be used.
+
+         The list of providers is ordered by precedence. For example
+         `['CUDAExecutionProvider', 'CPUExecutionProvider']`
+         means execute a node using CUDAExecutionProvider if capable,
+         otherwise execute using CPUExecutionProvider.
+         """
+         # recreate the underlying C.InferenceSession
+         self._reset_session(providers, provider_options)
+
+     def disable_fallback(self) -> None:
+         """
+         Disable the session.run() fallback mechanism.
+         """
+         self._enable_fallback = False
+
+     def enable_fallback(self) -> None:
+         """
+         Enable the session.run() fallback mechanism. If session.run() fails due to an internal
+         execution provider failure, reset the execution providers enabled for this session.
+         If a GPU is enabled, fall back to CUDAExecutionProvider;
+         otherwise fall back to CPUExecutionProvider.
+         """
+         self._enable_fallback = True
+
+     def _validate_input(self, feed_input_names):
+         missing_input_names = []
+         for input in self._inputs_meta:
+             if input.name not in feed_input_names and not input.type.startswith("optional"):
+                 missing_input_names.append(input.name)
+         if missing_input_names:
+             raise ValueError(
+                 f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names})."
+             )
+
+     def run(self, output_names, input_feed, run_options=None) -> Sequence[np.ndarray | SparseTensor | list | dict]:
+         """
+         Compute the predictions.
+
+         :param output_names: name of the outputs
+         :param input_feed: dictionary ``{ input_name: input_value }``
+         :param run_options: See :class:`onnxruntime.RunOptions`.
+         :return: list of results, every result is either a numpy array,
+             a sparse tensor, a list or a dictionary.
+
+         ::
+
+             sess.run([output_name], {input_name: x})
+         """
+         self._validate_input(list(input_feed.keys()))
+         if not output_names:
+             output_names = [output.name for output in self._outputs_meta]
+         try:
+             return self._sess.run(output_names, input_feed, run_options)
+         except C.EPFail as err:
+             if self._enable_fallback:
+                 print(f"EP Error: {err!s} using {self._providers}")
+                 print(f"Falling back to {self._fallback_providers} and retrying.")
+                 self.set_providers(self._fallback_providers)
+                 # Fallback only once.
+                 self.disable_fallback()
+                 return self._sess.run(output_names, input_feed, run_options)
+             raise
+
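A minimal end-to-end sketch of ``run()``; the model path and input shape are placeholders for whatever model is actually loaded::

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)
    outputs = sess.run(None, {input_name: x})  # passing None fetches every model output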
+     def run_async(self, output_names, input_feed, callback, user_data, run_options=None):
+         """
+         Compute the predictions asynchronously on a C++ thread from the ORT intra-op threadpool.
+
+         :param output_names: name of the outputs
+         :param input_feed: dictionary ``{ input_name: input_value }``
+         :param callback: Python function that accepts an array of results and a status string on error.
+             The callback will be invoked by a C++ thread from the ORT intra-op threadpool.
+         :param user_data: arbitrary object passed through to the callback's ``user_data`` argument.
+         :param run_options: See :class:`onnxruntime.RunOptions`.
+
+         ::
+
+             class MyData:
+                 def __init__(self):
+                     # ...
+                 def save_results(self, results):
+                     # ...
+
+             def callback(results: np.ndarray, user_data: MyData, err: str) -> None:
+                 if err:
+                     print(err)
+                 else:
+                     # save results to user_data
+
+             sess.run_async([output_name], {input_name: x}, callback, MyData())
+         """
+         self._validate_input(list(input_feed.keys()))
+         if not output_names:
+             output_names = [output.name for output in self._outputs_meta]
+         return self._sess.run_async(output_names, input_feed, callback, user_data, run_options)
+
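Because the callback fires on an ORT-owned thread, the caller typically needs a synchronization primitive to wait for completion. A sketch that reuses ``sess``, ``input_name``, and ``x`` from the ``run()`` sketch above and passes a ``threading.Event`` as ``user_data``::

    import threading

    done = threading.Event()
    collected = {}

    def on_complete(results, user_data, err):
        if err:
            collected["error"] = err
        else:
            collected["outputs"] = results
        user_data.set()  # wake the waiting thread

    sess.run_async(None, {input_name: x}, on_complete, done)
    done.wait()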
+     def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None) -> Sequence[OrtValue]:
+         """
+         Compute the predictions.
+
+         :param output_names: name of the outputs
+         :param input_dict_ort_values: dictionary ``{ input_name: input_ort_value }``
+             See the ``OrtValue`` class for how to create an ``OrtValue``
+             from a numpy array or a ``SparseTensor``.
+         :param run_options: See :class:`onnxruntime.RunOptions`.
+         :return: an array of ``OrtValue``
+
+         ::
+
+             sess.run_with_ort_values([output_name], {input_name: x_ort_value})
+         """
+
+         def invoke(sess, output_names, input_dict_ort_values, run_options):
+             input_dict = {}
+             for n, v in input_dict_ort_values.items():
+                 input_dict[n] = v._get_c_value()
+             result = sess.run_with_ort_values(input_dict, output_names, run_options)
+             if not isinstance(result, C.OrtValueVector):
+                 raise TypeError("run_with_ort_values() must return an instance of type 'OrtValueVector'.")
+             ort_values = [OrtValue(v) for v in result]
+             return ort_values
+
+         self._validate_input(list(input_dict_ort_values.keys()))
+         if not output_names:
+             output_names = [output.name for output in self._outputs_meta]
+         try:
+             return invoke(self._sess, output_names, input_dict_ort_values, run_options)
+         except C.EPFail as err:
+             if self._enable_fallback:
+                 print(f"EP Error: {err!s} using {self._providers}")
+                 print(f"Falling back to {self._fallback_providers} and retrying.")
+                 self.set_providers(self._fallback_providers)
+                 # Fallback only once.
+                 self.disable_fallback()
+                 return invoke(self._sess, output_names, input_dict_ort_values, run_options)
+             raise
+
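A sketch of calling ``run_with_ort_values`` with an input pre-wrapped as an ``OrtValue``, again reusing ``sess``, ``input_name``, and ``x`` from the ``run()`` sketch above::

    x_ort_value = ort.OrtValue.ortvalue_from_numpy(x)  # CPU-backed OrtValue
    outputs = sess.run_with_ort_values(None, {input_name: x_ort_value})
    y = outputs[0].numpy()  # copy the first output back to a numpy array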
+     def end_profiling(self):
+         """
+         End profiling and return the path of the file containing the results.
+
+         Profiling must have been enabled via
+         :meth:`onnxruntime.SessionOptions.enable_profiling`.
+         """
+         return self._sess.end_profiling()
+
+     def get_profiling_start_time_ns(self):
+         """
+         Return the profiling start time in nanoseconds.
+         Comparable to time.monotonic_ns() on Python 3.3+.
+         On some platforms, this timer may not be as precise as nanoseconds;
+         for instance, on Windows and macOS the precision is ~100 ns.
+         """
+         return self._sess.get_profiling_start_time_ns
+
+     def io_binding(self) -> IOBinding:
+         "Return an :class:`onnxruntime.IOBinding` object."
+         return IOBinding(self)
+
+     def run_with_iobinding(self, iobinding, run_options=None):
+         """
+         Compute the predictions.
+
+         :param iobinding: the IOBinding object with the graph inputs/outputs bound.
+         :param run_options: See :class:`onnxruntime.RunOptions`.
+         """
+         self._sess.run_with_iobinding(iobinding._iobinding, run_options)
+
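A sketch of the binding workflow these two methods enable; ``copy_outputs_to_cpu()`` is part of the full ``IOBinding`` API (its definition falls beyond this excerpt), and the tensor names are illustrative::

    binding = sess.io_binding()
    binding.bind_cpu_input(input_name, x)  # binds the numpy array's buffer directly
    binding.bind_output(output_name)       # let ORT allocate the output
    sess.run_with_iobinding(binding)
    result = binding.copy_outputs_to_cpu()[0]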
+     def set_ep_dynamic_options(self, options: dict[str, str]):
+         """
+         Set dynamic options for execution providers.
+
+         :param options: Dictionary of key-value pairs where both keys and values are strings.
+             These options will be passed to the execution providers to modify
+             their runtime behavior.
+         """
+         self._sess.set_ep_dynamic_options(options)
+
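As a concrete use, a workload-type hint can be sent to providers that support it; which keys a given execution provider honors is provider-specific, so treat this key as an assumed example rather than a guarantee::

    # Ask capable execution providers to switch to an efficiency-oriented mode at runtime.
    sess.set_ep_dynamic_options({"ep.dynamic.workload_type": "Efficient"})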
+     def get_tuning_results(self):
+         return self._sess.get_tuning_results()
+
+     def set_tuning_results(self, results, *, error_on_invalid=False):
+         return self._sess.set_tuning_results(results, error_on_invalid)
+
+     def run_with_ortvaluevector(self, run_options, feed_names, feeds, fetch_names, fetches, fetch_devices):
+         """
+         Compute the predictions, similarly to the other run_*() methods but with minimal C++/Python
+         conversion overhead.
+
+         :param run_options: See :class:`onnxruntime.RunOptions`.
+         :param feed_names: list of input names.
+         :param feeds: list of input OrtValue.
+         :param fetch_names: list of output names.
+         :param fetches: list of output OrtValue.
+         :param fetch_devices: list of output devices.
+         """
+         self._sess.run_with_ortvaluevector(run_options, feed_names, feeds, fetch_names, fetches, fetch_devices)
+
+
+ class InferenceSession(Session):
+     """
+     This is the main class used to run a model.
+     """
+
+     def __init__(
+         self,
+         path_or_bytes: str | bytes | os.PathLike,
+         sess_options: onnxruntime.SessionOptions | None = None,
+         providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
+         provider_options: Sequence[dict[Any, Any]] | None = None,
+         **kwargs,
+     ) -> None:
+         """
+         :param path_or_bytes: Filename or serialized ONNX or ORT format model in a byte string.
+         :param sess_options: Session options.
+         :param providers: Optional sequence of providers in order of decreasing
+             precedence. Values can either be provider names or tuples of
+             (provider name, options dict). If not provided, then all available
+             providers are used with the default precedence.
+         :param provider_options: Optional sequence of options dicts corresponding
+             to the providers listed in 'providers'.
+
+         The model type will be inferred unless explicitly set in the SessionOptions.
+         To explicitly set:
+
+         ::
+
+             so = onnxruntime.SessionOptions()
+             # so.add_session_config_entry('session.load_model_format', 'ONNX') or
+             so.add_session_config_entry('session.load_model_format', 'ORT')
+
+         A file extension of '.ort' will be inferred as an ORT format model.
+         All other filenames are assumed to be ONNX format models.
+
+         'providers' can contain either names or names and options. When any options
+         are given in 'providers', 'provider_options' should not be used.
+
+         The list of providers is ordered by precedence. For example
+         `['CUDAExecutionProvider', 'CPUExecutionProvider']`
+         means execute a node using `CUDAExecutionProvider`
+         if capable, otherwise execute using `CPUExecutionProvider`.
+         """
+         super().__init__(enable_fallback=int(kwargs.get("enable_fallback", 1)) == 1)
+
+         if isinstance(path_or_bytes, (str, os.PathLike)):
+             self._model_path = os.fspath(path_or_bytes)
+             self._model_bytes = None
+         elif isinstance(path_or_bytes, bytes):
+             self._model_path = None
+             self._model_bytes = path_or_bytes  # TODO: This is bad as we're holding the memory indefinitely
+         else:
+             raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'")
+
+         self._sess_options = sess_options
+         self._sess_options_initial = sess_options
+         if "read_config_from_model" in kwargs:
+             self._read_config_from_model = int(kwargs["read_config_from_model"]) == 1
+         else:
+             self._read_config_from_model = os.environ.get("ORT_LOAD_CONFIG_FROM_MODEL") == "1"
+
+         # internal parameters that we don't expect to be used in general so aren't documented
+         disabled_optimizers = kwargs.get("disabled_optimizers")
+
+         try:
+             self._create_inference_session(providers, provider_options, disabled_optimizers)
+         except (ValueError, RuntimeError) as e:
+             if self._enable_fallback:
+                 try:
+                     print("*************** EP Error ***************")
+                     print(f"EP Error {e} when using {providers}")
+                     print(f"Falling back to {self._fallback_providers} and retrying.")
+                     print("****************************************")
+                     self._create_inference_session(self._fallback_providers, None)
+                     # Fallback only once.
+                     self.disable_fallback()
+                     return
+                 except Exception as fallback_error:
+                     raise fallback_error from e
+             # Fallback is disabled. Raise the original error.
+             raise e
+
+     def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
+         available_providers = C.get_available_providers()
+
+         # Validate that TensorrtExecutionProvider and NvTensorRTRTXExecutionProvider are not both specified
+         if providers:
+             has_tensorrt = any(
+                 provider == "TensorrtExecutionProvider"
+                 or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
+                 for provider in providers
+             )
+             has_tensorrt_rtx = any(
+                 provider == "NvTensorRTRTXExecutionProvider"
+                 or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
+                 for provider in providers
+             )
+             if has_tensorrt and has_tensorrt_rtx:
+                 raise ValueError(
+                     "Cannot enable both 'TensorrtExecutionProvider' and 'NvTensorRTRTXExecutionProvider' "
+                     "in the same session."
+                 )
+         # TensorRT and TensorRT RTX can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
+         if "NvTensorRTRTXExecutionProvider" in available_providers:
+             if (
+                 providers
+                 and any(
+                     provider == "CUDAExecutionProvider"
+                     or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
+                     for provider in providers
+                 )
+                 and any(
+                     provider == "NvTensorRTRTXExecutionProvider"
+                     or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
+                     for provider in providers
+                 )
+             ):
+                 self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+             else:
+                 self._fallback_providers = ["CPUExecutionProvider"]
+         elif "TensorrtExecutionProvider" in available_providers:
+             if (
+                 providers
+                 and any(
+                     provider == "CUDAExecutionProvider"
+                     or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
+                     for provider in providers
+                 )
+                 and any(
+                     provider == "TensorrtExecutionProvider"
+                     or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
+                     for provider in providers
+                 )
+             ):
+                 self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+             else:
+                 self._fallback_providers = ["CPUExecutionProvider"]
+         else:
+             self._fallback_providers = ["CPUExecutionProvider"]
+
+         # validate providers and provider_options before other initialization
+         providers, provider_options = check_and_normalize_provider_args(
+             providers, provider_options, available_providers
+         )
+
+         # Print a warning if the user passed providers to InferenceSession() but the SessionOptions instance
+         # already has provider information (e.g., via add_provider_for_devices()). The providers specified
+         # here will take precedence.
+         if self._sess_options is not None and (providers or provider_options) and self._sess_options.has_providers():
+             warnings.warn(
+                 "Specified 'providers'/'provider_options' when creating InferenceSession but SessionOptions has "
+                 "already been configured with providers. InferenceSession will only use the providers "
+                 "passed to InferenceSession()."
+             )
+
+         session_options = self._sess_options if self._sess_options else C.get_default_session_options()
+
+         self._register_ep_custom_ops(session_options, providers, provider_options, available_providers)
+
+         if self._model_path:
+             sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
+         else:
+             sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
+
+         if disabled_optimizers is None:
+             disabled_optimizers = set()
+         elif not isinstance(disabled_optimizers, set):
+             # convert to set. assumes iterable
+             disabled_optimizers = set(disabled_optimizers)
+
+         # initialize the C++ InferenceSession
+         sess.initialize_session(providers, provider_options, disabled_optimizers)
+
+         self._sess = sess
+         self._sess_options = self._sess.session_options
+         self._inputs_meta = self._sess.inputs_meta
+         self._outputs_meta = self._sess.outputs_meta
+         self._overridable_initializers = self._sess.overridable_initializers
+         self._input_meminfos = self._sess.input_meminfos
+         self._output_meminfos = self._sess.output_meminfos
+         self._input_epdevices = self._sess.input_epdevices
+         self._model_meta = self._sess.model_meta
+         self._providers = self._sess.get_providers()
+         self._provider_options = self._sess.get_provider_options()
+         self._profiling_start_time_ns = self._sess.get_profiling_start_time_ns
+
+     def _reset_session(self, providers, provider_options) -> None:
+         "Release the underlying session object."
+         # The metadata references session internal structures,
+         # so they must be set to None to decrement the _sess reference count.
+         self._sess_options = None
+         self._inputs_meta = None
+         self._outputs_meta = None
+         self._overridable_initializers = None
+         self._input_meminfos = None
+         self._output_meminfos = None
+         self._input_epdevices = None
+         self._model_meta = None
+         self._providers = None
+         self._provider_options = None
+         self._profiling_start_time_ns = None
+
+         # create a new C.InferenceSession
+         self._sess = None
+         self._sess_options = self._sess_options_initial
+         self._create_inference_session(providers, provider_options)
+
+     def _register_ep_custom_ops(self, session_options, providers, provider_options, available_providers):
+         for i in range(len(providers)):
+             if providers[i] in available_providers and providers[i] == "TensorrtExecutionProvider":
+                 C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i])
+             elif (
+                 isinstance(providers[i], tuple)
+                 and providers[i][0] in available_providers
+                 and providers[i][0] == "TensorrtExecutionProvider"
+             ):
+                 C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])
+
+             if providers[i] in available_providers and providers[i] == "NvTensorRTRTXExecutionProvider":
+                 C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, provider_options[i])
+             elif (
+                 isinstance(providers[i], tuple)
+                 and providers[i][0] in available_providers
+                 and providers[i][0] == "NvTensorRTRTXExecutionProvider"
+             ):
+                 C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1])
+
+
+ def make_get_initializer_location_func_wrapper(
+     get_initializer_location_func: GetInitializerLocationFunc,
+ ) -> GetInitializerLocationWrapperFunc:
+     """
+     Wraps a user's "get initializer location" function. The returned wrapper function adheres to the
+     signature expected by ORT.
+
+     This wrapper is needed to:
+     - Convert the `initializer_value` parameter from `C.OrtValue` to `onnxruntime.OrtValue`, which is more
+       convenient for the user's function to use.
+     - Allow the user's function to return the original `external_info` parameter (this wrapper makes a copy).
+     """
+
+     def get_initializer_location_func_wrapper(
+         initializer_name: str,
+         initializer_value: C.OrtValue,
+         external_info: C.OrtExternalInitializerInfo | None,
+     ) -> C.OrtExternalInitializerInfo | None:
+         ret_val: C.OrtExternalInitializerInfo | None = get_initializer_location_func(
+             initializer_name, OrtValue(initializer_value), external_info
+         )
+         if ret_val is not None and ret_val == external_info:
+             # User returned `external_info` (const and owned by ORT). ORT expects the returned value to be
+             # a new instance (that it deletes), so make a copy.
+             ret_val = C.OrtExternalInitializerInfo(ret_val.filepath, ret_val.file_offset, ret_val.byte_size)
+         return ret_val
+
+     return get_initializer_location_func_wrapper
+
+
697
+ class ModelCompiler:
698
+ """
699
+ This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each
700
+ encapsulates a subgraph compiled/optimized for a specific execution provider.
701
+
702
+ Refer to the EPContext design document for more information about EPContext models:
703
+ https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html
704
+
705
+ ::
706
+
707
+ sess_options = onnxruntime.SessionOptions()
708
+ sess_options.add_provider("SomeExecutionProvider", {"option1": "value1"})
709
+ # Alternatively, allow ONNX Runtime to select the provider automatically given a policy:
710
+ # sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU)
711
+
712
+ model_compiler = onnxruntime.ModelCompiler(sess_options, "input_model.onnx")
713
+ model_compiler.compile_to_file("output_model.onnx")
714
+ """
715
+
716
+ def __init__(
717
+ self,
718
+ sess_options: onnxruntime.SessionOptions,
719
+ input_model_path_or_bytes: str | os.PathLike | bytes,
720
+ embed_compiled_data_into_model: bool = False,
721
+ external_initializers_file_path: str | os.PathLike | None = None,
722
+ external_initializers_size_threshold: int = 1024,
723
+ flags: int = C.OrtCompileApiFlags.NONE,
724
+ graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL,
725
+ get_initializer_location_func: GetInitializerLocationFunc | None = None,
726
+ ):
727
+ """
728
+ Creates a ModelCompiler instance.
729
+
730
+ :param sess_options: Session options containing the providers for which the model will be compiled.
731
+ Refer to SessionOptions.add_provider() and SessionOptions.set_provider_selection_policy().
732
+ :param input_model_path_or_bytes: The path to the input model file or bytes representing a serialized
733
+ ONNX model.
734
+ :param embed_compiled_data_into_model: Defaults to False. Set to True to embed compiled binary data into
735
+ EPContext nodes in the compiled model.
736
+ :param external_initializers_file_path: Defaults to None. Set to a path for a file that will store the
737
+ initializers for non-compiled nodes.
738
+ :param external_initializers_size_threshold: Defaults to 1024. Ignored if `external_initializers_file_path`
739
+ is None or empty. Initializers larger than this threshold are stored in the external initializers file.
740
+ :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of
741
+ flags in onnxruntime.OrtCompileApiFlags.
742
+ :param graph_optimization_level: The graph optimization level.
743
+ Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL.
744
+ :param get_initializer_location_func: Optional function called for every initializer to allow user to specify
745
+ whether an initializer should be stored within the model or externally. Example:
746
+ ```
747
+ def get_initializer_location(
748
+ initializer_name: str,
749
+ initializer_value: onnxrt.OrtValue,
750
+ external_info: onnxrt.OrtExternalInitializerInfo | None,
751
+ ) -> onnxrt.OrtExternalInitializerInfo | None:
752
+ byte_size = initializer_value.tensor_size_in_bytes()
753
+
754
+ if byte_size < 64:
755
+ return None # Store small initializer within compiled model.
756
+
757
+ # Else, write initializer to new external file.
758
+ value_np = initializer_value.numpy()
759
+ file_offset = ext_init_file.tell()
760
+ ext_init_file.write(value_np.tobytes())
761
+ return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size)
762
+ ```
763
+ """
764
+ input_model_path: str | os.PathLike | None = None
765
+ input_model_bytes: bytes | None = None
766
+ if isinstance(input_model_path_or_bytes, (str, os.PathLike)):
767
+ if not input_model_path_or_bytes:
768
+ raise ValueError("Input model path is empty")
769
+ input_model_path = os.fspath(input_model_path_or_bytes)
770
+ elif isinstance(input_model_path_or_bytes, bytes):
771
+ if len(input_model_path_or_bytes) == 0:
772
+ raise ValueError("Input model bytes array is empty")
773
+ input_model_bytes = input_model_path_or_bytes
774
+ else:
775
+ raise TypeError(f"Unable to load from type '{type(input_model_path_or_bytes)}'")
776
+
777
+ if external_initializers_file_path:
778
+ if not isinstance(external_initializers_file_path, (str, os.PathLike)):
779
+ arg_type = type(external_initializers_file_path)
780
+ raise TypeError(f"Output external initializer filepath is of unexpected type '{arg_type}'")
781
+ external_initializers_file_path = os.fspath(external_initializers_file_path)
782
+ else:
783
+ external_initializers_file_path = ""
784
+
785
+ if get_initializer_location_func is not None:
786
+ if external_initializers_file_path:
787
+ raise ValueError(
788
+ "Cannot initialize ModelCompiler with both `external_initializers_file_path` "
789
+ "and `get_initializer_location_func`"
790
+ )
791
+ self.get_initializer_location_func_wrapper = make_get_initializer_location_func_wrapper(
792
+ get_initializer_location_func
793
+ )
794
+ else:
795
+ self.get_initializer_location_func_wrapper = None
796
+
797
+ if input_model_path:
798
+ self._model_compiler = C.ModelCompiler(
799
+ sess_options,
800
+ input_model_path,
801
+ True, # is path
802
+ embed_compiled_data_into_model,
803
+ external_initializers_file_path,
804
+ external_initializers_size_threshold,
805
+ flags,
806
+ graph_optimization_level,
807
+ self.get_initializer_location_func_wrapper,
808
+ )
809
+ else:
810
+ self._model_compiler = C.ModelCompiler(
811
+ sess_options,
812
+ input_model_bytes,
813
+ False, # is bytes
814
+ embed_compiled_data_into_model,
815
+ external_initializers_file_path,
816
+ external_initializers_size_threshold,
817
+ flags,
818
+ graph_optimization_level,
819
+ self.get_initializer_location_func_wrapper,
820
+ )
821
+
822
+ def compile_to_file(self, output_model_path: str | None = None):
823
+ """
824
+ Compiles to an output file. If an output file path is not provided,
825
+ the output file path is generated based on the input model path by replacing
826
+ '.onnx' with '_ctx.onnx'. For example, the generated output file is 'model_ctx.onnx' for
827
+ an input model with path 'model.onnx'.
828
+
829
+ Raises an 'InvalidArgument' exception if the compilation options are invalid.
830
+
831
+ :param output_model_path: Defaults to None. The path for the output/compiled model.
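+
+ A minimal usage sketch (illustrative; assumes a local 'model.onnx' and `sess_options`
+ configured for an execution provider that supports compilation)::
+
+ compiler = onnxrt.ModelCompiler(sess_options, "model.onnx")
+ compiler.compile_to_file()  # writes 'model_ctx.onnx' next to the input model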
832
+ """
833
+ if output_model_path:
834
+ if not isinstance(output_model_path, (str, os.PathLike)):
835
+ raise TypeError(f"Output model's filepath is of unexpected type '{type(output_model_path)}'")
836
+ output_model_path = os.fspath(output_model_path)
837
+ self._model_compiler.compile_to_file(output_model_path)
838
+
839
+ def compile_to_bytes(self) -> bytes:
840
+ """
841
+ Compiles to bytes representing the serialized compiled ONNX model.
842
+
843
+ Raises an 'InvalidArgument' exception if the compilation options are invalid.
844
+
845
+ :return: A bytes object representing the compiled ONNX model.
846
+ """
847
+ return self._model_compiler.compile_to_bytes()
848
+
849
+ def compile_to_stream(self, write_function: Callable[[bytes], None]):
850
+ """
851
+ Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function.
852
+ Raises an 'InvalidArgument' exception if the compilation options are invalid.
853
+ :param write_function: A callable that accepts a bytes buffer to write.
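+
+ A minimal sketch (illustrative; streams the compiled bytes straight into an open file)::
+
+ with open("model_ctx.onnx", "wb") as f:
+ compiler.compile_to_stream(f.write)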
854
+ """
855
+ self._model_compiler.compile_to_stream(write_function)
856
+
857
+
858
+ class IOBinding:
859
+ """
860
+ This class provides the API to bind inputs/outputs to a specified device, e.g. a GPU.
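+
+ A typical flow (illustrative sketch; assumes `session` is an InferenceSession with an
+ input named 'X' and an output named 'Y', and `x_numpy` is a matching numpy array)::
+
+ binding = session.io_binding()
+ binding.bind_cpu_input("X", x_numpy)
+ binding.bind_output("Y")  # let ORT allocate the output on CPU
+ session.run_with_iobinding(binding)
+ y = binding.copy_outputs_to_cpu()[0]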
861
+ """
862
+
863
+ def __init__(self, session: Session):
864
+ self._iobinding = C.SessionIOBinding(session._sess)
865
+ self._numpy_obj_references = {}
866
+
867
+ def bind_cpu_input(self, name, arr_on_cpu):
868
+ """
869
+ Bind an input to an array on CPU.
870
+ :param name: input name
871
+ :param arr_on_cpu: input values as a python array on CPU
872
+ """
873
+ # Hold a reference to the numpy object as the bound OrtValue is backed
874
+ # directly by the data buffer of the numpy object and so the numpy object
875
+ # must be around until this IOBinding instance is around
876
+ self._numpy_obj_references[name] = arr_on_cpu
877
+ self._iobinding.bind_input(name, arr_on_cpu)
878
+
879
+ def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr):
880
+ """
881
+ :param name: input name
882
+ :param device_type: e.g. cpu, cuda, cann
883
+ :param device_id: device id, e.g. 0
884
+ :param element_type: input element type. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16)
885
+ :param shape: input shape
886
+ :param buffer_ptr: memory pointer to input data
887
+ """
888
+ self._iobinding.bind_input(
889
+ name,
890
+ C.OrtDevice(
891
+ get_ort_device_type(device_type),
892
+ C.OrtDevice.default_memory(),
893
+ device_id,
894
+ ),
895
+ element_type,
896
+ shape,
897
+ buffer_ptr,
898
+ )
899
+
900
+ def bind_ortvalue_input(self, name, ortvalue):
901
+ """
902
+ :param name: input name
903
+ :param ortvalue: OrtValue instance to bind
904
+ """
905
+ self._iobinding.bind_ortvalue_input(name, ortvalue._ortvalue)
906
+
907
+ def synchronize_inputs(self):
908
+ self._iobinding.synchronize_inputs()
909
+
910
+ def bind_output(
911
+ self,
912
+ name,
913
+ device_type="cpu",
914
+ device_id=0,
915
+ element_type=None,
916
+ shape=None,
917
+ buffer_ptr=None,
918
+ ):
919
+ """
920
+ :param name: output name
921
+ :param device_type: e.g. cpu, cuda, cann, cpu by default
922
+ :param device_id: device id, e.g. 0
923
+ :param element_type: output element type. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16)
924
+ :param shape: output shape
925
+ :param buffer_ptr: memory pointer to output data
926
+ """
927
+
928
+ # Follow the `if` path when the user has not provided any pre-allocated buffer but still
929
+ # would like to bind an output to a specific device (e.g. cuda).
930
+ # Pre-allocating an output buffer may not be an option for the user as :
931
+ # (1) They may not want to use a custom allocator specific to the device they want to bind the output to,
932
+ # in which case ORT will allocate the memory for the user
933
+ # (2) The output has a dynamic shape and hence the size of the buffer may not be fixed across runs
934
+ if buffer_ptr is None:
935
+ self._iobinding.bind_output(
936
+ name,
937
+ C.OrtDevice(
938
+ get_ort_device_type(device_type),
939
+ C.OrtDevice.default_memory(),
940
+ device_id,
941
+ ),
942
+ )
943
+ else:
944
+ if element_type is None or shape is None:
945
+ raise ValueError("`element_type` and `shape` must be provided if a pre-allocated buffer is provided")
946
+ self._iobinding.bind_output(
947
+ name,
948
+ C.OrtDevice(
949
+ get_ort_device_type(device_type),
950
+ C.OrtDevice.default_memory(),
951
+ device_id,
952
+ ),
953
+ element_type,
954
+ shape,
955
+ buffer_ptr,
956
+ )
957
+
958
+ def bind_ortvalue_output(self, name, ortvalue):
959
+ """
960
+ :param name: output name
961
+ :param ortvalue: OrtValue instance to bind
962
+ """
963
+ self._iobinding.bind_ortvalue_output(name, ortvalue._ortvalue)
964
+
965
+ def synchronize_outputs(self):
966
+ self._iobinding.synchronize_outputs()
967
+
968
+ def get_outputs(self):
969
+ """
970
+ Returns the output OrtValues from the Run() that preceded the call.
971
+ The data buffer of the obtained OrtValues may not reside in CPU memory.
972
+ """
973
+ outputs = self._iobinding.get_outputs()
974
+ if not isinstance(outputs, C.OrtValueVector):
975
+ raise TypeError("get_outputs() must return an instance of type 'OrtValueVector'.")
976
+ return [OrtValue(ortvalue) for ortvalue in outputs]
977
+
978
+ def get_outputs_as_ortvaluevector(self):
979
+ return self._iobinding.get_outputs()
980
+
981
+ def copy_outputs_to_cpu(self):
982
+ """Copy output contents to CPU."""
983
+ return self._iobinding.copy_outputs_to_cpu()
984
+
985
+ def clear_binding_inputs(self):
986
+ self._iobinding.clear_binding_inputs()
987
+
988
+ def clear_binding_outputs(self):
989
+ self._iobinding.clear_binding_outputs()
990
+
991
+
992
+ class OrtValue:
993
+ """
994
+ A data structure that supports all ONNX data formats (tensors and non-tensors) and allows users
995
+ to place the data backing these on a device, for example, a CUDA-capable device.
996
+ This class provides APIs to construct and deal with OrtValues.
997
+ """
998
+
999
+ def __init__(self, ortvalue: C.OrtValue, numpy_obj: np.ndarray | None = None):
1000
+ if isinstance(ortvalue, C.OrtValue):
1001
+ self._ortvalue = ortvalue
1002
+ # Hold a ref count to the numpy object if the OrtValue is backed directly
1003
+ # by its data buffer so that it isn't destroyed when the OrtValue is in use
1004
+ self._numpy_obj = numpy_obj
1005
+ else:
1006
+ # An end user won't hit this error
1007
+ raise ValueError(
1008
+ "`Provided ortvalue` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`"
1009
+ )
1010
+
1011
+ def _get_c_value(self) -> C.OrtValue:
1012
+ return self._ortvalue
1013
+
1014
+ @classmethod
1015
+ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0, vendor_id=-1) -> OrtValue:
1016
+ """
1017
+ Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
1018
+ A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu
1019
+
1020
+ :param numpy_obj: The Numpy object to construct the OrtValue from
1021
+ :param device_type: e.g. cpu, cuda, cann, cpu by default
1022
+ :param device_id: device id, e.g. 0
1023
+ :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu".
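+
+ Example (illustrative)::
+
+ x = np.ones((3, 4), dtype=np.float32)
+ x_ort = OrtValue.ortvalue_from_numpy(x)  # CPU: backed by x's own buffer, no copy
+ # x_gpu = OrtValue.ortvalue_from_numpy(x, "cuda", 0)  # copies; requires a CUDA-enabled build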
1024
+ """
1025
+ # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
1026
+ # is backed directly by the data buffer of the numpy object and so the numpy object
1027
+ # must be around until this OrtValue instance is around
1028
+ return cls(
1029
+ C.OrtValue.ortvalue_from_numpy(
1030
+ numpy_obj,
1031
+ OrtDevice.make(device_type, device_id, vendor_id)._get_c_device(),
1032
+ ),
1033
+ numpy_obj if device_type.lower() == "cpu" else None,
1034
+ )
1035
+
1036
+ @classmethod
1037
+ def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_type: int) -> OrtValue:
1038
+ """
1039
+ This method creates an instance of OrtValue on top of the numpy array.
1040
+ No data copy is made and the lifespan of the resulting OrtValue should never
1041
+ exceed the lifespan of the underlying numpy array. The API attempts to reinterpret
1042
+ the data type which is expected to be the same size. This is useful
1043
+ when we want to use an ONNX data type that is not supported by numpy.
1044
+
1045
+ :param data: numpy.ndarray.
1046
+ :param onnx_element_type: a valid onnx TensorProto::DataType enum value
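+
+ Example (illustrative; reinterprets uint16 storage as BFLOAT16, which numpy lacks)::
+
+ raw = np.zeros((2, 2), dtype=np.uint16)
+ bf16 = OrtValue.ortvalue_from_numpy_with_onnx_type(raw, 16)  # 16 == onnx.TensorProto.BFLOAT16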
1047
+ """
1048
+ return cls(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
1049
+
1050
+ @classmethod
1051
+ def ortvalue_from_shape_and_type(
1052
+ cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0, vendor_id: int = -1
1053
+ ) -> OrtValue:
1054
+ """
1055
+ Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
1056
+
1057
+ :param shape: List of integers indicating the shape of the OrtValue
1058
+ :param element_type: The data type of the elements. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16).
1059
+ :param device_type: e.g. cpu, cuda, cann, cpu by default
1060
+ :param device_id: device id, e.g. 0
1061
+ :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu".
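+
+ Example (illustrative)::
+
+ v_f32 = OrtValue.ortvalue_from_shape_and_type([2, 3], np.float32)
+ v_bf16 = OrtValue.ortvalue_from_shape_and_type([2, 3], 16)  # 16 == onnx.TensorProto.BFLOAT16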
1062
+ """
1063
+
1064
+ device = OrtDevice.make(device_type, device_id, vendor_id)._get_c_device()
1065
+
1066
+ # Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
1067
+ # This is helpful for some data types (like TensorProto.BFLOAT16) that are not available in numpy.
1068
+ if isinstance(element_type, int):
1069
+ return cls(
1070
+ C.OrtValue.ortvalue_from_shape_and_onnx_type(
1071
+ shape,
1072
+ element_type,
1073
+ device,
1074
+ )
1075
+ )
1076
+
1077
+ return cls(
1078
+ C.OrtValue.ortvalue_from_shape_and_type(
1079
+ shape,
1080
+ element_type,
1081
+ device,
1082
+ )
1083
+ )
1084
+
1085
+ @classmethod
1086
+ def ort_value_from_sparse_tensor(cls, sparse_tensor: SparseTensor) -> OrtValue:
1087
+ """
1088
+ The function will construct an OrtValue instance from a valid SparseTensor
1089
+ The new instance of OrtValue will assume the ownership of sparse_tensor
1090
+ """
1091
+ return cls(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
1092
+
1093
+ def as_sparse_tensor(self) -> SparseTensor:
1094
+ """
1095
+ The function will return SparseTensor contained in this OrtValue
1096
+ """
1097
+ return SparseTensor(self._ortvalue.as_sparse_tensor())
1098
+
1099
+ def data_ptr(self) -> int:
1100
+ """
1101
+ Returns the address of the first element in the OrtValue's data buffer
1102
+ """
1103
+ return self._ortvalue.data_ptr()
1104
+
1105
+ def device_name(self) -> str:
1106
+ """
1107
+ Returns the name of the device where the OrtValue's data buffer resides, e.g. cpu, cuda, cann
1108
+ """
1109
+ return self._ortvalue.device_name().lower()
1110
+
1111
+ def shape(self) -> Sequence[int]:
1112
+ """
1113
+ Returns the shape of the data in the OrtValue
1114
+ """
1115
+ return self._ortvalue.shape()
1116
+
1117
+ def data_type(self) -> str:
1118
+ """
1119
+ Returns the data type of the data in the OrtValue. E.g. 'tensor(int64)'
1120
+ """
1121
+ return self._ortvalue.data_type()
1122
+
1123
+ def element_type(self) -> int:
1124
+ """
1125
+ Returns the proto type of the data in the OrtValue
1126
+ if the OrtValue is a tensor.
1127
+ """
1128
+ return self._ortvalue.element_type()
1129
+
1130
+ def tensor_size_in_bytes(self) -> int:
1131
+ """
1132
+ Returns the size of the data in the OrtValue in bytes
1133
+ if the OrtValue is a tensor.
1134
+ """
1135
+ return self._ortvalue.tensor_size_in_bytes()
1136
+
1137
+ def has_value(self) -> bool:
1138
+ """
1139
+ Returns True if the OrtValue corresponding to an
1140
+ optional type contains data, else returns False
1141
+ """
1142
+ return self._ortvalue.has_value()
1143
+
1144
+ def is_tensor(self) -> bool:
1145
+ """
1146
+ Returns True if the OrtValue contains a Tensor, else returns False
1147
+ """
1148
+ return self._ortvalue.is_tensor()
1149
+
1150
+ def is_sparse_tensor(self) -> bool:
1151
+ """
1152
+ Returns True if the OrtValue contains a SparseTensor, else returns False
1153
+ """
1154
+ return self._ortvalue.is_sparse_tensor()
1155
+
1156
+ def is_tensor_sequence(self) -> bool:
1157
+ """
1158
+ Returns True if the OrtValue contains a Tensor Sequence, else returns False
1159
+ """
1160
+ return self._ortvalue.is_tensor_sequence()
1161
+
1162
+ def numpy(self) -> np.ndarray:
1163
+ """
1164
+ Returns a Numpy object from the OrtValue.
1165
+ Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
1166
+ Use accessors to gain a reference to non-Tensor objects such as SparseTensor
1167
+ """
1168
+ return self._ortvalue.numpy()
1169
+
1170
+ def update_inplace(self, np_arr) -> None:
1171
+ """
1172
+ Update the OrtValue in place with a new Numpy array. The numpy contents
1173
+ are copied over to the device memory backing the OrtValue. It can be used
1174
+ to update the input values for an InferenceSession with CUDA graph
1175
+ enabled or other scenarios where the OrtValue needs to be updated while
1176
+ the memory address cannot be changed.
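+
+ Example (illustrative; `x_ort` must already have the same shape and type)::
+
+ x_ort.update_inplace(np.full((3, 4), 2.0, dtype=np.float32))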
1177
+ """
1178
+ self._ortvalue.update_inplace(np_arr)
1179
+
1180
+
1181
+ def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None:
1182
+ """
1183
+ Copy tensor data from source OrtValue sequence to destination OrtValue sequence.
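+
+ Example (illustrative; copies a CPU tensor into a preallocated destination)::
+
+ src = OrtValue.ortvalue_from_numpy(np.ones((2, 2), dtype=np.float32))
+ dst = OrtValue.ortvalue_from_shape_and_type([2, 2], np.float32)
+ copy_tensors([src], [dst])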
1184
+ """
1185
+ c_sources = [s._get_c_value() for s in src]
1186
+ c_dsts = [d._get_c_value() for d in dst]
1187
+ C.copy_tensors(c_sources, c_dsts, stream)
1188
+
1189
+
1190
+ class OrtDevice:
1191
+ """
1192
+ A data structure that exposes the underlying C++ OrtDevice
1193
+ """
1194
+
1195
+ def __init__(self, c_ort_device):
1196
+ """
1197
+ Internal constructor
1198
+ """
1199
+ if isinstance(c_ort_device, C.OrtDevice):
1200
+ self._ort_device = c_ort_device
1201
+ else:
1202
+ # An end user won't hit this error
1203
+ raise ValueError(
1204
+ "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`"
1205
+ )
1206
+
1207
+ def _get_c_device(self):
1208
+ """
1209
+ Internal accessor to underlying object
1210
+ """
1211
+ return self._ort_device
1212
+
1213
+ @staticmethod
1214
+ def make(ort_device_name, device_id, vendor_id=-1):
1215
+ if vendor_id < 0:
1216
+ # backwards compatibility with predefined OrtDevice names
1217
+ return OrtDevice(
1218
+ C.OrtDevice(
1219
+ get_ort_device_type(ort_device_name),
1220
+ C.OrtDevice.default_memory(),
1221
+ device_id,
1222
+ )
1223
+ )
1224
+ else:
1225
+ # generic. use GPU or NPU for ort_device_name and provide a vendor id.
1226
+ # vendor id of 0 is valid in some cases (e.g. webgpu is generic and does not have a vendor id)
1227
+ return OrtDevice(
1228
+ C.OrtDevice(
1229
+ get_ort_device_type(ort_device_name),
1230
+ C.OrtDevice.default_memory(),
1231
+ vendor_id,
1232
+ device_id,
1233
+ )
1234
+ )
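+ # Example (illustrative): OrtDevice.make("cuda", 0) uses a predefined device name,
+ # while OrtDevice.make("gpu", 0, vendor_id=0x10DE) selects a GPU by PCI vendor id
+ # (0x10DE is NVIDIA's vendor id; vendor_id=0 is valid for generic devices like webgpu).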
1235
+
1236
+ def device_id(self):
1237
+ return self._ort_device.device_id()
1238
+
1239
+ def device_type(self):
1240
+ return self._ort_device.device_type()
1241
+
1242
+ def device_vendor_id(self):
1243
+ return self._ort_device.vendor_id()
1244
+
1245
+ def device_mem_type(self):
1246
+ return self._ort_device.mem_type()
1247
+
1248
+
1249
+ class SparseTensor:
1250
+ """
1251
+ A data structure that projects the C++ SparseTensor object.
1252
+ The class provides API to work with the object.
1253
+ Depending on the format, the class may hold more than one buffer.
1255
+ """
1256
+
1257
+ def __init__(self, sparse_tensor: C.SparseTensor):
1258
+ """
1259
+ Internal constructor
1260
+ """
1261
+ if isinstance(sparse_tensor, C.SparseTensor):
1262
+ self._tensor = sparse_tensor
1263
+ else:
1264
+ # An end user won't hit this error
1265
+ raise ValueError(
1266
+ "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`"
1267
+ )
1268
+
1269
+ def _get_c_tensor(self) -> C.SparseTensor:
1270
+ return self._tensor
1271
+
1272
+ @classmethod
1273
+ def sparse_coo_from_numpy(
1274
+ cls,
1275
+ dense_shape: npt.NDArray[np.int64],
1276
+ values: np.ndarray,
1277
+ coo_indices: npt.NDArray[np.int64],
1278
+ ort_device: OrtDevice,
1279
+ ) -> SparseTensor:
1280
+ """
1281
+ Factory method to construct a SparseTensor in COO format from given arguments
1282
+
1283
+ :param dense_shape: 1-D numpy array(int64) or a python list that contains a dense_shape of the sparse tensor
1284
+ and must reside in CPU memory.
1285
+ :param values: a homogeneous, contiguous 1-D numpy array that contains non-zero elements of the tensor
1286
+ of a type.
1287
+ :param coo_indices: contiguous numpy array(int64) that contains COO indices for the tensor. coo_indices may
1288
+ have a 1-D shape when it contains a linear index of non-zero values and its length must be equal to
1289
+ that of the values. It can also be of 2-D shape, in which case it contains pairs of coordinates for
1290
+ each of the nnz values, and its length must be exactly twice the length of the values.
1291
+ :param ort_device: describes the backing memory owned by the supplied numpy arrays. Only CPU memory is
1292
+ supported for non-numeric data types.
1293
+
1294
+ For primitive types, the method will map values and coo_indices arrays into native memory and will use
1295
+ them as backing storage. It will increment the reference count for numpy arrays and it will decrement it
1296
+ on GC. The buffers may reside in any storage, either CPU or GPU.
1297
+ For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
1298
+ on other devices and their memory cannot be mapped.
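+
+ Example (illustrative; a 3x3 tensor with non-zeros at (0, 1) and (2, 2))::
+
+ dense_shape = np.array([3, 3], dtype=np.int64)
+ values = np.array([1.0, 2.0], dtype=np.float32)
+ coo_indices = np.array([[0, 1], [2, 2]], dtype=np.int64)  # (row, col) pairs
+ st = SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, OrtDevice.make("cpu", 0))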
1299
+ """
1300
+ return cls(C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device()))
1301
+
1302
+ @classmethod
1303
+ def sparse_csr_from_numpy(
1304
+ cls,
1305
+ dense_shape: npt.NDArray[np.int64],
1306
+ values: np.ndarray,
1307
+ inner_indices: npt.NDArray[np.int64],
1308
+ outer_indices: npt.NDArray[np.int64],
1309
+ ort_device: OrtDevice,
1310
+ ) -> SparseTensor:
1311
+ """
1312
+ Factory method to construct a SparseTensor in CSR format from given arguments
1313
+
1314
+ :param dense_shape: 1-D numpy array(int64) or a python list that contains a dense_shape of the
1315
+ sparse tensor (rows, cols) and must reside in CPU memory.
1316
+ :param values: a contiguous, homogeneous 1-D numpy array that contains non-zero elements of the tensor
1317
+ of a type.
1318
+ :param inner_indices: contiguous 1-D numpy array(int64) that contains CSR inner indices for the tensor.
1319
+ Its length must be equal to that of the values.
1320
+ :param outer_indices: contiguous 1-D numpy array(int64) that contains CSR outer indices for the tensor.
1321
+ Its length must be equal to the number of rows + 1.
1322
+ :param ort_device: describes the backing memory owned by the supplied numpy arrays. Only CPU memory is
1323
+ supported for non-numeric data types.
1324
+
1325
+ For primitive types, the method will map values and indices arrays into native memory and will use them as
1326
+ backing storage. It will increment the reference count and decrement the count when it is GCed.
1327
+ The buffers may reside in any storage, either CPU or GPU.
1328
+ For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
1329
+ on other devices and their memory cannot be mapped.
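+
+ Example (illustrative; the same 3x3 tensor, non-zeros at (0, 1) and (2, 2), in CSR form)::
+
+ dense_shape = np.array([3, 3], dtype=np.int64)
+ values = np.array([1.0, 2.0], dtype=np.float32)
+ inner_indices = np.array([1, 2], dtype=np.int64)  # column index of each value
+ outer_indices = np.array([0, 1, 1, 2], dtype=np.int64)  # length == rows + 1
+ st = SparseTensor.sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, OrtDevice.make("cpu", 0))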
1330
+ """
1331
+ return cls(
1332
+ C.SparseTensor.sparse_csr_from_numpy(
1333
+ dense_shape,
1334
+ values,
1335
+ inner_indices,
1336
+ outer_indices,
1337
+ ort_device._get_c_device(),
1338
+ )
1339
+ )
1340
+
1341
+ def values(self) -> np.ndarray:
1342
+ """
1343
+ The method returns a numpy array that is backed by the native memory
1344
+ if the data type is numeric. Otherwise, the returned numpy array contains
1345
+ copies of the strings.
1346
+ """
1347
+ return self._tensor.values()
1348
+
1349
+ def as_coo_view(self):
1350
+ """
1351
+ The method will return the COO representation of the sparse tensor, which will enable
1352
+ querying COO indices. If the instance does not contain the COO format, it will throw.
1353
+ You can query coo indices as:
1354
+
1355
+ ::
1356
+
1357
+ coo_indices = sparse_tensor.as_coo_view().indices()
1358
+
1359
+ which will return a numpy array that is backed by the native memory.
1360
+ """
1361
+ return self._tensor.get_coo_data()
1362
+
1363
+ def as_csrc_view(self):
1364
+ """
1365
+ The method will return the CSR(C) representation of the sparse tensor, which will enable
1366
+ querying CSR(C) indices. If the instance does not contain the CSR(C) format, it will throw.
1367
+ You can query indices as:
1368
+
1369
+ ::
1370
+
1371
+ inner_indices = sparse_tensor.as_csrc_view().inner()
1372
+ outer_indices = sparse_tensor.as_csrc_view().outer()
1373
+
1374
+ returning numpy arrays backed by the native memory.
1375
+ """
1376
+ return self._tensor.get_csrc_data()
1377
+
1378
+ def as_blocksparse_view(self):
1379
+ """
1380
+ The method will return the BlockSparse representation of the sparse tensor, which will enable
1381
+ querying BlockSparse indices. If the instance does not contain the BlockSparse format, it will throw.
1382
+ You can query the BlockSparse indices as:
1383
+
1384
+ ::
1385
+
1386
+ block_sparse_indices = sparse_tensor.as_blocksparse_view().indices()
1387
+
1388
+ which will return a numpy array that is backed by the native memory
1389
+ """
1390
+ return self._tensor.get_blocksparse_data()
1391
+
1392
+ def to_cuda(self, ort_device):
1393
+ """
1394
+ Returns a copy of this instance on the specified cuda device
1395
+
1396
+ :param ort_device: an OrtDevice with name 'cuda' and a valid gpu device id
1397
+
1398
+ The method will throw if:
1399
+
1400
+ - this instance contains strings
1401
+ - this instance is already on GPU. Cross GPU copy is not supported
1402
+ - CUDA is not present in this build
1403
+ - if the specified device is not valid
1404
+ """
1405
+ return SparseTensor(self._tensor.to_cuda(ort_device._get_c_device()))
1406
+
1407
+ def format(self):
1408
+ """
1409
+ Returns an OrtSparseFormat enumeration
1410
+ """
1411
+ return self._tensor.format
1412
+
1413
+ def dense_shape(self) -> npt.NDArray[np.int64]:
1414
+ """
1415
+ Returns a numpy array(int64) containing a dense shape of a sparse tensor
1416
+ """
1417
+ return self._tensor.dense_shape()
1418
+
1419
+ def data_type(self) -> str:
1420
+ """
1421
+ Returns a string data type of the data in the OrtValue
1422
+ """
1423
+ return self._tensor.data_type()
1424
+
1425
+ def device_name(self) -> str:
1426
+ """
1427
+ Returns the name of the device where the SparseTensor data buffers reside, e.g. cpu, cuda
1428
+ """
1429
+ return self._tensor.device_name().lower()
1430
+
1431
+
1432
+ # Type hint for user-specified function that allows the user to specify initializer locations when compiling a model.
1433
+ GetInitializerLocationFunc = Callable[
1434
+ [str, OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
1435
+ ]
1436
+
1437
+ # Type hint that adheres to the signature expected by ORT.
1438
+ GetInitializerLocationWrapperFunc = Callable[
1439
+ [str, C.OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
1440
+ ]