onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
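For orientation, onnxruntime-directml is the DirectML build of ONNX Runtime, here packaged for CPython 3.13 on 64-bit Windows. As a minimal sketch (not part of the package diff itself), an installation of this version can be sanity-checked from Python; DmlExecutionProvider is the provider name the DirectML build is expected to register:

import onnxruntime as ort

# Confirm the installed version and that the DirectML execution provider is present.
print(ort.__version__)                # expected: 1.20.0
print(ort.get_available_providers())  # expected to include "DmlExecutionProvider"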
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
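As an aside, the file list above can be reproduced locally: a wheel is a plain zip archive, so its contents can be enumerated with the standard library. A minimal sketch, assuming the wheel has been downloaded under its canonical filename:

import zipfile

# A .whl file is a zip archive; namelist() yields the same paths listed above.
with zipfile.ZipFile("onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl") as whl:
    for name in whl.namelist():
        print(name)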
--- /dev/null
+++ b/onnxruntime/transformers/fusion_attention_sam2.py
@@ -0,0 +1,534 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from logging import getLogger
+ from typing import Tuple, Union
+
+ import numpy as np
+ from fusion_base import Fusion
+ from fusion_utils import NumpyHelper
+ from onnx import NodeProto, helper, numpy_helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionMultiHeadAttentionSam2(Fusion):
+     """
+     Fuse MultiHeadAttention subgraph of Segment Anything v2 (SAM2).
+     """
+
+     def __init__(
+         self,
+         model: OnnxModel,
+         hidden_size: int,
+         num_heads: int,
+     ):
+         super().__init__(model, "MultiHeadAttention", ["LayerNormalization"])
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+
+         # Flags to show warning only once
+         self.num_heads_warning = True
+         self.hidden_size_warning = True
+
+     def get_decoder_num_heads(self, reshape_q: NodeProto) -> int:
+         """Detect num_heads from a reshape node.
+
+         Args:
+             reshape_q (NodeProto): reshape node for Q
+         Returns:
+             int: num_heads, or 0 if not found
+         """
+         num_heads = 0
+
+         # We assume that reshape fusion has already run, so the shape is a constant like [0, 0, num_heads, head_size].
+         shape_value = self.model.get_constant_value(reshape_q.input[1])
+         if shape_value is not None:
+             if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [4]:
+                 num_heads = int(shape_value[2])
+
+         if isinstance(num_heads, int) and num_heads > 0:
+             return num_heads
+
+         return 0
+
+     def get_encoder_num_heads(self, reshape_in: NodeProto) -> int:
+         """Detect num_heads from a reshape node.
+
+         Args:
+             reshape_in (NodeProto): reshape node after the input projection
+         Returns:
+             int: num_heads, or 0 if not found
+         """
+         num_heads = 0
+
+         shape_value = self.model.get_constant_value(reshape_in.input[1])
+         if shape_value is not None:
+             if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [5]:
+                 num_heads = int(shape_value[3])
+         else:
+             concat_shape = self.model.match_parent(reshape_in, "Concat", 1)
+             if concat_shape is not None and len(concat_shape.input) == 5:
+                 # We assume that reshape fusion has already run, so input 3 of Concat holds num_heads.
+                 shape_value = self.model.get_constant_value(concat_shape.input[3])
+                 if shape_value is not None:
+                     if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [1]:
+                         num_heads = int(shape_value[0])
+
+         if isinstance(num_heads, int) and num_heads > 0:
+             return num_heads
+
+         return 0
+
+     def get_hidden_size(self, layernorm_node):
+         """Detect hidden_size from LayerNormalization node.
+         Args:
+             layernorm_node (NodeProto): LayerNormalization node before Q, K and V
+         Returns:
+             int: hidden_size, or 0 if not found
+         """
+         layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
+         if layernorm_bias:
+             return NumpyHelper.to_array(layernorm_bias).shape[0]
+
+         return 0
+
+     def get_num_heads_and_hidden_size(
+         self, reshape_q: NodeProto, layernorm_node: NodeProto, is_encoder: bool = False
+     ) -> Tuple[int, int]:
+         """Detect num_heads and hidden_size.
+
+         Args:
+             reshape_q (NodeProto): reshape node for Q
+             layernorm_node (NodeProto): LayerNormalization node before Q, K, V
+             is_encoder (bool): True to use the encoder detection path
+         Returns:
+             Tuple[int, int]: num_heads and hidden_size
+         """
+         if is_encoder:
+             num_heads = self.get_encoder_num_heads(reshape_q)
+         else:
+             num_heads = self.get_decoder_num_heads(reshape_q)
+         if num_heads <= 0:
+             num_heads = self.num_heads  # Fall back to user-specified value.
+
+         if self.num_heads > 0 and num_heads != self.num_heads:
+             if self.num_heads_warning:
+                 logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
+                 self.num_heads_warning = False  # Do not show the warning more than once.
+
+         hidden_size = self.get_hidden_size(layernorm_node)
+         if hidden_size <= 0:
+             hidden_size = self.hidden_size  # Fall back to user-specified value.
+
+         if self.hidden_size > 0 and hidden_size != self.hidden_size:
+             if self.hidden_size_warning:
+                 logger.warning(
+                     f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
+                 )
+                 self.hidden_size_warning = False  # Do not show the warning more than once.
+
+         return num_heads, hidden_size
+
+     def create_attention_node(
+         self,
+         q_matmul: NodeProto,
+         q_add: NodeProto,
+         k_matmul: NodeProto,
+         k_add: NodeProto,
+         v_matmul: NodeProto,
+         v_add: NodeProto,
+         num_heads: int,
+         hidden_size: int,
+         output: str,
+     ) -> Union[NodeProto, None]:
+         """Create an Attention node.
+
+         Args:
+             q_matmul (NodeProto): MatMul node of the fully connected layer for Q
+             q_add (NodeProto): Add node for the Q bias
+             k_matmul (NodeProto): MatMul node of the fully connected layer for K
+             k_add (NodeProto): Add node for the K bias
+             v_matmul (NodeProto): MatMul node of the fully connected layer for V
+             v_add (NodeProto): Add node for the V bias
+             num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+             hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
+             output (str): output name
+
+         Returns:
+             Union[NodeProto, None]: the node created, or None on failure.
+         """
+         if hidden_size > 0 and (hidden_size % num_heads) != 0:
+             logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
+             return None
+
+         q_weight = self.model.get_initializer(q_matmul.input[1])
+         k_weight = self.model.get_initializer(k_matmul.input[1])
+         v_weight = self.model.get_initializer(v_matmul.input[1])
+         if not (q_weight and k_weight and v_weight):
+             return None
+
+         qw = NumpyHelper.to_array(q_weight)
+         kw = NumpyHelper.to_array(k_weight)
+         vw = NumpyHelper.to_array(v_weight)
+         logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
+
+         attention_node_name = self.model.create_node_name("MultiHeadAttention")
+
+         attention_inputs = [
+             q_add.output[0],
+             k_add.output[0],
+             v_add.output[0],
+         ]
+
+         attention_node = helper.make_node(
+             "MultiHeadAttention",
+             inputs=attention_inputs,
+             outputs=[output],
+             name=attention_node_name,
+         )
+         attention_node.domain = "com.microsoft"
+         attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+
+         counter_name = "MultiHeadAttention ({})".format("cross attention")
+         self.increase_counter(counter_name)
+         return attention_node
+
+     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+         if self.fuse_sam_encoder_pattern(normalize_node, input_name_to_nodes, output_name_to_node):
+             return
+
+         match_qkv = self.match_attention_subgraph(normalize_node)
+         if match_qkv is None:
+             if normalize_node.input[0] not in output_name_to_node:
+                 return
+
+             skip_add = output_name_to_node[normalize_node.input[0]]
+             if skip_add.op_type != "Add":
+                 return
+
+             match_qkv = self.match_attention_subgraph(skip_add)
+
+         if match_qkv is None:
+             return
+
+         reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v = match_qkv
+
+         attention_last_node = reshape_qkv
+
+         q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, False)
+         if q_num_heads <= 0:
+             logger.debug("fuse_attention: failed to detect num_heads")
+             return
+
+         # The number of heads is the same for all paths, so pass q_num_heads when creating the attention node.
+         new_node = self.create_attention_node(
+             matmul_q,
+             add_q,
+             matmul_k,
+             add_k,
+             matmul_v,
+             add_v,
+             q_num_heads,
+             q_hidden_size,
+             output=attention_last_node.output[0],
+         )
+         if new_node is None:
+             return
+
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+         self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
+
+         # Use prune_graph to remove the remaining nodes, since they are shared by all attention nodes.
+         self.prune_graph = True
+
+     def match_attention_subgraph(self, node_after_output_projection):
+         """Match Q, K and V paths exported by PyTorch 2.*"""
+         qkv_nodes = self.model.match_parent_path(
+             node_after_output_projection,
+             ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+             [None, None, None, 0, 0],
+         )
+
+         if qkv_nodes is None:
+             return None
+
+         (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+
+         v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
+         if v_nodes is None:
+             logger.debug("fuse_attention: failed to match v path")
+             return None
+         (_, _, add_v, matmul_v) = v_nodes
+
+         qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+         if qk_nodes is not None:
+             (_softmax_qk, matmul_qk) = qk_nodes
+         else:
+             logger.debug("fuse_attention: failed to match qk path")
+             return None
+
+         q_nodes = self.model.match_parent_path(
+             matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, None, 0, 0, None]
+         )
+         if q_nodes is None:
+             logger.debug("fuse_attention: failed to match q path")
+             return None
+         (mul_q, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes
+
+         k_nodes = self.model.match_parent_path(
+             matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [1, None, 0, 0, None]
+         )
+         if k_nodes is None:
+             logger.debug("fuse_attention: failed to match k path")
+             return None
+
+         (_mul_k, _, _, add_k, matmul_k) = k_nodes
+
+         # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
+         mul_q_nodes = self.model.match_parent_path(
+             mul_q,
+             ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
+             [None, 0, 1, 0, 0, 0, 0, 0],
+         )
+         if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
+             logger.debug("fuse_attention: failed to match mul_q path")
+             return None
+
+         return reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v
+
+     # --------------------------------------------------------
+     # The following are for the SAM encoder
+     # --------------------------------------------------------
+     def fuse_sam_encoder_pattern(self, normalize_node, input_name_to_nodes, output_name_to_node) -> bool:
+         # SAM encoder attention layer pattern:
+         #            Add -----------+
+         #             |             |
+         #         LayerNorm         |
+         #             |             |
+         #          Reshape          |
+         #             |             |
+         #         Transpose         |
+         #             |             |
+         #           MatMul          |
+         #             |             |
+         #            Add            |
+         #             |             |
+         #          Reshape          |
+         #             |             |
+         #           Split           |
+         #             |             |
+         #  Self Attention subgraph  |
+         #             |             |
+         #          Reshape          |
+         #             |             |
+         #         Transpose         |
+         #             |             |
+         #          Reshape          |
+         #             |             |
+         #            Add -----------+
+         #             |
+         #         LayerNorm (starts from here)
+
+         nodes = self.model.match_parent_path(
+             normalize_node,
+             ["Add", "Reshape", "Transpose", "Reshape"],
+             [0, None, 0, 0],
+         )
+         if nodes is None:
+             nodes = self.model.match_parent_path(
+                 normalize_node,
+                 ["Add", "Slice", "Slice", "Reshape", "Transpose", "Reshape"],
+                 [0, None, 0, 0, 0, 0],
+             )
+         if nodes is None:
+             nodes = self.model.match_parent_path(
+                 normalize_node,
+                 ["Add"],
+                 [0],
+             )
+         if nodes is None:
+             return False
+
+         node_after_output_projection = nodes[-1]
+         matched_sdpa = self.match_sam_encoder_attention_subgraph(
+             node_after_output_projection, input_index=1 if len(nodes) == 1 else None
+         )
+         if matched_sdpa is None:
+             return False
+
+         reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v = matched_sdpa
+
+         # B, S, N, H => B, N, S, H
+         permutation_q = OnnxModel.get_node_attribute(transpose_q, "perm")
+         if (not isinstance(permutation_q, list)) or permutation_q != [0, 2, 1, 3]:
+             return False
+
+         # B, S, N, H => B, N, H, S
+         permutation_k = OnnxModel.get_node_attribute(transpose_k, "perm")
+         if (not isinstance(permutation_k, list)) or permutation_k != [0, 2, 3, 1]:
+             return False
+
+         # B, S, N, H => B, N, S, H
+         permutation_v = OnnxModel.get_node_attribute(transpose_v, "perm")
+         if (not isinstance(permutation_v, list)) or permutation_v != [0, 2, 1, 3]:
+             return False
+
+         input_projection_nodes = self.model.match_parent_path(
+             split_qkv,
+             ["Reshape", "Add", "MatMul"],
+             [0, 0, None],
+         )
+         if input_projection_nodes is None:
+             return False
+         reshape_in, add_in, matmul_in = input_projection_nodes
+         q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_in, normalize_node, True)
+         if q_num_heads <= 0:
+             logger.debug("fuse_attention: failed to detect num_heads")
+             return False
+
+         # Add a Reshape (with a new shape initializer) to convert 4D BxSxNxH to 3D BxSxD, which the MHA operator requires.
+         new_dims_name = "bsnh_to_bsd_reshape_dims"
+         new_dims = self.model.get_initializer(new_dims_name)
+         if new_dims is None:
+             new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name)
+             self.model.add_initializer(new_dims, self.this_graph_name)
+         reshape_q_name = self.model.create_node_name("Reshape")
+         reshape_q = helper.make_node(
+             "Reshape",
+             inputs=[transpose_q.input[0], new_dims_name],
+             outputs=[transpose_q.input[0] + "_BSD"],
+             name=reshape_q_name,
+         )
+         self.nodes_to_add.append(reshape_q)
+         self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name
+
+         # Reuse the transpose_q node to transpose K from BSNH to BNSH by updating its input and output.
+         transpose_k_bnsh = transpose_q
+         transpose_k_bnsh.input[0] = transpose_k.input[0]
+         transpose_k_bnsh.output[0] = transpose_k.input[0] + "_BNSH"
+
+         logger.debug(f"Found MHA: {q_num_heads=} {q_hidden_size=}")
+
+         # The number of heads is the same for all paths, so pass q_num_heads when creating the attention node.
+         new_node = self.create_mha_node(
+             reshape_q,
+             transpose_k_bnsh,
+             transpose_v,
+             q_num_heads,
+         )
+         if new_node is None:
+             return False
+
+         # Update the input of the next node that consumes the output of the MHA.
+         assert len(self.model.get_children(transpose_out, input_name_to_nodes)) == 1
+         reshape_out.input[0] = new_node.output[0]
+
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+         self.nodes_to_remove.extend([transpose_out])
+
+         # Use prune_graph to remove the remaining nodes, since they are shared by all attention nodes.
+         self.prune_graph = True
+         return True
+
+     def match_sam_encoder_attention_subgraph(self, node_after_output_projection, input_index=None):
+         """Match the SDPA pattern in the SAM2 encoder."""
+
+         # Nodes of the output projection and the second MatMul in SDPA.
+         out_nodes = self.model.match_parent_path(
+             node_after_output_projection,
+             ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+             [input_index, None, None, 0, 0],
+         )
+
+         if out_nodes is None:
+             return None
+
+         (_, _, reshape_out, transpose_out, matmul_qk_v) = out_nodes
+
+         # Split and Reshape are for packed QKV.
+         v_nodes = self.model.match_parent_path(matmul_qk_v, ["Transpose", "Squeeze", "Split", "Reshape"], [1, 0, 0, 0])
+         if v_nodes is None:
+             logger.debug("failed to match v path")
+             return None
+         (transpose_v, _, split_qkv, reshape_qkv) = v_nodes
+
+         qk_nodes = self.model.match_parent_path(matmul_qk_v, ["Softmax", "MatMul"], [0, 0])
+         if qk_nodes is not None:
+             (_softmax_qk, matmul_qk) = qk_nodes
+         else:
+             logger.debug("failed to match qk path")
+             return None
+
+         q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [0, None, 0, 0])
+         if q_nodes is None:
+             q_nodes = self.model.match_parent_path(
+                 matmul_qk,
+                 ["Mul", "Transpose", "Reshape", "Transpose", "MaxPool", "Transpose", "Reshape", "Squeeze", "Split"],
+                 [0, None, 0, 0, 0, 0, 0, 0, 0],
+             )
+             if q_nodes is None:
+                 logger.debug("failed to match q path")
+                 return None
+
+         if q_nodes[-1] != split_qkv:
+             return None
+         transpose_q = q_nodes[1]
+
+         k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [1, None, 0, 0])
+         if k_nodes is None:
+             logger.debug("failed to match k path")
+             return None
+
+         if k_nodes[-1] != split_qkv:
+             return None
+         (mul_k, transpose_k, _squeeze_k, _) = k_nodes
+
+         return reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v
+
+     def create_mha_node(
+         self,
+         reshape_q: NodeProto,
+         transpose_k: NodeProto,
+         transpose_v: NodeProto,
+         num_heads: int,
+     ) -> NodeProto:
+         """Create a MultiHeadAttention node for the SAM2 encoder.
+
+         Args:
+             reshape_q (NodeProto): Reshape node for Q, output is in 3D BxSxNH format
+             transpose_k (NodeProto): Transpose node for K, output is in BNSH format
+             transpose_v (NodeProto): Transpose node for V, output is in BNSH format
+             num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+
+         Returns:
+             NodeProto: the MultiHeadAttention node created.
+         """
+
+         attention_node_name = self.model.create_node_name("MultiHeadAttention")
+
+         inputs = [
+             reshape_q.output[0],
+             transpose_k.output[0],
+             transpose_v.output[0],
+         ]
+
+         # Create a new output name since the shape is 3D, which differs from the original output shape (4D).
+         output = attention_node_name + "_out"
+
+         attention_node = helper.make_node(
+             "MultiHeadAttention",
+             inputs=inputs,
+             outputs=[output],
+             name=attention_node_name,
+         )
+         attention_node.domain = "com.microsoft"
+         attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+
+         counter_name = "MultiHeadAttention ({})".format("self attention")
+         self.increase_counter(counter_name)
+         return attention_node
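For context on where this file fits: FusionMultiHeadAttentionSam2 is one of the fusion passes driven by the transformers optimizer shipped in this same wheel (onnxruntime/transformers/optimizer.py, entry 295 above), with onnx_model_sam2.py (entry 289) supplying the SAM2-specific model wrapper. Assuming the sam2 model type is registered with that optimizer, a hedged usage sketch follows; the model paths are placeholders:

from onnxruntime.transformers.optimizer import optimize_model

# Zero means "rely on detection": the fusion reads num_heads from Reshape
# shape constants and hidden_size from the LayerNormalization bias, and
# falls back to these values only when detection fails (see the code above).
optimized = optimize_model(
    "sam2_image_encoder.onnx",  # placeholder path
    model_type="sam2",
    num_heads=0,
    hidden_size=0,
)
optimized.save_model_to_file("sam2_image_encoder_fused.onnx")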