onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_attention_unet.py
@@ -0,0 +1,1304 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from logging import getLogger
+ from typing import Tuple, Union
+
+ import numpy as np
+ from fusion_base import Fusion
+ from fusion_utils import NumpyHelper
+ from onnx import NodeProto, TensorProto, helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionAttentionUnet(Fusion):
+     """
+     Fuse Attention subgraph of UNet into one Attention node.
+     """
+
+     def __init__(
+         self,
+         model: OnnxModel,
+         hidden_size: int,
+         num_heads: int,
+         is_cross_attention: bool,
+         enable_packed_qkv: bool,
+         enable_packed_kv: bool,
+     ):
+         super().__init__(
+             model,
+             "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention",
+             ["LayerNormalization"],
+         )
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         self.is_cross_attention = is_cross_attention
+
+         # Note: packing Q/K/V or K/V weights into one tensor makes it harder to update initializers for LoRA.
+         # To support LoRA, it is better to use separate Q, K and V inputs in offline optimization,
+         # and let the CUDA operator pre-pack those tensors into the preferred format based on available kernels.
+         # In this way, we can support LoRA and get optimal performance at the same time.
+         self.enable_packed_qkv = enable_packed_qkv
+         self.enable_packed_kv = enable_packed_kv
+
+         # Flags to show warning only once
+         self.num_heads_warning = True
+         self.hidden_size_warning = True
+
+     def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
+         """Detect num_heads from a reshape node.
+
+         Args:
+             reshape_q (NodeProto): reshape node for Q
+             is_torch2 (bool): graph pattern is from PyTorch 2.*
+         Returns:
+             int: num_heads, or 0 if not found
+         """
+         num_heads = 0
+         if is_torch2:
+             # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
+             reshape_parent = self.model.get_parent(reshape_q, 1)
+             if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
+                 num_heads = self.model.get_constant_value(reshape_parent.input[2])
+                 if isinstance(num_heads, np.ndarray) and list(num_heads.shape) == [1]:
+                     num_heads = int(num_heads)
+         else:
+             # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
+             q_shape_value = self.model.get_constant_value(reshape_q.input[1])
+             if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
+                 num_heads = int(q_shape_value[2])
+
+         if isinstance(num_heads, int) and num_heads > 0:
+             return num_heads
+
+         return 0
+
+     def get_hidden_size(self, layernorm_node):
+         """Detect hidden_size from LayerNormalization node.
+         Args:
+             layernorm_node (NodeProto): LayerNormalization node before Q, K and V
+         Returns:
+             int: hidden_size, or 0 if not found
+         """
+         layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
+         if layernorm_bias:
+             return NumpyHelper.to_array(layernorm_bias).shape[0]
+
+         return 0
+
+     def get_num_heads_and_hidden_size(
+         self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
+     ) -> Tuple[int, int]:
+         """Detect num_heads and hidden_size.
+
+         Args:
+             reshape_q (NodeProto): reshape node for Q
+             is_torch2 (bool): graph pattern is from PyTorch 2.*
+             layernorm_node (NodeProto): LayerNormalization node before Q, K, V
+         Returns:
+             Tuple[int, int]: num_heads and hidden_size
+         """
+         num_heads = self.get_num_heads(reshape_q, is_torch2)
+         if num_heads <= 0:
+             num_heads = self.num_heads  # Fall back to user specified value
+
+         if self.num_heads > 0 and num_heads != self.num_heads:
+             if self.num_heads_warning:
+                 logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
+                 self.num_heads_warning = False  # Do not show the warning more than once
+
+         hidden_size = self.get_hidden_size(layernorm_node)
+         if hidden_size <= 0:
+             hidden_size = self.hidden_size  # Fall back to user specified value
+
+         if self.hidden_size > 0 and hidden_size != self.hidden_size:
+             if self.hidden_size_warning:
+                 logger.warning(
+                     f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
+                 )
+                 self.hidden_size_warning = False  # Do not show the warning more than once
+
+         return num_heads, hidden_size
+
+     def create_attention_node(
+         self,
+         q_matmul: NodeProto,
+         k_matmul: NodeProto,
+         v_matmul: NodeProto,
+         num_heads: int,
+         hidden_size: int,
+         input: str,
+         output: str,
+     ) -> Union[NodeProto, None]:
+         """Create an Attention node.
+
+         Args:
+             q_matmul (NodeProto): MatMul node in the fully connected layer for Q
+             k_matmul (NodeProto): MatMul node in the fully connected layer for K
+             v_matmul (NodeProto): MatMul node in the fully connected layer for V
+             num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+             hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
+             input (str): input name
+             output (str): output name
+
+         Returns:
+             Union[NodeProto, None]: the node created or None if failed.
+         """
+         is_self_attention = not self.is_cross_attention
+
+         if is_self_attention:
+             if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
+                 logger.debug(
+                     "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
+                     q_matmul.input[0],
+                     k_matmul.input[0],
+                     v_matmul.input[0],
+                 )
+                 return None
+         else:
+             if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
+                 logger.debug(
+                     "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
+                     q_matmul.input[0],
+                     k_matmul.input[0],
+                     v_matmul.input[0],
+                 )
+                 return None
+
+         if hidden_size > 0 and (hidden_size % num_heads) != 0:
+             logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
+             return None
+
+         q_weight = self.model.get_initializer(q_matmul.input[1])
+         k_weight = self.model.get_initializer(k_matmul.input[1])
+         v_weight = self.model.get_initializer(v_matmul.input[1])
+         if not (q_weight and k_weight and v_weight):
+             return None
+
+         # Sometimes weights are stored in fp16
+         float_type = q_weight.data_type
+
+         qw = NumpyHelper.to_array(q_weight)
+         kw = NumpyHelper.to_array(k_weight)
+         vw = NumpyHelper.to_array(v_weight)
+         logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
+
+         # assert q and k have same shape as expected
+         if is_self_attention:
+             if qw.shape != kw.shape or qw.shape != vw.shape:
+                 return None
+
+             qw_in_size = qw.shape[0]
+
+             if hidden_size > 0 and hidden_size != qw_in_size:
+                 raise ValueError(
+                     f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
+                     "Please provide a correct input hidden size or pass in 0"
+                 )
+
+             # All the matrices can have the same shape, or the q, k matrices can have the same shape with v being different
+             # For 2d weights, the shapes would be [in_size, out_size].
+             # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
+             qw_out_size = int(np.prod(qw.shape[1:]))
+
+             if self.enable_packed_qkv:
+                 attention_node_name = self.model.create_node_name("MultiHeadAttention")
+
+                 c = qw_in_size
+                 n = num_heads
+                 h = qw_out_size // num_heads
+
+                 # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape
+                 qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
+                     c, n * 3 * h
+                 )
+
+                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
+                 self.add_initializer(
+                     name=matmul_node_name + "_weight",
+                     data_type=float_type,
+                     dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
+                     vals=qkv_weight,
+                 )
+
+                 matmul_node = helper.make_node(
+                     "MatMul",
+                     inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
+                     outputs=[matmul_node_name + "_out"],
+                     name=matmul_node_name,
+                 )
+                 self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
+
+                 self.add_initializer(
+                     name=matmul_node_name + "_reshape_shape",
+                     data_type=TensorProto.INT64,
+                     dims=[5],
+                     vals=[0, 0, n, 3, h],
+                     raw=False,
+                 )
+
+                 reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[
+                         matmul_node_name + "_out",
+                         matmul_node_name + "_reshape_shape",
+                     ],
+                     outputs=[attention_node_name + "_qkv_input"],
+                     name=matmul_node_name + "_reshape",
+                 )
+                 self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
+                 self.nodes_to_add.extend([matmul_node, reshape_node])
+                 self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
+
+             else:
+                 qkv_weight = np.stack((qw, kw, vw), axis=1)
+                 qkv_weight_dim = 3 * qw_out_size
+
+                 attention_node_name = self.model.create_node_name("Attention")
+
+                 self.add_initializer(
+                     name=attention_node_name + "_qkv_weight",
+                     data_type=float_type,
+                     dims=[qw_in_size, qkv_weight_dim],
+                     vals=qkv_weight,
+                 )
+         else:  # cross attention
+             attention_node_name = self.model.create_node_name("MultiHeadAttention")
+             if self.enable_packed_kv:
+                 if kw.shape != vw.shape:
+                     return None
+
+                 kw_in_size = kw.shape[0]
+                 vw_in_size = vw.shape[0]
+                 assert kw_in_size == vw_in_size
+
+                 qw_out_size = qw.shape[1]
+                 kw_out_size = kw.shape[1]
+                 vw_out_size = vw.shape[1]
+                 assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
+
+                 c = kw_in_size
+                 n = num_heads
+                 h = kw_out_size // num_heads
+
+                 # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
+                 kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
+
+                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
+                 self.add_initializer(
+                     name=matmul_node_name + "_weight",
+                     data_type=float_type,
+                     dims=[kv_weight.shape[0], kv_weight.shape[1]],
+                     vals=kv_weight,
+                 )
+
+                 matmul_node = helper.make_node(
+                     "MatMul",
+                     inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
+                     outputs=[matmul_node_name + "_out"],
+                     name=matmul_node_name,
+                 )
+                 self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
+
+                 self.add_initializer(
+                     name=matmul_node_name + "_reshape_shape",
+                     data_type=TensorProto.INT64,
+                     dims=[5],
+                     vals=[0, 0, n, 2, h],
+                     raw=False,
+                 )
+
+                 reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[
+                         matmul_node_name + "_out",
+                         matmul_node_name + "_reshape_shape",
+                     ],
+                     outputs=[attention_node_name + "_kv_input"],
+                     name=matmul_node_name + "_reshape",
+                 )
+                 self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
+                 self.nodes_to_add.extend([matmul_node, reshape_node])
+                 self.nodes_to_remove.extend([k_matmul, v_matmul])
+
+         # No bias, use zeros
+         qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
+         qkv_bias_dim = 3 * hidden_size
+
+         self.add_initializer(
+             name=attention_node_name + "_qkv_bias",
+             data_type=float_type,
+             dims=[qkv_bias_dim],
+             vals=qkv_bias,
+         )
+
+         if is_self_attention:
+             if not self.enable_packed_qkv:
+                 attention_inputs = [
+                     input,
+                     attention_node_name + "_qkv_weight",
+                     attention_node_name + "_qkv_bias",
+                 ]
+             else:
+                 attention_inputs = [attention_node_name + "_qkv_input"]
+         else:
+             if not self.enable_packed_kv:
+                 attention_inputs = [
+                     q_matmul.output[0],
+                     k_matmul.output[0],
+                     v_matmul.output[0],
+                     attention_node_name + "_qkv_bias",
+                 ]
+             else:
+                 attention_inputs = [
+                     q_matmul.output[0],
+                     attention_node_name + "_kv_input",
+                 ]
+
+         attention_node = helper.make_node(
+             "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
+             inputs=attention_inputs,
+             outputs=[output],
+             name=attention_node_name,
+         )
+         attention_node.domain = "com.microsoft"
+         attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+
+         counter_name = (
+             "Attention (self attention)"
+             if is_self_attention and not self.enable_packed_qkv
+             else "MultiHeadAttention ({})".format(
+                 "self attention with packed qkv"
+                 if self.enable_packed_qkv
+                 else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
+             )
+         )
+         self.increase_counter(counter_name)
+         return attention_node
+
+     def create_attention_node_lora(
+         self,
+         q_matmul_add: NodeProto,
+         k_matmul_add: NodeProto,
+         v_matmul_add: NodeProto,
+         num_heads: int,
+         hidden_size: int,
+         input: str,
+         output: str,
+     ) -> Union[NodeProto, None]:
+         """Create an Attention node.
+
+         Args:
+             q_matmul_add (NodeProto): Add node after the MatMul in the fully connected layer for Q
+             k_matmul_add (NodeProto): Add node after the MatMul in the fully connected layer for K
+             v_matmul_add (NodeProto): Add node after the MatMul in the fully connected layer for V
+             num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+             hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
+             input (str): input name
+             output (str): output name
+
+         Returns:
+             Union[NodeProto, None]: the node created or None if failed.
+         """
+         is_self_attention = not self.is_cross_attention
+
+         q_matmul = self.model.match_parent(q_matmul_add, "MatMul", 0)
+         k_matmul = self.model.match_parent(k_matmul_add, "MatMul", 0)
+         v_matmul = self.model.match_parent(v_matmul_add, "MatMul", 0)
+
+         q_lora_nodes = self.match_lora_path(q_matmul_add)
+         if q_lora_nodes is None:
+             return None
+         (q_lora_last_node, q_lora_matmul_1) = q_lora_nodes
+
+         k_lora_nodes = self.match_lora_path(k_matmul_add)
+         if k_lora_nodes is None:
+             return None
+         (k_lora_last_node, k_lora_matmul_1) = k_lora_nodes
+
+         v_lora_nodes = self.match_lora_path(v_matmul_add)
+         if v_lora_nodes is None:
+             return None
+         (v_lora_last_node, v_lora_matmul_1) = v_lora_nodes
+
+         if is_self_attention:
+             if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
+                 logger.debug(
+                     "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
+                     q_matmul.input[0],
+                     k_matmul.input[0],
+                     v_matmul.input[0],
+                 )
+                 return None
+
+             if (
+                 q_lora_matmul_1.input[0] != input
+                 or k_lora_matmul_1.input[0] != input
+                 or v_lora_matmul_1.input[0] != input
+             ):
+                 logger.debug(
+                     "For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %s",
+                     q_lora_matmul_1.input[0],
+                     k_lora_matmul_1.input[0],
+                     v_lora_matmul_1.input[0],
+                 )
+                 return None
+         else:
+             if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
+                 logger.debug(
+                     "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
+                     q_matmul.input[0],
+                     k_matmul.input[0],
+                     v_matmul.input[0],
+                 )
+                 return None
+
+             if (
+                 q_lora_matmul_1.input[0] != input
+                 or (k_lora_matmul_1.input[0] != v_lora_matmul_1.input[0])
+                 or (k_matmul.input[0] == input)
+             ):
+                 logger.debug(
+                     (
+                         "For cross attention, input hidden state for LoRA q and k/v weights shall be different. "
+                         "Got %s, %s, %s"
+                     ),
+                     q_lora_matmul_1.input[0],
+                     k_lora_matmul_1.input[0],
+                     v_lora_matmul_1.input[0],
+                 )
+                 return None
+
+         if hidden_size > 0 and (hidden_size % num_heads) != 0:
+             logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
+             return None
+
+         q_weight = self.model.get_initializer(q_matmul.input[1])
+         k_weight = self.model.get_initializer(k_matmul.input[1])
+         v_weight = self.model.get_initializer(v_matmul.input[1])
+         if not (q_weight and k_weight and v_weight):
+             return None
+
+         # Sometimes weights are stored in fp16
+         if q_weight.data_type == 10:
+             logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
+             return None
+
+         qw = NumpyHelper.to_array(q_weight)
+         kw = NumpyHelper.to_array(k_weight)
+         vw = NumpyHelper.to_array(v_weight)
+         logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
+
+         # assert q and k have same shape as expected
+         if is_self_attention:
+             if qw.shape != kw.shape or qw.shape != vw.shape:
+                 return None
+
+             qw_in_size = qw.shape[0]
+
+             if hidden_size > 0 and hidden_size != qw_in_size:
+                 raise ValueError(
+                     f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
+                     "Please provide a correct input hidden size or pass in 0"
+                 )
+
+             # All the matrices can have the same shape, or the q, k matrices can have the same shape with v being different
+             # For 2d weights, the shapes would be [in_size, out_size].
+             # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
+             qw_out_size = int(np.prod(qw.shape[1:]))
+
+             if self.enable_packed_qkv:
+                 attention_node_name = self.model.create_node_name("MultiHeadAttention")
+
+                 c = qw_in_size
+                 n = num_heads
+                 h = qw_out_size // num_heads
+
+                 # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape
+                 qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
+                     c, n * 3 * h
+                 )
+
+                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
+                 self.add_initializer(
+                     name=matmul_node_name + "_weight",
+                     data_type=TensorProto.FLOAT,
+                     dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
+                     vals=qkv_weight,
+                 )
+
+                 matmul_node = helper.make_node(
+                     "MatMul",
+                     inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
+                     outputs=[matmul_node_name + "_out"],
+                     name=matmul_node_name,
+                 )
+                 self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
+
+                 # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
+                 # the Q/K/V weights to be changed without having to re-run the optimizer.
+                 lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
+
+                 self.add_initializer(
+                     name=lora_weight_shape_tensor_name,
+                     data_type=TensorProto.INT64,
+                     dims=[4],
+                     vals=[0, 0, n, h],
+                     raw=False,
+                 )
+
+                 # Reshape the LoRA Q weights
+                 q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q")
+                 q_lora_reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[q_lora_last_node.output[0], lora_weight_shape_tensor_name],
+                     outputs=[q_lora_reshape_node_name + "_out"],
+                     name=q_lora_reshape_node_name,
+                 )
+                 self.node_name_to_graph_name[q_lora_reshape_node.name] = self.this_graph_name
+
+                 # Reshape the LoRA K weights
+                 k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
+                 k_lora_reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[k_lora_last_node.output[0], lora_weight_shape_tensor_name],
+                     outputs=[k_lora_reshape_node_name + "_out"],
+                     name=k_lora_reshape_node_name,
+                 )
+                 self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
+
+                 # Reshape the LoRA V weights
+                 v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
+                 v_lora_reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[v_lora_last_node.output[0], lora_weight_shape_tensor_name],
+                     outputs=[v_lora_reshape_node_name + "_out"],
+                     name=v_lora_reshape_node_name,
+                 )
+                 self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
+
+                 # Concat the reshaped LoRA Q/K/V weights together on the third axis
+                 qkv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_QKV")
+                 qkv_lora_concat_node = helper.make_node(
+                     "Concat",
+                     inputs=[
+                         q_lora_reshape_node.output[0],
+                         k_lora_reshape_node.output[0],
+                         v_lora_reshape_node.output[0],
+                     ],
+                     outputs=[qkv_lora_concat_node_name + "_out"],
+                     name=qkv_lora_concat_node_name,
+                 )
+                 qkv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
+                 self.node_name_to_graph_name[qkv_lora_concat_node.name] = self.this_graph_name
+
+                 # Reshape the LoRA concatenated weights to [..., n * 3 * h]
+                 reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape"
+                 self.add_initializer(
+                     name=reshaped_lora_weights_shape_tensor_name,
+                     data_type=TensorProto.INT64,
+                     dims=[3],
+                     vals=[0, 0, n * 3 * h],
+                     raw=False,
+                 )
+
+                 qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV")
+                 qkv_lora_reshaped_node = helper.make_node(
+                     "Reshape",
+                     inputs=[qkv_lora_concat_node.output[0], reshaped_lora_weights_shape_tensor_name],
+                     outputs=[qkv_lora_reshaped_node_name + "_out"],
+                     name=qkv_lora_reshaped_node_name,
+                 )
+                 self.node_name_to_graph_name[qkv_lora_reshaped_node.name] = self.this_graph_name
+
+                 # Add the LoRA Q/K/V weights to the base Q/K/V weights
+                 add_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_QKV")
+                 add_weights_node = helper.make_node(
+                     "Add",
+                     inputs=[qkv_lora_reshaped_node.output[0], matmul_node.output[0]],
+                     outputs=[add_weights_node_name + "_out"],
+                     name=add_weights_node_name,
+                 )
+                 self.node_name_to_graph_name[add_weights_node.name] = self.this_graph_name
+
+                 # Finally, reshape the concatenated Q/K/V result to 5D
+                 shape_tensor_name = add_weights_node_name + "_reshape_shape"
+                 self.add_initializer(
+                     name=shape_tensor_name,
+                     data_type=TensorProto.INT64,
+                     dims=[5],
+                     vals=[0, 0, n, 3, h],
+                     raw=False,
+                 )
+
+                 reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[add_weights_node.output[0], shape_tensor_name],
+                     outputs=[attention_node_name + "_qkv_input"],
+                     name=add_weights_node_name + "_reshape",
+                 )
+                 self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
+
+                 self.nodes_to_add.extend(
+                     [
+                         matmul_node,
+                         q_lora_reshape_node,
+                         k_lora_reshape_node,
+                         v_lora_reshape_node,
+                         qkv_lora_concat_node,
+                         qkv_lora_reshaped_node,
+                         add_weights_node,
+                         reshape_node,
+                     ]
+                 )
+                 self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul, q_matmul_add, k_matmul_add, v_matmul_add])
+             else:
+                 # TODO: Support non-packed QKV
+                 return None
+         else:  # cross attention
+             attention_node_name = self.model.create_node_name("MultiHeadAttention")
+             if self.enable_packed_kv:
+                 if kw.shape != vw.shape:
+                     return None
+
+                 kw_in_size = kw.shape[0]
+                 vw_in_size = vw.shape[0]
+                 assert kw_in_size == vw_in_size
+
+                 qw_out_size = qw.shape[1]
+                 kw_out_size = kw.shape[1]
+                 vw_out_size = vw.shape[1]
+                 assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
+
+                 c = kw_in_size
+                 n = num_heads
+                 h = kw_out_size // num_heads
+
+                 # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
+                 kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
+
+                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
+                 self.add_initializer(
+                     name=matmul_node_name + "_weight",
+                     data_type=TensorProto.FLOAT,
+                     dims=[kv_weight.shape[0], kv_weight.shape[1]],
+                     vals=kv_weight,
+                 )
+
+                 matmul_node = helper.make_node(
+                     "MatMul",
+                     inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
+                     outputs=[matmul_node_name + "_out"],
+                     name=matmul_node_name,
+                 )
+                 self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
+
+                 # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
+                 # the Q/K/V weights to be changed without having to re-run the optimizer.
+                 kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
+                 self.add_initializer(
+                     name=kv_lora_weight_shape_tensor_name,
+                     data_type=TensorProto.INT64,
+                     dims=[4],
+                     vals=[0, 0, n, h],
+                     raw=False,
+                 )
+
+                 # Reshape the LoRA K weights
+                 k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
+                 k_lora_reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[k_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
+                     outputs=[k_lora_reshape_node_name + "_out"],
+                     name=k_lora_reshape_node_name,
+                 )
+                 self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
+
+                 # Reshape the LoRA V weights
+                 v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
+                 v_lora_reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[v_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
+                     outputs=[v_lora_reshape_node_name + "_out"],
+                     name=v_lora_reshape_node_name,
+                 )
+                 self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
+
+                 # Concat the reshaped LoRA K/V weights together on the third axis
+                 kv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_KV")
+                 kv_lora_concat_node = helper.make_node(
+                     "Concat",
+                     inputs=[k_lora_reshape_node.output[0], v_lora_reshape_node.output[0]],
+                     outputs=[kv_lora_concat_node_name + "_out"],
+                     name=kv_lora_concat_node_name,
+                 )
+                 kv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
+                 self.node_name_to_graph_name[kv_lora_concat_node.name] = self.this_graph_name
+
+                 # Reshape the LoRA concatenated weights to [..., n * 2 * h]
+                 reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape"
+                 self.add_initializer(
+                     name=reshaped_kv_lora_weights_shape_tensor_name,
+                     data_type=TensorProto.INT64,
+                     dims=[3],
+                     vals=[0, 0, n * 2 * h],
+                     raw=False,
+                 )
+
+                 kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV")
+                 kv_lora_reshaped_node = helper.make_node(
+                     "Reshape",
+                     inputs=[kv_lora_concat_node.output[0], reshaped_kv_lora_weights_shape_tensor_name],
+                     outputs=[kv_lora_reshaped_node_name + "_out"],
+                     name=kv_lora_reshaped_node_name,
+                 )
+                 self.node_name_to_graph_name[kv_lora_reshaped_node.name] = self.this_graph_name
+
+                 # Add the LoRA K/V weights to the base K/V weights
+                 add_kv_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_KV")
+                 add_kv_weights_node = helper.make_node(
+                     "Add",
+                     inputs=[kv_lora_reshaped_node.output[0], matmul_node.output[0]],
+                     outputs=[add_kv_weights_node_name + "_out"],
+                     name=add_kv_weights_node_name,
+                 )
+                 self.node_name_to_graph_name[add_kv_weights_node.name] = self.this_graph_name
+
+                 # Finally, reshape the concatenated K/V result to 5D
+                 shape_tensor_name = add_kv_weights_node_name + "_reshape_shape"
+                 self.add_initializer(
+                     name=shape_tensor_name,
+                     data_type=TensorProto.INT64,
+                     dims=[5],
+                     vals=[0, 0, n, 2, h],
+                     raw=False,
+                 )
+
+                 reshape_node = helper.make_node(
+                     "Reshape",
+                     inputs=[add_kv_weights_node.output[0], shape_tensor_name],
+                     outputs=[attention_node_name + "_kv_input"],
+                     name=add_kv_weights_node_name + "_reshape",
+                 )
+                 self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
+                 self.nodes_to_add.extend(
+                     [
+                         matmul_node,
+                         k_lora_reshape_node,
+                         v_lora_reshape_node,
+                         kv_lora_concat_node,
+                         kv_lora_reshaped_node,
+                         add_kv_weights_node,
+                         reshape_node,
+                     ]
+                 )
+                 self.nodes_to_remove.extend([k_matmul, v_matmul, k_matmul_add, v_matmul_add])
+             else:
+                 # TODO: Support non-packed KV
+                 return None
+
+         # No bias, use zeros
+         qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
+         qkv_bias_dim = 3 * hidden_size
+         self.add_initializer(
+             name=attention_node_name + "_qkv_bias",
+             data_type=TensorProto.FLOAT,
+             dims=[qkv_bias_dim],
+             vals=qkv_bias,
+         )
+
+         if is_self_attention:
+             if not self.enable_packed_qkv:
+                 # TODO: Support non-packed QKV
+                 return None
+             else:
+                 attention_inputs = [attention_node_name + "_qkv_input"]
+         else:
+             if not self.enable_packed_kv:
+                 # TODO: Support non-packed QKV
+                 return None
+             else:
+                 attention_inputs = [
+                     q_matmul_add.output[0],
+                     attention_node_name + "_kv_input",
+                 ]
+
+         attention_node = helper.make_node(
+             "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
+             inputs=attention_inputs,
+             outputs=[output],
+             name=attention_node_name,
+         )
+         attention_node.domain = "com.microsoft"
+         attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+
+         counter_name = (
+             "Attention (self attention)"
+             if is_self_attention and not self.enable_packed_qkv
+             else "MultiHeadAttention ({})".format(
+                 "self attention with packed qkv"
+                 if self.enable_packed_qkv
+                 else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
+             )
+         )
+         self.increase_counter(counter_name)
+         return attention_node
+
+     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+         if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node):
+             return
+
+         node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
+
+         # In SD 1.5, for self attention, LayerNorm has parent Reshape
+         if node_before_layernorm is None and not self.is_cross_attention:
+             node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0)
+
+         if node_before_layernorm is None:
+             return
+
+         root_input = node_before_layernorm.output[0]
+
+         children_nodes = input_name_to_nodes[root_input]
+         skip_add = None
+         for node in children_nodes:
+             if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
+                 skip_add = node
+                 break
+         if skip_add is None:
+             return
+
+         match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
+         if match_qkv is not None:
+             is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
+
+             attention_last_node = reshape_qkv
+
+             q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
+             if q_num_heads <= 0:
+                 logger.debug("fuse_attention: failed to detect num_heads")
+                 return
+
+             # The number of heads is the same for all paths, so we pass q_num_heads when creating the attention node.
+             new_node = self.create_attention_node(
+                 matmul_q,
+                 matmul_k,
+                 matmul_v,
+                 q_num_heads,
+                 q_hidden_size,
+                 input=normalize_node.output[0],
+                 output=attention_last_node.output[0],
+             )
+             if new_node is None:
+                 return
+         else:
+             # Check if we have a LoRA pattern
+             match_qkv = self.match_qkv_torch1_lora(root_input, skip_add) or self.match_qkv_torch2_lora(
+                 root_input, skip_add
+             )
+             if match_qkv is None:
+                 return
+
+             is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v = match_qkv
+
+             attention_last_node = reshape_qkv
+
+             q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
+             if q_num_heads <= 0:
+                 logger.debug("fuse_attention: failed to detect num_heads")
+                 return
+
+             # The number of heads is the same for all paths, so we pass q_num_heads when creating the attention node.
+             new_node = self.create_attention_node_lora(
+                 matmul_add_q,
+                 matmul_add_k,
+                 matmul_add_v,
+                 q_num_heads,
+                 q_hidden_size,
+                 input=normalize_node.output[0],
+                 output=attention_last_node.output[0],
+             )
+             if new_node is None:
+                 return
+
+             q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
+             if q_num_heads <= 0:
+                 logger.debug("fuse_attention: failed to detect num_heads")
+                 return
+
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+         self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
+
+         # Use prune graph to remove nodes since they are shared by all attention nodes.
+         self.prune_graph = True
939
+
940
+ def match_qkv_torch1(self, root_input, skip_add):
941
+ """Match Q, K and V paths exported by PyTorch 1.*"""
942
+ another_input = 1 if skip_add.input[0] == root_input else 0
943
+ qkv_nodes = self.model.match_parent_path(
944
+ skip_add,
945
+ ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
946
+ [another_input, None, None, 0, 0, 0],
947
+ )
948
+
949
+ if qkv_nodes is None:
950
+ return None
951
+
952
+ (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
953
+
954
+ # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
955
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
956
+ if v_nodes is None:
957
+ logger.debug("fuse_attention: failed to match v path")
958
+ return None
959
+ (_, _, _, matmul_v) = v_nodes
960
+
961
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
962
+ if qk_nodes is not None:
963
+ (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
964
+ else:
965
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
966
+ if qk_nodes is not None:
967
+ (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
968
+ else:
969
+ logger.debug("fuse_attention: failed to match qk path")
970
+ return None
971
+
972
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
973
+ if q_nodes is None:
974
+ logger.debug("fuse_attention: failed to match q path")
975
+ return None
976
+ (_, _transpose_q, reshape_q, matmul_q) = q_nodes
977
+
978
+ k_nodes = self.model.match_parent_path(
979
+ matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
980
+ )
981
+ if k_nodes is None:
982
+ logger.debug("fuse_attention: failed to match k path")
983
+ return None
984
+
985
+ (_, _, _, _, matmul_k) = k_nodes
986
+
987
+ return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
988
+
989
+ def match_qkv_torch2(self, root_input, skip_add):
990
+ """Match Q, K and V paths exported by PyTorch 2.*"""
991
+ another_input = 1 if skip_add.input[0] == root_input else 0
992
+ qkv_nodes = self.model.match_parent_path(
993
+ skip_add,
994
+ ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
995
+ [another_input, None, None, 0, 0],
996
+ )
997
+
998
+ if qkv_nodes is None:
999
+ return None
1000
+
1001
+ (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
1002
+
1003
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
1004
+ if v_nodes is None:
1005
+ logger.debug("fuse_attention: failed to match v path")
1006
+ return None
1007
+ (_, _, matmul_v) = v_nodes
1008
+
1009
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
1010
+ if qk_nodes is not None:
1011
+ (_softmax_qk, matmul_qk) = qk_nodes
1012
+ else:
1013
+ logger.debug("fuse_attention: failed to match qk path")
1014
+ return None
1015
+
1016
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0])
1017
+ if q_nodes is None:
1018
+ logger.debug("fuse_attention: failed to match q path")
1019
+ return None
1020
+ (mul_q, _transpose_q, reshape_q, matmul_q) = q_nodes
1021
+
1022
+ k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0])
1023
+ if k_nodes is None:
1024
+ logger.debug("fuse_attention: failed to match k path")
1025
+ return None
1026
+
1027
+ (_mul_k, _, _, matmul_k) = k_nodes
1028
+
1029
+ # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
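+ # Multiplying Q and K each by sqrt(1/sqrt(head_size)) gives the standard attention scaling, since
+ # (sqrt(1/sqrt(d)) * Q) @ (sqrt(1/sqrt(d)) * K)^T == (Q @ K^T) / sqrt(d) with d = head_size.
+ # The path below traces that scalar back to the Reshape feeding Q, which is why the match must
+ # end at reshape_q.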
1030
+ mul_q_nodes = self.model.match_parent_path(
1031
+ mul_q,
1032
+ ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
1033
+ [None, 0, 1, 0, 0, 0, 0, 0],
1034
+ )
1035
+ if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
1036
+ logger.debug("fuse_attention: failed to match mul_q path")
1037
+ return None
1038
+
1039
+ return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
1040
+
1041
+ def match_qkv_torch1_lora(self, root_input, skip_add):
1042
+ """Match Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*"""
1043
+ another_input = 1 if skip_add.input[0] == root_input else 0
1044
+ qkv_nodes = self.model.match_parent_path(
1045
+ skip_add,
1046
+ ["Add", "Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
1047
+ [another_input, 0, None, None, 0, 0, 0],
1048
+ )
1049
+ if qkv_nodes is None:
1050
+ return None
1051
+
1052
+ (_, _, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
1053
+
1054
+ # No bias. For cross-attention, the input of the MatMul is the encoder_hidden_states graph input.
1055
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0])
1056
+ if v_nodes is None:
1057
+ logger.debug("fuse_attention: failed to match LoRA v path")
1058
+ return None
1059
+ (_, _, _, matmul_add_v) = v_nodes
1060
+
1061
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
1062
+ if qk_nodes is not None:
1063
+ (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
1064
+ else:
1065
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
1066
+ if qk_nodes is not None:
1067
+ (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
1068
+ else:
1069
+ logger.debug("fuse_attention: failed to match LoRA qk path")
1070
+ return None
1071
+
1072
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "Add"], [0, 0, 0, 0])
1073
+ if q_nodes is None:
1074
+ logger.debug("fuse_attention: failed to match LoRA q path")
1075
+ return None
1076
+ (_, _transpose_q, reshape_q, matmul_add_q) = q_nodes
1077
+
1078
+ k_nodes = self.model.match_parent_path(
1079
+ matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0, 0]
1080
+ )
1081
+ if k_nodes is None:
1082
+ logger.debug("fuse_attention: failed to match LoRA k path")
1083
+ return None
1084
+
1085
+ (_, _, _, _, matmul_add_k) = k_nodes
1086
+
1087
+ return False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
1088
+
1089
+ def match_qkv_torch2_lora(self, root_input, skip_add):
1090
+ """Match Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*"""
1091
+ another_input = 1 if skip_add.input[0] == root_input else 0
1092
+ qkv_nodes = self.model.match_parent_path(
1093
+ skip_add,
1094
+ ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
1095
+ [another_input, 0, None, None, 0, 0],
1096
+ )
1097
+ if qkv_nodes is None:
1098
+ return None
1099
+
1100
+ (_, _, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
1101
+
1102
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add"], [1, 0, 0])
1103
+ if v_nodes is None:
1104
+ logger.debug("fuse_attention: failed to match LoRA v path")
1105
+ return None
1106
+ (_, _, matmul_add_v) = v_nodes
1107
+
1108
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
1109
+ if qk_nodes is not None:
1110
+ (_softmax_qk, matmul_qk) = qk_nodes
1111
+ else:
1112
+ logger.debug("fuse_attention: failed to match LoRA qk path")
1113
+ return None
1114
+
1115
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [0, None, 0, 0])
1116
+ if q_nodes is None:
1117
+ logger.debug("fuse_attention: failed to match LoRA q path")
1118
+ return None
1119
+ (mul_q, _transpose_q, reshape_q, matmul_add_q) = q_nodes
1120
+
1121
+ k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [1, None, 0, 0])
1122
+ if k_nodes is None:
1123
+ logger.debug("fuse_attention: failed to match LoRA k path")
1124
+ return None
1125
+
1126
+ (_mul_k, _, _, matmul_add_k) = k_nodes
1127
+
1128
+ # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
1129
+ mul_q_nodes = self.model.match_parent_path(
1130
+ mul_q,
1131
+ ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
1132
+ [None, 0, 1, 0, 0, 0, 0, 0],
1133
+ )
1134
+ if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
1135
+ logger.debug("fuse_attention: failed to match LoRA mul_q path")
1136
+ return None
1137
+
1138
+ return True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
1139
+
1140
+ def match_lora_path(
1141
+ self,
1142
+ add_node: NodeProto,
1143
+ ):
1144
+ # LoRA paths can match one of the following patterns:
1145
+ # MatMul -> MatMul -> Add
1146
+ # MatMul -> MatMul -> Mul -> Add
1147
+ # MatMul -> MatMul -> Mul -> Mul -> Add
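+ # In a LoRA branch, the Add combines the frozen base projection with a low-rank update computed by
+ # two chained MatMuls (roughly x @ A @ B), and the optional Mul node(s) apply the LoRA scaling
+ # factor; the three variants above differ only in how that scaling is expressed.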
1148
+
1149
+ # Try matching MatMul -> MatMul -> Add
1150
+ lora_nodes = self.model.match_parent_path(
1151
+ add_node,
1152
+ ["MatMul", "MatMul"],
1153
+ [1, 0],
1154
+ )
1155
+
1156
+ if lora_nodes is not None:
1157
+ (lora_matmul_2_node, lora_matmul_1_node) = lora_nodes
1158
+ return (lora_matmul_2_node, lora_matmul_1_node)
1159
+
1160
+ # Try matching MatMul -> MatMul -> Mul -> Add
1161
+ lora_nodes = self.model.match_parent_path(
1162
+ add_node,
1163
+ ["Mul", "MatMul", "MatMul"],
1164
+ [1, 0, 0],
1165
+ )
1166
+
1167
+ if lora_nodes is not None:
1168
+ (lora_mul_node, _, lora_matmul_1_node) = lora_nodes
1169
+ return (lora_mul_node, lora_matmul_1_node)
1170
+
1171
+ # Try matching MatMul -> MatMul -> Mul -> Mul -> Add
1172
+ lora_nodes = self.model.match_parent_path(
1173
+ add_node,
1174
+ ["Mul", "Mul", "MatMul", "MatMul"],
1175
+ [1, 0, 0, 0],
1176
+ )
1177
+
1178
+ if lora_nodes is not None:
1179
+ (lora_mul_node, _, _, lora_matmul_1_node) = lora_nodes
1180
+ return (lora_mul_node, lora_matmul_1_node)
1181
+
1182
+ return None
1183
+
1184
+ def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node):
1185
+ """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension"""
1186
+ entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0])
1187
+ if entry_path is None:
1188
+ entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0])
1189
+ if entry_path is None:
1190
+ return False
1191
+ _cast, node_before_layernorm = entry_path
1192
+
1193
+ root_input = node_before_layernorm.output[0]
1194
+
1195
+ children_nodes = input_name_to_nodes[root_input]
1196
+ skip_add = None
1197
+ for node in children_nodes:
1198
+ if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet
1199
+ skip_add = node
1200
+ break
1201
+ if skip_add is None:
1202
+ return False
1203
+
1204
+ match_qkv = self.match_qkv_a1111(root_input, skip_add)
1205
+ if match_qkv is None:
1206
+ return False
1207
+
1208
+ (
1209
+ reshape_qkv,
1210
+ transpose_qkv,
1211
+ reshape_q,
1212
+ matmul_q,
1213
+ matmul_k,
1214
+ matmul_v,
1215
+ ) = match_qkv
1216
+
1217
+ cast_q = self.model.match_parent(matmul_q, "Cast", 0)
1218
+ cast_k = self.model.match_parent(matmul_k, "Cast", 0)
1219
+ cast_v = self.model.match_parent(matmul_v, "Cast", 0)
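+ # For self-attention, Q, K and V should share a single Cast of the same input; for cross-attention,
+ # Q is cast separately from K and V (which come from the encoder states), so the check flips to
+ # cast_q != cast_k while still requiring cast_k == cast_v.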
1220
+ if not (
1221
+ cast_q is not None
1222
+ and cast_k is not None
1223
+ and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
1224
+ and cast_k == cast_v
1225
+ ):
1226
+ return False
1227
+
1228
+ if cast_q.input[0] != normalize_node.output[0]:
1229
+ return False
1230
+
1231
+ attention_last_node = reshape_qkv
1232
+
1233
+ q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
1234
+ if q_num_heads <= 0:
1235
+ logger.debug("fuse_attention: failed to detect num_heads")
1236
+ return False
1237
+
1238
+ q_hidden_size = self.get_hidden_size(normalize_node)
1239
+
1240
+ # The number of heads is the same for all paths, so q_num_heads is passed when creating the attention node.
1241
+ new_node = self.create_attention_node(
1242
+ matmul_q,
1243
+ matmul_k,
1244
+ matmul_v,
1245
+ q_num_heads,
1246
+ q_hidden_size,
1247
+ input=matmul_q.input[0],
1248
+ output=attention_last_node.output[0],
1249
+ )
1250
+ if new_node is None:
1251
+ return False
1252
+
1253
+ self.nodes_to_add.append(new_node)
1254
+ self.node_name_to_graph_name[new_node.name] = self.this_graph_name
1255
+
1256
+ self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
1257
+
1258
+ # Use prune graph to remove nodes since they are shared by all attention nodes.
1259
+ self.prune_graph = True
1260
+ return True
1261
+
1262
+ def match_qkv_a1111(self, root_input, skip_add):
1263
+ """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension"""
1264
+ another_input = 1 if skip_add.input[0] == root_input else 0
1265
+ qkv_nodes = self.model.match_parent_path(
1266
+ skip_add,
1267
+ ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"],
1268
+ [another_input, None, None, 0, 0, 0],
1269
+ )
1270
+
1271
+ if qkv_nodes is None:
1272
+ return None
1273
+
1274
+ (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes
1275
+
1276
+ v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
1277
+ if v_nodes is None:
1278
+ logger.debug("fuse_attention: failed to match v path")
1279
+ return None
1280
+ (_, _, _, matmul_v) = v_nodes
1281
+
1282
+ qk_nodes = self.model.match_parent_path(
1283
+ einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None]
1284
+ )
1285
+ if qk_nodes is not None:
1286
+ (_, _, _softmax_qk, _, einsum_qk) = qk_nodes
1287
+ else:
1288
+ logger.debug("fuse_attention: failed to match qk path")
1289
+ return None
1290
+
1291
+ q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
1292
+ if q_nodes is None:
1293
+ logger.debug("fuse_attention: failed to match q path")
1294
+ return None
1295
+ (_, _transpose_q, reshape_q, matmul_q) = q_nodes
1296
+
1297
+ k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
1298
+ if k_nodes is None:
1299
+ logger.debug("fuse_attention: failed to match k path")
1300
+ return None
1301
+
1302
+ (_, _, _, matmul_k) = k_nodes
1303
+
1304
+ return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v