onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,667 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ import numpy as np
8
+ from fusion_base import Fusion
9
+ from fusion_utils import FusionUtils
10
+ from onnx import NodeProto, TensorProto, helper, numpy_helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
+ class FusionMultiHeadAttentionMMDit(Fusion):
17
+ """
18
+ Fuse MultiHeadAttention for Multimodal Diffusion Transformer (MMDiT).
19
+ """
20
+
21
def __init__(self, model: OnnxModel):
    """Set up the fusion to emit MultiHeadAttention nodes, searching from Softmax nodes.

    Args:
        model (OnnxModel): the model wrapper whose graph will be searched and rewritten.
    """
    super().__init__(
        model,
        fused_op_type="MultiHeadAttention",
        search_op_types=["Softmax"],
    )
    # Cache used by update_unsqueeze_axes_1_to_2: original Unsqueeze node name -> output
    # name of its replacement node, so each Unsqueeze is replaced at most once.
    self.unsqueeze_update_map = {}
24
+
25
def get_num_heads(self, start_node: NodeProto, output_name_to_node, input_index=0) -> int:
    """Detect num_heads from the Reshape shape that feeds q/k/v.

    The expected parent chain of ``start_node`` is::

        Concat (4 inputs, e.g. .. [-1] [24] ..)
           |  (shape input of Reshape, index 1)
        Reshape
           |
        Transpose
           |  (input ``input_index`` of start_node)
        (start_node)

    The third input of the Concat must resolve to a 1-D constant; its first
    element is returned as the number of heads.

    Args:
        start_node (NodeProto): node whose parent chain is inspected.
        output_name_to_node: mapping from output name to producer node, forwarded
            to the path matcher.
        input_index (int): which input of ``start_node`` leads to the Transpose.

    Returns:
        int: the detected number of heads, or 0 when the pattern does not match.
    """
    matched = self.model.match_parent_path(
        start_node,
        ["Transpose", "Reshape", "Concat"],
        [input_index, 0, 1],
        output_name_to_node=output_name_to_node,
    )
    if matched is None:
        return 0

    # The last matched node is the Concat that assembles the Reshape target shape.
    shape_concat = matched[-1]
    if len(shape_concat.input) == 4:
        heads = self.model.get_constant_value(shape_concat.input[2])
        # Only a rank-1 constant is accepted as the head-count entry.
        if heads is not None and len(heads.shape) == 1:
            return int(heads[0])

    return 0
57
+
58
def get_num_heads_from_k(self, transpose_k: NodeProto, output_name_to_node, concat_before_transpose: bool) -> int:
    """Detect num_heads for the key branch, starting from the Transpose applied to K.

    Two layouts are handled (the original notes say both appear in Stable Diffusion 3.5):

    * ``concat_before_transpose`` is false::

          ... -> Reshape -> Transpose -> SimplifiedLayerNormalization -> transpose_k

    * ``concat_before_transpose`` is true; a Concat joins two such branches and the
      SimplifiedLayerNormalization on the Concat's input 1 is followed::

          ... -> SimplifiedLayerNormalization --\
                                                 Concat -> transpose_k
          ... -> SimplifiedLayerNormalization --/   (input 1 is used)

    In both cases the located SimplifiedLayerNormalization node is handed to
    ``get_num_heads`` to read the head count from the Reshape shape above it.

    Args:
        transpose_k (NodeProto): the Transpose node on the key path.
        output_name_to_node: mapping from output name to producer node.
        concat_before_transpose (bool): whether a Concat sits directly above ``transpose_k``.

    Returns:
        int: the detected number of heads, or 0 when the pattern does not match.
    """
    if concat_before_transpose:
        parent_types = ["Concat", "SimplifiedLayerNormalization"]
        parent_inputs = [0, 1]
    else:
        parent_types = ["SimplifiedLayerNormalization"]
        parent_inputs = [0]

    matched = self.model.match_parent_path(
        transpose_k, parent_types, parent_inputs, output_name_to_node=output_name_to_node
    )
    if not matched:
        return 0

    # In either layout the SimplifiedLayerNormalization is the last node of the matched path.
    return self.get_num_heads(matched[-1], output_name_to_node)
106
+
107
def reshape_to_3d(self, input_name: str, output_name: str) -> str:
    """Queue a Reshape node that converts a 4D BxSxNxH tensor into 3D BxSxD.

    The target shape ``[0, 0, -1]`` is stored once as the initializer
    ``bsnh_to_bsd_reshape_dims`` and shared by every Reshape created here.

    Args:
        input_name (str): name of the 4D input tensor (BxSxNxH).
        output_name (str): name for the 3D output tensor (BxSxD, D = N * H).

    Returns:
        str: the output name of the new Reshape node.
    """
    shape_name = "bsnh_to_bsd_reshape_dims"

    # Create the shared shape initializer only if the model does not have it yet.
    if self.model.get_initializer(shape_name) is None:
        target_shape = np.array([0, 0, -1], dtype="int64")
        shape_tensor = numpy_helper.from_array(target_shape, name=shape_name)
        self.model.add_initializer(shape_tensor, self.this_graph_name)

    node_name = self.model.create_node_name("Reshape")
    reshape_node = helper.make_node(
        "Reshape",
        inputs=[input_name, shape_name],
        outputs=[output_name],
        name=node_name,
    )

    # Register the node so it is added to the current graph when the fusion is applied.
    self.nodes_to_add.append(reshape_node)
    self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
    return reshape_node.output[0]
132
+
133
def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> str | None:
    """Rewrite the query branch from BNSH to BSD layout (single-branch variant).

    MultiHeadAttention requires the query in BSD format.

    Before::

        ... -> Reshape -> Transpose(perm=0,2,1,3) -> SimplifiedLayerNorm -> Mul

    After::

        ... -> Reshape -> SimplifiedLayerNorm -> Reshape(shape=[0, 0, -1])

    The Transpose is bypassed by wiring its input straight into the
    SimplifiedLayerNormalization, whose output is renamed with a ``_BSNH`` suffix.

    Args:
        mul_q (NodeProto): the Mul node that consumes the normalized query.
        output_name_to_node: unused here; the path matcher is called without it.

    Returns:
        str | None: name of the BSD query tensor, or None when the pattern does not match.
    """
    matched = self.model.match_parent_path(
        mul_q,
        ["SimplifiedLayerNormalization", "Transpose"],
        [0, 0],
    )
    if matched is None:
        return None

    layer_norm, transpose = matched
    perm_matches = FusionUtils.check_node_attribute(transpose, "perm", [0, 2, 1, 3])
    if not perm_matches:
        return None

    # Update the graph: skip the Transpose and give the layer norm a new output name.
    layer_norm.input[0] = transpose.input[0]
    original_output = layer_norm.output[0]
    bsnh_output = original_output + "_BSNH"
    layer_norm.output[0] = bsnh_output

    return self.reshape_to_3d(bsnh_output, original_output + "_BSD")
180
+
181
def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
    """Rewrite the query branch from BNSH to BSD layout (two-branch Concat variant).

    MultiHeadAttention requires the query in BSD format.

    Before::

        Reshape                      Reshape
           |                            |
        Transpose(perm=0,2,1,3)      Transpose(perm=0,2,1,3)
           |                            |
        SimplifiedLayerNorm          SimplifiedLayerNorm
                    \\                /
                      Concat(axis=2)
                           |
                          Mul

    After::

        Reshape                      Reshape
           |                            |
        SimplifiedLayerNorm          SimplifiedLayerNorm
                    \\                /
                      Concat(axis=1)
                           |
                  Reshape(shape=[0, 0, -1])

    Args:
        mul_q (NodeProto): the Mul node that consumes the concatenated query.
        output_name_to_node: unused here; the path matcher is called without it.

    Returns:
        str | None: name of the BSD query tensor, or None when the pattern does not match.
    """
    first_branch = self.model.match_parent_path(
        mul_q,
        ["Concat", "SimplifiedLayerNormalization", "Transpose"],
        [0, 0, 0],
    )
    if first_branch is None:
        return None
    concat, sln_a, transpose_a = first_branch

    if len(concat.input) != 2:
        return None

    # The second branch hangs off input 1 of the same Concat.
    second_branch = self.model.match_parent_path(
        concat,
        ["SimplifiedLayerNormalization", "Transpose"],
        [1, 0],
    )
    if second_branch is None:
        return None
    sln_b, transpose_b = second_branch

    # Both Transposes must swap the middle two axes, and the Concat must join on axis 2.
    for transpose in (transpose_a, transpose_b):
        if not FusionUtils.check_node_attribute(transpose, "perm", [0, 2, 1, 3]):
            return None

    if not FusionUtils.check_node_attribute(concat, "axis", 2):
        return None

    # Update the graph: feed each layer norm directly from its Transpose's input.
    for layer_norm, transpose in ((sln_a, transpose_a), (sln_b, transpose_b)):
        layer_norm.input[0] = transpose.input[0]

    concat_output = concat.output[0]
    bsnh_concat = helper.make_node(
        "Concat",
        inputs=[sln_a.output[0], sln_b.output[0]],
        outputs=[concat_output + "_BSNH"],
        name=self.model.create_node_name("Concat"),
        axis=1,
    )
    self.nodes_to_add.append(bsnh_concat)
    self.node_name_to_graph_name[bsnh_concat.name] = self.this_graph_name

    return self.reshape_to_3d(bsnh_concat.output[0], concat_output + "_BSD")
259
+
260
def update_unsqueeze_axes_1_to_2(self, unsqueeze: NodeProto) -> str:
    """Queue a replacement for ``unsqueeze`` that uses axes [2], and return its output name.

    The replacement's output is the original output name plus ``_BSNH``. Results are
    cached by the original node name, so asking again for the same Unsqueeze returns
    the already-created output instead of adding another node.

    Args:
        unsqueeze (NodeProto): the Unsqueeze node to replace. If it has a single input,
            axes is written as an attribute; otherwise axes is supplied as a second
            input through the shared ``unsqueeze_axes_2`` initializer.

    Returns:
        str: output name of the replacement Unsqueeze node.
    """
    cached_output = self.unsqueeze_update_map.get(unsqueeze.name)
    if cached_output is not None:
        return cached_output

    new_output = unsqueeze.output[0] + "_BSNH"

    if len(unsqueeze.input) == 1:
        # Attribute form: axes is carried on the node itself.
        replacement = helper.make_node(
            "Unsqueeze",
            inputs=unsqueeze.input,
            outputs=[new_output],
            name=self.model.create_node_name("Unsqueeze"),
            axes=[2],
        )
    else:
        # Input form: axes comes from an initializer shared by all replacements.
        axes_name = "unsqueeze_axes_2"
        if self.model.get_initializer(axes_name) is None:
            axes_tensor = helper.make_tensor(
                name=axes_name,
                data_type=TensorProto.INT64,
                dims=[1],  # Shape of the tensor
                vals=[2],  # Tensor values
            )
            self.model.add_initializer(axes_tensor, self.this_graph_name)

        replacement = helper.make_node(
            "Unsqueeze",
            inputs=[unsqueeze.input[0], axes_name],
            outputs=[new_output],
            name=self.model.create_node_name("Unsqueeze"),
        )

    self.nodes_to_add.append(replacement)
    self.node_name_to_graph_name[replacement.name] = self.this_graph_name

    result = replacement.output[0]
    self.unsqueeze_update_map[unsqueeze.name] = result
    return result
295
+
296
def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: dict[str, NodeProto]) -> bool:
    """Update axes of Unsqueeze from [1] to [2] on both Mul branches feeding an Add.

    Expected pattern (on each of the two inputs of ``add``)::

        Unsqueeze(axes=[0])
             |
        Unsqueeze(axes=[1])
             |   (input 1 of Mul)
            Mul
             |
            Add

    Nothing is modified unless both branches match.

    Args:
        add (NodeProto): the Add node.
        output_name_to_node (dict[str, NodeProto]): mapping from output name to node.

    Returns:
        bool: True if the pattern is matched and updated successfully, False otherwise.
    """
    if len(add.input) != 2:
        return False

    def has_expected_axes(utils, branch) -> bool:
        # branch is [Mul, Unsqueeze nearest the Mul, Unsqueeze above it];
        # the nearer one must use axes [1] and the upper one axes [0].
        for node, expected in ((branch[1], [1]), (branch[2], [0])):
            axes = utils.get_squeeze_or_unsqueeze_axes(node)
            if axes is None or axes != expected:
                return False
        return True

    # Branch on input 1 of the Add is examined first.
    branch_b = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [1, 1, 0], output_name_to_node)
    if branch_b is None:
        return False

    fusion_utils = FusionUtils(self.model)
    if not has_expected_axes(fusion_utils, branch_b):
        return False

    # Then the branch on input 0 of the Add.
    branch_a = self.model.match_parent_path(add, ["Mul", "Unsqueeze", "Unsqueeze"], [0, 1, 0], output_name_to_node)
    if branch_a is None:
        return False

    if not has_expected_axes(fusion_utils, branch_a):
        return False

    # Both branches verified: point each Mul at the axes=[2] replacement Unsqueeze.
    branch_a[0].input[1] = self.update_unsqueeze_axes_1_to_2(branch_a[1])
    branch_b[0].input[1] = self.update_unsqueeze_axes_1_to_2(branch_b[1])
    return True
348
+
349
def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
    """
    Rewrite the Flux query subgraph (two normalized branches joined by Concat) from BNSH to BSD.

    Only a shallow match of the pattern is performed.

    Before:
                   |                           |
        Transpose(perm=0,2,1,3)     Transpose(perm=0,2,1,3)
                   |                           |
         SimplifiedLayerNorm         SimplifiedLayerNorm
                   |                    /
                   Concat(axis=2)
                       |
                      Mul     Mul
                       |     /
                        Add
                         |
                        Mul

    After (the Transpose nodes are bypassed, and a Reshape is added):

                   |                           |
         SimplifiedLayerNorm         SimplifiedLayerNorm
                   |                    /
                   Concat(axis=1)
                       |
                      Mul     Mul
                       |     /
                        Add
                         |
                   Reshape (shape=[0, 0, -1])

    Args:
        mul_q (NodeProto): the Mul node at the bottom of the query pattern.
        output_name_to_node: mapping from output name to node, forwarded to update_unsqueeze_axes.

    Returns:
        str | None: the value returned by reshape_to_3d for the Add output, or None when the
            pattern does not match (in which case the graph is left untouched).
    """
    first_branch = self.model.match_parent_path(
        mul_q,
        ["Add", "Mul", "Concat", "SimplifiedLayerNormalization", "Transpose"],
        [0, 0, 0, 0, 0],
    )
    if first_branch is None:
        return None
    add, _, concat, norm_first, transpose_first = first_branch

    if len(concat.input) != 2:
        return None

    second_branch = self.model.match_parent_path(
        concat,
        ["SimplifiedLayerNormalization", "Transpose"],
        [1, 0],
    )
    if second_branch is None:
        return None
    norm_second, transpose_second = second_branch

    # Both Transpose nodes must swap dimensions 1 and 2, and Concat must join along axis 2.
    for transpose in (transpose_first, transpose_second):
        if not FusionUtils.check_node_attribute(transpose, "perm", [0, 2, 1, 3]):
            return None

    if not FusionUtils.check_node_attribute(concat, "axis", 2):
        return None

    # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH.
    if not self.update_unsqueeze_axes(add, output_name_to_node):
        return None

    # From here on the graph is modified: feed each normalization directly from the Transpose input.
    norm_first.input[0] = transpose_first.input[0]
    norm_second.input[0] = transpose_second.input[0]

    # Add a Concat on axis 1 and point every consumer of the old Concat output at it.
    bsnh_concat = helper.make_node(
        "Concat",
        inputs=[norm_first.output[0], norm_second.output[0]],
        outputs=[concat.output[0] + "_BSNH"],
        name=self.model.create_node_name("Concat"),
        axis=1,
    )
    self.nodes_to_add.append(bsnh_concat)
    self.node_name_to_graph_name[bsnh_concat.name] = self.this_graph_name
    self.model.replace_input_of_all_nodes(concat.output[0], bsnh_concat.output[0])

    add_output = add.output[0]
    return self.reshape_to_3d(add_output, add_output + "_BSD")
432
+
433
def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
    """
    Rewrite the Flux query subgraph with a single normalized branch (no Concat) from BNSH to BSD.

    Only a shallow match of the pattern is performed.

    Before:
                   |
        Transpose(perm=0,2,1,3)
                   |
         SimplifiedLayerNorm
                   |
                  Mul     Mul
                   |     /
                    Add
                     |
                    Mul

    After (the Transpose is bypassed, and a Reshape is added):

                   |
         SimplifiedLayerNorm
                   |
                  Mul     Mul
                   |     /
                    Add
                     |
               Reshape (shape=[0, 0, -1])

    Args:
        mul_q (NodeProto): the Mul node at the bottom of the query pattern.
        output_name_to_node: mapping from output name to node, forwarded to update_unsqueeze_axes.

    Returns:
        str | None: the value returned by reshape_to_3d for the renamed Add output, or None when
            the pattern does not match (in which case the graph is left untouched).
    """
    matched = self.model.match_parent_path(
        mul_q,
        ["Add", "Mul", "SimplifiedLayerNormalization", "Transpose"],
        [0, 0, 0, 0],
    )
    if matched is None:
        return None
    add, _, norm, transpose = matched

    if not FusionUtils.check_node_attribute(transpose, "perm", [0, 2, 1, 3]):
        return None

    # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH.
    if not self.update_unsqueeze_axes(add, output_name_to_node):
        return None

    # From here on the graph is modified: bypass the Transpose and rename the Add output.
    norm.input[0] = transpose.input[0]
    bsnh_output = add.output[0] + "_BSNH"
    add.output[0] = bsnh_output

    return self.reshape_to_3d(bsnh_output, bsnh_output + "_BSD")
483
+
484
def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> str | None:
    """
    Generic BNSH-to-BSD conversion: queue a Transpose(perm=[0, 2, 1, 3]) on `q`, then reshape to 3D.

    Args:
        q (str): name of the query tensor to convert.
        output_name_to_node: not used by this method.

    Returns:
        str | None: the value returned by reshape_to_3d for the transposed tensor.
    """
    bsnh_output = q + "_BSNH"
    node_name = self.model.create_node_name("Transpose", name_prefix="Transpose_BNSH_to_BSNH")
    bnsh_to_bsnh = helper.make_node(
        "Transpose",
        [q],
        [bsnh_output],
        name=node_name,
        perm=[0, 2, 1, 3],
    )
    self.nodes_to_add.append(bnsh_to_bsnh)
    self.node_name_to_graph_name[bnsh_to_bsnh.name] = self.this_graph_name

    return self.reshape_to_3d(bsnh_output, q + "_BSD")
496
+
497
def create_multihead_attention_node(
    self,
    q: str,
    k: str,
    v: str,
    output: str,
    num_heads: int,
) -> NodeProto:
    """
    Build a com.microsoft MultiHeadAttention node with inputs (q, k, v) and a single output.

    Args:
        q (str): name of q
        k (str): name of k
        v (str): name of v
        output (str): output name of MHA
        num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.

    Returns:
        NodeProto: the node created.
    """
    assert num_heads > 0

    # Only Query, Key and Value are wired (Proj_Bias, Mask, Attention_Bias, Past_K, Past_V are optional),
    # and only the first output is produced (Present_K, Present_V are optional).
    # No mask is used in MMDit model, so we need not set the optional mask_filter_value attribute.
    return helper.make_node(
        "MultiHeadAttention",
        inputs=[q, k, v],
        outputs=[output],
        name=self.model.create_node_name("MultiHeadAttention"),
        domain="com.microsoft",
        num_heads=num_heads,
    )
539
+
540
def fuse(self, node, input_name_to_nodes, output_name_to_node):
    """
    Try to replace the attention subgraph around a Softmax node with a MultiHeadAttention node.

    The method returns without changing anything when the Softmax consumers, the Q scaling path,
    the K path or the V path do not match the expected pattern, or when the number of heads cannot
    be determined. On success it queues the new node, marks the MatMul/Transpose/Reshape after the
    Softmax for removal and requests graph pruning.

    Args:
        node: the Softmax node to start matching from.
        input_name_to_nodes: mapping from input name to consumer nodes.
        output_name_to_node: mapping from output name to producer node.
    """
    assert node.op_type == "Softmax"
    softmax = node

    # Softmax output shall not be graph output.
    if self.model.find_graph_output(softmax.output[0]):
        return

    # Softmax -> MatMul (with V) -> Transpose -> Reshape
    after_softmax = self.model.match_child_path(
        softmax, ["MatMul", "Transpose", "Reshape"], [(0, 0), (0, 0), (0, 0)], input_name_to_nodes
    )
    if after_softmax is None:
        return
    matmul_s_v, transpose_out, reshape_out = after_softmax
    if not FusionUtils.check_node_attribute(transpose_out, "perm", [0, 2, 1, 3]):
        return

    # Q path: the scale applied to Q must be derived from the shape of Q itself.
    q_path = self.model.match_parent_path(
        softmax,
        ["MatMul", "Mul", "Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape"],
        [0, 0, 1, 0, 1, 0, 0, 0],
    )
    if q_path is None:
        return
    matmul_qk, mul_q, sqrt_q_2, _, _, _, _, shape_q = q_path

    q_bnsh = mul_q.input[0]
    if q_bnsh != shape_q.input[0]:
        return

    # K path: K is transposed on its last two dimensions and then scaled.
    k_path = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose"], [1, 0])
    if k_path is None:
        return
    mul_k, transpose_k = k_path
    k = transpose_k.input[0]
    if not FusionUtils.check_node_attribute(transpose_k, "perm", [0, 1, 3, 2]):
        return

    # The K scale must take the same input as the Sqrt that scales Q.
    k_scale_path = self.model.match_parent_path(mul_k, ["Sqrt", "Div"], [1, 0])
    if k_scale_path is None:
        return
    if k_scale_path[0].input[0] != sqrt_q_2.input[0]:
        return

    v = matmul_s_v.input[1]

    # Here we sanity check the v path to make sure it is in the expected BNSH format.
    concat_v = self.model.match_parent(matmul_s_v, "Concat", input_index=1, output_name_to_node=output_name_to_node)
    if concat_v is not None:
        # Match v path like:
        #   -- Transpose (perm=[0,2,1,3]) ----+
        #                                     |
        #                                     v
        #   -- Transpose (perm=[0,2,1,3]) -> Concat -> (v)
        expected_transposes = ((concat_v, 0), (concat_v, 1))
    else:
        # Match v path like:
        #   -- Transpose (perm=[0,2,1,3]) -> (v)
        expected_transposes = ((matmul_s_v, 1),)

    for consumer, input_index in expected_transposes:
        transpose_v = self.model.match_parent(
            consumer, "Transpose", input_index=input_index, output_name_to_node=output_name_to_node
        )
        if transpose_v is None:
            return
        if not FusionUtils.check_node_attribute(transpose_v, "perm", [0, 2, 1, 3]):
            return

    # Match patterns for Flux.
    if concat_v:
        num_heads = self.get_num_heads(concat_v, output_name_to_node)
    else:
        num_heads = self.get_num_heads(matmul_s_v, output_name_to_node, input_index=1)

    if num_heads == 0:
        # Match patterns for Stable Diffusion 3.5.
        num_heads = self.get_num_heads_from_k(transpose_k, output_name_to_node, concat_v is not None)
        if num_heads <= 0:
            return

    # Q is in BNSH format, we need to adjust it to BSD format due to limitation of MHA op.
    # TODO: MHA op support BNSH format to reduce the effort in fusion.
    if concat_v is not None:
        query = self.adjust_query_from_bnsh_to_bsd(mul_q, output_name_to_node)
    else:
        query = self.adjust_query_from_bnsh_to_bsd_no_concat(mul_q, output_name_to_node)

    # Try the Flux-specific rewrites in order, stopping at the first one that succeeds.
    for flux_adjust in (
        self.adjust_flux_query_from_bnsh_to_bsd,
        self.adjust_flux_single_query_from_bnsh_to_bsd,
    ):
        if query is not None:
            break
        query = flux_adjust(mul_q, output_name_to_node)

    if query is None:
        # Fallback to use Transpose and Reshape to adjust query from BNSH to BSD.
        # This is more general approach.
        # However, it might be slower if the extra Transpose node cannot be removed by ORT optimizer.
        query = self.transpose_reshape_bnsh_to_bsd(q_bnsh, output_name_to_node)

    mha_node = self.create_multihead_attention_node(
        q=query,
        k=k,
        v=v,
        output=reshape_out.output[0],
        num_heads=num_heads,
    )
    self.nodes_to_add.append(mha_node)
    self.node_name_to_graph_name[mha_node.name] = self.this_graph_name

    self.nodes_to_remove.extend([matmul_s_v, transpose_out, reshape_out])

    # Use prune graph to remove nodes
    self.prune_graph = True