onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_bart_attention.py
@@ -0,0 +1,640 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+import numpy as np
+from fusion_attention import AttentionMask, FusionAttention
+from onnx import TensorProto, helper
+from onnx_model import OnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class FusionBartAttention(FusionAttention):
+    """
+    Fuse Bart Attention subgraph into one Attention node.
+    """
+
+    def __init__(
+        self,
+        model: OnnxModel,
+        hidden_size: int,
+        num_heads: int,
+        attention_mask: AttentionMask,
+    ):
+        super().__init__(model, hidden_size, num_heads, attention_mask)
+
+    def check_runtime_shape_path(
+        self,
+        reshape_qkv_2,
+        reshape_qkv_1,
+        reshape_q_2,
+        reshape_k_2,
+        reshape_v_2,
+        root_input,
+    ):
+        concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
+        if concat_qkv_2_path is None:
+            return False
+        concat_qkv_2 = concat_qkv_2_path[0]
+
+        reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
+        reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
+        if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None:
+            return False
+
+        _, gather_1, shape_1 = reshape_qkv_2_path_1
+        _, gather_2, shape_2 = reshape_qkv_2_path_2
+
+        if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
+            return False
+
+        reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0])
+        reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0])
+        if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None:
+            return False
+        if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name:
+            return False
+
+        reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
+        reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
+        reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
+        if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None:
+            return False
+
+        mul_q = reshape_q_2_path[-1]
+        mul_k = reshape_k_2_path[-1]
+        mul_v = reshape_v_2_path[-1]
+
+        gather_1_out = gather_1.output[0]
+        if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
+            return False
+
+        return True
+
+    def check_runtime_shape_path_openai(
+        self,
+        reshape_qkv_2,
+        matmul_qkv,
+        add_qk,
+        matmul_qk,
+        add_q,
+    ):
+        reshape_qkv_2_path = self.model.match_parent_path(
+            reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0]
+        )
+        if reshape_qkv_2_path is None:
+            return False
+        else:
+            if reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]:
+                return False
+
+        matmul_qk_path_1 = self.model.match_parent_path(
+            matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0]
+        )
+        matmul_qk_path_2 = self.model.match_parent_path(
+            matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0]
+        )
+        if matmul_qk_path_1 is None or matmul_qk_path_2 is None:
+            return False
+
+        mul_1 = matmul_qk_path_1[0]
+        mul_2 = matmul_qk_path_2[0]
+        if mul_1.input[1] != mul_2.input[1]:
+            return False
+        if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]:
+            return False
+
+        # For decoder attentions only
+        if add_qk is not None:
+            add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1])
+            if add_qk_path is None:
+                return False
+            slice_q_path_1 = self.model.match_parent_path(
+                add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0]
+            )
+            slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
+            if slice_q_path_1 is None and slice_q_path_2 is None:
+                return False
+            _, unsqueeze_1, _, _ = slice_q_path_1
+            unsqueeze_2, _, _ = slice_q_path_2
+            if unsqueeze_1.input[0] != unsqueeze_2.input[0]:
+                return False
+            if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]:
+                return False
+
+        return True
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        # Track if fusion is occurring for OpenAI implementation of Whisper
+        model_impl_openai = False
+
+        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+        qkv_nodes = self.model.match_parent_path(
+            normalize_node,
+            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [1, 1, 0, 0, 0, 0],
+        )
+        qkv_nodes_openai = self.model.match_parent_path(
+            normalize_node,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [1, 1, 0, 0, 0],
+        )
+        if qkv_nodes is not None:
+            (
+                add_out,
+                matmul_out,
+                reshape_qkv_2,
+                transpose_qkv,
+                reshape_qkv_1,
+                matmul_qkv,
+            ) = qkv_nodes
+        elif qkv_nodes_openai is not None:
+            qkv_nodes = qkv_nodes_openai
+            (
+                add_out,
+                matmul_out,
+                reshape_qkv_2,
+                transpose_qkv,
+                matmul_qkv,
+            ) = qkv_nodes
+            # Set model implementation to openai
+            model_impl_openai = True
+        else:
+            return
+
+        other_inputs = []
+        for input in normalize_node.input:
+            if input not in output_name_to_node:
+                continue
+            if input == qkv_nodes[0].output[0]:
+                continue
+            other_inputs.append(input)
+        if len(other_inputs) != 1:
+            return
+        root_input = other_inputs[0]
+
+        # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
+        # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
+        # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
+        # children nodes for each of its output names.
+        """
+                              root_input
+            +---------------------------------------------------+
+            |                                                   |
+            |                                                   |
+        SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
+        """
+        skip_layernorm = output_name_to_node[root_input]
+        # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose
+        # child is the LayerNormalization node.
+        if skip_layernorm.op_type == "Add":
+            skip_layernorm = self.model.get_children(skip_layernorm)[0]
+        for output in skip_layernorm.output:
+            if not output:
+                continue
+            children = input_name_to_nodes[output]
+            children_types = [child.op_type for child in children]
+            if children_types.count("MatMul") >= 1:
+                root_input = output
+                break
+
+        graph_input_names = set([node.name for node in self.model.graph().input])
+        graph_output_names = set([node.name for node in self.model.graph().output])
+
+        v_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            ["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 0, 0, None],
+        )
+        v_nodes_openai = self.model.match_parent_path(
+            matmul_qkv,
+            ["Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 0, None],
+        )
+        v_nodes_with_past_self_attn = self.model.match_parent_path(
+            # Decoder attention with past value concatenated before MatMul
+            matmul_qkv,
+            ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 1, 0, 0, None],
+        )
+        v_nodes_with_past_cross_attn = self.model.match_parent_path(
+            # Decoder attention with past value directly used in MatMul
+            matmul_qkv,
+            ["Reshape"],
+            [1],
+        )
+        v_nodes_with_past_cross_attn_openai = self.model.match_parent_path(
+            matmul_qkv,
+            ["Transpose", "Reshape", "Reshape", "Transpose"],
+            [1, 0, 0, 0],
+        )
+        past_v, present_v = "", ""
+        reshape_v_2, add_v = None, None
+        if v_nodes is not None:
+            (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
+            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
+            present_v = transpose_v.output[0]
+        elif v_nodes_openai is not None:
+            v_nodes = v_nodes_openai
+            (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
+            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
+
+            # Find the child path to access the correct present_v values
+            # Openai impl provides present/past v values in 3D format
+            # whereas ort MultiHeadAttention expects v values in 4D, hence the
+            # additional Reshape and Transpose nodes are added
+            # For encoder attention types
+            # Add -> Reshape -> Transpose -> Present_V
+            reshape_path = self.model.match_child_path(
+                add_v,
+                ["Reshape", "Transpose"],
+                exclude=[reshape_v_1],
+            )
+            # For decoder attention types
+            # add_v_node   Reshape <- Transpose <- Past_V
+            #          \      /
+            #           \    /
+            #         -> Concat <-
+            #                |
+            #                |--> Reshape -> Transpose -> Present_V
+            concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"])
+            if reshape_path is not None:
+                (_, transpose_add_v) = reshape_path
+                if transpose_add_v.output[0] in graph_output_names:
+                    present_v = transpose_add_v.output[0]
+            if concat_path is not None:
+                (concat_v, _, transpose_concat_v) = concat_path
+                if transpose_concat_v.output[0] in graph_output_names:
+                    present_v = transpose_concat_v.output[0]
+                concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0])
+                _, transpose_concat_v_in = concat_nodes
+                past_v = transpose_concat_v_in.input[0]
+        elif v_nodes_with_past_self_attn is not None:
+            (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn
+            v_nodes = v_nodes_with_past_self_attn
+            past_v = concat_v.input[0]
+            present_v = concat_v.output[0]
+        elif (
+            v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names
+        ):
+            v_nodes = v_nodes_with_past_cross_attn
+            past_v = v_nodes[-1].input[0]
+            present_v = v_nodes[-1].output[0]
+            if present_v not in graph_output_names:
+                identity_node_v = list(
+                    filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v])
+                )
+                present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else ""
+        elif (
+            v_nodes_with_past_cross_attn_openai is not None
+            and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names
+        ):
+            v_nodes = v_nodes_with_past_cross_attn_openai
+            past_v = v_nodes[-1].input[0]
+            present_v = v_nodes[-1].output[0]
+            if present_v not in graph_output_names:
+                identity_node_v = list(
+                    filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v])
+                )
+                present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else ""
+        else:
+            logger.debug("fuse_attention: failed to match v path")
+            return
+        past_v = past_v if past_v in graph_input_names else ""
+        present_v = present_v if present_v in graph_output_names else ""
+
+        qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+        qk_nodes_2 = self.model.match_parent_path(
+            matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0]
+        )
+        qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
+        add_qk = None
+        if qk_nodes_1 is not None:
+            _, matmul_qk = qk_nodes_1
+            qk_nodes = qk_nodes_1
+        elif qk_nodes_2 is not None:
+            _, _, add_qk, _, matmul_qk = qk_nodes_2
+            qk_nodes = qk_nodes_2
+        elif qk_nodes_2_openai is not None:
+            _, add_qk, matmul_qk = qk_nodes_2_openai
+            qk_nodes = qk_nodes_2_openai
+        else:
+            return
+
+        q_nodes = self.model.match_parent_path(
+            matmul_qk,
+            ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
+            [0, 0, 0, 0, 0, 1],
+        )
+        q_nodes_openai = self.model.match_parent_path(
+            matmul_qk,
+            ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
+            [0, 0, 0, 0, 1],
+        )
+        reshape_q_2 = None
+        if q_nodes is not None:
+            reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes
+        elif q_nodes_openai is not None:
+            q_nodes = q_nodes_openai
+            mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes
+        else:
+            return
+
+        k_nodes_with_bias = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 0, 0, 0, 1],
+        )
+        k_nodes_with_bias_openai = self.model.match_parent_path(
+            matmul_qk,
+            ["Mul", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 0, 0],
+        )
+        k_nodes_no_bias = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 0, 0, 0],
+        )
+        k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path(
+            # Decoder attention with past key concatenated before MatMul
+            matmul_qk,
+            ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 0, 1, 0, 0],
+        )
+        k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path(
+            # Decoder attention with past key directly used in MatMul
+            matmul_qk,
+            ["Transpose", "Reshape"],
+            [1, 0],
+        )
+        k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path(
+            # Decoder attention with past key directly used in MatMul
+            matmul_qk,
+            ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
+            [1, 0, 0, 0, 0],
+        )
+        past_k, present_k = "", ""
+        reshape_k_2, reshape_k_1, matmul_k = None, None, None
+        if k_nodes_with_bias is not None:
+            _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias
+            k_nodes = k_nodes_with_bias
+        elif k_nodes_with_bias_openai is not None:
+            mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias_openai
+            k_nodes = k_nodes_with_bias_openai
+            present_k = matmul_k.output[0]
+
+            # Find the child path to access the correct present_k values
+            # Openai impl provides present/past k values in 3D format
+            # whereas ort MultiHeadAttention expects k values in 4D, hence the
+            # additional Reshape and Transpose nodes are added
+            # For encoder attention types
+            # Matmul -> Reshape -> Transpose -> Present_K
+            reshape_path = self.model.match_child_path(
+                matmul_k,
+                ["Reshape", "Transpose"],
+                exclude=[reshape_k_1],
+            )
+            # For decoder attention types
+            # matmul_k_node   Reshape <- Transpose <- Past_K
+            #             \       /
+            #              \     /
+            #            -> Concat <-
+            #                  |
+            #                  |--> Reshape -> Transpose -> Present_K
+            concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"])
+            if reshape_path is not None:
+                (_, transpose_matmul_k) = reshape_path
+                if transpose_matmul_k.output[0] in graph_output_names:
+                    present_k = transpose_matmul_k.output[0]
+            if concat_path is not None:
+                (concat_k, _, transpose_concat_k) = concat_path
+                if transpose_concat_k.output[0] in graph_output_names:
+                    present_k = transpose_concat_k.output[0]
+                concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0])
+                _, transpose_concat_k_in = concat_nodes
+                past_k = transpose_concat_k_in.input[0]
+        elif k_nodes_no_bias is not None:
+            _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias
+            k_nodes = k_nodes_no_bias
+            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
+            present_k = transpose_k_1.output[0]
+        elif k_nodes_no_bias_with_past_self_attn is not None:
+            _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn
+            k_nodes = k_nodes_no_bias_with_past_self_attn
+            past_k = concat_k.input[0]
+            present_k = concat_k.output[0]
+        elif (
+            k_nodes_no_bias_with_past_cross_attn is not None
+            and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names
+        ):
+            k_nodes = k_nodes_no_bias_with_past_cross_attn
+            past_k = k_nodes[-1].input[0]
+            present_k = k_nodes[-1].output[0]
+            if present_k not in graph_output_names:
+                identity_node_k = list(
+                    filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k])
+                )
+                present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else ""
+        elif (
+            k_nodes_no_bias_with_past_cross_attn_openai is not None
+            and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names
+        ):
+            k_nodes = k_nodes_no_bias_with_past_cross_attn_openai
+            past_k = k_nodes[-1].input[0]
+            present_k = k_nodes[-1].output[0]
+            if present_k not in graph_output_names:
+                identity_node_k = list(
+                    filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k])
+                )
+                present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else ""
+        else:
+            return
+        past_k = past_k if past_k in graph_input_names else ""
+        present_k = present_k if present_k in graph_output_names else ""
+
+        if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn):
+            # Create empty Add node for attention graph
+            bias_dim = self.model.get_initializer(add_v.input[0]).dims[0]
+            empty_bias_name = "empty_bias"
+            empty_tensor = self.model.get_initializer(empty_bias_name)
+            if empty_tensor is None:
+                self.add_initializer(
+                    empty_bias_name,
+                    TensorProto.FLOAT,
+                    dims=[bias_dim],
+                    vals=np.array([0.0] * bias_dim, dtype=np.float32),
+                )
+
+            add_name = self.model.create_node_name("Add")
+            add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)
+
+        if (
+            model_impl_openai
+            and not past_k
+            and not self.check_runtime_shape_path_openai(
+                reshape_qkv_2,
+                matmul_qkv,
+                add_qk,
+                matmul_qk,
+                add_q,
+            )
+        ):
+            return
+        elif (
+            not model_impl_openai
+            and not past_k
+            and not self.check_runtime_shape_path(
+                reshape_qkv_2,
+                reshape_qkv_1,
+                reshape_q_2,
+                reshape_k_2,
+                reshape_v_2,
+                root_input,
+            )
+        ):
+            return
+
+        three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals()
+        one_root_input = (
+            not three_root_inputs
+            and matmul_k.input[0] == root_input
+            and matmul_q.input[0] == root_input
+            and matmul_v.input[0] == root_input
+        )
+        two_root_inputs = (
+            not three_root_inputs
+            and matmul_q.input[0] == root_input
+            and matmul_k.input[0] == matmul_v.input[0]
+            and matmul_k.input[0] != matmul_q.input[0]
+        )
+
+        # There are 5 types of attention:
+        # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1
+        # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2
+        # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value
+        # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1
+        # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1
+        encoder_attention = one_root_input and qk_nodes == qk_nodes_1
+        decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai)
+        decoder_attention_with_past = (
+            (encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v
+        )
+        decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1
+        decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1
+
+        # For decoder_attention, the attention mask needs to be included in the attention node
+        mask_index = None
+        if decoder_attention:
+            mask_nodes_bart = self.model.match_parent_path(
+                add_qk,
+                ["Where"],
+                [1],
+            )
+            mask_nodes_whisper = self.model.match_parent_path(
+                add_qk,
+                ["Expand", "Unsqueeze", "Unsqueeze", "Where"],
+                [1, 0, 0, 0],
+            )
+            if mask_nodes_whisper is not None:
+                mask_index = mask_nodes_whisper[0].output[-1]
+            elif mask_nodes_bart is not None:
+                mask_index = mask_nodes_bart[0].output[-1]
+
+        if (
+            encoder_attention
+            or decoder_attention
+            or decoder_attention_with_past
+            or decoder_cross_attention
+            or decoder_cross_attention_with_past
+        ):
+            attention_last_node = reshape_qkv_2
+            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1)
+
+            if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
+                logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
+                return
+
+            new_node = None
+            if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
+                # Note: Decoder attention with past key and past value is fused as multihead attention
+                # rather than attention because multihead attention supports separate past key and past
+                # value whereas attention supports concatenated past key and past value.
+                new_node = (
+                    self.create_multihead_attention_node(
+                        matmul_q,
+                        matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k,
+                        matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v,
+                        add_q,
+                        add_k if decoder_cross_attention or decoder_attention_with_past else None,
+                        add_v if decoder_cross_attention or decoder_attention_with_past else None,
+                        num_heads,
+                        hidden_size,
+                        attention_last_node.output[0],
+                        past_k=past_k if decoder_attention_with_past else "",
+                        past_v=past_v if decoder_attention_with_past else "",
+                        present_k=present_k,
+                        present_v=present_v,
+                        packed_qkv=decoder_attention_with_past,
+                    )
+                    if self.use_multi_head_attention
+                    else None
+                )
+            else:
+                # Temporarily set multihead attention flag to false
+                use_multi_head_attention_ground_truth = self.use_multi_head_attention
+                self.use_multi_head_attention = False
+                new_node = self.create_attention_node(
+                    None,
+                    matmul_q,
+                    matmul_k,
+                    matmul_v,
+                    add_q,
+                    add_k,
+                    add_v,
+                    num_heads,
+                    hidden_size,
+                    root_input,
+                    attention_last_node.output[0],
+                    add_qk_str=mask_index if decoder_attention else None,
+                    past_k=past_k,
+                    past_v=past_v,
+                    present_k=present_k,
+                    present_v=present_v,
+                )
+                self.use_multi_head_attention = use_multi_head_attention_ground_truth
+            if new_node is None:
+                return
+
+            self.nodes_to_add.append(new_node)
+            self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
+            self.nodes_to_remove.extend(qk_nodes)
+
+            # When using multihead attention, keep MatMul nodes in original graph
+            if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
+                if q_nodes[-1].op_type == "MatMul":
+                    q_nodes.pop()
+                if k_nodes[-1].op_type == "MatMul":
+                    k_nodes.pop()
+                if v_nodes[-1].op_type == "MatMul":
+                    v_nodes.pop()
+                if self.disable_multi_head_attention_bias and (
+                    decoder_cross_attention or decoder_cross_attention_with_past
+                ):
+                    if q_nodes[-1].op_type == "Add":
+                        q_nodes.pop()
+                    if k_nodes[-1].op_type == "Add":
+                        k_nodes.pop()
+                    if v_nodes[-1].op_type == "Add":
+                        v_nodes.pop()
+
+            self.nodes_to_remove.extend(q_nodes)
+            self.nodes_to_remove.extend(k_nodes)
+            self.nodes_to_remove.extend(v_nodes)
+
+            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+            self.prune_graph = True
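
For context, FusionBartAttention is not meant to be called directly; it runs as part of the transformers optimizer pipeline when the model type is "bart" (onnx_model_bart.py in the listing above wires it into BartOnnxModel). Below is a minimal usage sketch; the file paths and the num_heads/hidden_size values are placeholder assumptions for illustration, not values taken from this package:

    # Hypothetical example: fuse attention subgraphs in an exported BART-style decoder.
    from onnxruntime.transformers.optimizer import optimize_model

    opt_model = optimize_model(
        "bart_decoder.onnx",  # placeholder path to a locally exported model
        model_type="bart",    # selects BartOnnxModel, which applies FusionBartAttention
        num_heads=16,         # pass 0 to let the fuser infer these from the graph
        hidden_size=1024,
    )
    opt_model.save_model_to_file("bart_decoder_fused.onnx")

One design choice visible in fuse() above: decoder patterns with separate past_k/past_v inputs are rewritten to MultiHeadAttention, which accepts separate past key and value tensors, while the remaining cases produce a single fused Attention node, which expects concatenated past key/value.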