onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
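
The hunk that follows is onnxruntime/transformers/fusion_attention_unet.py (+1307 lines), which implements FusionAttentionUnet, the pass that fuses UNet attention subgraphs into Attention/MultiHeadAttention nodes. As a quick orientation (this sketch is not part of the package diff), the fusion is normally driven through onnxruntime.transformers.optimizer, and the result can be run on the DirectML execution provider shipped in this wheel; the model path and any values shown below are illustrative placeholders:

    # Minimal sketch (not from the diff): optimize a UNet ONNX model with the transformer
    # fusions in this wheel, then run it on DirectML. Paths are placeholders.
    import onnxruntime as ort
    from onnxruntime.transformers import optimizer
    from onnxruntime.transformers.fusion_options import FusionOptions

    fusion_options = FusionOptions("unet")
    fusion_options.enable_packed_qkv = True  # pack Q/K/V weights for self attention
    fusion_options.enable_packed_kv = True   # pack K/V weights for cross attention

    optimized = optimizer.optimize_model(
        "unet.onnx",          # placeholder path
        model_type="unet",
        num_heads=0,          # 0 lets FusionAttentionUnet detect the value from the graph
        hidden_size=0,        # 0 lets it detect the value from the LayerNormalization node
        optimization_options=fusion_options,
    )
    optimized.save_model_to_file("unet_optimized.onnx")

    session = ort.InferenceSession("unet_optimized.onnx", providers=["DmlExecutionProvider"])
    print(session.get_providers())
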
@@ -0,0 +1,1307 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ import numpy as np
8
+ from fusion_base import Fusion
9
+ from fusion_utils import NumpyHelper
10
+ from onnx import NodeProto, TensorProto, helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
+ class FusionAttentionUnet(Fusion):
17
+ """
18
+ Fuse Attention subgraph of UNet into one Attention node.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ model: OnnxModel,
24
+ hidden_size: int,
25
+ num_heads: int,
26
+ is_cross_attention: bool,
27
+ enable_packed_qkv: bool,
28
+ enable_packed_kv: bool,
29
+ ):
30
+ super().__init__(
31
+ model,
32
+ "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention",
33
+ ["LayerNormalization"],
34
+ )
35
+ self.hidden_size = hidden_size
36
+ self.num_heads = num_heads
37
+ self.is_cross_attention = is_cross_attention
38
+
39
+ # Note: packing the Q/K/V or K/V weights into one tensor makes it harder to update initializers for LoRA.
40
+ # To support LoRA, it is better to use separate Q, K and V inputs in offline optimization,
41
+ # and let the CUDA operator pre-pack those tensors into the preferred format based on the available kernels.
42
+ # In this way, we can support LoRA and get optimal performance at the same time.
43
+ self.enable_packed_qkv = enable_packed_qkv
44
+ self.enable_packed_kv = enable_packed_kv
45
+
46
+ # Flags to show warning only once
47
+ self.num_heads_warning = True
48
+ self.hidden_size_warning = True
49
+
50
+ def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
51
+ """Detect num_heads from a reshape node.
52
+
53
+ Args:
54
+ reshape_q (NodeProto): reshape node for Q
55
+ is_torch2 (bool): whether the graph pattern was exported by PyTorch 2.*
56
+ Returns:
57
+ int: num_heads, or 0 if not found
58
+ """
59
+ num_heads = 0
60
+ if is_torch2:
61
+ # we assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
62
+ reshape_parent = self.model.get_parent(reshape_q, 1)
63
+ if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
64
+ num_heads = self.model.get_constant_value(reshape_parent.input[2])
65
+ if isinstance(num_heads, np.ndarray) and list(num_heads.shape) == [1]:
66
+ num_heads = int(num_heads)
67
+ else:
68
+ # we assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
69
+ q_shape_value = self.model.get_constant_value(reshape_q.input[1])
70
+ if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
71
+ num_heads = int(q_shape_value[2])
72
+
73
+ if isinstance(num_heads, int) and num_heads > 0:
74
+ return num_heads
75
+
76
+ return 0
77
+
78
+ def get_hidden_size(self, layernorm_node):
79
+ """Detect hidden_size from LayerNormalization node.
80
+ Args:
81
+ layernorm_node (NodeProto): LayerNormalization node before Q, K and V
82
+ Returns:
83
+ int: hidden_size, or 0 if not found
84
+ """
85
+ layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
86
+ if layernorm_bias:
87
+ return NumpyHelper.to_array(layernorm_bias).shape[0]
88
+
89
+ return 0
90
+
91
+ def get_num_heads_and_hidden_size(
92
+ self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
93
+ ) -> tuple[int, int]:
94
+ """Detect num_heads and hidden_size.
95
+
96
+ Args:
97
+ reshape_q (NodeProto): reshape node for Q
98
+ layernorm_node (NodeProto): LayerNormalization node before Q, K and V
99
+ is_torch2 (bool): whether the graph pattern was exported by PyTorch 2.*
100
+ Returns:
101
+ Tuple[int, int]: num_heads and hidden_size
102
+ """
103
+ num_heads = self.get_num_heads(reshape_q, is_torch2)
104
+ if num_heads <= 0:
105
+ num_heads = self.num_heads # Fall back to user specified value
106
+
107
+ if self.num_heads > 0 and num_heads != self.num_heads:
108
+ if self.num_heads_warning:
109
+ logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
110
+ self.num_heads_warning = False # Do not show the warning more than once
111
+
112
+ hidden_size = self.get_hidden_size(layernorm_node)
113
+ if hidden_size <= 0:
114
+ hidden_size = self.hidden_size # Fall back to user specified value
115
+
116
+ if self.hidden_size > 0 and hidden_size != self.hidden_size:
117
+ if self.hidden_size_warning:
118
+ logger.warning(
119
+ f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
120
+ )
121
+ self.hidden_size_warning = False # Do not show the warning more than once
122
+
123
+ return num_heads, hidden_size
124
+
125
+ def create_attention_node(
126
+ self,
127
+ q_matmul: NodeProto,
128
+ k_matmul: NodeProto,
129
+ v_matmul: NodeProto,
130
+ num_heads: int,
131
+ hidden_size: int,
132
+ input: str,
133
+ output: str,
134
+ ) -> NodeProto | None:
135
+ """Create an Attention node.
136
+
137
+ Args:
138
+ q_matmul (NodeProto): MatMul node in the fully connected layer for Q
139
+ k_matmul (NodeProto): MatMul node in the fully connected layer for K
140
+ v_matmul (NodeProto): MatMul node in the fully connected layer for V
141
+ num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
142
+ hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
143
+ input (str): input name
144
+ output (str): output name
145
+
146
+ Returns:
147
+ Union[NodeProto, None]: the node created or None if failed.
148
+ """
149
+ is_self_attention = not self.is_cross_attention
150
+
151
+ if is_self_attention:
152
+ if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
153
+ logger.debug(
154
+ "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
155
+ q_matmul.input[0],
156
+ k_matmul.input[0],
157
+ v_matmul.input[0],
158
+ )
159
+ return None
160
+ else:
161
+ if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
162
+ logger.debug(
163
+ "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
164
+ q_matmul.input[0],
165
+ k_matmul.input[0],
166
+ v_matmul.input[0],
167
+ )
168
+ return None
169
+
170
+ if hidden_size > 0 and (hidden_size % num_heads) != 0:
171
+ logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
172
+ return None
173
+
174
+ q_weight = self.model.get_initializer(q_matmul.input[1])
175
+ k_weight = self.model.get_initializer(k_matmul.input[1])
176
+ v_weight = self.model.get_initializer(v_matmul.input[1])
177
+ if not (q_weight and k_weight and v_weight):
178
+ return None
179
+
180
+ # Sometimes weights are stored in fp16
181
+ float_type = q_weight.data_type
182
+
183
+ qw = NumpyHelper.to_array(q_weight)
184
+ kw = NumpyHelper.to_array(k_weight)
185
+ vw = NumpyHelper.to_array(v_weight)
186
+ logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
187
+
188
+ # assert that q, k and v have the same shape as expected
189
+ if is_self_attention:
190
+ if qw.shape != kw.shape or qw.shape != vw.shape:
191
+ return None
192
+
193
+ qw_in_size = qw.shape[0]
194
+
195
+ if hidden_size > 0 and hidden_size != qw_in_size:
196
+ raise ValueError(
197
+ f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
198
+ "Please provide a correct input hidden size or pass in 0"
199
+ )
200
+
201
+ # All the matrices can have the same shape, or the q and k matrices can have the same shape while v is different.
202
+ # For 2D weights, the shape would be [in_size, out_size].
203
+ # For 3D weights, the shape would be [in_size, a, b] where a*b = out_size.
204
+ qw_out_size = int(np.prod(qw.shape[1:]))
205
+
206
+ if self.enable_packed_qkv:
207
+ attention_node_name = self.model.create_node_name("MultiHeadAttention")
208
+
209
+ c = qw_in_size
210
+ n = num_heads
211
+ h = qw_out_size // num_heads
212
+
213
+ # Concat and interleave weights so that the output of fused QKV GEMM has [B, S, N, 3, H] shape
214
+ qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
215
+ c, n * 3 * h
216
+ )
217
+
218
+ matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
219
+ self.add_initializer(
220
+ name=matmul_node_name + "_weight",
221
+ data_type=float_type,
222
+ dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
223
+ vals=qkv_weight,
224
+ )
225
+
226
+ matmul_node = helper.make_node(
227
+ "MatMul",
228
+ inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
229
+ outputs=[matmul_node_name + "_out"],
230
+ name=matmul_node_name,
231
+ )
232
+ self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
233
+
234
+ self.add_initializer(
235
+ name=matmul_node_name + "_reshape_shape",
236
+ data_type=TensorProto.INT64,
237
+ dims=[5],
238
+ vals=[0, 0, n, 3, h],
239
+ raw=False,
240
+ )
241
+
242
+ reshape_node = helper.make_node(
243
+ "Reshape",
244
+ inputs=[
245
+ matmul_node_name + "_out",
246
+ matmul_node_name + "_reshape_shape",
247
+ ],
248
+ outputs=[attention_node_name + "_qkv_input"],
249
+ name=matmul_node_name + "_reshape",
250
+ )
251
+ self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
252
+ self.nodes_to_add.extend([matmul_node, reshape_node])
253
+ self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
254
+
255
+ else:
256
+ qkv_weight = np.stack((qw, kw, vw), axis=1)
257
+ qkv_weight_dim = 3 * qw_out_size
258
+
259
+ attention_node_name = self.model.create_node_name("Attention")
260
+
261
+ self.add_initializer(
262
+ name=attention_node_name + "_qkv_weight",
263
+ data_type=float_type,
264
+ dims=[qw_in_size, qkv_weight_dim],
265
+ vals=qkv_weight,
266
+ )
267
+ else: # cross attention
268
+ attention_node_name = self.model.create_node_name("MultiHeadAttention")
269
+ if self.enable_packed_kv:
270
+ if kw.shape != vw.shape:
271
+ return None
272
+
273
+ kw_in_size = kw.shape[0]
274
+ vw_in_size = vw.shape[0]
275
+ assert kw_in_size == vw_in_size
276
+
277
+ qw_out_size = qw.shape[1]
278
+ kw_out_size = kw.shape[1]
279
+ vw_out_size = vw.shape[1]
280
+ assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
281
+
282
+ c = kw_in_size
283
+ n = num_heads
284
+ h = kw_out_size // num_heads
285
+
286
+ # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
287
+ kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
288
+
289
+ matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
290
+ self.add_initializer(
291
+ name=matmul_node_name + "_weight",
292
+ data_type=float_type,
293
+ dims=[kv_weight.shape[0], kv_weight.shape[1]],
294
+ vals=kv_weight,
295
+ )
296
+
297
+ matmul_node = helper.make_node(
298
+ "MatMul",
299
+ inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
300
+ outputs=[matmul_node_name + "_out"],
301
+ name=matmul_node_name,
302
+ )
303
+ self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
304
+
305
+ self.add_initializer(
306
+ name=matmul_node_name + "_reshape_shape",
307
+ data_type=TensorProto.INT64,
308
+ dims=[5],
309
+ vals=[0, 0, n, 2, h],
310
+ raw=False,
311
+ )
312
+
313
+ reshape_node = helper.make_node(
314
+ "Reshape",
315
+ inputs=[
316
+ matmul_node_name + "_out",
317
+ matmul_node_name + "_reshape_shape",
318
+ ],
319
+ outputs=[attention_node_name + "_kv_input"],
320
+ name=matmul_node_name + "_reshape",
321
+ )
322
+ self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
323
+ self.nodes_to_add.extend([matmul_node, reshape_node])
324
+ self.nodes_to_remove.extend([k_matmul, v_matmul])
325
+
326
+ # No bias, use zeros
327
+ qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
328
+ qkv_bias_dim = 3 * hidden_size
329
+
330
+ self.add_initializer(
331
+ name=attention_node_name + "_qkv_bias",
332
+ data_type=float_type,
333
+ dims=[qkv_bias_dim],
334
+ vals=qkv_bias,
335
+ )
336
+
337
+ if is_self_attention:
338
+ if not self.enable_packed_qkv:
339
+ attention_inputs = [
340
+ input,
341
+ attention_node_name + "_qkv_weight",
342
+ attention_node_name + "_qkv_bias",
343
+ ]
344
+ else:
345
+ attention_inputs = [attention_node_name + "_qkv_input"]
346
+ else:
347
+ if not self.enable_packed_kv:
348
+ attention_inputs = [
349
+ q_matmul.output[0],
350
+ k_matmul.output[0],
351
+ v_matmul.output[0],
352
+ attention_node_name + "_qkv_bias",
353
+ ]
354
+ else:
355
+ attention_inputs = [
356
+ q_matmul.output[0],
357
+ attention_node_name + "_kv_input",
358
+ ]
359
+
360
+ attention_node = helper.make_node(
361
+ "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
362
+ inputs=attention_inputs,
363
+ outputs=[output],
364
+ name=attention_node_name,
365
+ )
366
+ attention_node.domain = "com.microsoft"
367
+ attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
368
+
369
+ counter_name = (
370
+ "Attention (self attention)"
371
+ if is_self_attention and not self.enable_packed_qkv
372
+ else "MultiHeadAttention ({})".format(
373
+ "self attention with packed qkv"
374
+ if self.enable_packed_qkv
375
+ else "cross attention with packed kv"
376
+ if self.enable_packed_kv
377
+ else "cross attention"
378
+ )
379
+ )
380
+ self.increase_counter(counter_name)
381
+ return attention_node
382
+
383
+ def create_attention_node_lora(
384
+ self,
385
+ q_matmul_add: NodeProto,
386
+ k_matmul_add: NodeProto,
387
+ v_matmul_add: NodeProto,
388
+ num_heads: int,
389
+ hidden_size: int,
390
+ input: str,
391
+ output: str,
392
+ ) -> NodeProto | None:
393
+ """Create an Attention node.
394
+
395
+ Args:
396
+ q_matmul_add (NodeProto): Add node that combines the Q MatMul output with its LoRA path
397
+ k_matmul_add (NodeProto): Add node that combines the K MatMul output with its LoRA path
398
+ v_matmul_add (NodeProto): Add node that combines the V MatMul output with its LoRA path
399
+ num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
400
+ hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
401
+ input (str): input name
402
+ output (str): output name
403
+
404
+ Returns:
405
+ Union[NodeProto, None]: the node created or None if failed.
406
+ """
407
+ is_self_attention = not self.is_cross_attention
408
+
409
+ q_matmul = self.model.match_parent(q_matmul_add, "MatMul", 0)
410
+ k_matmul = self.model.match_parent(k_matmul_add, "MatMul", 0)
411
+ v_matmul = self.model.match_parent(v_matmul_add, "MatMul", 0)
412
+
413
+ q_lora_nodes = self.match_lora_path(q_matmul_add)
414
+ if q_lora_nodes is None:
415
+ return None
416
+ (q_lora_last_node, q_lora_matmul_1) = q_lora_nodes
417
+
418
+ k_lora_nodes = self.match_lora_path(k_matmul_add)
419
+ if k_lora_nodes is None:
420
+ return None
421
+ (k_lora_last_node, k_lora_matmul_1) = k_lora_nodes
422
+
423
+ v_lora_nodes = self.match_lora_path(v_matmul_add)
424
+ if v_lora_nodes is None:
425
+ return None
426
+ (v_lora_last_node, v_lora_matmul_1) = v_lora_nodes
427
+
428
+ if is_self_attention:
429
+ if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
430
+ logger.debug(
431
+ "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
432
+ q_matmul.input[0],
433
+ k_matmul.input[0],
434
+ v_matmul.input[0],
435
+ )
436
+ return None
437
+
438
+ if (
439
+ q_lora_matmul_1.input[0] != input
440
+ or k_lora_matmul_1.input[0] != input
441
+ or v_lora_matmul_1.input[0] != input
442
+ ):
443
+ logger.debug(
444
+ "For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %s",
445
+ q_lora_matmul_1.input[0],
446
+ k_lora_matmul_1.input[0],
447
+ v_lora_matmul_1.input[0],
448
+ )
449
+ return None
450
+ else:
451
+ if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
452
+ logger.debug(
453
+ "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
454
+ q_matmul.input[0],
455
+ k_matmul.input[0],
456
+ v_matmul.input[0],
457
+ )
458
+ return None
459
+
460
+ if (
461
+ q_lora_matmul_1.input[0] != input
462
+ or (k_lora_matmul_1.input[0] != v_lora_matmul_1.input[0])
463
+ or (k_matmul.input[0] == input)
464
+ ):
465
+ logger.debug(
466
+ (
467
+ "For cross attention, input hidden state for LoRA q and k/v weights shall be different. "
468
+ "Got %s, %s, %s"
469
+ ),
470
+ q_lora_matmul_1.input[0],
471
+ k_lora_matmul_1.input[0],
472
+ v_lora_matmul_1.input[0],
473
+ )
474
+ return None
475
+
476
+ if hidden_size > 0 and (hidden_size % num_heads) != 0:
477
+ logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
478
+ return None
479
+
480
+ q_weight = self.model.get_initializer(q_matmul.input[1])
481
+ k_weight = self.model.get_initializer(k_matmul.input[1])
482
+ v_weight = self.model.get_initializer(v_matmul.input[1])
483
+ if not (q_weight and k_weight and v_weight):
484
+ return None
485
+
486
+ # Sometimes weights are stored in fp16
487
+ if q_weight.data_type == 10:
488
+ logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
489
+ return None
490
+
491
+ qw = NumpyHelper.to_array(q_weight)
492
+ kw = NumpyHelper.to_array(k_weight)
493
+ vw = NumpyHelper.to_array(v_weight)
494
+ logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
495
+
496
+ # assert that q, k and v have the same shape as expected
497
+ if is_self_attention:
498
+ if qw.shape != kw.shape or qw.shape != vw.shape:
499
+ return None
500
+
501
+ qw_in_size = qw.shape[0]
502
+
503
+ if hidden_size > 0 and hidden_size != qw_in_size:
504
+ raise ValueError(
505
+ f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
506
+ "Please provide a correct input hidden size or pass in 0"
507
+ )
508
+
509
+ # All the matrices can have the same shape, or the q and k matrices can have the same shape while v is different.
510
+ # For 2D weights, the shape would be [in_size, out_size].
511
+ # For 3D weights, the shape would be [in_size, a, b] where a*b = out_size.
512
+ qw_out_size = int(np.prod(qw.shape[1:]))
513
+
514
+ if self.enable_packed_qkv:
515
+ attention_node_name = self.model.create_node_name("MultiHeadAttention")
516
+
517
+ c = qw_in_size
518
+ n = num_heads
519
+ h = qw_out_size // num_heads
520
+
521
+ # Concat and interleave weights so that the output of fused QKV GEMM has [B, S, N, 3, H] shape
522
+ qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
523
+ c, n * 3 * h
524
+ )
525
+
526
+ matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
527
+ self.add_initializer(
528
+ name=matmul_node_name + "_weight",
529
+ data_type=TensorProto.FLOAT,
530
+ dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
531
+ vals=qkv_weight,
532
+ )
533
+
534
+ matmul_node = helper.make_node(
535
+ "MatMul",
536
+ inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
537
+ outputs=[matmul_node_name + "_out"],
538
+ name=matmul_node_name,
539
+ )
540
+ self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
541
+
542
+ # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
543
+ # the Q/K/V weights to be changed without having to re-run the optimizer.
544
+ lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
545
+
546
+ self.add_initializer(
547
+ name=lora_weight_shape_tensor_name,
548
+ data_type=TensorProto.INT64,
549
+ dims=[4],
550
+ vals=[0, 0, n, h],
551
+ raw=False,
552
+ )
553
+
554
+ # Reshape the LoRA Q weights
555
+ q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q")
556
+ q_lora_reshape_node = helper.make_node(
557
+ "Reshape",
558
+ inputs=[q_lora_last_node.output[0], lora_weight_shape_tensor_name],
559
+ outputs=[q_lora_reshape_node_name + "_out"],
560
+ name=q_lora_reshape_node_name,
561
+ )
562
+ self.node_name_to_graph_name[q_lora_reshape_node.name] = self.this_graph_name
563
+
564
+ # Reshape the LoRA K weights
565
+ k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
566
+ k_lora_reshape_node = helper.make_node(
567
+ "Reshape",
568
+ inputs=[k_lora_last_node.output[0], lora_weight_shape_tensor_name],
569
+ outputs=[k_lora_reshape_node_name + "_out"],
570
+ name=k_lora_reshape_node_name,
571
+ )
572
+ self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
573
+
574
+ # Reshape the LoRA V weights
575
+ v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
576
+ v_lora_reshape_node = helper.make_node(
577
+ "Reshape",
578
+ inputs=[v_lora_last_node.output[0], lora_weight_shape_tensor_name],
579
+ outputs=[v_lora_reshape_node_name + "_out"],
580
+ name=v_lora_reshape_node_name,
581
+ )
582
+ self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
583
+
584
+ # Concat the reshaped LoRA Q/K/V weights together along axis 3
585
+ qkv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_QKV")
586
+ qkv_lora_concat_node = helper.make_node(
587
+ "Concat",
588
+ inputs=[
589
+ q_lora_reshape_node.output[0],
590
+ k_lora_reshape_node.output[0],
591
+ v_lora_reshape_node.output[0],
592
+ ],
593
+ outputs=[qkv_lora_concat_node_name + "_out"],
594
+ name=qkv_lora_concat_node_name,
595
+ )
596
+ qkv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
597
+ self.node_name_to_graph_name[qkv_lora_concat_node.name] = self.this_graph_name
598
+
599
+ # Reshape the LoRA concatenated weights to [..., n * 3 * h]
600
+ reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape"
601
+ self.add_initializer(
602
+ name=reshaped_lora_weights_shape_tensor_name,
603
+ data_type=TensorProto.INT64,
604
+ dims=[3],
605
+ vals=[0, 0, n * 3 * h],
606
+ raw=False,
607
+ )
608
+
609
+ qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV")
610
+ qkv_lora_reshaped_node = helper.make_node(
611
+ "Reshape",
612
+ inputs=[qkv_lora_concat_node.output[0], reshaped_lora_weights_shape_tensor_name],
613
+ outputs=[qkv_lora_reshaped_node_name + "_out"],
614
+ name=qkv_lora_reshaped_node_name,
615
+ )
616
+ self.node_name_to_graph_name[qkv_lora_reshaped_node.name] = self.this_graph_name
617
+
618
+ # Add the LoRA Q/K/V weights to the base Q/K/V weights
619
+ add_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_QKV")
620
+ add_weights_node = helper.make_node(
621
+ "Add",
622
+ inputs=[qkv_lora_reshaped_node.output[0], matmul_node.output[0]],
623
+ outputs=[add_weights_node_name + "_out"],
624
+ name=add_weights_node_name,
625
+ )
626
+ self.node_name_to_graph_name[add_weights_node.name] = self.this_graph_name
627
+
628
+ # Finally, reshape the concatenated Q/K/V result to 5D
629
+ shape_tensor_name = add_weights_node_name + "_reshape_shape"
630
+ self.add_initializer(
631
+ name=shape_tensor_name,
632
+ data_type=TensorProto.INT64,
633
+ dims=[5],
634
+ vals=[0, 0, n, 3, h],
635
+ raw=False,
636
+ )
637
+
638
+ reshape_node = helper.make_node(
639
+ "Reshape",
640
+ inputs=[add_weights_node.output[0], shape_tensor_name],
641
+ outputs=[attention_node_name + "_qkv_input"],
642
+ name=add_weights_node_name + "_reshape",
643
+ )
644
+ self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
645
+
646
+ self.nodes_to_add.extend(
647
+ [
648
+ matmul_node,
649
+ q_lora_reshape_node,
650
+ k_lora_reshape_node,
651
+ v_lora_reshape_node,
652
+ qkv_lora_concat_node,
653
+ qkv_lora_reshaped_node,
654
+ add_weights_node,
655
+ reshape_node,
656
+ ]
657
+ )
658
+ self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul, q_matmul_add, k_matmul_add, v_matmul_add])
659
+ else:
660
+ # TODO: Support non-packed QKV
661
+ return None
662
+ else: # cross attention
663
+ attention_node_name = self.model.create_node_name("MultiHeadAttention")
664
+ if self.enable_packed_kv:
665
+ if kw.shape != vw.shape:
666
+ return None
667
+
668
+ kw_in_size = kw.shape[0]
669
+ vw_in_size = vw.shape[0]
670
+ assert kw_in_size == vw_in_size
671
+
672
+ qw_out_size = qw.shape[1]
673
+ kw_out_size = kw.shape[1]
674
+ vw_out_size = vw.shape[1]
675
+ assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
676
+
677
+ c = kw_in_size
678
+ n = num_heads
679
+ h = kw_out_size // num_heads
680
+
681
+ # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
682
+ kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
683
+
684
+ matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
685
+ self.add_initializer(
686
+ name=matmul_node_name + "_weight",
687
+ data_type=TensorProto.FLOAT,
688
+ dims=[kv_weight.shape[0], kv_weight.shape[1]],
689
+ vals=kv_weight,
690
+ )
691
+
692
+ matmul_node = helper.make_node(
693
+ "MatMul",
694
+ inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
695
+ outputs=[matmul_node_name + "_out"],
696
+ name=matmul_node_name,
697
+ )
698
+ self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
699
+
700
+ # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
701
+ # the Q/K/V weights to be changed without having to re-run the optimizer.
702
+ kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
703
+ self.add_initializer(
704
+ name=kv_lora_weight_shape_tensor_name,
705
+ data_type=TensorProto.INT64,
706
+ dims=[4],
707
+ vals=[0, 0, n, h],
708
+ raw=False,
709
+ )
710
+
711
+ # Reshape the LoRA K weights
712
+ k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
713
+ k_lora_reshape_node = helper.make_node(
714
+ "Reshape",
715
+ inputs=[k_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
716
+ outputs=[k_lora_reshape_node_name + "_out"],
717
+ name=k_lora_reshape_node_name,
718
+ )
719
+ self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
720
+
721
+ # Reshape the LoRA V weights
722
+ v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
723
+ v_lora_reshape_node = helper.make_node(
724
+ "Reshape",
725
+ inputs=[v_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
726
+ outputs=[v_lora_reshape_node_name + "_out"],
727
+ name=v_lora_reshape_node_name,
728
+ )
729
+ self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
730
+
731
+ # Concat the reshaped LoRA K/V weights together along axis 3
732
+ kv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_KV")
733
+ kv_lora_concat_node = helper.make_node(
734
+ "Concat",
735
+ inputs=[k_lora_reshape_node.output[0], v_lora_reshape_node.output[0]],
736
+ outputs=[kv_lora_concat_node_name + "_out"],
737
+ name=kv_lora_concat_node_name,
738
+ )
739
+ kv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
740
+ self.node_name_to_graph_name[kv_lora_concat_node.name] = self.this_graph_name
741
+
742
+ # Reshape the LoRA concatenated weights to [..., n * 2 * h]
743
+ reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape"
744
+ self.add_initializer(
745
+ name=reshaped_kv_lora_weights_shape_tensor_name,
746
+ data_type=TensorProto.INT64,
747
+ dims=[3],
748
+ vals=[0, 0, n * 2 * h],
749
+ raw=False,
750
+ )
751
+
752
+ kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV")
753
+ kv_lora_reshaped_node = helper.make_node(
754
+ "Reshape",
755
+ inputs=[kv_lora_concat_node.output[0], reshaped_kv_lora_weights_shape_tensor_name],
756
+ outputs=[kv_lora_reshaped_node_name + "_out"],
757
+ name=kv_lora_reshaped_node_name,
758
+ )
759
+ self.node_name_to_graph_name[kv_lora_reshaped_node.name] = self.this_graph_name
760
+
761
+ # Add the LoRA K/V weights to the base K/V weights
762
+ add_kv_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_KV")
763
+ add_kv_weights_node = helper.make_node(
764
+ "Add",
765
+ inputs=[kv_lora_reshaped_node.output[0], matmul_node.output[0]],
766
+ outputs=[add_kv_weights_node_name + "_out"],
767
+ name=add_kv_weights_node_name,
768
+ )
769
+ self.node_name_to_graph_name[add_kv_weights_node.name] = self.this_graph_name
770
+
771
+ # Finally, reshape the concatenated K/V result to 5D
772
+ shape_tensor_name = add_kv_weights_node_name + "_reshape_shape"
773
+ self.add_initializer(
774
+ name=shape_tensor_name,
775
+ data_type=TensorProto.INT64,
776
+ dims=[5],
777
+ vals=[0, 0, n, 2, h],
778
+ raw=False,
779
+ )
780
+
781
+ reshape_node = helper.make_node(
782
+ "Reshape",
783
+ inputs=[add_kv_weights_node.output[0], shape_tensor_name],
784
+ outputs=[attention_node_name + "_kv_input"],
785
+ name=add_kv_weights_node_name + "_reshape",
786
+ )
787
+ self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
788
+ self.nodes_to_add.extend(
789
+ [
790
+ matmul_node,
791
+ k_lora_reshape_node,
792
+ v_lora_reshape_node,
793
+ kv_lora_concat_node,
794
+ kv_lora_reshaped_node,
795
+ add_kv_weights_node,
796
+ reshape_node,
797
+ ]
798
+ )
799
+ self.nodes_to_remove.extend([k_matmul, v_matmul, k_matmul_add, v_matmul_add])
800
+ else:
801
+ # TODO: Support non-packed KV
802
+ return None
803
+
804
+ # No bias, use zeros
805
+ qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
806
+ qkv_bias_dim = 3 * hidden_size
807
+ self.add_initializer(
808
+ name=attention_node_name + "_qkv_bias",
809
+ data_type=TensorProto.FLOAT,
810
+ dims=[qkv_bias_dim],
811
+ vals=qkv_bias,
812
+ )
813
+
814
+ if is_self_attention:
815
+ if not self.enable_packed_qkv:
816
+ # TODO: Support non-packed QKV
817
+ return None
818
+ else:
819
+ attention_inputs = [attention_node_name + "_qkv_input"]
820
+ else:
821
+ if not self.enable_packed_kv:
822
+ # TODO: Support non-packed KV
823
+ return None
824
+ else:
825
+ attention_inputs = [
826
+ q_matmul_add.output[0],
827
+ attention_node_name + "_kv_input",
828
+ ]
829
+
830
+ attention_node = helper.make_node(
831
+ "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
832
+ inputs=attention_inputs,
833
+ outputs=[output],
834
+ name=attention_node_name,
835
+ )
836
+ attention_node.domain = "com.microsoft"
837
+ attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
838
+
839
+ counter_name = (
840
+ "Attention (self attention)"
841
+ if is_self_attention and not self.enable_packed_qkv
842
+ else "MultiHeadAttention ({})".format(
843
+ "self attention with packed qkv"
844
+ if self.enable_packed_qkv
845
+ else "cross attention with packed kv"
846
+ if self.enable_packed_kv
847
+ else "cross attention"
848
+ )
849
+ )
850
+ self.increase_counter(counter_name)
851
+ return attention_node
852
+
853
+ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
854
+ if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node):
855
+ return
856
+
857
+ node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
858
+
859
+ # In SD 1.5, for self attention, the LayerNormalization node has a Reshape parent
860
+ if node_before_layernorm is None and not self.is_cross_attention:
861
+ node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0)
862
+
863
+ if node_before_layernorm is None:
864
+ return
865
+
866
+ root_input = node_before_layernorm.output[0]
867
+
868
+ children_nodes = input_name_to_nodes[root_input]
869
+ skip_add = None
870
+ for node in children_nodes:
871
+ if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet
872
+ skip_add = node
873
+ break
874
+ if skip_add is None:
875
+ return
876
+
877
+ match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
878
+ if match_qkv is not None:
879
+ is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
880
+
881
+ attention_last_node = reshape_qkv
882
+
883
+ q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
884
+ if q_num_heads <= 0:
885
+ logger.debug("fuse_attention: failed to detect num_heads")
886
+ return
887
+
888
+ # The number of heads is the same for all paths, so q_num_heads is passed when creating the attention node.
889
+ new_node = self.create_attention_node(
890
+ matmul_q,
891
+ matmul_k,
892
+ matmul_v,
893
+ q_num_heads,
894
+ q_hidden_size,
895
+ input=normalize_node.output[0],
896
+ output=attention_last_node.output[0],
897
+ )
898
+ if new_node is None:
899
+ return
900
+ else:
901
+ # Check if we have a LoRA pattern
902
+ match_qkv = self.match_qkv_torch1_lora(root_input, skip_add) or self.match_qkv_torch2_lora(
903
+ root_input, skip_add
904
+ )
905
+ if match_qkv is None:
906
+ return
907
+
908
+ is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v = match_qkv
909
+
910
+ attention_last_node = reshape_qkv
911
+
912
+ q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
913
+ if q_num_heads <= 0:
914
+ logger.debug("fuse_attention: failed to detect num_heads")
915
+ return
916
+
917
+ # The number of heads is the same for all paths, so q_num_heads is passed when creating the attention node.
918
+ new_node = self.create_attention_node_lora(
919
+ matmul_add_q,
920
+ matmul_add_k,
921
+ matmul_add_v,
922
+ q_num_heads,
923
+ q_hidden_size,
924
+ input=normalize_node.output[0],
925
+ output=attention_last_node.output[0],
926
+ )
927
+ if new_node is None:
928
+ return
929
+
930
+ q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
931
+ if q_num_heads <= 0:
932
+ logger.debug("fuse_attention: failed to detect num_heads")
933
+ return
934
+
935
+ self.nodes_to_add.append(new_node)
936
+ self.node_name_to_graph_name[new_node.name] = self.this_graph_name
937
+
938
+ self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
939
+
940
+ # Use prune graph to remove nodes since they are shared by all attention nodes.
941
+ self.prune_graph = True
942
+
943
+ def match_qkv_torch1(self, root_input, skip_add):
944
+ """Match Q, K and V paths exported by PyTorch 1.*"""
945
+ another_input = 1 if skip_add.input[0] == root_input else 0
946
+ qkv_nodes = self.model.match_parent_path(
947
+ skip_add,
948
+ ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
949
+ [another_input, None, None, 0, 0, 0],
950
+ )
951
+
952
+ if qkv_nodes is None:
953
+ return None
954
+
955
+ (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
956
+
957
+ # No bias. For cross-attention, the input of the MatMul is the encoder_hidden_states graph input.
958
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
959
+ if v_nodes is None:
960
+ logger.debug("fuse_attention: failed to match v path")
961
+ return None
962
+ (_, _, _, matmul_v) = v_nodes
963
+
964
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
965
+ if qk_nodes is not None:
966
+ (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
967
+ else:
968
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
969
+ if qk_nodes is not None:
970
+ (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
971
+ else:
972
+ logger.debug("fuse_attention: failed to match qk path")
973
+ return None
974
+
975
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
976
+ if q_nodes is None:
977
+ logger.debug("fuse_attention: failed to match q path")
978
+ return None
979
+ (_, _transpose_q, reshape_q, matmul_q) = q_nodes
980
+
981
+ k_nodes = self.model.match_parent_path(
982
+ matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
983
+ )
984
+ if k_nodes is None:
985
+ logger.debug("fuse_attention: failed to match k path")
986
+ return None
987
+
988
+ (_, _, _, _, matmul_k) = k_nodes
989
+
990
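The leading False in the tuple returned below is the is_torch2 flag that fuse() unpacks first; match_qkv_torch2 returns True in the same position.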
+ return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
991
+
992
+ def match_qkv_torch2(self, root_input, skip_add):
993
+ """Match Q, K and V paths exported by PyTorch 2.*"""
994
+ another_input = 1 if skip_add.input[0] == root_input else 0
995
+ qkv_nodes = self.model.match_parent_path(
996
+ skip_add,
997
+ ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
998
+ [another_input, None, None, 0, 0],
999
+ )
1000
+
1001
+ if qkv_nodes is None:
1002
+ return None
1003
+
1004
+ (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
1005
+
1006
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
1007
+ if v_nodes is None:
1008
+ logger.debug("fuse_attention: failed to match v path")
1009
+ return None
1010
+ (_, _, matmul_v) = v_nodes
1011
+
1012
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
1013
+ if qk_nodes is not None:
1014
+ (_softmax_qk, matmul_qk) = qk_nodes
1015
+ else:
1016
+ logger.debug("fuse_attention: failed to match qk path")
1017
+ return None
1018
+
1019
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0])
1020
+ if q_nodes is None:
1021
+ logger.debug("fuse_attention: failed to match q path")
1022
+ return None
1023
+ (mul_q, _transpose_q, reshape_q, matmul_q) = q_nodes
1024
+
1025
+ k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0])
1026
+ if k_nodes is None:
1027
+ logger.debug("fuse_attention: failed to match k path")
1028
+ return None
1029
+
1030
+ (_mul_k, _, _, matmul_k) = k_nodes
1031
+
1032
+ # Q and K are each scaled by sqrt(1.0/sqrt(head_size)), so their product carries the standard 1/sqrt(head_size) attention scale.
1033
+ mul_q_nodes = self.model.match_parent_path(
1034
+ mul_q,
1035
+ ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
1036
+ [None, 0, 1, 0, 0, 0, 0, 0],
1037
+ )
1038
+ if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
1039
+ logger.debug("fuse_attention: failed to match mul_q path")
1040
+ return None
1041
+
1042
+ return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
1043
+
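The Mul nodes matched above carry sqrt(1.0/sqrt(head_size)) on both the Q and the K branch (the same check appears in match_qkv_torch2_lora below), so the product of Q and transposed K ends up with the usual 1/sqrt(head_size) scaled-dot-product factor. A quick numpy sanity check of that identity; shapes and values are arbitrary and purely illustrative:

    import numpy as np

    head_size = 64
    q = np.random.randn(2, 8, head_size)
    k = np.random.randn(2, 8, head_size)

    scale = np.sqrt(1.0 / np.sqrt(head_size))                      # the Mul constant on each branch
    split = (q * scale) @ np.transpose(k * scale, (0, 2, 1))       # what the exported graph computes
    fused = (q @ np.transpose(k, (0, 2, 1))) / np.sqrt(head_size)  # standard scaled-dot-product scaling
    assert np.allclose(split, fused)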
1044
+ def match_qkv_torch1_lora(self, root_input, skip_add):
1045
+ """Match Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*"""
1046
+ another_input = 1 if skip_add.input[0] == root_input else 0
1047
+ qkv_nodes = self.model.match_parent_path(
1048
+ skip_add,
1049
+ ["Add", "Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
1050
+ [another_input, 0, None, None, 0, 0, 0],
1051
+ )
1052
+ if qkv_nodes is None:
1053
+ return None
1054
+
1055
+ (_, _, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
1056
+
1057
+ # No bias. For cross-attention, the input of the MatMul is the encoder_hidden_states graph input.
1058
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0])
1059
+ if v_nodes is None:
1060
+ logger.debug("fuse_attention: failed to match LoRA v path")
1061
+ return None
1062
+ (_, _, _, matmul_add_v) = v_nodes
1063
+
1064
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
1065
+ if qk_nodes is not None:
1066
+ (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
1067
+ else:
1068
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
1069
+ if qk_nodes is not None:
1070
+ (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
1071
+ else:
1072
+ logger.debug("fuse_attention: failed to match LoRA qk path")
1073
+ return None
1074
+
1075
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "Add"], [0, 0, 0, 0])
1076
+ if q_nodes is None:
1077
+ logger.debug("fuse_attention: failed to match LoRA q path")
1078
+ return None
1079
+ (_, _transpose_q, reshape_q, matmul_add_q) = q_nodes
1080
+
1081
+ k_nodes = self.model.match_parent_path(
1082
+ matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0, 0]
1083
+ )
1084
+ if k_nodes is None:
1085
+ logger.debug("fuse_attention: failed to match LoRA k path")
1086
+ return None
1087
+
1088
+ (_, _, _, _, matmul_add_k) = k_nodes
1089
+
1090
+ return False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
1091
+
1092
+ def match_qkv_torch2_lora(self, root_input, skip_add):
1093
+ """Match Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*"""
1094
+ another_input = 1 if skip_add.input[0] == root_input else 0
1095
+ qkv_nodes = self.model.match_parent_path(
1096
+ skip_add,
1097
+ ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
1098
+ [another_input, 0, None, None, 0, 0],
1099
+ )
1100
+ if qkv_nodes is None:
1101
+ return None
1102
+
1103
+ (_, _, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
1104
+
1105
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add"], [1, 0, 0])
1106
+ if v_nodes is None:
1107
+ logger.debug("fuse_attention: failed to match LoRA v path")
1108
+ return None
1109
+ (_, _, matmul_add_v) = v_nodes
1110
+
1111
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
1112
+ if qk_nodes is not None:
1113
+ (_softmax_qk, matmul_qk) = qk_nodes
1114
+ else:
1115
+ logger.debug("fuse_attention: failed to match LoRA qk path")
1116
+ return None
1117
+
1118
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [0, None, 0, 0])
1119
+ if q_nodes is None:
1120
+ logger.debug("fuse_attention: failed to match LoRA q path")
1121
+ return None
1122
+ (mul_q, _transpose_q, reshape_q, matmul_add_q) = q_nodes
1123
+
1124
+ k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [1, None, 0, 0])
1125
+ if k_nodes is None:
1126
+ logger.debug("fuse_attention: failed to match LoRA k path")
1127
+ return None
1128
+
1129
+ (_mul_k, _, _, matmul_add_k) = k_nodes
1130
+
1131
+ # Q and K are each scaled by sqrt(1.0/sqrt(head_size)), so their product carries the standard 1/sqrt(head_size) attention scale.
1132
+ mul_q_nodes = self.model.match_parent_path(
1133
+ mul_q,
1134
+ ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
1135
+ [None, 0, 1, 0, 0, 0, 0, 0],
1136
+ )
1137
+ if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
1138
+ logger.debug("fuse_attention: failed to match LoRA mul_q path")
1139
+ return None
1140
+
1141
+ return True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
1142
+
1143
+ def match_lora_path(
1144
+ self,
1145
+ add_node: NodeProto,
1146
+ ):
1147
+ # LoRA paths can take one of the following forms:
1148
+ # MatMul -> MatMul -> Add
1149
+ # MatMul -> MatMul -> Mul -> Add
1150
+ # MatMul -> MatMul -> Mul -> Mul -> Add
1151
+
1152
+ # Try matching MatMul -> MatMul -> Add
1153
+ lora_nodes = self.model.match_parent_path(
1154
+ add_node,
1155
+ ["MatMul", "MatMul"],
1156
+ [1, 0],
1157
+ )
1158
+
1159
+ if lora_nodes is not None:
1160
+ (lora_matmul_2_node, lora_matmul_1_node) = lora_nodes
1161
+ return (lora_matmul_2_node, lora_matmul_1_node)
1162
+
1163
+ # Try matching MatMul -> MatMul -> Mul -> Add
1164
+ lora_nodes = self.model.match_parent_path(
1165
+ add_node,
1166
+ ["Mul", "MatMul", "MatMul"],
1167
+ [1, 0, 0],
1168
+ )
1169
+
1170
+ if lora_nodes is not None:
1171
+ (lora_mul_node, _, lora_matmul_1_node) = lora_nodes
1172
+ return (lora_mul_node, lora_matmul_1_node)
1173
+
1174
+ # Try matching MatMul -> MatMul -> Mul -> Mul -> Add
1175
+ lora_nodes = self.model.match_parent_path(
1176
+ add_node,
1177
+ ["Mul", "Mul", "MatMul", "MatMul"],
1178
+ [1, 0, 0, 0],
1179
+ )
1180
+
1181
+ if lora_nodes is not None:
1182
+ (lora_mul_node, _, _, lora_matmul_1_node) = lora_nodes
1183
+ return (lora_mul_node, lora_matmul_1_node)
1184
+
1185
+ return None
1186
+
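The patterns listed in match_lora_path correspond to a LoRA branch added on top of the base projection: the base MatMul output plus a low-rank update that may be rescaled by one or two Mul nodes before the final Add. A small numpy sketch of what that subgraph computes; the names, shapes and the folded scaling constant alpha are illustrative assumptions:

    import numpy as np

    hidden, rank = 320, 4
    x = np.random.randn(2, 77, hidden)

    W = np.random.randn(hidden, hidden)   # base projection (the MatMul feeding the outer Add)
    A = np.random.randn(hidden, rank)     # LoRA down-projection (first MatMul in the matched path)
    B = np.random.randn(rank, hidden)     # LoRA up-projection (second MatMul in the matched path)
    alpha = 0.75                          # optional Mul scale(s) folded into one constant

    # MatMul -> MatMul -> Mul -> Add
    out = x @ W + alpha * ((x @ A) @ B)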
1187
+ def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node):
1188
+ """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension"""
1189
+ entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0])
1190
+ if entry_path is None:
1191
+ entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0])
1192
+ if entry_path is None:
1193
+ return False
1194
+ _cast, node_before_layernorm = entry_path
1195
+
1196
+ root_input = node_before_layernorm.output[0]
1197
+
1198
+ children_nodes = input_name_to_nodes[root_input]
1199
+ skip_add = None
1200
+ for node in children_nodes:
1201
+ if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet
1202
+ skip_add = node
1203
+ break
1204
+ if skip_add is None:
1205
+ return False
1206
+
1207
+ match_qkv = self.match_qkv_a1111(root_input, skip_add)
1208
+ if match_qkv is None:
1209
+ return False
1210
+
1211
+ (
1212
+ reshape_qkv,
1213
+ transpose_qkv,
1214
+ reshape_q,
1215
+ matmul_q,
1216
+ matmul_k,
1217
+ matmul_v,
1218
+ ) = match_qkv
1219
+
1220
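In the A1111 fp16 export, Q, K and V each read from a Cast node. The check below requires a consistent source: for self-attention Q and K must share the same Cast, for cross-attention they must come from different Casts (Q from the hidden states, K and V from the encoder states), and K and V must always share theirs; the subsequent check then ties cast_q back to the LayerNorm output.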
+ cast_q = self.model.match_parent(matmul_q, "Cast", 0)
1221
+ cast_k = self.model.match_parent(matmul_k, "Cast", 0)
1222
+ cast_v = self.model.match_parent(matmul_v, "Cast", 0)
1223
+ if not (
1224
+ cast_q is not None
1225
+ and cast_k is not None
1226
+ and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
1227
+ and cast_k == cast_v
1228
+ ):
1229
+ return False
1230
+
1231
+ if cast_q.input[0] != normalize_node.output[0]:
1232
+ return False
1233
+
1234
+ attention_last_node = reshape_qkv
1235
+
1236
+ q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
1237
+ if q_num_heads <= 0:
1238
+ logger.debug("fuse_attention: failed to detect num_heads")
1239
+ return False
1240
+
1241
+ q_hidden_size = self.get_hidden_size(normalize_node)
1242
+
1243
+ # The number of heads is the same for all paths, so we pass q_num_heads when creating the attention node.
1244
+ new_node = self.create_attention_node(
1245
+ matmul_q,
1246
+ matmul_k,
1247
+ matmul_v,
1248
+ q_num_heads,
1249
+ q_hidden_size,
1250
+ input=matmul_q.input[0],
1251
+ output=attention_last_node.output[0],
1252
+ )
1253
+ if new_node is None:
1254
+ return False
1255
+
1256
+ self.nodes_to_add.append(new_node)
1257
+ self.node_name_to_graph_name[new_node.name] = self.this_graph_name
1258
+
1259
+ self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
1260
+
1261
+ # Use graph pruning to remove the remaining nodes since they are shared by all attention nodes.
1262
+ self.prune_graph = True
1263
+ return True
1264
+
1265
+ def match_qkv_a1111(self, root_input, skip_add):
1266
+ """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension"""
1267
+ another_input = 1 if skip_add.input[0] == root_input else 0
1268
+ qkv_nodes = self.model.match_parent_path(
1269
+ skip_add,
1270
+ ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"],
1271
+ [another_input, None, None, 0, 0, 0],
1272
+ )
1273
+
1274
+ if qkv_nodes is None:
1275
+ return None
1276
+
1277
+ (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes
1278
+
1279
+ v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
1280
+ if v_nodes is None:
1281
+ logger.debug("fuse_attention: failed to match v path")
1282
+ return None
1283
+ (_, _, _, matmul_v) = v_nodes
1284
+
1285
+ qk_nodes = self.model.match_parent_path(
1286
+ einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None]
1287
+ )
1288
+ if qk_nodes is not None:
1289
+ (_, _, _softmax_qk, _, einsum_qk) = qk_nodes
1290
+ else:
1291
+ logger.debug("fuse_attention: failed to match qk path")
1292
+ return None
1293
+
1294
+ q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
1295
+ if q_nodes is None:
1296
+ logger.debug("fuse_attention: failed to match q path")
1297
+ return None
1298
+ (_, _transpose_q, reshape_q, matmul_q) = q_nodes
1299
+
1300
+ k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
1301
+ if k_nodes is None:
1302
+ logger.debug("fuse_attention: failed to match k path")
1303
+ return None
1304
+
1305
+ (_, _, _, matmul_k) = k_nodes
1306
+
1307
+ return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
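Unlike the PyTorch exports above, the A1111 graph expresses the Q-times-K and attention-times-V products as Einsum nodes (with extra Casts for fp16), but these are just batched matmuls. A small numpy illustration of the equivalence; the einsum equation strings are assumptions chosen for the example, since the actual equations are attributes of the Einsum nodes and are not shown here:

    import numpy as np

    b, i, j, d = 2, 16, 24, 40
    q = np.random.randn(b, i, d)
    k = np.random.randn(b, j, d)
    v = np.random.randn(b, j, d)

    scores = np.einsum("bid,bjd->bij", q, k)           # same as q @ k.transpose(0, 2, 1)
    assert np.allclose(scores, q @ k.transpose(0, 2, 1))

    probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)   # softmax over j
    out = np.einsum("bij,bjd->bid", probs, v)           # same as probs @ v
    assert np.allclose(out, probs @ v)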