onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
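
Before the per-file diff, a brief orientation: this wheel ships the DirectML build of ONNX Runtime, so inference normally goes through the DmlExecutionProvider registered by the bundled DLLs. The following is a minimal, hypothetical usage sketch, not taken from the package itself; the model path, input dtype, and shapes are placeholders.

# Hedged sketch: running an ONNX model on DirectML with this wheel.
# "model.onnx" and the float32 dummy input are placeholders for illustration.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],  # DirectML first, CPU fallback
)

inp = session.get_inputs()[0]
# Replace symbolic dimensions (reported as strings) with 1 just to build a dummy tensor.
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
outputs = session.run(None, {inp.name: np.zeros(shape, dtype=np.float32)})
print([o.shape for o in outputs])

The largest addition in the file list above is onnxruntime/transformers/fusion_attention.py (+1189 lines); its diff hunk follows.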
onnxruntime/transformers/fusion_attention.py
@@ -0,0 +1,1189 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ import numpy as np
8
+ from fusion_base import Fusion
9
+ from fusion_options import AttentionMaskFormat
10
+ from fusion_utils import FusionUtils, NumpyHelper
11
+ from onnx import NodeProto, TensorProto, helper, numpy_helper
12
+ from onnx_model import OnnxModel
13
+
14
+ logger = getLogger(__name__)
15
+
16
+
17
+ class AttentionMask:
18
+ """
19
+ Helper that converts an attention mask input into the int32 mask or 1D mask-index form consumed by the fused Attention node.
20
+ """
21
+
22
+ def __init__(self, model: OnnxModel):
23
+ self.model = model
24
+ # A lookup table with mask input as key, and mask index output as value
25
+ self.mask_indice = {}
26
+ # A lookup table with mask input as key, and cast (to int32) output as value
27
+ self.mask_casted = {}
28
+ self.utils = FusionUtils(model)
29
+ self.mask_format = AttentionMaskFormat.MaskIndexEnd
30
+ self.opset_version = model.get_opset_version()
31
+
32
+ def set_mask_format(self, mask_format: AttentionMaskFormat):
33
+ self.mask_format = mask_format
34
+
35
+ def set_mask_indice(self, mask, mask_index):
36
+ if mask in self.mask_indice:
37
+ assert mask_index == self.mask_indice[mask]
38
+ self.mask_indice[mask] = mask_index
39
+
40
+ def get_first_mask(self):
41
+ assert len(self.mask_indice) > 0
42
+ return next(iter(self.mask_indice))
43
+
44
+ def process_mask(self, mask_2d: str) -> str | None:
45
+ if self.mask_format == AttentionMaskFormat.NoMask:
46
+ return None
47
+
48
+ if mask_2d in self.mask_indice:
49
+ return self.mask_indice[mask_2d]
50
+
51
+ # Add cast to convert int64 to int32
52
+ if self.model.find_graph_input(mask_2d):
53
+ casted, input_name = self.utils.cast_graph_input_to_int32(mask_2d)
54
+ else:
55
+ input_name, _cast_node = self.utils.cast_input_to_int32(mask_2d)
56
+ casted = True
57
+
58
+ if casted:
59
+ self.mask_casted[mask_2d] = input_name
60
+
61
+ # Attention supports int32 attention mask (2D) since 1.4.0
62
+ if self.mask_format == AttentionMaskFormat.AttentionMask:
63
+ self.mask_indice[mask_2d] = input_name
64
+ return input_name
65
+
66
+ # Add a mask processing node to convert attention mask to mask index (1D)
67
+ output_name = self.model.create_node_name("mask_index")
68
+ if self.opset_version < 13:
69
+ mask_index_node = helper.make_node(
70
+ "ReduceSum",
71
+ inputs=[input_name],
72
+ outputs=[output_name],
73
+ name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
74
+ )
75
+ mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
76
+ else:
77
+ # ReduceSum-13: axes is moved from attribute to input
78
+ axes_name = "ort_const_1_reduce_sum_axes"
79
+ if self.model.get_initializer(axes_name) is None:
80
+ self.model.add_initializer(
81
+ helper.make_tensor(
82
+ name=axes_name,
83
+ data_type=TensorProto.INT64,
84
+ dims=[1],
85
+ vals=[1],
86
+ raw=False,
87
+ )
88
+ )
89
+ mask_index_node = helper.make_node(
90
+ "ReduceSum",
91
+ inputs=[input_name, axes_name],
92
+ outputs=[output_name],
93
+ name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
94
+ )
95
+ mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])
96
+
97
+ self.model.add_node(mask_index_node)
98
+
99
+ self.mask_indice[mask_2d] = output_name
100
+ return output_name
101
+
102
+
103
+ class FusionAttention(Fusion):
104
+ """
105
+ Fuse Attention subgraph into one Attention node.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ model: OnnxModel,
111
+ hidden_size: int,
112
+ num_heads: int,
113
+ attention_mask: AttentionMask | None = None,
114
+ use_multi_head_attention: bool = False,
115
+ disable_multi_head_attention_bias: bool = False,
116
+ search_op_types: list[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006
117
+ ):
118
+ attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
119
+ super().__init__(model, attention_op_name, search_op_types)
120
+ self.hidden_size = hidden_size
121
+ self.num_heads = num_heads
122
+ self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
123
+ self.use_multi_head_attention = use_multi_head_attention
124
+ self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
125
+ self.mask_filter_value = None
126
+
127
+ # Flags to show warning only once
128
+ self.num_heads_warning = True
129
+ self.hidden_size_warning = True
130
+
131
+ self.shape_infer = None
132
+ self.shape_infer_done = False
133
+
134
+ def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> tuple[int, int]:
135
+ """
136
+ Detect num_heads and hidden_size from Concat node in the following subgraph:
137
+
138
+ SkipLayerNormalization or EmbedLayerNormalization
139
+ / |
140
+ MatMul Shape
141
+ | |
142
+ Add Gather(indices=0)
143
+ | |
144
+ | Unsqueeze
145
+ | |
146
+ | Concat (*, -1, 12, 64)
147
+ | /
148
+ Reshape
149
+ |
150
+ Transpose
151
+ """
152
+ if len(concat.input) == 4:
153
+ num_heads = self.model.get_constant_value(concat.input[2])
154
+ head_size = self.model.get_constant_value(concat.input[3])
155
+ if (
156
+ isinstance(num_heads, np.ndarray)
157
+ and num_heads.size == 1
158
+ and isinstance(head_size, np.ndarray)
159
+ and head_size.size == 1
160
+ ):
161
+ return num_heads[0], num_heads[0] * head_size[0]
162
+
163
+ return self.num_heads, self.hidden_size
164
+
165
+ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
166
+ """Detect num_heads and hidden_size from a reshape node.
167
+
168
+ Args:
169
+ reshape_q (NodeProto): reshape node for Q
170
+
171
+ Returns:
172
+ Tuple[int, int]: num_heads and hidden_size
173
+ """
174
+ # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
175
+ q_shape_value = self.model.get_constant_value(reshape_q.input[1])
176
+ if q_shape_value is None:
177
+ concat = self.model.get_parent(reshape_q, 1)
178
+ if concat is not None and concat.op_type == "Concat":
179
+ return self.get_num_heads_and_hidden_size_from_concat(concat)
180
+ logger.debug("%s is not initializer.", reshape_q.input[1])
181
+ return self.num_heads, self.hidden_size # Fall back to user specified value
182
+
183
+ if (
184
+ (not isinstance(q_shape_value, np.ndarray))
185
+ or len(q_shape_value) != 4
186
+ or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0)
187
+ ):
188
+ logger.debug("q_shape_value=%s. Expected value are like [0, 0, num_heads, head_size].", q_shape_value)
189
+ return self.num_heads, self.hidden_size # Fall back to user specified value
190
+
191
+ num_heads = q_shape_value[2]
192
+ head_size = q_shape_value[3]
193
+ hidden_size = num_heads * head_size
194
+
195
+ if self.num_heads > 0 and num_heads != self.num_heads:
196
+ if self.num_heads_warning:
197
+ logger.warning(
198
+ "--num_heads is %d. Detected value is %d. Using detected value.", self.num_heads, num_heads
199
+ )
200
+ self.num_heads_warning = False # Do not show the warning more than once
201
+
202
+ if self.hidden_size > 0 and hidden_size != self.hidden_size:
203
+ if self.hidden_size_warning:
204
+ logger.warning(
205
+ "--hidden_size is %d. Detected value is %d. Using detected value.", self.hidden_size, hidden_size
206
+ )
207
+ self.hidden_size_warning = False # Do not show the warning more than once
208
+
209
+ return num_heads, hidden_size
210
+
211
+ def get_add_qk_str(self, add_qk: NodeProto):
212
+ if not self.shape_infer_done:
213
+ self.shape_infer = self.model.infer_runtime_shape(update=True)
214
+ self.shape_infer_done = True
215
+
216
+ if self.shape_infer is None:
217
+ return None
218
+
219
+ input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
220
+ input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])
221
+
222
+ if input_0_shape is None or input_1_shape is None:
223
+ logger.debug("one of the inputs of %s is None", add_qk)
224
+ return None
225
+
226
+ if input_0_shape != input_1_shape:
227
+ logger.debug("the shape of two inputs of %s is not same", add_qk)
228
+ return None
229
+
230
+ return add_qk.input[1]
231
+
232
+ def reshape_add_qk(self, add_qk: str):
233
+ # Convert 4D mask from (B,1,S,T) to (B,N,S,T)
234
+ # B = batch size, N = num heads, S = source sequence length, T = target sequence length
235
+ mask_output_name = add_qk + "_mask"
236
+
237
+ # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
238
+ concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
239
+ if len(concat_node) == 1:
240
+ return mask_output_name
241
+
242
+ assert len(concat_node) == 0
243
+ concat_node_name = self.model.create_node_name("Concat")
244
+ concat_add_qk_fp32 = helper.make_node(
245
+ "Concat",
246
+ inputs=[add_qk for _ in range(self.num_heads)],
247
+ outputs=[mask_output_name],
248
+ name=concat_node_name,
249
+ axis=1,
250
+ )
251
+ # Add new node to graph
252
+ self.nodes_to_add.append(concat_add_qk_fp32)
253
+ self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
254
+
255
+ return mask_output_name
256
+
257
+ def concat_kv(self, past_k: str, past_v: str) -> str:
258
+ """Concatenate past_k and past_v inputs to create past_kv input.
259
+
260
+ Args:
261
+ past_k (str): name of past K value
262
+ past_v (str): name of past V value
263
+
264
+ Returns:
265
+ kv_output_name (str): name of past KV value
266
+ """
267
+ # Unsqueeze K and V nodes from (B,N,P,H) to (1,B,N,P,H)
268
+ # B = batch size, N = num heads, P = past sequence length, H = head size
269
+ unsqueeze_k_name = self.model.create_node_name("Unsqueeze")
270
+ unsqueeze_v_name = self.model.create_node_name("Unsqueeze")
271
+ k_5d_name = (past_k + "_5d").replace(".", "_")
272
+ v_5d_name = (past_v + "_5d").replace(".", "_")
273
+
274
+ k_5d = helper.make_node(
275
+ "Unsqueeze",
276
+ inputs=[past_k],
277
+ outputs=[k_5d_name],
278
+ name=unsqueeze_k_name,
279
+ axes=[0],
280
+ )
281
+ v_5d = helper.make_node(
282
+ "Unsqueeze",
283
+ inputs=[past_v],
284
+ outputs=[v_5d_name],
285
+ name=unsqueeze_v_name,
286
+ axes=[0],
287
+ )
288
+
289
+ # Add unsqueeze nodes to graph
290
+ self.nodes_to_add.append(k_5d)
291
+ self.nodes_to_add.append(v_5d)
292
+ self.node_name_to_graph_name[unsqueeze_k_name] = self.this_graph_name
293
+ self.node_name_to_graph_name[unsqueeze_v_name] = self.this_graph_name
294
+
295
+ # Concat K and V to get one node of size (2,B,N,P,H)
296
+ concat_node_name = self.model.create_node_name("Concat")
297
+ kv_output_name = past_v.replace(".value", ".kv").replace(".", "_").replace("_value", "_kv")
298
+ concat_kv = helper.make_node(
299
+ "Concat",
300
+ inputs=[k_5d_name, v_5d_name],
301
+ outputs=[kv_output_name],
302
+ name=concat_node_name,
303
+ axis=0,
304
+ )
305
+
306
+ # Add concat node to graph
307
+ self.nodes_to_add.append(concat_kv)
308
+ self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
309
+
310
+ return kv_output_name
311
+
312
+ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
313
+ """Split kv_node containing present KV values into separate present K and present V values.
314
+
315
+ Args:
316
+ present_k_name (str): name of output to store present K value in
317
+ present_v_name (str): name of output to store present V value in
318
+ kv_node (str): name of present KV values
319
+ """
320
+ # Split kv_node into present_k and present_v nodes
321
+
322
+ # Create initializers for indexing kv_node, whose shape is (2,B,N,P,H)
323
+ k_index, v_index = "index_0", "index_1"
324
+ k_dim = self.model.get_initializer(k_index)
325
+ v_dim = self.model.get_initializer(v_index)
326
+ if k_dim is None:
327
+ k_dim = numpy_helper.from_array(np.array(0, dtype="int64"), name=k_index)
328
+ self.model.add_initializer(k_dim, self.this_graph_name)
329
+ if v_dim is None:
330
+ v_dim = numpy_helper.from_array(np.array(1, dtype="int64"), name=v_index)
331
+ self.model.add_initializer(v_dim, self.this_graph_name)
332
+
333
+ # Create nodes to index kv_node
334
+ gather_k_name = self.model.create_node_name("Gather")
335
+ gather_v_name = self.model.create_node_name("Gather")
336
+ present_k = helper.make_node(
337
+ "Gather",
338
+ inputs=[kv_node, k_index],
339
+ outputs=[present_k_name],
340
+ name=gather_k_name,
341
+ axis=0,
342
+ )
343
+ present_v = helper.make_node(
344
+ "Gather",
345
+ inputs=[kv_node, v_index],
346
+ outputs=[present_v_name],
347
+ name=gather_v_name,
348
+ axis=0,
349
+ )
350
+
351
+ # Add gather nodes to graph
352
+ self.nodes_to_add.append(present_k)
353
+ self.nodes_to_add.append(present_v)
354
+ self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
355
+ self.node_name_to_graph_name[gather_v_name] = self.this_graph_name
356
+
357
+ def create_combined_qkv_bias(
358
+ self,
359
+ q_add: NodeProto,
360
+ k_add: NodeProto | None,
361
+ v_add: NodeProto | None,
362
+ name_prefix: str,
363
+ ) -> str:
364
+ q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
365
+ qb = NumpyHelper.to_array(q_bias)
366
+ kb = np.zeros_like(qb)
367
+ vb = np.zeros_like(qb)
368
+ if k_add is not None:
369
+ k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
370
+ kb = NumpyHelper.to_array(k_bias)
371
+ if v_add is not None:
372
+ v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
373
+ vb = NumpyHelper.to_array(v_bias)
374
+
375
+ qkv_bias = np.stack((qb, kb, vb), axis=0)
376
+ qkv_bias_dim = 3 * np.prod(qb.shape)
377
+
378
+ bias_name = name_prefix + "_qkv_bias"
379
+ self.add_initializer(
380
+ name=bias_name,
381
+ data_type=q_bias.data_type,
382
+ dims=[qkv_bias_dim],
383
+ vals=qkv_bias,
384
+ )
385
+ return bias_name
386
+
387
+ def create_packed_qkv_matmul_node(
388
+ self,
389
+ q_matmul: NodeProto,
390
+ k_matmul: NodeProto,
391
+ v_matmul: NodeProto,
392
+ q_add: NodeProto,
393
+ k_add: NodeProto | None,
394
+ v_add: NodeProto | None,
395
+ ) -> tuple[NodeProto, NodeProto, NodeProto]:
396
+ """Create packed QKV MatMul node before MultiHeadAttention node.
397
+ This is for the scenario where an Attention node should be created but cannot be created
398
+ because past_key and past_value are separate inputs and not one concatenated input.
399
+
400
+ Args:
401
+ q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
402
+ k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size)
403
+ v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size)
404
+ q_add (NodeProto): name of Add from Q path
405
+ k_add (NodeProto): name of Add from K path
406
+ v_add (NodeProto): name of Add from V path
407
+
408
+ Returns:
409
+ q_output (NodeProto): Slice node for Q
410
+ k_output (NodeProto): Slice node for K
411
+ v_output (NodeProto): Slice node for V
412
+ """
413
+ matmul_node_name = self.model.create_node_name("MatMul")
414
+
415
+ # Check that input for Q, K, V is the same
416
+ assert q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
417
+
418
+ # Create packed QKV weight
419
+ q_weight = self.model.get_initializer(q_matmul.input[1])
420
+ k_weight = self.model.get_initializer(k_matmul.input[1])
421
+ v_weight = self.model.get_initializer(v_matmul.input[1])
422
+
423
+ qw = NumpyHelper.to_array(q_weight)
424
+ kw = NumpyHelper.to_array(k_weight)
425
+ vw = NumpyHelper.to_array(v_weight)
426
+
427
+ assert qw.shape == kw.shape and kw.shape == vw.shape
428
+ d = qw.shape[0]
429
+
430
+ qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
431
+ qkv_weight_name = matmul_node_name + "_qkv_weight"
432
+
433
+ self.add_initializer(
434
+ name=qkv_weight_name,
435
+ data_type=q_weight.data_type,
436
+ dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
437
+ vals=qkv_weight,
438
+ )
439
+
440
+ # Create packed QKV MatMul with output (B, S, 3*D)
441
+ # Output is of the form:
442
+ #
443
+ # [[[Q Q ... Q Q K K ... K K V V ... V V]]]
444
+ # [Q Q ... Q Q K K ... K K V V ... V V]
445
+ # .
446
+ # .
447
+ # .
448
+ # [[Q Q ... Q Q K K ... K K V V ... V V]
449
+ # [Q Q ... Q Q K K ... K K V V ... V V]]]
450
+ qkv_matmul_output = matmul_node_name + "_qkv_out"
451
+ qkv_matmul = helper.make_node(
452
+ "MatMul",
453
+ inputs=[q_matmul.input[0], qkv_weight_name],
454
+ outputs=[qkv_matmul_output],
455
+ name=matmul_node_name,
456
+ )
457
+ self.node_name_to_graph_name[matmul_node_name] = self.this_graph_name
458
+
459
+ qkv_nodes = [qkv_matmul]
460
+
461
+ # Create Slice nodes to access Q, K, V
462
+ q_slice_name = matmul_node_name + "_q_start_index"
463
+ self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
464
+ k_slice_name = matmul_node_name + "_k_start_index"
465
+ self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
466
+ v_slice_name = matmul_node_name + "_v_start_index"
467
+ self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
468
+ end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
469
+ self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
470
+ qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
471
+ self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)
472
+
473
+ q_slice_output = matmul_node_name + "_q_out"
474
+ q_slice = helper.make_node(
475
+ "Slice",
476
+ inputs=[qkv_matmul_output, q_slice_name, k_slice_name, qkv_last_axis_name],
477
+ outputs=[q_slice_output],
478
+ name=self.model.create_node_name("Slice"),
479
+ )
480
+ self.node_name_to_graph_name[q_slice.name] = self.this_graph_name
481
+ k_slice_output = matmul_node_name + "_k_out"
482
+ k_slice = helper.make_node(
483
+ "Slice",
484
+ inputs=[qkv_matmul_output, k_slice_name, v_slice_name, qkv_last_axis_name],
485
+ outputs=[k_slice_output],
486
+ name=self.model.create_node_name("Slice"),
487
+ )
488
+ self.node_name_to_graph_name[k_slice.name] = self.this_graph_name
489
+ v_slice_output = matmul_node_name + "_v_out"
490
+ v_slice = helper.make_node(
491
+ "Slice",
492
+ inputs=[qkv_matmul_output, v_slice_name, end_of_qkv_name, qkv_last_axis_name],
493
+ outputs=[v_slice_output],
494
+ name=self.model.create_node_name("Slice"),
495
+ )
496
+ self.node_name_to_graph_name[v_slice.name] = self.this_graph_name
497
+
498
+ q_output = q_slice
499
+ k_output = k_slice
500
+ v_output = v_slice
501
+ qkv_nodes.extend([q_slice, k_slice, v_slice])
502
+
503
+ if self.disable_multi_head_attention_bias:
504
+ if q_add is not None:
505
+ initializer_input = 1 if self.model.get_initializer(q_add.input[1]) else 0
506
+ if np.any(NumpyHelper.to_array(self.model.get_initializer(q_add.input[initializer_input]))):
507
+ q_add.input[1 - initializer_input] = q_slice_output
508
+ q_output = q_add
509
+ qkv_nodes.append(q_add)
510
+ self.node_name_to_graph_name[q_add.name] = self.this_graph_name
511
+ if k_add is not None:
512
+ initializer_input = 1 if self.model.get_initializer(k_add.input[1]) else 0
513
+ if np.any(NumpyHelper.to_array(self.model.get_initializer(k_add.input[initializer_input]))):
514
+ k_add.input[1 - initializer_input] = k_slice_output
515
+ k_output = k_add
516
+ qkv_nodes.append(k_add)
517
+ self.node_name_to_graph_name[k_add.name] = self.this_graph_name
518
+ if v_add is not None:
519
+ initializer_input = 1 if self.model.get_initializer(v_add.input[1]) else 0
520
+ if np.any(NumpyHelper.to_array(self.model.get_initializer(v_add.input[initializer_input]))):
521
+ v_add.input[1 - initializer_input] = v_slice_output
522
+ v_output = v_add
523
+ qkv_nodes.append(v_add)
524
+ self.node_name_to_graph_name[v_add.name] = self.this_graph_name
525
+
526
+ # Add nodes to graph
527
+ self.nodes_to_add.extend(qkv_nodes)
528
+ return q_output, k_output, v_output
529
+
530
+ # This function is used in child classes for bart or conformer model.
531
+ def create_multihead_attention_node(
532
+ self,
533
+ q_matmul: NodeProto,
534
+ k_matmul: NodeProto | str | None,
535
+ v_matmul: NodeProto | str | None,
536
+ q_add: NodeProto,
537
+ k_add: NodeProto | None,
538
+ v_add: NodeProto | None,
539
+ num_heads: int,
540
+ hidden_size: int,
541
+ output: str,
542
+ key_padding_mask: str = "",
543
+ add_qk: str = "",
544
+ unidirectional: bool = False,
545
+ past_k: str = "",
546
+ past_v: str = "",
547
+ present_k: str = "",
548
+ present_v: str = "",
549
+ packed_qkv: bool = False,
550
+ ) -> NodeProto | None:
551
+ """Create a MultiHeadAttention node.
552
+
553
+ Args:
554
+ q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
555
+ k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
556
+ v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
557
+ q_add (NodeProto): name of Add from Q path
558
+ k_add (NodeProto): name of Add from K path
559
+ v_add (NodeProto): name of Add from V path
560
+ num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
561
+ hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
562
+ output (str): output name of MHA
563
+ key_padding_mask (str): name of key padding mask
564
+ add_qk (str): name of add after Q x K'
565
+ unidirectional (bool): whether to apply causal attention mask automatically or not
566
+ past_k (str): name of past K value - (batch_size, num_heads, past_sequence_length, head_size)
567
+ past_v (str): name of past V value - (batch_size, num_heads, past_sequence_length, head_size)
568
+ present_k (str): name of present K value - (batch_size, num_heads, sequence_length, head_size)
569
+ present_v (str): name of present V value - (batch_size, num_heads, sequence_length, head_size)
570
+ packed_qkv (bool): whether to combine MatMuls from Q, K, V paths
571
+ Note: This is for the scenario where an Attention node should be created but cannot be created
572
+ because past_key and past_value are separate inputs and not one concatenated input.
573
+
574
+ Returns:
575
+ Union[NodeProto, None]: the node created or None if failed.
576
+ """
577
+ # B = batch size, N = num heads, P = past seq len, H = head size
578
+ assert num_heads > 0
579
+
580
+ if hidden_size > 0 and (hidden_size % num_heads) != 0:
581
+ logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
582
+ return None
583
+
584
+ graph_input_names = {node.name for node in self.model.graph().input}
585
+ mha_node_name = self.model.create_node_name("Attention")
586
+
587
+ # Add initial Q/K/V inputs for MHA
588
+ mha_inputs = []
589
+ if packed_qkv:
590
+ q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node(
591
+ q_matmul,
592
+ k_matmul,
593
+ v_matmul,
594
+ q_add,
595
+ k_add,
596
+ v_add,
597
+ )
598
+ mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]])
599
+ elif isinstance(k_matmul, NodeProto) and isinstance(v_matmul, NodeProto):
600
+ if self.disable_multi_head_attention_bias:
601
+ mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]])
602
+ else:
603
+ mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]])
604
+ elif (
605
+ isinstance(k_matmul, str)
606
+ and isinstance(v_matmul, str)
607
+ and k_matmul in graph_input_names
608
+ and v_matmul in graph_input_names
609
+ ):
610
+ if self.disable_multi_head_attention_bias:
611
+ mha_inputs.extend([q_add.output[0], k_matmul, v_matmul])
612
+ else:
613
+ mha_inputs.extend([q_matmul.output[0], k_matmul, v_matmul])
614
+ else:
615
+ return None
616
+
617
+ # Add bias to inputs for MHA
618
+ # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume
619
+ # bias has been added to key and value when they are in BNSH format, so only bias for query is used.
620
+ # We need to add checks if this assumption turns out not to hold.
621
+ if not self.disable_multi_head_attention_bias:
622
+ bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name)
623
+ mha_inputs.append(bias_name)
624
+ else:
625
+ mha_inputs.append("")
626
+
627
+ # Add optional inputs for MHA
628
+ if past_k and past_v:
629
+ mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
630
+ elif key_padding_mask or add_qk:
631
+ mha_inputs.extend([key_padding_mask, add_qk])
632
+
633
+ # Add outputs for MHA
634
+ mha_outputs = [output]
635
+ if present_k and present_v:
636
+ mha_outputs.extend([present_k, present_v])
637
+
638
+ mha_node = helper.make_node(
639
+ "MultiHeadAttention",
640
+ inputs=mha_inputs,
641
+ outputs=mha_outputs,
642
+ name=mha_node_name,
643
+ )
644
+ mha_node.domain = "com.microsoft"
645
+ mha_node.attribute.append(helper.make_attribute("num_heads", num_heads))
646
+ if unidirectional:
647
+ mha_node.attribute.append(helper.make_attribute("unidirectional", int(unidirectional)))
648
+
649
+ self.increase_counter("MultiHeadAttention")
650
+ return mha_node
651
+
652
+ def create_attention_node(
653
+ self,
654
+ mask_index: str | None,
655
+ q_matmul: NodeProto,
656
+ k_matmul: NodeProto,
657
+ v_matmul: NodeProto,
658
+ q_add: NodeProto,
659
+ k_add: NodeProto,
660
+ v_add: NodeProto,
661
+ num_heads: int,
662
+ hidden_size: int,
663
+ first_input: str,
664
+ output: str,
665
+ add_qk_str: str = "",
666
+ causal: bool = False,
667
+ past_k: str = "",
668
+ past_v: str = "",
669
+ present_k: str = "",
670
+ present_v: str = "",
671
+ scale: float | None = None,
672
+ ) -> NodeProto | None:
673
+ """Create an Attention node.
674
+
675
+ Args:
676
+ mask_index (str | None): mask input
677
+ q_matmul (NodeProto): MatMul node in fully connection for Q
678
+ k_matmul (NodeProto): MatMul node in fully connection for K
679
+ v_matmul (NodeProto): MatMul node in fully connection for V
680
+ q_add (NodeProto): Add bias node in fully connection for Q
681
+ k_add (NodeProto): Add bias node in fully connection for K
682
+ v_add (NodeProto): Add bias node in fully connection for V
683
+ num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
684
+ hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
685
+ first_input (str): first input name
686
+ output (str): output name
687
+ add_qk_str (str): name of Add node after Q x K'
688
+ causal: whether it is uni-directional mask.
689
+ past_k (str): name of input for past K value
690
+ past_v (str): name of input for past V value
691
+ present_k (str): name of output to store present K value
692
+ present_v (str): name of output to store present V value
693
+ scale: scale before softmax
694
+
695
+ Returns:
696
+ Union[NodeProto, None]: the node created or None if failed.
697
+ """
698
+ assert num_heads > 0
699
+
700
+ if hidden_size > 0 and (hidden_size % num_heads) != 0:
701
+ logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
702
+ return None
703
+
704
+ has_bias = True
705
+ if q_add is None and k_add is None and v_add is None:
706
+ has_bias = False
707
+
708
+ q_weight = self.model.get_initializer(q_matmul.input[1])
709
+ k_weight = self.model.get_initializer(k_matmul.input[1])
710
+ v_weight = self.model.get_initializer(v_matmul.input[1])
711
+
712
+ q_bias, k_bias, v_bias = None, None, None
713
+ if has_bias:
714
+ q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
715
+ k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
716
+ v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
717
+
718
+ if not (k_weight and v_weight and q_bias and k_bias):
719
+ return None
720
+
721
+ if q_weight is None:
722
+ print(
723
+ f"{q_matmul.input[1]} is not an initializer. "
724
+ "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
725
+ )
726
+ return None
727
+
728
+ qw = NumpyHelper.to_array(q_weight)
729
+ kw = NumpyHelper.to_array(k_weight)
730
+ vw = NumpyHelper.to_array(v_weight)
731
+
732
+ # assert q and k have same shape as expected
733
+ assert qw.shape == kw.shape
734
+
735
+ qw_in_size = qw.shape[0]
736
+ kw_in_size = kw.shape[0]
737
+ vw_in_size = vw.shape[0]
738
+
739
+ assert qw_in_size == kw_in_size == vw_in_size
740
+
741
+ if hidden_size > 0 and hidden_size != qw_in_size:
742
+ logger.warning(
743
+ "Input hidden size (%d) is not same as weight matrix dimension of q,k,v (%d). "
744
+ "Please provide a correct input hidden size or pass in 0",
745
+ hidden_size,
746
+ qw_in_size,
747
+ )
748
+
749
+ is_qkv_diff_dims = False
750
+ if qw.shape != vw.shape:
751
+ is_qkv_diff_dims = True
752
+
753
+ # All the matrices can have the same shape or q, k matrices can have the same shape with v being different
754
+ # For 2d weights, the shapes would be [in_size, out_size].
755
+ # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
756
+ qw_out_size = np.prod(qw.shape[1:])
757
+ kw_out_size = np.prod(kw.shape[1:])
758
+ vw_out_size = np.prod(vw.shape[1:])
759
+
760
+ qkv_weight_dim = 0
761
+ if is_qkv_diff_dims:
762
+ qkv_weight = np.concatenate((qw, kw, vw), axis=1)
763
+ qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
764
+ else:
765
+ qkv_weight = np.stack((qw, kw, vw), axis=1)
766
+ qkv_weight_dim = 3 * qw_out_size
767
+
768
+ qkv_bias_dim = 0
769
+ qkv_bias: np.ndarray | None = None
770
+ if has_bias:
771
+ qb = NumpyHelper.to_array(q_bias)
772
+ kb = NumpyHelper.to_array(k_bias)
773
+ vb = NumpyHelper.to_array(v_bias)
774
+
775
+ q_bias_shape = np.prod(qb.shape)
776
+ k_bias_shape = np.prod(kb.shape)
777
+ v_bias_shape = np.prod(vb.shape)
778
+
779
+ assert q_bias_shape == k_bias_shape == qw_out_size
780
+ assert v_bias_shape == vw_out_size
781
+
782
+ if is_qkv_diff_dims:
783
+ qkv_bias = np.concatenate((qb, kb, vb), axis=0)
784
+ qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
785
+ else:
786
+ qkv_bias = np.stack((qb, kb, vb), axis=0)
787
+ qkv_bias_dim = 3 * q_bias_shape
788
+
789
+ attention_node_name = self.model.create_node_name("Attention")
790
+
791
+ if not self.use_multi_head_attention:
792
+ self.add_initializer(
793
+ name=attention_node_name + "_qkv_weight",
794
+ data_type=q_weight.data_type,
795
+ dims=[qw_in_size, int(qkv_weight_dim)],
796
+ vals=qkv_weight,
797
+ )
798
+
799
+ if has_bias:
800
+ self.add_initializer(
801
+ name=attention_node_name + "_qkv_bias",
802
+ data_type=q_bias.data_type,
803
+ dims=[int(qkv_bias_dim)],
804
+ vals=qkv_bias,
805
+ )
806
+
807
+ # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
808
+ if self.use_multi_head_attention:
809
+ if add_qk_str:
810
+ logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
811
+ return None
812
+
813
+ attention_inputs = [
814
+ q_matmul.output[0],
815
+ k_matmul.output[0],
816
+ v_matmul.output[0],
817
+ attention_node_name + "_qkv_bias",
818
+ ]
819
+
820
+ if mask_index is not None:
821
+ attention_inputs.append(mask_index)
822
+
823
+ attention_node = helper.make_node(
824
+ "MultiHeadAttention",
825
+ inputs=attention_inputs,
826
+ outputs=[output],
827
+ name=attention_node_name,
828
+ )
829
+ self.increase_counter("MultiHeadAttention")
830
+
831
+ else:
832
+ attention_inputs = [
833
+ first_input,
834
+ attention_node_name + "_qkv_weight",
835
+ attention_node_name + "_qkv_bias" if has_bias else "",
836
+ ]
837
+ if mask_index is not None:
838
+ attention_inputs.append(mask_index)
839
+ else:
840
+ attention_inputs.append("")
841
+
842
+ past_exists = past_k and past_v
843
+ if past_exists:
844
+ past_kv = self.concat_kv(past_k, past_v)
845
+ attention_inputs.append(past_kv)
846
+
847
+ if add_qk_str:
848
+ # Add additional add to attention node (input name = attention_bias)
849
+ if not past_exists:
850
+ attention_inputs.append("")
851
+ attention_inputs.append(add_qk_str)
852
+
853
+ attention_outputs = [output]
854
+ if present_k and present_v:
855
+ present_kv = present_k.replace(".key", "").replace("_key", "").replace(".", "_")
856
+ attention_outputs.append(present_kv)
857
+ self.split_kv(present_k, present_v, present_kv)
858
+
859
+ attention_node = helper.make_node(
860
+ "Attention",
861
+ inputs=attention_inputs,
862
+ outputs=attention_outputs,
863
+ name=attention_node_name,
864
+ )
865
+ self.increase_counter("Attention")
866
+
867
+ attention_node.domain = "com.microsoft"
868
+ attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
869
+
870
+ if causal:
871
+ attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])
872
+
873
+ if scale is not None:
874
+ attention_node.attribute.extend([helper.make_attribute("scale", scale)])
875
+
876
+ if is_qkv_diff_dims:
877
+ attention_node.attribute.extend(
878
+ [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
879
+ )
880
+
881
+ if self.mask_filter_value is not None:
882
+ attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
883
+
884
+ return attention_node
885
+
886
+ def fuse(self, node, input_name_to_nodes, output_name_to_node):
887
+ # Sometimes we cannot fuse SkipLayerNormalization because the Add before LayerNorm has an output that is used by nodes outside the SkipLayerNorm pattern
888
+ # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
889
+ normalize_node = node
890
+ start_node = normalize_node
891
+ if normalize_node.op_type == "LayerNormalization":
892
+ add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
893
+ if add_before_layernorm is not None:
894
+ start_node = add_before_layernorm
895
+ else:
896
+ return
897
+
898
+         # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+         qkv_nodes = self.model.match_parent_path(
+             start_node,
+             ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+             [None, None, 0, 0, 0],
+         )
+         einsum_node = None
+         if qkv_nodes is not None:
+             (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+         else:
+             # Match Albert
+             qkv_nodes = self.model.match_parent_path(
+                 start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
+             )
+             if qkv_nodes is not None:
+                 (_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
+             else:
+                 return
+
+         other_inputs = []
+         for _i, node_input in enumerate(start_node.input):
+             if node_input not in output_name_to_node:
+                 continue
+
+             if node_input == qkv_nodes[0].output[0]:
+                 continue
+             other_inputs.append(node_input)
+         if len(other_inputs) != 1:
+             return
+
+         root_input = other_inputs[0]
+
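+         # Some models (e.g. flaubert) have a Mul feeding the LayerNormalization; adjust
+         # root_input so it points at the tensor that actually feeds the Q/K/V projections.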
+         # Match flaubert                      Mask
+         #                                      |
+         # Mul --> LayerNormalization -->  Attention --> MatMul --> Add
+         #  |                                                        |
+         #  |                                                        |
+         #  +--------------------------------------------------------+
+         mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
+         if mul_before_layernorm is not None:
+             mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
+             if mul_children is not None and len(mul_children) == 2:
+                 layernorm_node = mul_children[1]
+                 if layernorm_node.op_type == "LayerNormalization":
+                     root_input = layernorm_node.output[0]
+                 else:
+                     return
+             elif mul_children is not None and len(mul_children) == 5:
+                 root_input = mul_before_layernorm.output[0]
+             else:
+                 return
+         elif normalize_node.op_type == "LayerNormalization":
+             children = input_name_to_nodes[root_input]
+             for child in children:
+                 if child.op_type == "LayerNormalization":
+                     root_input = child.output[0]
+
+         # When the Add before the LayerNormalization produces an output
+         # that is consumed by nodes other than the LayerNormalization itself,
+         # the fused SkipLayerNormalization will have several outputs.
+         # In this case we need to pick the one used by Attention.
+         # For example, this is the case for ViT:
+         # SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
+         #  |                                                           |
+         #  |                                                           |
+         #  +-----------------------------------------------------------+
+         parent_node = output_name_to_node[root_input]
+         if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
+             root_input = parent_node.output[0]
+
+         children = input_name_to_nodes[root_input]
+         children_types = [child.op_type for child in children]
+         if children_types.count("MatMul") != 3:
+             return
+
+         v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
+         if v_nodes is None:
+             logger.debug("fuse_attention: failed to match v path")
+             return
+         (_, _, add_v, matmul_v) = v_nodes
+
+         is_distill = False
+         is_distill_add = False
+         is_no_mask_attention = False
+         is_sdpa = False
+         qk_paths = {
+             "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
+             "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
+             "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
+             "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
+             "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
+             "sdpa": (["Softmax", "Add", "MatMul", "Mul", "Sqrt"], [0, 0, None, 0, 1]),
+         }
+
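+         # Try each candidate score path in turn; the key of the first match selects the
+         # masking variant (distill / distill_add / no-mask / sdpa) handled below.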
+         qk_nodes = None
+         for k, v in qk_paths.items():
+             qk_nodes = self.model.match_parent_path(matmul_qkv, v[0], v[1])
+             if qk_nodes is None:
+                 continue
+             if k == "path3":
+                 is_distill = True
+             elif k == "path4":
+                 is_distill_add = True
+             elif k == "path5":
+                 is_no_mask_attention = True
+             elif k == "sdpa":
+                 is_sdpa = True
+             break
+
+         if qk_nodes is None:
+             logger.debug("fuse_attention: failed to match qk path")
+             return
+
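+         # Unpack the matched score path; which tuple slots hold Add/Where/MatMul depends on the path.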
+         add_qk = None
+         matmul_qk = None
+         where_qk = None
+         after_q = None
+         if is_distill:
+             (_, where_qk, matmul_qk, _) = qk_nodes
+         elif is_distill_add:
+             (_, add_qk, where_qk, matmul_qk) = qk_nodes
+         elif is_no_mask_attention:
+             (_, _, matmul_qk) = qk_nodes
+         elif is_sdpa:
+             (_, add_qk, matmul_qk, after_q, _) = qk_nodes
+         else:
+             (_, add_qk, _, matmul_qk) = qk_nodes
+
+         after_q = after_q or matmul_qk
+         q_nodes = self.model.match_parent_path(after_q, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
+         if q_nodes is None:
+             q_nodes = self.model.match_parent_path(
+                 after_q,
+                 ["Div", "Transpose", "Reshape", "Add", "MatMul"],
+                 [0, 0, 0, 0, None],
+             )
+             if q_nodes is None:
+                 logger.debug("fuse_attention: failed to match q path")
+                 return
+         reshape_q = q_nodes[-3]
+         add_q = q_nodes[-2]
+         matmul_q = q_nodes[-1]
+
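+         # In the SDPA pattern both Q and K are pre-scaled by a Mul fed by a Sqrt node;
+         # trace through the Mul on the K side before matching the K projection path.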
+         after_k = matmul_qk
+         if is_sdpa:
+             mul_k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Sqrt"], [1, None])
+             if mul_k_nodes is None:
+                 logger.debug("fuse_attention: failed to match mul sqrt q path")
+                 return
+             (after_k, _) = mul_k_nodes
+
+         k_nodes = self.model.match_parent_path(
+             after_k, ["Transpose", "Reshape", "Add", "MatMul"], [0 if is_sdpa else 1, 0, 0, None]
+         )
+         if k_nodes is None:
+             k_nodes = self.model.match_parent_path(
+                 matmul_qk,
+                 ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
+                 [1, 0, 0, 0, None],
+             )
+             if k_nodes is None:
+                 logger.debug("fuse_attention: failed to match k path")
+                 return
+         add_k = k_nodes[-2]
+         matmul_k = k_nodes[-1]
+
+         # Note that Cast might be removed by ONNX Runtime, so we match multiple patterns here.
+         mask_nodes = None
+         add_qk_str = ""
+         if is_distill:
+             _, mask_nodes, _ = self.model.match_parent_paths(
+                 where_qk,
+                 [
+                     (["Expand", "Reshape", "Equal"], [0, 0, 0]),
+                     (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
+                     (["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
+                 ],
+                 output_name_to_node,
+             )
+         elif is_distill_add:
+             _, mask_nodes, _ = self.model.match_parent_paths(
+                 where_qk,
+                 [
+                     (["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
+                     (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
+                 ],
+                 output_name_to_node,
+             )
+             if add_qk is not None:
+                 add_qk_str = self.get_add_qk_str(add_qk)
+                 if add_qk_str is None:
+                     logger.debug("fuse_attention: failed to verify shape inference of %s", add_qk)
+                     return
+         elif is_no_mask_attention:
+             pass
+         else:
+             _, mask_nodes, _ = self.model.match_parent_paths(
+                 add_qk,
+                 [
+                     (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]),
+                     (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
+                     # The following two patterns are for SDPA.
+                     (["Where", "Cast", "Sub", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0]),
+                     (["Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [None, 0, 0, 1, 0, 0, 0]),
+                 ],
+                 output_name_to_node,
+             )
+         if not is_no_mask_attention and mask_nodes is None:
+             logger.debug("fuse_attention: failed to match mask path")
+             return
+
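+         # When the mask path multiplies by a negative constant, remember it so the fused node
+         # can reproduce the same masking value via the mask_filter_value attribute.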
+         if not is_no_mask_attention and len(mask_nodes) > 1:
+             _, mul_val = self.model.get_constant_input(mask_nodes[0])
+             # The mask value shall be a float scalar (usually the lowest float value).
+             if (
+                 (mul_val is None)
+                 or not (isinstance(mul_val, np.ndarray) and mul_val.size == 1)
+                 or (float(mul_val) >= 0)
+             ):
+                 return
+             if float(mul_val) != -10000:
+                 self.mask_filter_value = float(mul_val)
+
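+         # Only fuse when the Q, K and V projection MatMuls all consume the same root input.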
+         if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
+             mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None
+
+             attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
+
+             q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+             if q_num_heads <= 0 or q_hidden_size <= 0:
+                 logger.warning(
+                     "Failed to detect num_heads and hidden_size for Attention fusion. "
+                     "Please specify those parameters in argument."
+                 )
+                 return
+
+             # The number of heads is the same for the Q, K and V paths, so q_num_heads is passed
+             # when creating the attention node. q_hidden_size is the input hidden size; it is used
+             # as needed, while the hidden sizes for Q and K are extracted appropriately.
+             new_node = self.create_attention_node(
+                 mask_index=mask_index,
+                 q_matmul=matmul_q,
+                 k_matmul=matmul_k,
+                 v_matmul=matmul_v,
+                 q_add=add_q,
+                 k_add=add_k,
+                 v_add=add_v,
+                 num_heads=q_num_heads,
+                 hidden_size=q_hidden_size,
+                 first_input=root_input,
+                 output=attention_last_node.output[0],
+                 add_qk_str=add_qk_str,
+             )
+
+             if new_node is None:
+                 return
+
+             self.nodes_to_add.append(new_node)
+             self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
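+             # Albert (Einsum) graphs consume a 4D tensor, so insert a Reshape that restores the
+             # (batch, sequence, num_heads, head_size) layout before feeding the Einsum node.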
+             if einsum_node is not None:
+                 unique_index = einsum_node.input[0]
+                 new_edge = "edge_modified_" + unique_index
+
+                 shape_tensor = self.add_initializer(
+                     name="shape_modified_tensor" + unique_index,
+                     data_type=TensorProto.INT64,
+                     dims=[4],
+                     vals=[0, 0, q_num_heads, int(q_hidden_size / q_num_heads)],
+                     raw=False,
+                 )
+
+                 self.model.add_node(
+                     helper.make_node(
+                         "Reshape",
+                         [attention_last_node.output[0], shape_tensor.name],
+                         [new_edge],
+                         "reshape_modified_" + unique_index,
+                     ),
+                     self.this_graph_name,
+                 )
+                 einsum_node.input[0] = new_edge
+
+             self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
+             self.nodes_to_remove.extend(qk_nodes)
+
+             # For the MultiHeadAttention operator, the MatMul nodes for the Q/K/V projections shall not be fused.
+             self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
+             self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
+             self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])
+
+             # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+             self.prune_graph = True