onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
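
The hunk that follows is onnxruntime/transformers/fusion_attention.py (+1235 lines), which rewrites exported attention subgraphs into single Attention / MultiHeadAttention nodes (com.microsoft domain). That fusion is normally reached through the package's transformer optimizer rather than instantiated directly; a minimal sketch of that entry point, with an illustrative model path and illustrative num_heads/hidden_size values (not taken from this diff):

from onnxruntime.transformers import optimizer

# Optimize an exported BERT-style ONNX model; attention fusion runs as part of
# the "bert" optimization passes. Passing 0 for num_heads/hidden_size lets the
# fuser try to detect the values from the graph instead.
opt_model = optimizer.optimize_model(
    "model.onnx",      # illustrative input path
    model_type="bert",
    num_heads=12,      # illustrative value
    hidden_size=768,   # illustrative value
)
opt_model.save_model_to_file("model_opt.onnx")
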
onnxruntime/transformers/fusion_attention.py
@@ -0,0 +1,1235 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+from fusion_base import Fusion
+from fusion_options import AttentionMaskFormat
+from fusion_utils import FusionUtils, NumpyHelper
+from onnx import NodeProto, TensorProto, helper, numpy_helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class AttentionMask:
+    """
+    Fuse Attention subgraph into one Attention node.
+    """
+
+    def __init__(self, model: OnnxModel):
+        self.model = model
+        # A lookup table with mask input as key, and mask index output as value
+        self.mask_indice = {}
+        # A lookup table with mask input as key, and cast (to int32) output as value
+        self.mask_casted = {}
+        self.utils = FusionUtils(model)
+        self.mask_format = AttentionMaskFormat.MaskIndexEnd
+        self.opset_version = model.get_opset_version()
+
+    def set_mask_format(self, mask_format: AttentionMaskFormat):
+        self.mask_format = mask_format
+
+    def set_mask_indice(self, mask, mask_index):
+        if mask in self.mask_indice:
+            assert mask_index == self.mask_indice[mask]
+        self.mask_indice[mask] = mask_index
+
+    def get_first_mask(self):
+        assert len(self.mask_indice) > 0
+        return next(iter(self.mask_indice))
+
+    def process_mask(self, input: str) -> str:
+        if self.mask_format == AttentionMaskFormat.NoMask:
+            return None
+
+        if input in self.mask_indice:
+            return self.mask_indice[input]
+
+        # Add cast to convert int64 to int32
+        if self.model.find_graph_input(input):
+            casted, input_name = self.utils.cast_graph_input_to_int32(input)
+        else:
+            input_name, cast_node = self.utils.cast_input_to_int32(input)
+            casted = True
+
+        if casted:
+            self.mask_casted[input] = input_name
+
+        # Attention supports int32 attention mask (2D) since 1.4.0
+        if self.mask_format == AttentionMaskFormat.AttentionMask:
+            self.mask_indice[input] = input_name
+            return input_name
+
+        # Add a mask processing node to convert attention mask to mask index (1D)
+        output_name = self.model.create_node_name("mask_index")
+        if self.opset_version < 13:
+            mask_index_node = helper.make_node(
+                "ReduceSum",
+                inputs=[input_name],
+                outputs=[output_name],
+                name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
+            )
+            mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)])
+        else:
+            # ReduceSum-13: axes is moved from attribute to input
+            axes_name = "ort_const_1_reduce_sum_axes"
+            if self.model.get_initializer(axes_name) is None:
+                self.model.add_initializer(
+                    helper.make_tensor(
+                        name=axes_name,
+                        data_type=TensorProto.INT64,
+                        dims=[1],
+                        vals=[1],
+                        raw=False,
+                    )
+                )
+            mask_index_node = helper.make_node(
+                "ReduceSum",
+                inputs=[input_name, axes_name],
+                outputs=[output_name],
+                name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
+            )
+            mask_index_node.attribute.extend([helper.make_attribute("keepdims", 0)])
+
+        self.model.add_node(mask_index_node)
+
+        self.mask_indice[input] = output_name
+        return output_name
+
+
+class FusionAttention(Fusion):
+    """
+    Fuse Attention subgraph into one Attention node.
+    """
+
+    def __init__(
+        self,
+        model: OnnxModel,
+        hidden_size: int,
+        num_heads: int,
+        attention_mask: Optional[AttentionMask] = None,
+        use_multi_head_attention: bool = False,
+        disable_multi_head_attention_bias: bool = False,
+        search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006
+    ):
+        attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
+        super().__init__(model, attention_op_name, search_op_types)
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
+        self.use_multi_head_attention = use_multi_head_attention
+        self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
+        self.mask_filter_value = None
+
+        # Flags to show warning only once
+        self.num_heads_warning = True
+        self.hidden_size_warning = True
+
+        self.shape_infer = None
+        self.shape_infer_done = True
+
+    def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]:
+        """
+        Detect num_heads and hidden_size from Concat node in the following subgraph:
+
+        SkipLayerNormalization or EmbedLayerNormalization
+                        /        |
+                     MatMul    Shape
+                        |        |
+                       Add    Gather(indices=0)
+                        |        |
+                        |     Unsqueeze
+                        |        |
+                        |     Concat (*, -1, 12, 64)
+                        |     /
+                     Reshape
+                        |
+                     Transpose
+        """
+        if len(concat.input) == 4:
+            num_heads = self.model.get_constant_value(concat.input[2])
+            head_size = self.model.get_constant_value(concat.input[3])
+            if (
+                isinstance(num_heads, np.ndarray)
+                and num_heads.size == 1
+                and isinstance(head_size, np.ndarray)
+                and head_size.size == 1
+            ):
+                return num_heads[0], num_heads[0] * head_size[0]
+
+        return self.num_heads, self.hidden_size
+
+    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
+        """Detect num_heads and hidden_size from a reshape node.
+
+        Args:
+            reshape_q (NodeProto): reshape node for Q
+
+        Returns:
+            Tuple[int, int]: num_heads and hidden_size
+        """
+        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
+        q_shape = self.model.get_initializer(reshape_q.input[1])
+        if q_shape is None:
+            concat = self.model.get_parent(reshape_q, 1)
+            if concat is not None and concat.op_type == "Concat":
+                return self.get_num_heads_and_hidden_size_from_concat(concat)
+            logger.debug(f"{reshape_q.input[1]} is not initializer.")
+            return self.num_heads, self.hidden_size # Fall back to user specified value
+
+        q_shape_value = NumpyHelper.to_array(q_shape)
+        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
+            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].")
+            return self.num_heads, self.hidden_size # Fall back to user specified value
+
+        num_heads = q_shape_value[2]
+        head_size = q_shape_value[3]
+        hidden_size = num_heads * head_size
+
+        if self.num_heads > 0 and num_heads != self.num_heads:
+            if self.num_heads_warning:
+                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
+                self.num_heads_warning = False # Do not show the warning more than once
+
+        if self.hidden_size > 0 and hidden_size != self.hidden_size:
+            if self.hidden_size_warning:
+                logger.warning(
+                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
+                )
+                self.hidden_size_warning = False # Do not show the warning more than once
+
+        return num_heads, hidden_size
+
+    def get_add_qk_str(self, add_qk: NodeProto):
+        if not self.shape_infer_done:
+            self.shape_infer = self.model.infer_runtime_shape(update=True)
+            self.shape_infer_done = True
+
+        if self.shape_infer is None:
+            return None
+
+        input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0])
+        input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1])
+
+        if input_0_shape is None or input_1_shape is None:
+            logger.debug(f"one of the inputs of {add_qk} is None")
+            return None
+
+        if input_0_shape != input_1_shape:
+            logger.debug(f"the shape of two inputs of {add_qk} is not same")
+            return None
+
+        return add_qk.input[1]
+
+    def reshape_add_qk(self, add_qk: str):
+        # Convert 4D mask from (B,1,S,T) to (B,N,S,T)
+        # B = batch size, N = num heads, S = source sequence length, T = target sequence length
+        mask_output_name = add_qk + "_mask"
+
+        # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
+        concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
+        if len(concat_node) == 1:
+            return mask_output_name
+
+        assert len(concat_node) == 0
+        concat_node_name = self.model.create_node_name("Concat")
+        concat_add_qk_fp32 = helper.make_node(
+            "Concat",
+            inputs=[add_qk for _ in range(self.num_heads)],
+            outputs=[mask_output_name],
+            name=concat_node_name,
+            axis=1,
+        )
+        # Add new node to graph
+        self.nodes_to_add.append(concat_add_qk_fp32)
+        self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
+
+        return mask_output_name
+
+    def concat_kv(self, past_k: str, past_v: str) -> str:
+        """Concatenate past_k and past_v inputs to create past_kv input.
+
+        Args:
+            past_k (str): name of past K value
+            past_v (str): name of past V value
+
+        Returns:
+            kv_output_name (str): name of past KV value
+        """
+        # Unsqueeze K and V nodes from (B,N,P,H) to (1,B,N,P,H)
+        # B = batch size, N = num heads, P = past sequence length, H = head size
+        unsqueeze_k_name = self.model.create_node_name("Unsqueeze")
+        unsqueeze_v_name = self.model.create_node_name("Unsqueeze")
+        k_5d_name = (past_k + "_5d").replace(".", "_")
+        v_5d_name = (past_v + "_5d").replace(".", "_")
+
+        k_5d = helper.make_node(
+            "Unsqueeze",
+            inputs=[past_k],
+            outputs=[k_5d_name],
+            name=unsqueeze_k_name,
+            axes=[0],
+        )
+        v_5d = helper.make_node(
+            "Unsqueeze",
+            inputs=[past_v],
+            outputs=[v_5d_name],
+            name=unsqueeze_v_name,
+            axes=[0],
+        )
+
+        # Add unsqueeze nodes to graph
+        self.nodes_to_add.append(k_5d)
+        self.nodes_to_add.append(v_5d)
+        self.node_name_to_graph_name[unsqueeze_k_name] = self.this_graph_name
+        self.node_name_to_graph_name[unsqueeze_v_name] = self.this_graph_name
+
+        # Concat K and V to get one node of size (2,B,N,P,H)
+        concat_node_name = self.model.create_node_name("Concat")
+        kv_output_name = past_v.replace(".value", ".kv").replace(".", "_").replace("_value", "_kv")
+        concat_kv = helper.make_node(
+            "Concat",
+            inputs=[k_5d_name, v_5d_name],
+            outputs=[kv_output_name],
+            name=concat_node_name,
+            axis=0,
+        )
+
+        # Add concat node to graph
+        self.nodes_to_add.append(concat_kv)
+        self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
+
+        return kv_output_name
+
+    def reshape_kv(self, past_k: str, past_v: str) -> (str, str):
+        """Reshape past_k and past_v from 4D to 3D to use as inputs for multihead attention node.
+
+        Args:
+            past_k (str): name of past K value of shape 4D
+            past_v (str): name of past V value of shape 4D
+
+        Returns:
+            k_3d (str): name of past K value of shape 3D
+            v_3d (str): name of past V value of shape 3D
+        """
+        # Reshape past_k and past_v from (B,N,P,H) to (B,P,N*H)
+        # B = batch size, N = num heads, P = past seq len, H = head size
+
+        # Create initializer for reshaping past_k and past_v
+        new_dims_name = "kv_4d_to_3d"
+        new_dims = self.model.get_initializer(new_dims_name)
+        if new_dims is None:
+            new_dims = numpy_helper.from_array(
+                np.array([0, -1, self.model.hidden_size], dtype="int64"), name=new_dims_name
+            )
+            self.model.add_initializer(new_dims, self.this_graph_name)
+
+        reshape_k_name = self.model.create_node_name("Reshape")
+        reshape_v_name = self.model.create_node_name("Reshape")
+        k_3d_name = (past_k + "_3d").replace(".", "_")
+        v_3d_name = (past_v + "_3d").replace(".", "_")
+
+        k_3d = helper.make_node(
+            "Reshape",
+            inputs=[past_k, new_dims_name],
+            outputs=[k_3d_name],
+            name=reshape_k_name,
+        )
+        v_3d = helper.make_node(
+            "Reshape",
+            inputs=[past_v, new_dims_name],
+            outputs=[v_3d_name],
+            name=reshape_v_name,
+        )
+
+        # Add reshape nodes to graph
+        self.nodes_to_add.append(k_3d)
+        self.nodes_to_add.append(v_3d)
+        self.node_name_to_graph_name[reshape_k_name] = self.this_graph_name
+        self.node_name_to_graph_name[reshape_v_name] = self.this_graph_name
+
+        return k_3d_name, v_3d_name
+
+    def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
+        """Split kv_node containing present KV values into separate present K and present V values.
+
+        Args:
+            present_k_name (str): name of output to store present K value in
+            present_v_name (str): name of output to store present V value in
+            kv_node (str): name of present KV values
+        """
+        # Split kv_node into present_k and present_v nodes
+
+        # Create initializers for indexing kv_node, whose shape is (2,B,N,P,H)
+        k_index, v_index = "index_0", "index_1"
+        k_dim = self.model.get_initializer(k_index)
+        v_dim = self.model.get_initializer(v_index)
+        if k_dim is None:
+            k_dim = numpy_helper.from_array(np.array(0, dtype="int64"), name=k_index)
+            self.model.add_initializer(k_dim, self.this_graph_name)
+        if v_dim is None:
+            v_dim = numpy_helper.from_array(np.array(1, dtype="int64"), name=v_index)
+            self.model.add_initializer(v_dim, self.this_graph_name)
+
+        # Create nodes to index kv_node
+        gather_k_name = self.model.create_node_name("Gather")
+        gather_v_name = self.model.create_node_name("Gather")
+        present_k = helper.make_node(
+            "Gather",
+            inputs=[kv_node, k_index],
+            outputs=[present_k_name],
+            name=gather_k_name,
+            axis=0,
+        )
+        present_v = helper.make_node(
+            "Gather",
+            inputs=[kv_node, v_index],
+            outputs=[present_v_name],
+            name=gather_v_name,
+            axis=0,
+        )
+
+        # Add gather nodes to graph
+        self.nodes_to_add.append(present_k)
+        self.nodes_to_add.append(present_v)
+        self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
+        self.node_name_to_graph_name[gather_v_name] = self.this_graph_name
+
+    def transpose_kv(self, past_k: str, past_v: str):
+        """Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H)
+
+        Args:
+            past_k (str): name of past K value of shape (B,N,P,H)
+            past_v (str): name of past V value of shape (B,N,P,H)
+
+        Returns:
+            past_k_transpose (str): name of past K value of shape (B,P,N,H)
+            past_v_transpose (str): name of past V value of shape (B,P,N,H)
+        """
+        past_k_transpose = (past_k + "_transposed").replace(".", "_")
+        past_v_transpose = (past_v + "_transposed").replace(".", "_")
+        transpose_k_name = self.model.create_node_name("Transpose")
+        transpose_v_name = self.model.create_node_name("Transpose")
+
+        transpose_k = helper.make_node(
+            "Transpose",
+            inputs=[past_k],
+            outputs=[past_k_transpose],
+            name=transpose_k_name,
+            perm=[0, 2, 1, 3],
+        )
+        transpose_v = helper.make_node(
+            "Transpose",
+            inputs=[past_v],
+            outputs=[past_v_transpose],
+            name=transpose_v_name,
+            perm=[0, 2, 1, 3],
+        )
+
+        # Add reshape nodes to graph
+        self.nodes_to_add.append(transpose_k)
+        self.nodes_to_add.append(transpose_v)
+        self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name
+        self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name
+
+        return past_k_transpose, past_v_transpose
+
+    def create_combined_qkv_bias(
+        self,
+        q_add: NodeProto,
+        k_add: Union[NodeProto, None],
+        v_add: Union[NodeProto, None],
+        name_prefix: str,
+    ) -> Union[NodeProto, None]:
+        q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
+        qb = NumpyHelper.to_array(q_bias)
+        kb = np.zeros_like(qb)
+        vb = np.zeros_like(qb)
+        if k_add is not None:
+            k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
+            kb = NumpyHelper.to_array(k_bias)
+        if v_add is not None:
+            v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
+            vb = NumpyHelper.to_array(v_bias)
+
+        qkv_bias = np.stack((qb, kb, vb), axis=0)
+        qkv_bias_dim = 3 * np.prod(qb.shape)
+
+        bias_name = name_prefix + "_qkv_bias"
+        self.add_initializer(
+            name=bias_name,
+            data_type=q_bias.data_type,
+            dims=[qkv_bias_dim],
+            vals=qkv_bias,
+        )
+        return bias_name
+
+    def create_packed_qkv_matmul_node(
+        self,
+        q_matmul: NodeProto,
+        k_matmul: NodeProto,
+        v_matmul: NodeProto,
+        q_add: NodeProto,
+        k_add: Union[NodeProto, None],
+        v_add: Union[NodeProto, None],
+        num_heads: int,
+    ) -> Union[NodeProto, None]:
+        """Create packed QKV MatMul node before MultiHeadAttention node.
+        This is for the scenario where an Attention node should be created but cannot be created
+        because past_key and past_value are separate inputs and not one concatenated input.
+
+        Args:
+            q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
+            k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size)
+            v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size)
+            q_add (NodeProto): name of Add from Q path
+            k_add (NodeProto): name of Add from K path
+            v_add (NodeProto): name of Add from V path
+            num_heads (int): number of heads
+
+        Returns:
+            Union[NodeProto, None]: the node created or None if failed.
+        """
+        matmul_node_name = self.model.create_node_name("MatMul")
+
+        # Check that input for Q, K, V is the same
+        assert q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
+
+        # Created packed QKV weight
+        q_weight = self.model.get_initializer(q_matmul.input[1])
+        k_weight = self.model.get_initializer(k_matmul.input[1])
+        v_weight = self.model.get_initializer(v_matmul.input[1])
+
+        qw = NumpyHelper.to_array(q_weight)
+        kw = NumpyHelper.to_array(k_weight)
+        vw = NumpyHelper.to_array(v_weight)
+
+        assert qw.shape == kw.shape and kw.shape == vw.shape
+        d = qw.shape[0]
+
+        qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
+        qkv_weight_name = matmul_node_name + "_qkv_weight"
+
+        self.add_initializer(
+            name=qkv_weight_name,
+            data_type=q_weight.data_type,
+            dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
+            vals=qkv_weight,
+        )
+
+        # Created packed QKV MatMul with output (B, S, 3*D)
+        # Output is of the form:
+        #
+        # [[[Q Q ... Q Q K K ... K K V V ... V V]]]
+        #   [Q Q ... Q Q K K ... K K V V ... V V]
+        #                     .
+        #                     .
+        #                     .
+        #  [[Q Q ... Q Q K K ... K K V V ... V V]
+        #   [Q Q ... Q Q K K ... K K V V ... V V]]]
+        qkv_matmul_output = matmul_node_name + "_qkv_out"
+        qkv_matmul = helper.make_node(
+            "MatMul",
+            inputs=[q_matmul.input[0], qkv_weight_name],
+            outputs=[qkv_matmul_output],
+            name=matmul_node_name,
+        )
+        self.node_name_to_graph_name[matmul_node_name] = self.this_graph_name
+
+        qkv_nodes = [qkv_matmul]
+
+        # Create Slice nodes to access Q, K, V
+        q_slice_name = matmul_node_name + "_q_start_index"
+        self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
+        k_slice_name = matmul_node_name + "_k_start_index"
+        self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
+        v_slice_name = matmul_node_name + "_v_start_index"
+        self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
+        end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
+        self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
+        qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
+        self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)
+
+        q_slice_output = matmul_node_name + "_q_out"
+        q_slice = helper.make_node(
+            "Slice",
+            inputs=[qkv_matmul_output, q_slice_name, k_slice_name, qkv_last_axis_name],
+            outputs=[q_slice_output],
+            name=self.model.create_node_name("Slice"),
+        )
+        self.node_name_to_graph_name[q_slice.name] = self.this_graph_name
+        k_slice_output = matmul_node_name + "_k_out"
+        k_slice = helper.make_node(
+            "Slice",
+            inputs=[qkv_matmul_output, k_slice_name, v_slice_name, qkv_last_axis_name],
+            outputs=[k_slice_output],
+            name=self.model.create_node_name("Slice"),
+        )
+        self.node_name_to_graph_name[k_slice.name] = self.this_graph_name
+        v_slice_output = matmul_node_name + "_v_out"
+        v_slice = helper.make_node(
+            "Slice",
+            inputs=[qkv_matmul_output, v_slice_name, end_of_qkv_name, qkv_last_axis_name],
+            outputs=[v_slice_output],
+            name=self.model.create_node_name("Slice"),
+        )
+        self.node_name_to_graph_name[v_slice.name] = self.this_graph_name
+
+        q_output = q_slice
+        k_output = k_slice
+        v_output = v_slice
+        qkv_nodes.extend([q_slice, k_slice, v_slice])
+
+        if self.disable_multi_head_attention_bias:
+            if q_add is not None:
+                initializer_input = 1 if self.model.get_initializer(q_add.input[1]) else 0
+                if np.any(NumpyHelper.to_array(self.model.get_initializer(q_add.input[initializer_input]))):
+                    q_add.input[1 - initializer_input] = q_slice_output
+                    q_output = q_add
+                    qkv_nodes.append(q_add)
+                    self.node_name_to_graph_name[q_add.name] = self.this_graph_name
+            if k_add is not None:
+                initializer_input = 1 if self.model.get_initializer(k_add.input[1]) else 0
+                if np.any(NumpyHelper.to_array(self.model.get_initializer(k_add.input[initializer_input]))):
+                    k_add.input[1 - initializer_input] = k_slice_output
+                    k_output = k_add
+                    qkv_nodes.append(k_add)
+                    self.node_name_to_graph_name[k_add.name] = self.this_graph_name
+            if v_add is not None:
+                initializer_input = 1 if self.model.get_initializer(v_add.input[1]) else 0
+                if np.any(NumpyHelper.to_array(self.model.get_initializer(v_add.input[initializer_input]))):
+                    v_add.input[1 - initializer_input] = v_slice_output
+                    v_output = v_add
+                    qkv_nodes.append(v_add)
+                    self.node_name_to_graph_name[v_add.name] = self.this_graph_name
+
+        # Add nodes to graph
+        self.nodes_to_add.extend(qkv_nodes)
+        return q_output, k_output, v_output
+
+    def create_multihead_attention_node(
+        self,
+        q_matmul: NodeProto,
+        k_matmul: Union[NodeProto, str, None],
+        v_matmul: Union[NodeProto, str, None],
+        q_add: NodeProto,
+        k_add: Union[NodeProto, None],
+        v_add: Union[NodeProto, None],
+        num_heads: int,
+        hidden_size: int,
+        output: str,
+        key_padding_mask: str = "",
+        add_qk: str = "",
+        past_k: str = "",
+        past_v: str = "",
+        present_k: str = "",
+        present_v: str = "",
+        packed_qkv: bool = False,
+    ) -> Union[NodeProto, None]:
+        """Create a MultiHeadAttention node.
+
+        Args:
+            q_matmul (NodeProto): name of MatMul from Q path - (batch_size, sequence_length, hidden_size)
+            k_matmul (NodeProto): name of MatMul from K path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
+            v_matmul (NodeProto): name of MatMul from V path - (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, past_sequence_length, head_size)
+            q_add (NodeProto): name of Add from Q path
+            k_add (NodeProto): name of Add from K path
+            v_add (NodeProto): name of Add from V path
+            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
+            output (str): output name of MHA
+            key_padding_mask (str): name of key padding mask
+            add_qk (str): name of add after Q x K'
+            past_k (str): name of past K value - (batch_size, num_heads, past_sequence_length, head_size)
+            past_v (str): name of past V value - (batch_size, num_heads, past_sequence_length, head_size)
+            present_k (str): name of present K value - (batch_size, num_heads, sequence_length, head_size)
+            present_v (str): name of present V value - (batch_size, num_heads, sequence_length, head_size)
+            packed_qkv (bool): whether to combine MatMuls from Q, K, V paths
+                Note: This is for the scenario where an Attention node should be created but cannot be created
+                because past_key and past_value are separate inputs and not one concatenated input.
+
+        Returns:
+            Union[NodeProto, None]: the node created or None if failed.
+        """
+        # B = batch size, N = num heads, P = past seq len, H = head size
+        assert num_heads > 0
+
+        if hidden_size > 0 and (hidden_size % num_heads) != 0:
+            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
+            return None
+
+        graph_input_names = set([node.name for node in self.model.graph().input])
+        mha_node_name = self.model.create_node_name("Attention")
+
+        # Add initial Q/K/V inputs for MHA
+        mha_inputs = []
+        if packed_qkv:
+            q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node(
+                q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, num_heads
+            )
+            mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]])
+        elif type(k_matmul) is NodeProto and type(v_matmul) is NodeProto:
+            if self.disable_multi_head_attention_bias:
+                mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]])
+            else:
+                mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]])
+        elif (
+            type(k_matmul) == str # noqa: E721
+            and type(v_matmul) == str # noqa: E721
+            and k_matmul in graph_input_names
+            and v_matmul in graph_input_names
+        ):
+            if self.disable_multi_head_attention_bias:
+                mha_inputs.extend([q_add.output[0], k_matmul, v_matmul])
+            else:
+                mha_inputs.extend([q_matmul.output[0], k_matmul, v_matmul])
+        else:
+            return None
+
+        # Add bias to inputs for MHA
+        # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume
+        # bias has been added to key and value when they are in BNSH format, so only bias for query is used.
+        # Need add checks if we found such assumption is not true.
+        if not self.disable_multi_head_attention_bias:
+            bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name)
+            mha_inputs.append(bias_name)
+        else:
+            mha_inputs.append("")
+
+        # Add optional inputs for MHA
+
+        if past_k and past_v:
+            mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
+        elif key_padding_mask or add_qk:
+            mha_inputs.extend([key_padding_mask, add_qk])
+
+        # Add outputs for MHA
+        mha_outputs = [output]
+        if present_k and present_v:
+            mha_outputs.extend([present_k, present_v])
+
+        mha_node = helper.make_node(
+            "MultiHeadAttention",
+            inputs=mha_inputs,
+            outputs=mha_outputs,
+            name=mha_node_name,
+        )
+        mha_node.domain = "com.microsoft"
+        mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+        return mha_node
+
+    def create_attention_node(
+        self,
+        mask_index: str,
+        q_matmul: NodeProto,
+        k_matmul: NodeProto,
+        v_matmul: NodeProto,
+        q_add: NodeProto,
+        k_add: NodeProto,
+        v_add: NodeProto,
+        num_heads: int,
+        hidden_size: int,
+        input: str,
+        output: str,
+        add_qk_str: str = "",
+        past_k: str = "",
+        past_v: str = "",
+        present_k: str = "",
+        present_v: str = "",
+        scale: Optional[float] = None,
+        causal: bool = False,
+    ) -> Union[NodeProto, None]:
+        """Create an Attention node.
+
+        Args:
+            mask_index (str): mask input
+            q_matmul (NodeProto): MatMul node in fully connection for Q
+            k_matmul (NodeProto): MatMul node in fully connection for K
+            v_matmul (NodeProto): MatMul node in fully connection for V
+            q_add (NodeProto): Add bias node in fully connection for Q
+            k_add (NodeProto): Add bias node in fully connection for K
+            v_add (NodeProto): Add bias node in fully connection for V
+            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
+            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
+            input (str): input name
+            output (str): output name
+            add_qk_str (str): name of Add node after Q x K'
+            past_k (str): name of input for past K value
+            past_v (str): name of input for past V value
+            present_k (str): name of output to store present K value
+            present_v (str): name of output to store present V value
+            scale: scale before softmax
+            causal: whether it is uni-directional mask.
+
+        Returns:
+            Union[NodeProto, None]: the node created or None if failed.
+        """
+        assert num_heads > 0
+
+        if hidden_size > 0 and (hidden_size % num_heads) != 0:
+            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
+            return None
+
+        has_bias = True
+        if q_add is None and k_add is None and v_add is None:
+            has_bias = False
+
+        q_weight = self.model.get_initializer(q_matmul.input[1])
+        k_weight = self.model.get_initializer(k_matmul.input[1])
+        v_weight = self.model.get_initializer(v_matmul.input[1])
+
+        q_bias, k_bias, v_bias = None, None, None
+        if has_bias:
+            q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
+            k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
+            v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
+
+        if not (k_weight and v_weight and q_bias and k_bias):
+            return None
+
+        if q_weight is None:
+            print(
+                f"{q_matmul.input[1]} is not an initializer. "
+                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
+            )
+            return None
+
+        qw = NumpyHelper.to_array(q_weight)
+        kw = NumpyHelper.to_array(k_weight)
+        vw = NumpyHelper.to_array(v_weight)
+
+        # assert q and k have same shape as expected
+        assert qw.shape == kw.shape
+
+        qw_in_size = qw.shape[0]
+        kw_in_size = kw.shape[0]
+        vw_in_size = vw.shape[0]
+
+        assert qw_in_size == kw_in_size == vw_in_size
+
+        if hidden_size > 0 and hidden_size != qw_in_size:
+            logger.warning(
+                f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
+                "Please provide a correct input hidden size or pass in 0"
+            )
+
+        is_qkv_diff_dims = False
+        if qw.shape != vw.shape:
+            is_qkv_diff_dims = True
+
+        # All the matrices can have the same shape or q, k matrices can have the same shape with v being different
+        # For 2d weights, the shapes would be [in_size, out_size].
+        # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
+        qw_out_size = np.prod(qw.shape[1:])
+        kw_out_size = np.prod(kw.shape[1:])
+        vw_out_size = np.prod(vw.shape[1:])
+
+        qkv_weight_dim = 0
+        if is_qkv_diff_dims:
+            qkv_weight = np.concatenate((qw, kw, vw), axis=1)
+            qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
+        else:
+            qkv_weight = np.stack((qw, kw, vw), axis=1)
+            qkv_weight_dim = 3 * qw_out_size
+
+        if has_bias:
+            qb = NumpyHelper.to_array(q_bias)
+            kb = NumpyHelper.to_array(k_bias)
+            vb = NumpyHelper.to_array(v_bias)
+
+            q_bias_shape = np.prod(qb.shape)
+            k_bias_shape = np.prod(kb.shape)
+            v_bias_shape = np.prod(vb.shape)
+
+            assert q_bias_shape == k_bias_shape == qw_out_size
+            assert v_bias_shape == vw_out_size
+
+            if is_qkv_diff_dims:
+                qkv_bias = np.concatenate((qb, kb, vb), axis=0)
+                qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
+            else:
+                qkv_bias = np.stack((qb, kb, vb), axis=0)
+                qkv_bias_dim = 3 * q_bias_shape
+
+        attention_node_name = self.model.create_node_name("Attention")
+
+        if not self.use_multi_head_attention:
+            self.add_initializer(
+                name=attention_node_name + "_qkv_weight",
+                data_type=q_weight.data_type,
+                dims=[qw_in_size, qkv_weight_dim],
+                vals=qkv_weight,
+            )
+
+        if has_bias:
+            self.add_initializer(
+                name=attention_node_name + "_qkv_bias",
+                data_type=q_bias.data_type,
+                dims=[qkv_bias_dim],
+                vals=qkv_bias,
+            )
+
+        # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
+        if self.use_multi_head_attention:
+            if add_qk_str:
+                logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
+                return None
+
+            attention_inputs = [
+                q_matmul.output[0],
+                k_matmul.output[0],
+                v_matmul.output[0],
+                attention_node_name + "_qkv_bias",
+            ]
+
+            if mask_index is not None:
+                attention_inputs.append(mask_index)
+
+            attention_node = helper.make_node(
+                "MultiHeadAttention",
+                inputs=attention_inputs,
+                outputs=[output],
+                name=attention_node_name,
+            )
+        else:
+            attention_inputs = [
+                input,
901
+ attention_node_name + "_qkv_weight",
902
+ attention_node_name + "_qkv_bias" if has_bias else "",
903
+ ]
904
+ if mask_index is not None:
905
+ attention_inputs.append(mask_index)
906
+ else:
907
+ attention_inputs.append("")
908
+
909
+ past_exists = past_k and past_v
910
+ if past_exists:
911
+ past_kv = self.concat_kv(past_k, past_v)
912
+ attention_inputs.append(past_kv)
913
+
914
+ if add_qk_str is not None:
915
+ mask_output_name = self.reshape_add_qk(add_qk_str)
916
+
917
+ # Add attention mask to attention node
918
+ if not past_exists:
919
+ attention_inputs.append("")
920
+ attention_inputs.append(mask_output_name)
921
+
922
+ attention_outputs = [output]
923
+ if present_k and present_v:
924
+ present_kv = present_k.replace(".key", "").replace("_key", "").replace(".", "_")
925
+ attention_outputs.append(present_kv)
926
+ self.split_kv(present_k, present_v, present_kv)
927
+
928
+ attention_node = helper.make_node(
929
+ "Attention",
930
+ inputs=attention_inputs,
931
+ outputs=attention_outputs,
932
+ name=attention_node_name,
933
+ )
934
+
935
+ attention_node.domain = "com.microsoft"
936
+ attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
937
+
938
+ if causal:
939
+ attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])
940
+
941
+ if scale is not None:
942
+ attention_node.attribute.extend([helper.make_attribute("scale", scale)])
943
+
944
+ if is_qkv_diff_dims:
945
+ attention_node.attribute.extend(
946
+ [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
947
+ )
948
+
949
+ if self.mask_filter_value is not None:
950
+ attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
951
+
952
+ return attention_node
953
+
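The attributes appended above (num_heads, unidirectional, scale, qkv_hidden_sizes, mask_filter_value) travel on the returned NodeProto and can be read back with the standard onnx helpers. A minimal sketch, assuming an already-constructed node object named fused_node (a placeholder, not defined in the packaged source):

from onnx import helper

# Inspect the attributes of a fused com.microsoft Attention node.
for attr in fused_node.attribute:
    print(attr.name, helper.get_attribute_value(attr))
# For a BERT-base style layer this would typically include, e.g.:
#   num_heads 12
#   mask_filter_value -10000.0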
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        # Sometimes we cannot fuse SkipLayerNormalization because the Add before LayerNormalization
+        # has an output that is used by nodes outside the fused subgraph.
+        # Conceptually we treat the Add before LayerNormalization as a SkipLayerNormalization node
+        # since they share the same pattern.
+        start_node = normalize_node
+        if normalize_node.op_type == "LayerNormalization":
+            add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
+            if add_before_layernorm is not None:
+                start_node = add_before_layernorm
+            else:
+                return
+
+        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+        qkv_nodes = self.model.match_parent_path(
+            start_node,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [None, None, 0, 0, 0],
+        )
+        einsum_node = None
+        if qkv_nodes is not None:
+            (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+        else:
+            # Match Albert
+            qkv_nodes = self.model.match_parent_path(
+                start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
+            )
+            if qkv_nodes is not None:
+                (_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes
+            else:
+                return
+
+        other_inputs = []
+        for _i, input in enumerate(start_node.input):
+            if input not in output_name_to_node:
+                continue
+
+            if input == qkv_nodes[0].output[0]:
+                continue
+            other_inputs.append(input)
+        if len(other_inputs) != 1:
+            return
+
+        root_input = other_inputs[0]
+        """
+        Match flaubert                     Mask
+                                            |
+        Mul --> LayerNormalization -->  Attention --> MatMul --> Add
+         |                                                        |
+         |                                                        |
+         +--------------------------------------------------------+
+        """
+        mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
+        if mul_before_layernorm is not None:
+            mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
+            if mul_children is not None and len(mul_children) == 2:
+                layernorm_node = mul_children[1]
+                if layernorm_node.op_type == "LayerNormalization":
+                    root_input = layernorm_node.output[0]
+                else:
+                    return
+            elif mul_children is not None and len(mul_children) == 5:
+                root_input = mul_before_layernorm.output[0]
+            else:
+                return
+        elif normalize_node.op_type == "LayerNormalization":
+            children = input_name_to_nodes[root_input]
+            for child in children:
+                if child.op_type == "LayerNormalization":
+                    root_input = child.output[0]
+
+        """
+        When the Add before LayerNormalization produces an output
+        that is consumed by nodes other than the LayerNormalization itself,
+        the fused SkipLayerNormalization will have several outputs.
+        In this case we need to pick the one used in Attention.
+
+        For example, this is the case for ViT:
+
+        SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
+         |                                                           |
+         |                                                           |
+         +-----------------------------------------------------------+
+        """
+        parent_node = output_name_to_node[root_input]
+        if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
+            root_input = parent_node.output[0]
+
+        children = input_name_to_nodes[root_input]
+        children_types = [child.op_type for child in children]
+        if children_types.count("MatMul") != 3:
+            return
+
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return
+        (_, _, add_v, matmul_v) = v_nodes
+
+        is_distill = False
+        is_distill_add = False
+        is_no_mask_attention = False
+        qk_paths = {
+            "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
+            "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
+            "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
+            "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
+            "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
+        }
+
+        qk_nodes = None
+        for k, v in qk_paths.items():
+            qk_nodes = self.model.match_parent_path(matmul_qkv, v[0], v[1])
+            if qk_nodes is None:
+                continue
+            if k == "path3":
+                is_distill = True
+            if k == "path4":
+                is_distill_add = True
+            if k == "path5":
+                is_no_mask_attention = True
+            break
+
+        if qk_nodes is None:
+            logger.debug("fuse_attention: failed to match qk path")
+            return
+
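Each entry in qk_paths pairs a list of expected ancestor op types with the input index to follow at each step, and match_parent_path walks the producers of matmul_qkv accordingly (None means any input may match). A minimal standalone sketch of the idea, written against a plain onnx GraphProto rather than the wrapper class used above, so the function name and structure here are illustrative only:

def match_parent_path_sketch(graph, start_node, op_types, input_indices):
    """Follow producers of start_node upward; return the matched nodes or None.

    op_types[i] is the expected op type of the i-th ancestor; input_indices[i]
    selects which input edge to follow (None = try every input).
    """
    producers = {out: node for node in graph.node for out in node.output}
    path, current = [], start_node
    for op_type, idx in zip(op_types, input_indices):
        candidates = [current.input[idx]] if idx is not None else list(current.input)
        parent = next(
            (producers[name] for name in candidates
             if name in producers and producers[name].op_type == op_type),
            None,
        )
        if parent is None:
            return None
        path.append(parent)
        current = parent
    return path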
+        add_qk = None
+        matmul_qk = None
+        where_qk = None
+        if is_distill:
+            (_, where_qk, matmul_qk, _) = qk_nodes
+        elif is_distill_add:
+            (_, add_qk, where_qk, matmul_qk) = qk_nodes
+        elif is_no_mask_attention:
+            (_, _, matmul_qk) = qk_nodes
+        else:
+            (_, add_qk, _, matmul_qk) = qk_nodes
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
+        if q_nodes is None:
+            q_nodes = self.model.match_parent_path(
+                matmul_qk,
+                ["Div", "Transpose", "Reshape", "Add", "MatMul"],
+                [0, 0, 0, 0, None],
+            )
+            if q_nodes is None:
+                logger.debug("fuse_attention: failed to match q path")
+                return
+        reshape_q = q_nodes[-3]
+        add_q = q_nodes[-2]
+        matmul_q = q_nodes[-1]
+
+        k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
+        if k_nodes is None:
+            k_nodes = self.model.match_parent_path(
+                matmul_qk,
+                ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
+                [1, 0, 0, 0, None],
+            )
+            if k_nodes is None:
+                logger.debug("fuse_attention: failed to match k path")
+                return
+        add_k = k_nodes[-2]
+        matmul_k = k_nodes[-1]
+
+        # Note that Cast might be removed by OnnxRuntime, so we match patterns both with and without Cast here.
+        mask_nodes = None
+        add_qk_str = None
+        if is_distill:
+            _, mask_nodes, _ = self.model.match_parent_paths(
+                where_qk,
+                [
+                    (["Expand", "Reshape", "Equal"], [0, 0, 0]),
+                    (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
+                    (["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
+                ],
+                output_name_to_node,
+            )
+        elif is_distill_add:
+            _, mask_nodes, _ = self.model.match_parent_paths(
+                where_qk,
+                [
+                    (["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
+                    (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
+                ],
+                output_name_to_node,
+            )
+            if add_qk is not None:
+                add_qk_str = self.get_add_qk_str(add_qk)
+                if add_qk_str is None:
+                    logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
+                    return
+        elif is_no_mask_attention:
+            pass
+        else:
+            _, mask_nodes, _ = self.model.match_parent_paths(
+                add_qk,
+                [
+                    (
+                        ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
+                        [None, 0, 1, 0, 0],
+                    ),
+                    (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
+                ],
+                output_name_to_node,
+            )
+        if not is_no_mask_attention and mask_nodes is None:
+            logger.debug("fuse_attention: failed to match mask path")
+            return
+
+        if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
+            _, mul_val = self.model.get_constant_input(mask_nodes[0])
+            if mul_val != -10000:
+                self.mask_filter_value = mul_val
+
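The Mul node matched above is the tail of the usual additive-mask pattern, (1 - mask) * large_negative, which is added to the attention scores before Softmax; when the constant differs from -10000, it is recorded as mask_filter_value on the fused node. A small numpy illustration of that pattern (toy values, not from the packaged source):

import numpy as np

mask = np.array([1, 1, 1, 0], dtype=np.float32)   # 1 = keep token, 0 = padding
scores = np.zeros((4,), dtype=np.float32)         # toy attention scores for one query

mask_filter_value = -10000.0
additive_mask = (1.0 - mask) * mask_filter_value  # the Sub -> Mul pattern matched above
probs = np.exp(scores + additive_mask)
probs /= probs.sum()
print(probs)  # the padding position receives (near) zero probability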
+        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
+            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None
+
+            attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv
+
+            q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+            if q_num_heads <= 0 or q_hidden_size <= 0:
+                logger.warning(
+                    "Failed to detect num_heads and hidden_size for Attention fusion. "
+                    "Please specify those parameters in argument."
+                )
+                return
+
+            # The number of heads is the same for all paths, so q_num_heads is passed when creating the attention node.
+            # input_hidden_size is the hidden size of the layer input; the hidden sizes for Q and K are extracted separately where needed.
+            new_node = self.create_attention_node(
+                mask_index,
+                matmul_q,
+                matmul_k,
+                matmul_v,
+                add_q,
+                add_k,
+                add_v,
+                q_num_heads,
+                q_hidden_size,
+                root_input,
+                attention_last_node.output[0],
+                add_qk_str,
+            )
+            if new_node is None:
+                return
+
+            self.nodes_to_add.append(new_node)
+            self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+            if einsum_node is not None:
+                unique_index = einsum_node.input[0]
+                new_edge = "edge_modified_" + unique_index
+
+                shape_tensor = self.add_initializer(
+                    name="shape_modified_tensor" + unique_index,
+                    data_type=TensorProto.INT64,
+                    dims=[4],
+                    vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]),
+                    raw=False,
+                )
+
+                self.model.add_node(
+                    helper.make_node(
+                        "Reshape",
+                        [attention_last_node.output[0], shape_tensor.name],
+                        [new_edge],
+                        "reshape_modified_" + unique_index,
+                    ),
+                    self.this_graph_name,
+                )
+                einsum_node.input[0] = new_edge
+
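The shape initializer [0, 0, q_num_heads, head_size] relies on ONNX Reshape semantics: with the default allowzero=0, a 0 in the target shape copies the corresponding dimension from the input, so batch and sequence length are preserved while the hidden dimension is split into heads. A toy numpy-level check of the intended result (illustrative values only):

import numpy as np

batch, seq_len, num_heads, head_size = 2, 16, 12, 64   # toy values for illustration
x = np.zeros((batch, seq_len, num_heads * head_size))

# ONNX Reshape with target [0, 0, num_heads, head_size] keeps dims 0 and 1 unchanged;
# numpy has no "copy this dim" shorthand, so the kept dims are spelled out explicitly here.
y = x.reshape(x.shape[0], x.shape[1], num_heads, head_size)
print(y.shape)  # (2, 16, 12, 64)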
+            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
+            self.nodes_to_remove.extend(qk_nodes)
+
+            # For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
+            self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
+            self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
+            self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])
+
+            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+            self.prune_graph = True
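End to end, these fusion passes are normally triggered through the transformers optimizer shipped with the package rather than called directly. A minimal usage sketch (the file names and the num_heads/hidden_size values are placeholders for your own model):

from onnxruntime.transformers import optimizer

# Optimize an exported BERT-style model; attention subgraphs that match the
# patterns above are replaced by fused com.microsoft Attention nodes.
opt_model = optimizer.optimize_model(
    "model.onnx",        # path to the exported ONNX model (placeholder)
    model_type="bert",
    num_heads=12,        # set to your model's head count, or 0 to auto-detect
    hidden_size=768,     # set to your model's hidden size, or 0 to auto-detect
)
opt_model.save_model_to_file("model_optimized.onnx")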