onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/huggingface_models.py
@@ -0,0 +1,167 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ # List of model class names used when loading a pretrained model
+ MODEL_CLASSES = [
+     "AutoModel",
+     "AutoModelWithLMHead",
+     "AutoModelForSequenceClassification",
+     "AutoModelForQuestionAnswering",
+     "AutoModelForCausalLM",
+ ]
+
+ # List of pretrained models: https://huggingface.co/transformers/pretrained_models.html
+ # Maps a pretrained model name to a tuple of (input names, opset_version, use_external_data_format, optimization model type)
+ MODELS = {
+     # BERT
+     "bert-base-uncased": (
+         ["input_ids", "attention_mask", "token_type_ids"],
+         12,
+         False,
+         "bert",
+     ),
+     "bert-large-uncased": (
+         ["input_ids", "attention_mask", "token_type_ids"],
+         12,
+         False,
+         "bert",
+     ),
+     "bert-base-cased": (
+         ["input_ids", "attention_mask", "token_type_ids"],
+         12,
+         False,
+         "bert",
+     ),
+     # "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-base-chinese": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-base-german-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-large-uncased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-large-cased-whole-word-masking": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-large-uncased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
+     #     "token_type_ids"], 12, False, "bert"),
+     # "bert-large-cased-whole-word-masking-finetuned-squad": (["input_ids", "attention_mask",
+     #     "token_type_ids"], 12, False, "bert"),
+     # "bert-base-cased-finetuned-mrpc": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-base-german-dbmdz-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # "bert-base-german-dbmdz-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"),
+     # todo: more models to add
+     # GPT (no past state)
+     "openai-gpt": (["input_ids"], 11, False, "gpt2"),
+     # GPT-2 (no past state; use benchmark_gpt2.py for past_key_values)
+     "gpt2": (["input_ids"], 11, False, "gpt2"),
+     "gpt2-medium": (["input_ids"], 11, False, "gpt2"),
+     "gpt2-large": (["input_ids"], 11, True, "gpt2"),
+     "gpt2-xl": (["input_ids"], 11, True, "gpt2"),
+     "distilgpt2": (["input_ids"], 11, False, "gpt2"),
+     # Transformer-XL (model uses Einsum, which needs opset version 12 or later)
+     "transfo-xl-wt103": (["input_ids", "mems"], 12, False, "bert"),
+     # XLNet
+     "xlnet-base-cased": (["input_ids"], 12, False, "bert"),
+     "xlnet-large-cased": (["input_ids"], 12, False, "bert"),
+     # XLM
+     "xlm-mlm-en-2048": (["input_ids"], 11, True, "bert"),
+     "xlm-mlm-ende-1024": (["input_ids"], 11, False, "bert"),
+     "xlm-mlm-enfr-1024": (["input_ids"], 11, False, "bert"),
+     # RoBERTa
+     "roberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
+     "roberta-large": (["input_ids", "attention_mask"], 12, False, "bert"),
+     "roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"),
+     "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"),
+     "distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"),
+     # DistilBERT
+     "distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"),
+     "distilbert-base-uncased-distilled-squad": (
+         ["input_ids", "attention_mask"],
+         11,
+         False,
+         "bert",
+     ),
+     # CTRL
+     "ctrl": (["input_ids"], 11, True, "bert"),
+     # CamemBERT
+     "camembert-base": (["input_ids"], 11, False, "bert"),
+     # ALBERT
+     "albert-base-v1": (["input_ids"], 12, False, "bert"),
+     "albert-large-v1": (["input_ids"], 12, False, "bert"),
+     "albert-xlarge-v1": (["input_ids"], 12, True, "bert"),
+     # "albert-xxlarge-v1": (["input_ids"], 12, True, "bert"),
+     "albert-base-v2": (["input_ids"], 12, False, "bert"),
+     "albert-large-v2": (["input_ids"], 12, False, "bert"),
+     "albert-xlarge-v2": (["input_ids"], 12, True, "bert"),
+     # "albert-xxlarge-v2": (["input_ids"], 12, True, "bert"),
+     # T5 (use benchmark_t5.py instead)
+     # "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
+     # "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"),
+     # "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
+     # "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
+     # "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"),
+     # "valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"),
+     # XLM-RoBERTa
+     "xlm-roberta-base": (["input_ids"], 11, False, "bert"),
+     "xlm-roberta-large": (["input_ids"], 11, True, "bert"),
+     # FlauBERT
+     "flaubert/flaubert_small_cased": (["input_ids"], 11, False, "bert"),
+     # "flaubert/flaubert_base_uncased": (["input_ids"], 11, False, "bert"),
+     "flaubert/flaubert_base_cased": (["input_ids"], 11, False, "bert"),
+     # "flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"),
+     # Bart
+     "facebook/bart-large": (["input_ids", "attention_mask"], 11, False, "bart"),
+     "facebook/bart-base": (["input_ids", "attention_mask"], 11, False, "bart"),
+     "facebook/bart-large-mnli": (["input_ids", "attention_mask"], 11, False, "bart"),
+     "facebook/bart-large-cnn": (["input_ids", "attention_mask"], 11, False, "bart"),
+     # DialoGPT
+     "microsoft/DialoGPT-small": (["input_ids"], 11, False, "gpt2"),
+     "microsoft/DialoGPT-medium": (["input_ids"], 11, False, "gpt2"),
+     # "microsoft/DialoGPT-large": (["input_ids"], 11, True, "gpt2"),
+     # Reformer
+     # "google/reformer-enwik8": (["input_ids"], 11, False, "bert"),
+     # "google/reformer-crime-and-punishment": (["input_ids"], 11, False, "bert"),
+     # MarianMT
+     # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"),
+     # Longformer (use benchmark_longformer.py instead)
+     # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"),
+     # "allenai/longformer-large-4096": (["input_ids"], 12, False, "bert"),
+     # MBart
+     "facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"),
+     "facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"),
+     # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"),
+     # # Longformer
+     # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"),
+     # "allenai/longformer-large-4096": (["input_ids"], 12, True, "bert"),
+     # "funnel-transformer/small": (["input_ids"], 12, False, "bert"),
+     # "funnel-transformer/small-base": (["input_ids"], 12, False, "bert"),
+     # "funnel-transformer/medium": (["input_ids"], 12, False, "bert"),
+     # "funnel-transformer/medium-base": (["input_ids"], 12, False, "bert"),
+     # "funnel-transformer/intermediate": (["input_ids"], 12, False, "bert"),
+     # "funnel-transformer/intermediate-base": (["input_ids"], 12, False, "bert"),
+     # "funnel-transformer/large": (["input_ids"], 12, True, "bert"),
+     # "funnel-transformer/large-base": (["input_ids"], 12, True, "bert"),
+     # "funnel-transformer/xlarge": (["input_ids"], 12, True, "bert"),
+     # "funnel-transformer/xlarge-base": (["input_ids"], 12, True, "bert"),
+     # LayoutLM
+     "microsoft/layoutlm-base-uncased": (["input_ids"], 11, False, "bert"),
+     "microsoft/layoutlm-large-uncased": (["input_ids"], 11, False, "bert"),
+     # SqueezeBERT
+     "squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"),
+     "squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"),
+     "squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"),
+     "unc-nlp/lxmert-base-uncased": (
+         ["input_ids", "visual_feats", "visual_pos"],
+         11,
+         False,
+         "bert",
+     ),
+     # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"),
+     # "google/pegasus-large": (["input_ids"], 11, False, "bert"),
+     # ViT
+     "google/vit-base-patch16-224": (["pixel_values"], 12, False, "vit"),
+     # Swin
+     "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 12, False, "swin"),
+     "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 12, False, "swin"),
+     "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 12, False, "swin"),
+ }
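
Each `MODELS` entry packs the export configuration for one pretrained checkpoint. As an illustration only (this snippet is not part of the wheel), a caller could unpack an entry like this; the chosen key is an arbitrary example:

```python
# Illustrative sketch (assumed usage, not shipped in the wheel):
# unpack the export configuration for one pretrained model.
from onnxruntime.transformers.huggingface_models import MODEL_CLASSES, MODELS

model_name = "bert-base-uncased"  # example key; any key of MODELS works
input_names, opset_version, use_external_data_format, model_type = MODELS[model_name]

print(f"inputs={input_names}, opset={opset_version}")
print(f"external data format={use_external_data_format}, optimizer model type={model_type}")
print(f"model classes tried when loading: {MODEL_CLASSES}")
```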
onnxruntime/transformers/import_utils.py
@@ -0,0 +1,20 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import importlib.metadata
+ import importlib.util
+
+
+ def is_installed(package):
+     try:
+         dist = importlib.metadata.distribution(package)
+     except importlib.metadata.PackageNotFoundError:
+         try:
+             spec = importlib.util.find_spec(package)
+         except ModuleNotFoundError:
+             return False
+
+         return spec is not None
+
+     return dist is not None
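
`is_installed` first asks `importlib.metadata` for an installed distribution, then falls back to `importlib.util.find_spec` for modules that are importable without distribution metadata. A minimal check, with arbitrary example package names:

```python
# Illustrative sketch (assumed usage, not shipped in the wheel).
from onnxruntime.transformers.import_utils import is_installed

for package in ("numpy", "definitely-not-installed"):  # example names
    print(package, "->", is_installed(package))
```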
onnxruntime/transformers/io_binding_helper.py
@@ -0,0 +1,442 @@
+ import copy
+ import logging
+ from collections import OrderedDict
+ from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+
+ import numpy
+ import torch
+
+ from onnxruntime import InferenceSession, RunOptions
+
+ # Type alias
+ ShapeDict = Mapping[str, Union[Tuple, List[int]]]
+
+ logger = logging.getLogger(__name__)
+
+
+ class TypeHelper:
+     @staticmethod
+     def get_input_type(ort_session: InferenceSession, name: str) -> str:
+         for _i, input in enumerate(ort_session.get_inputs()):
+             if input.name == name:
+                 return input.type
+         raise ValueError(f"input name {name} not found")
+
+     @staticmethod
+     def get_output_type(ort_session, name: str) -> str:
+         for _i, output in enumerate(ort_session.get_outputs()):
+             if output.name == name:
+                 return output.type
+
+         raise ValueError(f"output name {name} not found")
+
+     @staticmethod
+     def ort_type_to_numpy_type(ort_type: str):
+         ort_type_to_numpy_type_map = {
+             "tensor(int64)": numpy.longlong,
+             "tensor(int32)": numpy.intc,
+             "tensor(float)": numpy.float32,
+             "tensor(float16)": numpy.float16,
+             "tensor(bool)": bool,
+         }
+         if ort_type not in ort_type_to_numpy_type_map:
+             raise ValueError(f"{ort_type} not found in map")
+
+         return ort_type_to_numpy_type_map[ort_type]
+
+     @staticmethod
+     def ort_type_to_torch_type(ort_type: str):
+         ort_type_to_torch_type_map = {
+             "tensor(int64)": torch.int64,
+             "tensor(int32)": torch.int32,
+             "tensor(float)": torch.float32,
+             "tensor(float16)": torch.float16,
+             "tensor(bool)": torch.bool,
+         }
+         if ort_type not in ort_type_to_torch_type_map:
+             raise ValueError(f"{ort_type} not found in map")
+
+         return ort_type_to_torch_type_map[ort_type]
+
+     @staticmethod
+     def numpy_type_to_torch_type(numpy_type: numpy.dtype):
+         numpy_type_to_torch_type_map = {
+             numpy.longlong: torch.int64,
+             numpy.intc: torch.int32,
+             numpy.int32: torch.int32,
+             numpy.float32: torch.float32,
+             numpy.float16: torch.float16,
+             bool: torch.bool,
+         }
+         if numpy_type not in numpy_type_to_torch_type_map:
+             raise ValueError(f"{numpy_type} not found in map")
+
+         return numpy_type_to_torch_type_map[numpy_type]
+
+     @staticmethod
+     def torch_type_to_numpy_type(torch_type: torch.dtype):
+         torch_type_to_numpy_type_map = {
+             torch.int64: numpy.longlong,
+             torch.int32: numpy.intc,
+             torch.float32: numpy.float32,
+             torch.float16: numpy.float16,
+             torch.bool: bool,
+         }
+         if torch_type not in torch_type_to_numpy_type_map:
+             raise ValueError(f"{torch_type} not found in map")
+
+         return torch_type_to_numpy_type_map[torch_type]
+
+     @staticmethod
+     def get_io_numpy_type_map(ort_session: InferenceSession) -> Dict[str, numpy.dtype]:
+         """Create a mapping from input/output name to numpy data type."""
+         name_to_numpy_type = {}
+         for input in ort_session.get_inputs():
+             name_to_numpy_type[input.name] = TypeHelper.ort_type_to_numpy_type(input.type)
+
+         for output in ort_session.get_outputs():
+             name_to_numpy_type[output.name] = TypeHelper.ort_type_to_numpy_type(output.type)
+         return name_to_numpy_type
+
+
+ class IOBindingHelper:
+     @staticmethod
+     def get_output_buffers(ort_session: InferenceSession, output_shapes, device):
+         """Returns a dictionary mapping each output name to a 1D tensor with enough space for the given shape."""
+         output_buffers = {}
+         for name, shape in output_shapes.items():
+             ort_type = TypeHelper.get_output_type(ort_session, name)
+             torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
+             output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch_type, device=device)
+         return output_buffers
+
+     @staticmethod
+     def prepare_io_binding(
+         ort_session,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         past: List[torch.Tensor],
+         output_buffers,
+         output_shapes,
+         name_to_np_type=None,
+     ):
+         """Returns an IO binding object for a session."""
+         if name_to_np_type is None:
+             name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_session)
+
+         # Bind inputs and outputs to onnxruntime session
+         io_binding = ort_session.io_binding()
+
+         # Bind inputs
+         assert input_ids.is_contiguous()
+         io_binding.bind_input(
+             "input_ids",
+             input_ids.device.type,
+             0,
+             name_to_np_type["input_ids"],
+             list(input_ids.size()),
+             input_ids.data_ptr(),
+         )
+
+         if past is not None:
+             for i, past_i in enumerate(past):
+                 assert past_i.is_contiguous()
+
+                 data_ptr = past_i.data_ptr()
+                 if data_ptr == 0:
+                     # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero.
+                     # Work around this by passing the data pointer of input_ids. The actual data is not used for past, so it does not matter.
+                     data_ptr = input_ids.data_ptr()
+
+                 io_binding.bind_input(
+                     f"past_{i}",
+                     past_i.device.type,
+                     0,
+                     name_to_np_type[f"past_{i}"],
+                     list(past_i.size()),
+                     data_ptr,
+                 )
+
+         if attention_mask is not None:
+             assert attention_mask.is_contiguous()
+             io_binding.bind_input(
+                 "attention_mask",
+                 attention_mask.device.type,
+                 0,
+                 name_to_np_type["attention_mask"],
+                 list(attention_mask.size()),
+                 attention_mask.data_ptr(),
+             )
+
+         if position_ids is not None:
+             assert position_ids.is_contiguous()
+             io_binding.bind_input(
+                 "position_ids",
+                 position_ids.device.type,
+                 0,
+                 name_to_np_type["position_ids"],
+                 list(position_ids.size()),
+                 position_ids.data_ptr(),
+             )
+
+         # Bind outputs
+         for output in ort_session.get_outputs():
+             output_name = output.name
+             output_buffer = output_buffers[output_name]
+             logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}")
+             io_binding.bind_output(
+                 output_name,
+                 output_buffer.device.type,
+                 0,
+                 name_to_np_type[output_name],
+                 output_shapes[output_name],
+                 output_buffer.data_ptr(),
+             )
+
+         return io_binding
+
+     @staticmethod
+     def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
+         """Copy results to CPU. Returns a list of numpy arrays."""
+         ort_outputs = []
+         for output in ort_session.get_outputs():
+             output_name = output.name
+             buffer = output_buffers[output_name]
+             shape = output_shapes[output_name]
+             copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach()
+             if return_numpy:
+                 ort_outputs.append(copy_tensor.cpu().numpy())
+             else:
+                 ort_outputs.append(copy_tensor)
+         return ort_outputs
+
+
+ class CudaSession:
+     """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider"""
+
+     def __init__(self, ort_session: InferenceSession, device: torch.device, enable_cuda_graph=False):
+         self.ort_session = ort_session
+         self.input_names = [input.name for input in self.ort_session.get_inputs()]
+         self.output_names = [output.name for output in self.ort_session.get_outputs()]
+         self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session)
+         self.io_binding = self.ort_session.io_binding()
+         self.enable_cuda_graph = enable_cuda_graph
+
+         self.input_tensors = OrderedDict()
+         self.output_tensors = OrderedDict()
+         self.device = device
+
+         # Pairs of input and output names that share the same buffer.
+         self.buffer_sharing: Dict[str, str] = {}
+
+     def set_buffer_sharing(self, input_name: str, output_name: str):
+         assert input_name in self.input_names
+         assert output_name in self.output_names
+         self.buffer_sharing[input_name] = output_name
+         self.buffer_sharing[output_name] = input_name
+
+     def __del__(self):
+         del self.input_tensors
+         del self.output_tensors
+         del self.io_binding
+
+     def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
+         device_id = tensor.device.index if tensor.device.index is not None else 0
+         tensor_shape = [1] if len(tensor.shape) == 0 else list(tensor.shape)
+
+         self.io_binding.bind_input(
+             name,
+             tensor.device.type,
+             device_id,
+             self.io_name_to_numpy_type[name],
+             tensor_shape,
+             tensor.data_ptr(),
+         )
+
+         if name in self.buffer_sharing:
+             self.io_binding.bind_output(
+                 self.buffer_sharing[name],
+                 tensor.device.type,
+                 device_id,
+                 self.io_name_to_numpy_type[name],
+                 tensor_shape,
+                 tensor.data_ptr(),
+             )
+             self.output_tensors[self.buffer_sharing[name]] = tensor
+
+     def allocate_buffers(self, shape_dict: ShapeDict):
+         """Allocate tensors for I/O Binding"""
+         if self.enable_cuda_graph:
+             for name, shape in shape_dict.items():
+                 if name in self.input_names:
+                     # Reuse the allocated buffer when the shape is the same
+                     if name in self.input_tensors:
+                         if tuple(self.input_tensors[name].shape) == tuple(shape):
+                             continue
+                         raise RuntimeError("Expect static input shape for cuda graph")
+
+                     numpy_dtype = self.io_name_to_numpy_type[name]
+                     tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to(
+                         device=self.device
+                     )
+                     self.input_tensors[name] = tensor
+                     self.bind_input_and_buffer_sharing(name, tensor)
+
+         for name, shape in shape_dict.items():
+             if name in self.output_names:
+                 # Reuse the allocated buffer when the shape is the same
+                 if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape):
+                     continue
+
+                 if name in self.buffer_sharing:
+                     continue
+
+                 numpy_dtype = self.io_name_to_numpy_type[name]
+                 tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to(
+                     device=self.device
+                 )
+                 self.output_tensors[name] = tensor
+
+                 self.io_binding.bind_output(
+                     name,
+                     tensor.device.type,
+                     tensor.device.index if tensor.device.index is not None else 0,
+                     numpy_dtype,
+                     list(tensor.size()),
+                     tensor.data_ptr(),
+                 )
+
+     def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
+         """Bind input tensors and run inference"""
+         for name, tensor in feed_dict.items():
+             assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
+             if name in self.input_names:
+                 if self.enable_cuda_graph:
+                     assert self.input_tensors[name].nelement() == tensor.nelement()
+                     assert self.input_tensors[name].dtype == tensor.dtype
+                     assert tensor.device.type == "cuda"
+                     self.input_tensors[name].copy_(tensor)
+                 else:
+                     self.bind_input_and_buffer_sharing(name, tensor)
+
+         if synchronize:
+             self.io_binding.synchronize_inputs()
+             self.ort_session.run_with_iobinding(self.io_binding, run_options)
+             self.io_binding.synchronize_outputs()
+         else:
+             self.ort_session.run_with_iobinding(self.io_binding, run_options)
+
+         return self.output_tensors
+
+     @staticmethod
+     def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> Dict[str, Any]:
+         options = {
+             "device_id": device_id,
+             "arena_extend_strategy": "kSameAsRequested",
+             "enable_cuda_graph": enable_cuda_graph,
+         }
+
+         # Stream is the address of a CUDA stream. 0 means the default stream.
+         if stream != 0:
+             options["user_compute_stream"] = str(stream)
+
+         return options
+
+
+ class GpuBinding(CudaSession):
+     def __init__(
+         self,
+         ort_session: InferenceSession,
+         device: torch.device,
+         shape_dict: ShapeDict,
+         enable_gpu_graph: bool = False,
+         gpu_graph_id: int = -1,
+         stream: int = 0,
+         buffer_sharing: Optional[Dict[str, str]] = None,
+     ):
+         super().__init__(ort_session, device, enable_gpu_graph)
+         if buffer_sharing:
+             for input_name, output_name in buffer_sharing.items():
+                 self.set_buffer_sharing(input_name, output_name)
+
+         self.allocate_buffers(shape_dict)
+         self.gpu_graph_id = gpu_graph_id
+         # For cuda graph, keep a copy of shape_dict to check whether the shape is the same in later inference.
+         self.shape_dict = copy.deepcopy(shape_dict) if enable_gpu_graph else None
+         self.stream = stream
+         # The gpu graph id of the last run. It will be saved to image metadata.
+         self.last_run_gpu_graph_id = None
+
+     def get_run_options(self, disable_cuda_graph_in_run: bool = False) -> RunOptions:
+         options = RunOptions()
+
+         gpu_graph_id = -1 if disable_cuda_graph_in_run else self.gpu_graph_id
+
+         options.add_run_config_entry("gpu_graph_id", str(gpu_graph_id))
+
+         self.last_run_gpu_graph_id = gpu_graph_id
+
+         return options
+
+     def infer(self, feed_dict: Dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False):
+         run_options = self.get_run_options(disable_cuda_graph_in_run)
+
+         if self.stream:
+             run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
+
+         return super().infer(feed_dict, run_options)
+
+
+ class GpuBindingManager:
+     """A manager for I/O bindings that supports multiple CUDA Graphs.
+     One cuda graph is reused for the same input shape; a new cuda graph is added automatically for a new input shape.
+     """
+
+     def __init__(self, ort_session: InferenceSession, device: torch.device, stream: int = 0, max_cuda_graphs: int = 1):
+         self.ort_session = ort_session
+         self.device = device
+
+         # Bindings that support cuda graphs. A binding can still disable cuda graph for a specific run.
+         self.graph_bindings = []
+
+         # Binding for runs that do not use cuda graph.
+         self.no_graph_binding = None
+
+         self.stream = stream
+
+         self.max_cuda_graphs = max_cuda_graphs
+
+     def get_binding(
+         self,
+         shape_dict: ShapeDict,
+         use_cuda_graph: bool = False,
+         buffer_sharing: Optional[Dict[str, str]] = None,
+     ) -> GpuBinding:
+         for gpu_graph_binding in self.graph_bindings:
+             # Found a cuda graph that was captured with the same shape
+             if gpu_graph_binding.shape_dict == shape_dict:
+                 return gpu_graph_binding
+
+         # Reached the maximum number of cuda graphs. Return a binding without cuda graph.
+         if len(self.graph_bindings) >= self.max_cuda_graphs or (not use_cuda_graph):
+             if self.no_graph_binding is None:
+                 self.no_graph_binding = GpuBinding(
+                     self.ort_session, self.device, shape_dict, stream=self.stream, buffer_sharing=buffer_sharing
+                 )
+             else:
+                 self.no_graph_binding.allocate_buffers(shape_dict)
+             return self.no_graph_binding
+
+         # This is a new input shape; create a new cuda graph
+         gpu_graph_binding = GpuBinding(
+             self.ort_session,
+             self.device,
+             shape_dict,
+             enable_gpu_graph=True,
+             gpu_graph_id=len(self.graph_bindings),
+             stream=self.stream,
+             buffer_sharing=buffer_sharing,
+         )
+         self.graph_bindings.append(gpu_graph_binding)
+         return gpu_graph_binding
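
`CudaSession` binds pre-allocated device tensors once, so steady-state runs skip per-call allocation; with `enable_cuda_graph`, the bound addresses stay fixed and `infer` only copies new values into them. A minimal sketch of the intended flow follows. The model path, input/output names, and shapes are illustrative assumptions, and the snippet requires an onnxruntime build that includes the CUDA execution provider (which this DirectML wheel does not ship):

```python
# Illustrative sketch (assumed usage, not shipped in the wheel).
import torch
from onnxruntime import InferenceSession
from onnxruntime.transformers.io_binding_helper import CudaSession

device = torch.device("cuda", 0)
provider_options = CudaSession.get_cuda_provider_options(device_id=0, enable_cuda_graph=False)
session = InferenceSession(
    "model.onnx",  # hypothetical model with input "input_ids" and output "logits"
    providers=[("CUDAExecutionProvider", provider_options)],
)

cuda_session = CudaSession(session, device)
# Output shapes are assumptions for the hypothetical model; outputs are bound once here.
cuda_session.allocate_buffers({"logits": (1, 128, 32000)})

input_ids = torch.zeros((1, 128), dtype=torch.int64, device=device)
outputs = cuda_session.infer({"input_ids": input_ids})  # dict: output name -> torch.Tensor
print(outputs["logits"].shape)
```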