onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,74 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
# Model class names used to load pretrained models for export/benchmarking.
# NOTE(review): the original comment described this as a mapping to tuples,
# but it is a flat list of class-name strings (Hugging Face transformers
# Auto* classes, presumably — confirm against the benchmark driver).
MODEL_CLASSES = [
    "AutoModel",
    "AutoModelWithLMHead",
    "AutoModelForSequenceClassification",
    "AutoModelForQuestionAnswering",
    "AutoModelForCausalLM",
]
15
+
16
# Maps a pretrained model name to a tuple of
#   (input names, opset version, use_external_data_format, optimization model type).
# Models such as GPT, T5 and Bart have their own convert_to_onnx.py in the models
# sub-directory and are therefore excluded here.
MODELS = {
    # BERT
    "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"),
    "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"),
    # Transformer-XL (this model uses Einsum, which needs opset version 16 or later)
    "transfo-xl-wt103": (["input_ids", "mems"], 16, False, "bert"),
    # XLNet
    "xlnet-base-cased": (["input_ids"], 16, False, "bert"),
    "xlnet-large-cased": (["input_ids"], 16, False, "bert"),
    # XLM
    "xlm-mlm-en-2048": (["input_ids"], 16, True, "bert"),
    "xlm-mlm-ende-1024": (["input_ids"], 16, False, "bert"),
    "xlm-mlm-enfr-1024": (["input_ids"], 16, False, "bert"),
    # RoBERTa
    "roberta-base": (["input_ids", "attention_mask"], 16, False, "bert"),
    "roberta-large": (["input_ids", "attention_mask"], 16, False, "bert"),
    "roberta-large-mnli": (["input_ids", "attention_mask"], 16, False, "bert"),
    "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 16, False, "bert"),
    "distilroberta-base": (["input_ids", "attention_mask"], 16, False, "bert"),
    # DistilBERT
    "distilbert-base-uncased": (["input_ids", "attention_mask"], 16, False, "bert"),
    "distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 16, False, "bert"),
    # CTRL
    "ctrl": (["input_ids"], 16, True, "bert"),
    # CamemBERT
    "camembert-base": (["input_ids"], 16, False, "bert"),
    # ALBERT
    "albert-base-v1": (["input_ids"], 16, False, "bert"),
    "albert-large-v1": (["input_ids"], 16, False, "bert"),
    "albert-xlarge-v1": (["input_ids"], 16, True, "bert"),
    # "albert-xxlarge-v1": (["input_ids"], 16, True, "bert"),
    "albert-base-v2": (["input_ids"], 16, False, "bert"),
    "albert-large-v2": (["input_ids"], 16, False, "bert"),
    "albert-xlarge-v2": (["input_ids"], 16, True, "bert"),
    # "albert-xxlarge-v2": (["input_ids"], 16, True, "bert"),
    # XLM-RoBERTa
    "xlm-roberta-base": (["input_ids"], 16, False, "bert"),
    "xlm-roberta-large": (["input_ids"], 16, True, "bert"),
    # FlauBERT
    "flaubert/flaubert_small_cased": (["input_ids"], 16, False, "bert"),
    "flaubert/flaubert_base_cased": (["input_ids"], 16, False, "bert"),
    # "flaubert/flaubert_large_cased": (["input_ids"], 16, False, "bert"),
    # LayoutLM
    "microsoft/layoutlm-base-uncased": (["input_ids"], 16, False, "bert"),
    "microsoft/layoutlm-large-uncased": (["input_ids"], 16, False, "bert"),
    # SqueezeBERT
    "squeezebert/squeezebert-uncased": (["input_ids"], 16, False, "bert"),
    "squeezebert/squeezebert-mnli": (["input_ids"], 16, False, "bert"),
    "squeezebert/squeezebert-mnli-headless": (["input_ids"], 16, False, "bert"),
    # LXMERT
    "unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 16, False, "bert"),
    # ViT
    "google/vit-base-patch16-224": (["pixel_values"], 16, False, "vit"),
    # Swin
    "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
    "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
    "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
}
@@ -0,0 +1,20 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import importlib.metadata
6
+ import importlib.util
7
+
8
+
9
def is_installed(package):
    """Return True if *package* is available in the current environment.

    First checks installed distribution metadata. If no distribution with
    that name is found, falls back to checking whether *package* is
    importable as a module.

    Args:
        package: Distribution name or importable module name.

    Returns:
        bool: True when distribution metadata or an importable module spec
        exists for *package*; False otherwise.
    """
    try:
        dist = importlib.metadata.distribution(package)
    except importlib.metadata.PackageNotFoundError:
        try:
            spec = importlib.util.find_spec(package)
        except (ModuleNotFoundError, ValueError):
            # ModuleNotFoundError: a parent package of a dotted name is not
            # importable. ValueError: the module exists but has no usable
            # __spec__ (e.g. __main__) — the original code let this escape,
            # turning an availability check into an unexpected exception.
            return False

        return spec is not None

    return dist is not None
@@ -0,0 +1,487 @@
1
+ import copy
2
+ import logging
3
+ from collections import OrderedDict
4
+ from collections.abc import Mapping
5
+ from typing import Any
6
+
7
+ import numpy
8
+ import torch
9
+ from onnx import TensorProto
10
+
11
+ from onnxruntime import InferenceSession, RunOptions
12
+
13
+ # Type alias
14
+ ShapeDict = Mapping[str, tuple | list[int]]
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class TypeHelper:
    """Translate between ONNX Runtime type strings (e.g. ``"tensor(float)"``),
    numpy scalar types, torch dtypes and ONNX ``TensorProto`` element types,
    and build per-name type maps for an ``InferenceSession``."""

    @staticmethod
    def get_input_type(ort_session: "InferenceSession", name: str) -> str:
        """Return the ORT type string of the session input called *name*.

        Raises:
            ValueError: if the session has no input with that name.
        """
        for input_meta in ort_session.get_inputs():
            if input_meta.name == name:
                return input_meta.type
        raise ValueError(f"input name {name} not found")

    @staticmethod
    def get_output_type(ort_session, name: str) -> str:
        """Return the ORT type string of the session output called *name*.

        Raises:
            ValueError: if the session has no output with that name.
        """
        for output_meta in ort_session.get_outputs():
            if output_meta.name == name:
                return output_meta.type

        raise ValueError(f"output name {name} not found")

    @staticmethod
    def ort_type_to_numpy_type(ort_type: str):
        """Map an ORT type string to the corresponding numpy scalar type.

        Raises:
            ValueError: for ORT types with no entry in the map
                (e.g. bfloat16, which numpy does not support).
        """
        ort_type_to_numpy_type_map = {
            "tensor(int64)": numpy.longlong,
            "tensor(int32)": numpy.intc,
            "tensor(float)": numpy.float32,
            "tensor(float16)": numpy.float16,
            "tensor(bool)": bool,
            "tensor(uint8)": numpy.uint8,
            "tensor(int8)": numpy.int8,
        }
        if ort_type not in ort_type_to_numpy_type_map:
            raise ValueError(f"{ort_type} not found in map")

        return ort_type_to_numpy_type_map[ort_type]

    @staticmethod
    def ort_type_to_torch_type(ort_type: str):
        """Map an ORT type string to the corresponding torch dtype.

        Raises:
            ValueError: for ORT types with no entry in the map.
        """
        ort_type_to_torch_type_map = {
            "tensor(int64)": torch.int64,
            "tensor(int32)": torch.int32,
            "tensor(float)": torch.float32,
            "tensor(float16)": torch.float16,
            "tensor(bfloat16)": torch.bfloat16,
            "tensor(bool)": torch.bool,
            "tensor(uint8)": torch.uint8,
            "tensor(int8)": torch.int8,
        }
        if ort_type not in ort_type_to_torch_type_map:
            raise ValueError(f"{ort_type} not found in map")

        return ort_type_to_torch_type_map[ort_type]

    @staticmethod
    def get_io_onnx_type_map(ort_session: "InferenceSession") -> dict[str, int]:
        """Create a mapping from input/output name to onnx data type"""
        name_to_onnx_type = {}
        for input_meta in ort_session.get_inputs():
            name_to_onnx_type[input_meta.name] = TypeHelper.ort_type_to_onnx_type(input_meta.type)

        for output_meta in ort_session.get_outputs():
            name_to_onnx_type[output_meta.name] = TypeHelper.ort_type_to_onnx_type(output_meta.type)
        return name_to_onnx_type

    @staticmethod
    def ort_type_to_onnx_type(ort_type: str):
        """Map an ORT type string to the corresponding ``TensorProto`` element type.

        Raises:
            ValueError: for ORT types with no entry in the map.
        """
        ort_type_to_onnx_type_map = {
            "tensor(int64)": TensorProto.INT64,
            "tensor(int32)": TensorProto.INT32,
            "tensor(float)": TensorProto.FLOAT,
            "tensor(float16)": TensorProto.FLOAT16,
            "tensor(bfloat16)": TensorProto.BFLOAT16,
            "tensor(bool)": TensorProto.BOOL,
            "tensor(uint8)": TensorProto.UINT8,
            "tensor(int8)": TensorProto.INT8,
        }
        if ort_type not in ort_type_to_onnx_type_map:
            raise ValueError(f"{ort_type} not found in map")

        return ort_type_to_onnx_type_map[ort_type]

    @staticmethod
    def numpy_type_to_torch_type(numpy_type: numpy.dtype):
        """Map a numpy scalar type to the corresponding torch dtype.

        Raises:
            ValueError: for numpy types with no entry in the map.
        """
        numpy_type_to_torch_type_map = {
            numpy.longlong: torch.int64,
            numpy.intc: torch.int32,
            numpy.int32: torch.int32,
            numpy.float32: torch.float32,
            numpy.float16: torch.float16,
            bool: torch.bool,
            numpy.uint8: torch.uint8,
            numpy.int8: torch.int8,
        }
        if numpy_type not in numpy_type_to_torch_type_map:
            raise ValueError(f"{numpy_type} not found in map")

        return numpy_type_to_torch_type_map[numpy_type]

    @staticmethod
    def torch_type_to_numpy_type(torch_type: torch.dtype):
        """Map a torch dtype to the corresponding numpy scalar type.

        Raises:
            ValueError: for torch dtypes with no numpy equivalent
                (e.g. bfloat16).
        """
        torch_type_to_numpy_type_map = {
            torch.int64: numpy.longlong,
            torch.int32: numpy.intc,
            torch.float32: numpy.float32,
            torch.float16: numpy.float16,
            torch.bool: bool,
            torch.uint8: numpy.uint8,
            # torch.int8 was missing, although the inverse map
            # (numpy_type_to_torch_type) supports numpy.int8 and every other
            # ORT-type map here includes int8; added for symmetry.
            torch.int8: numpy.int8,
        }
        if torch_type not in torch_type_to_numpy_type_map:
            raise ValueError(f"{torch_type} not found in map")

        return torch_type_to_numpy_type_map[torch_type]

    @staticmethod
    def get_io_numpy_type_map(ort_session: "InferenceSession") -> dict[str, numpy.dtype]:
        """Create a mapping from input/output name to numpy data type"""
        name_to_numpy_type = {}
        for input_meta in ort_session.get_inputs():
            name_to_numpy_type[input_meta.name] = TypeHelper.ort_type_to_numpy_type(input_meta.type)

        for output_meta in ort_session.get_outputs():
            name_to_numpy_type[output_meta.name] = TypeHelper.ort_type_to_numpy_type(output_meta.type)
        return name_to_numpy_type

    @staticmethod
    def get_io_torch_type_map(ort_session: "InferenceSession") -> dict[str, torch.dtype]:
        """Create a mapping from input/output name to torch data type"""
        name_to_torch_type = {}
        for input_meta in ort_session.get_inputs():
            name_to_torch_type[input_meta.name] = TypeHelper.ort_type_to_torch_type(input_meta.type)

        for output_meta in ort_session.get_outputs():
            name_to_torch_type[output_meta.name] = TypeHelper.ort_type_to_torch_type(output_meta.type)
        return name_to_torch_type
149
+
150
+
151
class IOBindingHelper:
    """Static helpers to run an InferenceSession with I/O binding.

    prepare_io_binding assumes a GPT-2 style model interface whose inputs are
    named input_ids, position_ids, attention_mask and past_{i}.
    """

    @staticmethod
    def get_output_buffers(ort_session: InferenceSession, output_shapes, device):
        """Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape."""
        output_buffers = {}
        for name, shape in output_shapes.items():
            ort_type = TypeHelper.get_output_type(ort_session, name)
            torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
            # Flat buffer large enough for the whole output; it is reshaped to
            # the real shape later by get_outputs_from_io_binding_buffer.
            output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch_type, device=device)
        return output_buffers

    @staticmethod
    def prepare_io_binding(
        ort_session,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        past: list[torch.Tensor],
        output_buffers,
        output_shapes,
    ):
        """IO binding for a session: bind inputs (input_ids, position_ids, attention_mask, past_*) and outputs.

        All tensors must be contiguous since raw data pointers are passed to
        onnxruntime. position_ids, attention_mask and past may be None, in which
        case the corresponding inputs are not bound. Returns the IOBinding
        object, ready for run_with_iobinding.
        """

        name_to_onnx_type = TypeHelper.get_io_onnx_type_map(ort_session)

        # Bind inputs and outputs to onnxruntime session
        io_binding = ort_session.io_binding()

        # Bind inputs
        assert input_ids.is_contiguous()
        io_binding.bind_input(
            "input_ids",
            input_ids.device.type,
            0,
            name_to_onnx_type["input_ids"],
            list(input_ids.size()),
            input_ids.data_ptr(),
        )

        if past is not None:
            for i, past_i in enumerate(past):
                assert past_i.is_contiguous()

                data_ptr = past_i.data_ptr()
                if data_ptr == 0:
                    # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero.
                    # Here we workaround and pass data pointer of input_ids. Actual data is not used for past so it does not matter.
                    data_ptr = input_ids.data_ptr()

                io_binding.bind_input(
                    f"past_{i}",
                    past_i.device.type,
                    0,
                    name_to_onnx_type[f"past_{i}"],
                    list(past_i.size()),
                    data_ptr,
                )

        if attention_mask is not None:
            assert attention_mask.is_contiguous()
            io_binding.bind_input(
                "attention_mask",
                attention_mask.device.type,
                0,
                name_to_onnx_type["attention_mask"],
                list(attention_mask.size()),
                attention_mask.data_ptr(),
            )

        if position_ids is not None:
            assert position_ids.is_contiguous()
            io_binding.bind_input(
                "position_ids",
                position_ids.device.type,
                0,
                name_to_onnx_type["position_ids"],
                list(position_ids.size()),
                position_ids.data_ptr(),
            )

        # Bind outputs: each output writes into the caller-provided flat buffer,
        # declared with the expected output shape for this run.
        for output in ort_session.get_outputs():
            output_name = output.name
            output_buffer = output_buffers[output_name]
            logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}")
            io_binding.bind_output(
                output_name,
                output_buffer.device.type,
                0,
                name_to_onnx_type[output_name],
                output_shapes[output_name],
                output_buffer.data_ptr(),
            )

        return io_binding

    @staticmethod
    def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
        """Copy results to cpu. Returns a list of numpy arrays (or detached torch
        tensors when return_numpy is False), in the session's output order."""
        ort_outputs = []
        for output in ort_session.get_outputs():
            output_name = output.name
            buffer = output_buffers[output_name]
            shape = output_shapes[output_name]
            # Buffers are flat and possibly over-sized: take only the first
            # prod(shape) elements, then reshape to the real output shape.
            # NOTE(review): numpy.prod(()) is 1.0 (a float) — assumes shapes here
            # are non-empty sequences; confirm against callers.
            copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach()
            if return_numpy:
                ort_outputs.append(copy_tensor.cpu().numpy())
            else:
                ort_outputs.append(copy_tensor)
        return ort_outputs
261
+
262
+
263
class CudaSession:
    """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider"""

    def __init__(self, ort_session: InferenceSession, device: torch.device, enable_cuda_graph=False):
        """Wrap an InferenceSession with I/O binding on the given torch device.

        Args:
            ort_session: session whose inputs and outputs are bound.
            device: torch device on which I/O buffers are allocated.
            enable_cuda_graph: when True, input buffers are allocated once and
                reused so that a captured cuda graph keeps valid addresses.
        """
        self.ort_session = ort_session
        self.input_names = [model_input.name for model_input in self.ort_session.get_inputs()]
        self.output_names = [model_output.name for model_output in self.ort_session.get_outputs()]
        self.io_name_to_onnx_type = TypeHelper.get_io_onnx_type_map(self.ort_session)
        self.io_name_to_torch_type = TypeHelper.get_io_torch_type_map(self.ort_session)
        self.io_binding = self.ort_session.io_binding()
        self.enable_cuda_graph = enable_cuda_graph

        # Buffers owned by this session, keyed by input/output name.
        self.input_tensors = OrderedDict()
        self.output_tensors = OrderedDict()
        self.device = device

        # Pairs of input and output names that share the same buffer.
        self.buffer_sharing: dict[str, str] = {}

    def set_buffer_sharing(self, input_name: str, output_name: str):
        """Declare that an input and an output shall alias the same device buffer."""
        assert input_name in self.input_names
        assert output_name in self.output_names
        # Record the pair in both directions so either name resolves its peer.
        self.buffer_sharing[input_name] = output_name
        self.buffer_sharing[output_name] = input_name

    def __del__(self):
        # Release buffers and the binding object. __init__ may have failed before
        # all attributes existed, so guard each deletion to keep __del__ silent.
        for attr in ("input_tensors", "output_tensors", "io_binding"):
            if hasattr(self, attr):
                delattr(self, attr)

    def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
        """Bind tensor as input `name`; when that input shares a buffer with an
        output, bind the same memory as that output as well."""
        device_id = tensor.device.index if tensor.device.index is not None else 0
        # Represent a rank-0 tensor as shape [1] for the binding call.
        tensor_shape = [1] if len(tensor.shape) == 0 else list(tensor.shape)

        self.io_binding.bind_input(
            name,
            tensor.device.type,
            device_id,
            self.io_name_to_onnx_type[name],
            tensor_shape,
            tensor.data_ptr(),
        )

        if name in self.buffer_sharing:
            self.io_binding.bind_output(
                self.buffer_sharing[name],
                tensor.device.type,
                device_id,
                self.io_name_to_onnx_type[name],
                tensor_shape,
                tensor.data_ptr(),
            )
            self.output_tensors[self.buffer_sharing[name]] = tensor

    def allocate_buffers(self, shape_dict: ShapeDict):
        """Allocate tensors for I/O Binding"""
        if self.enable_cuda_graph:
            for name, shape in shape_dict.items():
                if name in self.input_names:
                    # Reuse allocated buffer when the shape is same
                    if name in self.input_tensors:
                        if tuple(self.input_tensors[name].shape) == tuple(shape):
                            continue
                        raise RuntimeError("Expect static input shape for cuda graph")

                    torch_dtype = self.io_name_to_torch_type[name]
                    # Allocate directly on the target device (no CPU staging copy).
                    tensor = torch.empty(tuple(shape), dtype=torch_dtype, device=self.device)
                    self.input_tensors[name] = tensor
                    self.bind_input_and_buffer_sharing(name, tensor)

        for name, shape in shape_dict.items():
            if name in self.output_names:
                # Reuse allocated buffer when the shape is same
                if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape):
                    continue

                # Outputs that alias an input are bound in bind_input_and_buffer_sharing.
                if name in self.buffer_sharing:
                    continue

                torch_dtype = self.io_name_to_torch_type[name]
                tensor = torch.empty(tuple(shape), dtype=torch_dtype, device=self.device)
                self.output_tensors[name] = tensor

                self.io_binding.bind_output(
                    name,
                    tensor.device.type,
                    tensor.device.index if tensor.device.index is not None else 0,
                    self.io_name_to_onnx_type[name],
                    list(tensor.size()),
                    tensor.data_ptr(),
                )

    def infer(
        self, feed_dict: dict[str, torch.Tensor], run_options: RunOptions | None = None, synchronize: bool = True
    ):
        """Bind input tensors and run inference.

        Returns the dict of pre-allocated output tensors (name -> tensor); they
        are overwritten on every call.
        """
        for name, tensor in feed_dict.items():
            assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
            if name in self.input_names:
                if self.enable_cuda_graph:
                    # A captured cuda graph replays fixed addresses: copy data
                    # into the pre-allocated buffer instead of re-binding.
                    assert self.input_tensors[name].nelement() == tensor.nelement()
                    assert self.input_tensors[name].dtype == tensor.dtype
                    assert tensor.device.type == "cuda"
                    self.input_tensors[name].copy_(tensor)
                else:
                    self.bind_input_and_buffer_sharing(name, tensor)

        if synchronize:
            self.io_binding.synchronize_inputs()
            self.ort_session.run_with_iobinding(self.io_binding, run_options)
            self.io_binding.synchronize_outputs()
        else:
            self.ort_session.run_with_iobinding(self.io_binding, run_options)

        return self.output_tensors

    @staticmethod
    def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> dict[str, Any]:
        """Build CUDAExecutionProvider options for sessions used with this class."""
        options = {
            "device_id": device_id,
            "arena_extend_strategy": "kSameAsRequested",
            "enable_cuda_graph": enable_cuda_graph,
        }

        # Stream is address of a CUDA stream. 0 means the default stream.
        if stream != 0:
            options["user_compute_stream"] = str(stream)

        return options
390
+
391
+
392
class GpuBinding(CudaSession):
    """A CudaSession pre-bound to fixed input/output shapes, optionally backed
    by a cuda graph identified by gpu_graph_id."""

    def __init__(
        self,
        ort_session: InferenceSession,
        device: torch.device,
        shape_dict: ShapeDict,
        enable_gpu_graph: bool = False,
        gpu_graph_id: int = -1,
        stream: int = 0,
        buffer_sharing: dict[str, str] | None = None,
    ):
        super().__init__(ort_session, device, enable_gpu_graph)

        # Register every requested input/output buffer alias before allocation.
        for input_name, output_name in (buffer_sharing or {}).items():
            self.set_buffer_sharing(input_name, output_name)

        self.allocate_buffers(shape_dict)
        self.gpu_graph_id = gpu_graph_id
        # For cuda graph, keep a copy of shape_dict so that later inferences can
        # check whether the shape is the same one the graph was captured with.
        self.shape_dict = copy.deepcopy(shape_dict) if enable_gpu_graph else None
        self.stream = stream
        # The gpu graph id of last run. It will be saved to image metadata.
        self.last_run_gpu_graph_id = None

    def get_run_options(self, disable_cuda_graph_in_run: bool = False) -> RunOptions:
        """Build RunOptions carrying the gpu graph id for the coming run."""
        if disable_cuda_graph_in_run:
            graph_id = -1
        else:
            graph_id = self.gpu_graph_id

        self.last_run_gpu_graph_id = graph_id

        options = RunOptions()
        options.add_run_config_entry("gpu_graph_id", str(graph_id))
        return options

    def infer(self, feed_dict: dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False):
        """Run inference, optionally skipping the cuda graph for this run."""
        options = self.get_run_options(disable_cuda_graph_in_run)

        # A non-default stream implies the caller owns synchronization.
        if self.stream:
            options.add_run_config_entry("disable_synchronize_execution_providers", "1")

        return super().infer(feed_dict, options)
434
+
435
+
436
+ class GpuBindingManager:
437
+ """A manager for I/O bindings that support multiple CUDA Graphs.
438
+ One cuda graph is reused for same input shape. Automatically add a new cuda graph for new input shape.
439
+ """
440
+
441
+ def __init__(self, ort_session: InferenceSession, device: torch.device, stream: int = 0, max_cuda_graphs: int = 1):
442
+ self.ort_session = ort_session
443
+ self.device = device
444
+
445
+ # Binding supports cuda graphs. For a binding, it is able to disable cuda graph for a specific run.
446
+ self.graph_bindings = []
447
+
448
+ # Binding for not using cuda graph.
449
+ self.no_graph_binding = None
450
+
451
+ self.stream = stream
452
+
453
+ self.max_cuda_graphs = max_cuda_graphs
454
+
455
+ def get_binding(
456
+ self,
457
+ shape_dict: ShapeDict,
458
+ use_cuda_graph: bool = False,
459
+ buffer_sharing: dict[str, str] | None = None,
460
+ ) -> GpuBinding:
461
+ for gpu_graph_binding in self.graph_bindings:
462
+ # Found a cuda graph that captured with the same shape
463
+ if gpu_graph_binding.shape_dict == shape_dict:
464
+ return gpu_graph_binding
465
+
466
+ # Reached the maximum number of cuda graphs. Return a binding without cuda graph.
467
+ if len(self.graph_bindings) >= self.max_cuda_graphs or (not use_cuda_graph):
468
+ if self.no_graph_binding is None:
469
+ self.no_graph_binding = GpuBinding(
470
+ self.ort_session, self.device, shape_dict, stream=self.stream, buffer_sharing=buffer_sharing
471
+ )
472
+ else:
473
+ self.no_graph_binding.allocate_buffers(shape_dict)
474
+ return self.no_graph_binding
475
+
476
+ # This is a new input shape, create a new cuda graph
477
+ gpu_graph_binding = GpuBinding(
478
+ self.ort_session,
479
+ self.device,
480
+ shape_dict,
481
+ enable_gpu_graph=True,
482
+ gpu_graph_id=len(self.graph_bindings),
483
+ stream=self.stream,
484
+ buffer_sharing=buffer_sharing,
485
+ )
486
+ self.graph_bindings.append(gpu_graph_binding)
487
+ return gpu_graph_binding