onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1592 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import logging
6
+ from typing import Optional, Union
7
+
8
+ from fusion_attention import FusionAttention
9
+ from fusion_base import Fusion
10
+ from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class FusionRotaryAttention(FusionAttention):
17
+ """
18
+ Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ model: OnnxModel,
24
+ hidden_size: int,
25
+ num_heads: int,
26
+ ):
27
+ super().__init__(
28
+ model,
29
+ hidden_size,
30
+ num_heads,
31
+ use_multi_head_attention=True,
32
+ search_op_types=[
33
+ "SimplifiedLayerNormalization",
34
+ "SkipSimplifiedLayerNormalization",
35
+ "LayerNormalization",
36
+ "SkipLayerNormalization",
37
+ "Add",
38
+ ],
39
+ )
40
+
41
+ def create_mha_node(
42
+ self,
43
+ input: str,
44
+ output: str,
45
+ q_rotary: NodeProto,
46
+ k_rotary: NodeProto,
47
+ v_matmul: NodeProto,
48
+ attn_mask: str = "",
49
+ add_qk: str = "",
50
+ past_k: str = "",
51
+ past_v: str = "",
52
+ present_k: str = "",
53
+ present_v: str = "",
54
+ scale: Optional[float] = None,
55
+ ) -> Union[NodeProto, None]:
56
+ assert self.num_heads > 0
57
+
58
+ if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0:
59
+ logger.debug(
60
+ f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}"
61
+ )
62
+ return None
63
+
64
+ mha_node_name = self.model.create_node_name("MultiHeadAttention")
65
+ mha_inputs = [
66
+ q_rotary.output[0],
67
+ k_rotary.output[0],
68
+ v_matmul.output[0],
69
+ "", # bias
70
+ attn_mask, # key_padding_mask
71
+ add_qk, # attention_bias
72
+ past_k,
73
+ past_v,
74
+ ]
75
+
76
+ mha_outputs = [output]
77
+ if present_k and present_v:
78
+ mha_outputs.extend([present_k, present_v])
79
+
80
+ mha_node = helper.make_node(
81
+ "MultiHeadAttention",
82
+ inputs=mha_inputs,
83
+ outputs=mha_outputs,
84
+ name=mha_node_name,
85
+ )
86
+
87
+ mha_node.domain = "com.microsoft"
88
+ mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
89
+ if scale is not None:
90
+ mha_node.attribute.extend([helper.make_attribute("scale", scale)])
91
+ if self.mask_filter_value is not None:
92
+ mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
93
+
94
+ self.increase_counter("MultiHeadAttention")
95
+ return mha_node
96
+
97
    def check_runtime_shape_paths_for_function(
        self,
        reshape_qkv_2,  # Reshape after Transpose
        reshape_qkv_1,  # Reshape before Transpose
        reshape_q_2,  # Reshape after RotaryEmbedding
        reshape_k_2,  # Reshape after RotaryEmbedding
        reshape_v_2,  # Reshape after Transpose
        reshape_v_1,  # Reshape before Transpose
        add_qk,  # Add before Softmax
        root_input,  # Root input to attention subgraph
    ):
        """Verify the runtime shape-computation paths of a rotary attention subgraph.

        Walks the Shape/Gather/Unsqueeze/Concat chains that feed each Reshape (and
        the Slice chain feeding the attention-mask Add) and confirms they all stem
        from the same Gather nodes rooted at ``root_input``. Returns True only when
        every expected path matches; any mismatch or missing path returns False.
        """
        # Check #1: check paths for qkv nodes
        concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
        concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1])
        if concat_qkv_2_path is None or concat_qkv_1_path is None:
            return False
        concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0]

        reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
        if (
            reshape_qkv_2_path_1 is None
            or reshape_qkv_2_path_2 is None
            or reshape_qkv_1_path_1 is None
            or reshape_qkv_1_path_2 is None
        ):
            return False

        # gather_1/gather_2 are the reference Gather nodes; every other path below
        # must resolve to these same nodes (compared by name).
        _, gather_1, shape_1 = reshape_qkv_2_path_1
        _, gather_2, shape_2 = reshape_qkv_2_path_2

        # Check root_input --> Shape --> Gather connection
        if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
            return False

        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2
        if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name:
            return False

        # Check #2: check paths for v nodes
        concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1])
        concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1])
        if concat_v_2_path is None or concat_v_1_path is None:
            return False
        concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0]

        reshape_v_2_path_1 = self.model.match_parent_path(
            concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_v_2_path_2 = self.model.match_parent_path(
            concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0]
        )
        reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if (
            reshape_v_2_path_1 is None
            or reshape_v_2_path_2 is None
            or reshape_v_1_path_1 is None
            or reshape_v_1_path_2 is None
        ):
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1
        # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2
        if (
            reshape_v_2_path_1[2].name != gather_1.name
            or reshape_v_2_path_2[2].name != gather_2.name
            or reshape_v_1_path_1[1].name != gather_1.name
            or reshape_v_1_path_2[1].name != gather_2.name
        ):
            return False

        # Check #3: check paths for k nodes
        concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1])
        if concat_k_2_path is None:
            return False
        concat_k_2 = concat_k_2_path[0]

        reshape_k_2_path_1 = self.model.match_parent_path(
            concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_k_2_path_2 = self.model.match_parent_path(
            concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0]
        )
        if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None:
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1
        # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2
        if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name:
            return False

        # Check #4: check paths for q nodes
        concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1])
        if concat_q_2_path is None:
            return False
        concat_q_2 = concat_q_2_path[0]

        reshape_q_2_path_1 = self.model.match_parent_path(
            concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None:
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2
        if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name:
            return False

        # Check #5: check Mul nodes are the same for q, k, v
        # (all three Mul nodes must consume the output of gather_1)
        mul_q = reshape_q_2_path_1[1]
        mul_k = reshape_k_2_path_1[1]
        mul_v = reshape_v_2_path_1[1]
        gather_1_out = gather_1.output[0]
        if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
            return False

        # Check #6: check paths for attention mask nodes
        # Two accepted variants: with or without a Cast between Add and Concat.
        attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
        attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
        if attn_mask_path_1 is not None:
            _, slice_qk_2, slice_qk_1 = attn_mask_path_1
        elif attn_mask_path_2 is not None:
            _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2
        else:
            return False
        # Check first input to Slice #1 is 3D attention mask of shape (B,S,T)
        # (identified by graph-input name; only these two names are accepted)
        if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}:
            return False

        slice_qk_2_path = self.model.match_parent_path(
            slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
        )
        slice_qk_1_path_1 = self.model.match_parent_path(
            slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
        )
        slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1])
        if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None:
            return False

        # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path
        # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1
        if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name:
            return False

        # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2
        # Check if first input to Add and Unsqueeze #1 is position ids
        if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]:
            return False

        return True
252
+
253
+ def check_runtime_shape_paths_for_nodes(
254
+ self,
255
+ reshape_qkv, # Final reshape before o_proj MatMul
256
+ reshape_q, # Reshape before q_proj MatMul
257
+ reshape_k, # Reshape before k_proj MatMul
258
+ reshape_v, # Reshape before v_proj MatMul
259
+ root_input, # Root input to attention subgraph
260
+ ):
261
+ # Check #1: check paths for qkv nodes
262
+ concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1])
263
+ if concat_qkv_path is None:
264
+ return False
265
+ concat_qkv = concat_qkv_path[0]
266
+
267
+ reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
268
+ reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
269
+ if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None:
270
+ return False
271
+
272
+ _, gather_1, shape_1 = reshape_qkv_path_1
273
+ _, gather_2, shape_2 = reshape_qkv_path_2
274
+
275
+ # Check root_input --> Shape --> Gather connection
276
+ if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
277
+ return False
278
+
279
+ # Check #2: check paths for v nodes
280
+ concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1])
281
+ if concat_v_path is None:
282
+ return False
283
+ concat_v = concat_v_path[0]
284
+
285
+ reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
286
+ reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
287
+ if reshape_v_path_1 is None or reshape_v_path_2 is None:
288
+ return False
289
+
290
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
291
+ if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name:
292
+ return False
293
+
294
+ # Check #3: check paths for k nodes
295
+ concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1])
296
+ if concat_k_path is None:
297
+ return False
298
+ concat_k = concat_k_path[0]
299
+
300
+ reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
301
+ reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
302
+ if reshape_k_path_1 is None or reshape_k_path_2 is None:
303
+ return False
304
+
305
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
306
+ if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name:
307
+ return False
308
+
309
+ # Check #4: check paths for q nodes
310
+ concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1])
311
+ if concat_q_path is None:
312
+ return False
313
+ concat_q = concat_q_path[0]
314
+
315
+ reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
316
+ reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
317
+ if reshape_q_path_1 is None or reshape_q_path_2 is None:
318
+ return False
319
+
320
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
321
+ if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name:
322
+ return False
323
+
324
+ return True
325
+
326
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Fuse a rotary-embedding attention subgraph into a single MultiHeadAttention node.

        Starting from a normalization/Add node, this matches (in order) the qkv output path,
        the v path, the qk (Softmax) path, the attention-mask path, the k path, and the q path
        against several known export variants (LLaMA-2 Microsoft / Hugging Face / 70B
        distributed, Phi-2 DirectML). If every stage matches, it rewires the RotaryEmbedding
        inputs to read directly from the q/k projection MatMuls, creates the fused MHA node,
        and schedules the matched subgraph nodes for removal. Any failed match logs a debug
        message and aborts without modifying the graph.
        """
        if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
            return

        # qkv_nodes_1 is for LLaMA-2 Microsoft
        # qkv_nodes_2 is for LLaMA-2 Hugging Face
        # qkv_nodes_3 is for LLaMA-2 distributed Hugging Face model
        qkv_nodes = None
        qkv_nodes_1 = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        qkv_nodes_2 = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0],
        )
        qkv_nodes_3 = self.model.match_parent_path(
            normalize_node,
            ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        if qkv_nodes_1 is not None:
            _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
            qkv_nodes = qkv_nodes_1
        elif qkv_nodes_2 is not None:
            _, reshape_qkv, _, matmul_qkv = qkv_nodes_2
            qkv_nodes = qkv_nodes_2
        elif qkv_nodes_3 is not None:
            _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
            qkv_nodes = qkv_nodes_3
        else:
            logger.debug("fuse_rotary_attention: failed to match qkv nodes")
            return

        # v_nodes_1 is for LLaMA-2 Microsoft
        # v_nodes_3 is for LLaMA-2 Hugging Face
        # v_nodes_4 is for LLaMA-2 70B model
        # v_nodes_5 is for Phi-2 DirectML
        past_v, present_v, past_seq_len = "", "", ""
        v_nodes = None
        add_v = None
        v_nodes_1 = self.model.match_parent_path(
            matmul_qkv,
            ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        v_nodes_2 = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "MatMul"],
            [1, 1, 0, 0],
        )
        v_nodes_3 = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        # The 70B model repeats KV heads (grouped-query attention), so the v path can take
        # one of several Expand/Where shapes; all 9 candidate paths must match at once.
        _, v_nodes_4, _ = self.model.match_parent_paths_all(
            matmul_qkv,
            [
                (
                    ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Mul",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Mul",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 2, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 3, 0, 0, 0, 1, 0, 0],
                ),
            ],
            output_name_to_node=None,
        )
        v_nodes_5 = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )
        if v_nodes_1 is not None:
            reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
            v_nodes = v_nodes_1

            # Past/present KV cache: the Concat prepends past_v (via Slice) to the new values.
            concat_v_path = self.model.match_parent_path(
                concat_v,
                ["Slice", "Unsqueeze"],
                [0, 2],
            )
            if concat_v_path is None:
                logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
                return

            past_v = concat_v_path[0].input[0]
            past_seq_len = concat_v_path[-1].input[0]
            present_v = concat_v.output[0]
        elif v_nodes_2 is not None:
            concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
            v_nodes = v_nodes_2
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif v_nodes_3 is not None:
            transpose_v, reshape_v, matmul_v = v_nodes_3
            v_nodes = v_nodes_3
            present_v = transpose_v.output[0]
        elif v_nodes_4 is not None and len(v_nodes_4) == 9:
            concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
            v_nodes = v_nodes_4
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif v_nodes_5 is not None:
            concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
            # Phi-2 projections carry a bias Add; treat the Add as the v-projection endpoint.
            matmul_v = add_v
            v_nodes = v_nodes_5
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        else:
            logger.debug("fuse_rotary_attention: failed to match v path")
            return

        # Scaled-dot-product core: Softmax(Add(Div(MatMul(q,k)), mask))
        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "Div", "MatMul"],
            [0, 0, 0, 0],
        )
        add_qk, matmul_qk = None, None
        if qk_nodes is not None:
            _, add_qk, _, matmul_qk = qk_nodes
        else:
            logger.debug("fuse_rotary_attention: failed to match qk nodes")
            return

        # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
        # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
        # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
        # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
        attn_mask, add_qk_str = "", ""
        attn_mask_nodes_1 = self.model.match_parent_path(
            add_qk,
            ["Concat", "Slice", "Slice"],
            [1, 0, 0],
        )
        attn_mask_nodes_2 = self.model.match_parent_path(
            add_qk,
            ["Cast", "Concat", "Slice", "Slice"],
            [1, 0, 0, 0],
        )
        attn_mask_nodes_3 = self.model.match_parent_path(
            add_qk,
            ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_4 = self.model.match_parent_path(
            add_qk,
            ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_5 = self.model.match_parent_path(
            add_qk,
            ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_6 = self.model.match_parent_path(
            add_qk,
            ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_7 = self.model.match_parent_path(
            add_qk,
            ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 0, 0, 0, 1, 0, 0, 0],
        )
        if attn_mask_nodes_1 is not None:
            _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
            attn_mask = slice_mask_1.output[0]
        elif attn_mask_nodes_2 is not None:
            _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
            attn_mask = slice_mask_1.output[0]
        elif attn_mask_nodes_3 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
        elif attn_mask_nodes_4 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
        elif attn_mask_nodes_5 is not None:
            # The mask has already been reshaped to (B,N,S,T)
            add_qk_str = attn_mask_nodes_5[0].output[0]
        elif attn_mask_nodes_6 is not None:
            # The mask has already been reshaped to (B,N,S,T)
            add_qk_str = attn_mask_nodes_6[0].output[0]
        elif attn_mask_nodes_7 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
        else:
            logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
            return

        # k_nodes_1 is for LLaMA-2 Microsoft
        # k_nodes_2 is for LLaMA-2 Hugging Face
        # k_nodes_4 is for LLaMA-2 70B Hugging Face
        past_k, present_k = "", ""
        k_nodes = None
        slice_k = None
        concat_k_half = None
        k_nodes_1 = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        k_nodes_2 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        k_nodes_3 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [1, 0, 1, 0, 0, 0],
        )
        # Same grouped-query-attention variants as v_nodes_4, but through RotaryEmbedding.
        _, k_nodes_4, _ = self.model.match_parent_paths_all(
            matmul_qk,
            [
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Unsqueeze",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Mul",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Mul",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
                ),
            ],
            output_name_to_node=None,
        )
        k_nodes_5 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 0, 0, 0, 1],
        )
        if k_nodes_1 is not None:
            reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
            k_nodes = k_nodes_1

            concat_k_path = self.model.match_parent_path(
                concat_k,
                ["Slice", "Unsqueeze"],
                [0, 2],
            )
            if concat_k_path is None:
                logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
                return

            past_k = concat_k_path[0].input[0]
            shared_past_seq_len = concat_k_path[-1].input[0]
            present_k = concat_k.output[0]

            # k and v caches must be indexed by the same past-sequence-length value.
            assert past_seq_len == shared_past_seq_len
        elif k_nodes_2 is not None:
            _, rotary_k, _, reshape_k, matmul_k = k_nodes_2
            k_nodes = k_nodes_2
            present_k = rotary_k.output[0]
        elif k_nodes_3 is not None:
            _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
            k_nodes = k_nodes_3
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif k_nodes_4 is not None and len(k_nodes_4) == 9:
            reshape_k, matmul_k = k_nodes_4[0][-2:]
            concat_k, rotary_k = k_nodes_4[0][-5:-3]
            k_nodes = k_nodes_4
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif k_nodes_5 is not None:
            # Phi-2 applies rotary embedding to only part of the head (partial rotary);
            # slice_k feeds the rotated half, concat_k_half rejoins the pass-through half.
            _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
            k_nodes = k_nodes_5
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        else:
            logger.debug("fuse_rotary_attention: failed to match k nodes")
            return

        # q_nodes_1 is for LLaMA-2 Microsoft
        # q_nodes_2 is for LLaMA-2 Hugging Face
        # q_nodes_3 is for Phi-2 DirectML
        q_nodes = None
        slice_q = None
        concat_q_half = None
        q_nodes_1 = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
            [0, 0, 0, 0],
        )
        q_nodes_2 = self.model.match_parent_path(
            matmul_qk,
            ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [0, 0, 0, 0],
        )
        q_nodes_3 = self.model.match_parent_path(
            matmul_qk,
            ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 0, 1],
        )
        if q_nodes_1 is not None:
            reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
            q_nodes = q_nodes_1
        elif q_nodes_2 is not None:
            rotary_q, _, reshape_q, matmul_q = q_nodes_2
            q_nodes = q_nodes_2
        elif q_nodes_3 is not None:
            concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
            q_nodes = q_nodes_3
        else:
            logger.debug("fuse_rotary_attention: failed to match q nodes")
            return

        # NOTE(review): this uses `and`, so it only rejects when BOTH pairs mismatch;
        # presumably intended to require a single shared root input — confirm upstream.
        if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]:
            logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
            return

        root_output = ""
        if qkv_nodes == qkv_nodes_1:
            if not self.check_runtime_shape_paths_for_function(
                reshape_qkv_2,
                reshape_qkv_1,
                reshape_q_2,
                reshape_k_2,
                reshape_v_2,
                reshape_v_1,
                add_qk,
                matmul_q.input[0],
            ):
                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
                return
            root_output = reshape_qkv_2.output[0]

        elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
            if not self.check_runtime_shape_paths_for_nodes(
                reshape_qkv,
                reshape_q,
                reshape_k,
                reshape_v,
                matmul_q.input[0],
            ):
                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
                return
            root_output = reshape_qkv.output[0]

        # Rename inputs of rotary_q/k so it connects with output of matmul_q/k
        # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
        # After: MatMul --> RotaryEmbedding
        rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
        rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]

        # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
        if concat_q_half is None:
            rotary_k.output[0] = rotary_k.name + "_output_0"

        # Drop the AllReduce head so removal below starts at the MatMul.
        if qkv_nodes == qkv_nodes_3:
            qkv_nodes = qkv_nodes[1:]

        def create_hidden_size_concat_node(reshape_q):
            """Detect num_heads and hidden_size for ONNX model from phi-2
            Args:
                reshape_q (NodeProto): reshape node for q
            Returns:
                hidden_size_concat_node(NodeProto): Concat node to be used by reshape
            """
            concat = self.model.match_parent(reshape_q, "Concat", 1)

            if concat is None:
                logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
                return None

            # The shape is a tensor like [?, ?, num_heads, head_size]
            num_head_constant_node = self.model.get_constant_value(concat.input[2])
            head_size_constant_node = self.model.get_constant_value(concat.input[3])

            if num_head_constant_node is None or head_size_constant_node is None:
                logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size")
                return None

            num_head_value = num_head_constant_node[0]
            head_size_value = head_size_constant_node[0]

            hidden_size = num_head_value * head_size_value

            hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
            if self.model.get_initializer(hidden_size_initilizer) is None:
                self.add_initializer(
                    name=hidden_size_initilizer,
                    data_type=TensorProto.INT64,
                    dims=[1],
                    vals=[hidden_size],
                    raw=False,
                )

            hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")

            hidden_size_concat_node = helper.make_node(
                "Concat",
                inputs=[
                    concat.input[0],
                    concat.input[1],
                    hidden_size_initilizer,
                ],
                outputs=[hidden_size_reshape_node_name + "output_0"],
                name=hidden_size_reshape_node_name,
            )
            hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])

            return hidden_size_concat_node

        # Add Transpose and Reshape nodes for partial rotary embedding applied in phi-2 before passing into MHA
        if concat_q_half and concat_k_half:
            # Transpose the key output of rotary Embedding
            k_transpose_node_name = self.model.create_node_name("Transpose")
            k_tranpose_output_name = k_transpose_node_name + "_output_0"
            k_transpose_node = helper.make_node(
                "Transpose",
                inputs=[concat_k_half.output[0]],
                outputs=[k_tranpose_output_name],
                name=k_transpose_node_name,
            )

            k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

            # Transpose the query output of rotary Embedding
            q_transpose_node_name = self.model.create_node_name("Transpose")
            q_tranpose_output_name = q_transpose_node_name + "_output_0"
            q_transpose_node = helper.make_node(
                "Transpose",
                inputs=[concat_q_half.output[0]],
                outputs=[q_tranpose_output_name],
                name=q_transpose_node_name,
            )

            q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

            hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
            if hidden_size_concat_node is None:
                logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
                return

            # Reshape the Rotary Embedding output for key from 4D to 3D
            concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
            concat_k_reshape_node = helper.make_node(
                "Reshape",
                inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
                outputs=[concat_k_reshape_node_name + "_output_0"],
                name=concat_k_reshape_node_name,
            )

            # Reshape the Rotary Embedding output for query from 4D to 3D
            concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
            concat_q_reshape_node = helper.make_node(
                "Reshape",
                inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
                outputs=[concat_q_reshape_node_name + "_output_0"],
                name=concat_q_reshape_node_name,
            )

            # The MHA node below now reads q/k from the new 3D reshapes.
            rotary_k = concat_k_reshape_node
            rotary_q = concat_q_reshape_node

            self.nodes_to_add.append(hidden_size_concat_node)
            self.nodes_to_add.append(k_transpose_node)
            self.nodes_to_add.append(q_transpose_node)
            self.nodes_to_add.append(concat_k_reshape_node)
            self.nodes_to_add.append(concat_q_reshape_node)

            self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
            self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
            self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
            self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
            self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name

        new_node = self.create_mha_node(
            matmul_q.input[0],
            root_output,
            rotary_q,
            rotary_k,
            matmul_v,
            attn_mask,
            add_qk_str,
            past_k,
            past_v,
            present_k,
            present_v,
        )
        if new_node is None:
            logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        # Remove the matched subgraph; the projection MatMuls (path tails) are kept.
        self.nodes_to_remove.extend(qkv_nodes[1:])

        if v_nodes != v_nodes_4:
            self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
        else:
            nodes_to_keep = [v_nodes[0][-1]]
            for temp_path in v_nodes:
                self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

        self.nodes_to_remove.extend(qk_nodes)

        if k_nodes == k_nodes_1:
            self.nodes_to_remove.extend(k_nodes[:-2])
        elif k_nodes == k_nodes_2:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[2])
            self.nodes_to_remove.append(k_nodes[3])
        elif k_nodes == k_nodes_3:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[1])
            self.nodes_to_remove.append(k_nodes[3])
            self.nodes_to_remove.append(k_nodes[4])
        elif k_nodes == k_nodes_5:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[1])
        elif k_nodes == k_nodes_4:
            nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
            for temp_path in k_nodes:
                self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

        if q_nodes == q_nodes_1:
            self.nodes_to_remove.extend(q_nodes[:-2])
        elif q_nodes == q_nodes_2:
            self.nodes_to_remove.append(q_nodes[1])
            self.nodes_to_remove.append(q_nodes[2])
        self.prune_graph = True
1106
+
1107
+
1108
+ class FusionRotaryEmbeddings(Fusion):
1109
+ def __init__(self, model: OnnxModel):
1110
+ self.base_name = "RotaryEmbedding"
1111
+ super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"])
1112
+
1113
+ # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output.
1114
+ # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter.
1115
+ # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used.
1116
+ def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
1117
+ # Find extra outputs and Constant nodes attached to those outputs
1118
+ extra_constants, extra_outputs = [], []
1119
+ for fn_node in function.node:
1120
+ if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output:
1121
+ extra_constants.append(fn_node)
1122
+ output_index = list(function.output).index(fn_node.output[0])
1123
+ extra_outputs.append(rot_emb_node.output[output_index])
1124
+
1125
+ # Set extra Constant node outputs as initializers
1126
+ extra_initializers = []
1127
+ for extra_constant in extra_constants:
1128
+ constant_tensorproto = extra_constant.attribute[0].t
1129
+ constant_tensorproto.name = self.model.create_node_name("Constant")
1130
+ self.model.add_initializer(constant_tensorproto)
1131
+ extra_initializers.append(constant_tensorproto.name)
1132
+
1133
+ # Update references of Constant node outputs to initializer references
1134
+ for extra_output, extra_initializer in zip(extra_outputs, extra_initializers):
1135
+ nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node))
1136
+ for node_to_update in nodes_to_update:
1137
+ OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer)
1138
+
1139
+ return extra_outputs
1140
+
1141
+ def create_rotary_embeddings_from_function(self, node: NodeProto):
1142
+ rotary_emb_node_name = self.model.create_node_name(self.base_name)
1143
+
1144
+ matmul_path = self.model.match_parent_path(
1145
+ node,
1146
+ ["Reshape", "MatMul"],
1147
+ [0, 0],
1148
+ )
1149
+ if matmul_path is not None:
1150
+ reshape_node, matmul_node = matmul_path
1151
+ else:
1152
+ logger.debug("fuse_rotary_embeddings: failed to match MatMul")
1153
+ return
1154
+
1155
+ rotary_emb_inputs = [
1156
+ matmul_node.output[0], # x is of shape (B,S,D) instead of (B,S,N,H)
1157
+ node.input[1], # position_ids
1158
+ ]
1159
+
1160
+ # Convert cos_cache and sin_cache from node attributes to model initializers
1161
+ cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node))
1162
+ sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node))
1163
+ cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
1164
+
1165
+ if (
1166
+ len(cos_cache_node) == 1
1167
+ and len(sin_cache_node) == 1
1168
+ and self.model.get_initializer(cos_cache_name) is None
1169
+ and self.model.get_initializer(sin_cache_name) is None
1170
+ ):
1171
+ cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
1172
+ sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
1173
+
1174
+ cos_cache_tensor = helper.make_tensor(
1175
+ name=cos_cache_name,
1176
+ data_type=TensorProto.FLOAT,
1177
+ dims=list(cos_cache.shape),
1178
+ vals=cos_cache.flatten().tolist(),
1179
+ )
1180
+ self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
1181
+ sin_cache_tensor = helper.make_tensor(
1182
+ name=sin_cache_name,
1183
+ data_type=TensorProto.FLOAT,
1184
+ dims=list(sin_cache.shape),
1185
+ vals=sin_cache.flatten().tolist(),
1186
+ )
1187
+ self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
1188
+
1189
+ self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
1190
+
1191
+ rotary_emb_inputs.extend([cos_cache_name, sin_cache_name])
1192
+
1193
+ rotary_emb_outputs = node.output
1194
+ if len(rotary_emb_outputs) > 1:
1195
+ # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers
1196
+ func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions))
1197
+ assert len(func) == 1
1198
+ extra_outputs = self.reassign_extra_outputs(node, func[0])
1199
+ rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs))
1200
+ assert len(rotary_emb_outputs) == 1
1201
+
1202
+ rotary_emb_node = helper.make_node(
1203
+ self.base_name,
1204
+ inputs=rotary_emb_inputs,
1205
+ outputs=rotary_emb_outputs,
1206
+ name=rotary_emb_node_name,
1207
+ interleaved=1,
1208
+ )
1209
+ rotary_emb_node.domain = "com.microsoft"
1210
+
1211
+ self.nodes_to_remove.append(reshape_node)
1212
+
1213
+ return rotary_emb_node
1214
+
1215
+ def create_rotary_embeddings_from_nodes(
1216
+ self,
1217
+ root_input: str,
1218
+ position_ids: str,
1219
+ cos_slice: str,
1220
+ sin_slice: str,
1221
+ output: str,
1222
+ ):
1223
+ rotary_emb_node_name = self.model.create_node_name(self.base_name)
1224
+
1225
+ # Convert cos_cache and sin_cache from node attributes to model initializers
1226
+ cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node))
1227
+ sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node))
1228
+ cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
1229
+
1230
+ if (
1231
+ len(cos_cache_node) == 1
1232
+ and len(sin_cache_node) == 1
1233
+ and self.model.get_initializer(cos_cache_name) is None
1234
+ and self.model.get_initializer(sin_cache_name) is None
1235
+ ):
1236
+ cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
1237
+ sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
1238
+
1239
+ # Reshape cos/sin cache from (M, H) to (M, H/2)
1240
+ head_size = cos_cache.shape[1]
1241
+ cos_cache = cos_cache[:, : (head_size // 2)]
1242
+ sin_cache = sin_cache[:, : (head_size // 2)]
1243
+
1244
+ cos_cache_tensor = helper.make_tensor(
1245
+ name=cos_cache_name,
1246
+ data_type=TensorProto.FLOAT,
1247
+ dims=list(cos_cache.shape),
1248
+ vals=cos_cache.flatten().tolist(),
1249
+ )
1250
+ self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
1251
+ sin_cache_tensor = helper.make_tensor(
1252
+ name=sin_cache_name,
1253
+ data_type=TensorProto.FLOAT,
1254
+ dims=list(sin_cache.shape),
1255
+ vals=sin_cache.flatten().tolist(),
1256
+ )
1257
+ self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
1258
+
1259
+ self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
1260
+
1261
+ rotary_emb_node = helper.make_node(
1262
+ self.base_name,
1263
+ inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
1264
+ outputs=[output],
1265
+ name=rotary_emb_node_name,
1266
+ interleaved=0,
1267
+ )
1268
+ rotary_emb_node.domain = "com.microsoft"
1269
+ return rotary_emb_node
1270
+
1271
+ def fuse(self, node, input_name_to_nodes, output_name_to_node):
1272
+ # Node is either RotaryEmbedding function or Add
1273
+ if self.base_name not in node.op_type and node.op_type != "Add":
1274
+ return
1275
+
1276
+ # Check if node is "RotaryEmbedding nn.Module" exported as a function
1277
+ # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
1278
+ rotary_emb_node = None
1279
+ if node.op_type != "Add":
1280
+ # Verify that function has the correct inputs
1281
+ if len(node.input) not in {4, 5} or node.input[1] not in {
1282
+ "pos",
1283
+ "pos_id",
1284
+ "position_id",
1285
+ "pos_ids",
1286
+ "position_ids",
1287
+ }:
1288
+ logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
1289
+ return
1290
+
1291
+ rotary_emb_node = self.create_rotary_embeddings_from_function(node)
1292
+ if rotary_emb_node is None:
1293
+ logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
1294
+ return
1295
+
1296
+ # Remove RotaryEmbedding function
1297
+ self.nodes_to_remove.append(node)
1298
+
1299
+ # Remove RotaryEmbedding function's shape inference stored in value_info
1300
+ # The new shape will be calculated during symbolic shape inference
1301
+ old_shape_infer = list(
1302
+ filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info)
1303
+ )
1304
+ assert len(old_shape_infer) == 1
1305
+ self.model.model.graph.value_info.remove(old_shape_infer[0])
1306
+
1307
+ else:
1308
+ # Rotary embeddings are defined using the below functions:
1309
+ #
1310
+ # def rotate_half(x):
1311
+ # """Rotates half the hidden dims of the input."""
1312
+ # x1 = x[..., : x.shape[-1] // 2]
1313
+ # x2 = x[..., x.shape[-1] // 2 :]
1314
+ # return torch.cat((-x2, x1), dim=-1)
1315
+ #
1316
+ # def apply_rope(x, cos, sin, position_ids):
1317
+ # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
1318
+ # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
1319
+ # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
1320
+ # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
1321
+ # x_embed = (x * cos) + (rotate_half(x) * sin)
1322
+ # return x_embed
1323
+
1324
+ # Check paths for rotate_half(x)
1325
+ rotate_half_x2_path_1_1 = self.model.match_parent_path(
1326
+ node,
1327
+ ["Mul", "Concat", "Neg", "Slice", "Transpose"],
1328
+ [1, 0, 0, 0, 0],
1329
+ )
1330
+
1331
+ rotate_half_x2_path_1_2 = self.model.match_parent_path(
1332
+ node,
1333
+ ["Mul", "Concat", "Neg", "Slice", "Slice"],
1334
+ [1, 0, 0, 0, 0],
1335
+ )
1336
+
1337
+ rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2
1338
+
1339
+ rotate_half_x2_path_2_1 = self.model.match_parent_path(
1340
+ node,
1341
+ ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
1342
+ [1, 0, 0, 0, 1, 0, 0, 0, 0],
1343
+ )
1344
+
1345
+ rotate_half_x2_path_2_2 = self.model.match_parent_path(
1346
+ node,
1347
+ ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
1348
+ [1, 0, 0, 0, 1, 0, 0, 0, 0],
1349
+ )
1350
+
1351
+ rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2
1352
+
1353
+ if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
1354
+ logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
1355
+ return
1356
+
1357
+ rotate_half_x1_path_1_1 = self.model.match_parent_path(
1358
+ node,
1359
+ ["Mul", "Concat", "Slice", "Transpose"],
1360
+ [1, 0, 1, 0],
1361
+ )
1362
+
1363
+ rotate_half_x1_path_1_2 = self.model.match_parent_path(
1364
+ node,
1365
+ ["Mul", "Concat", "Slice", "Slice"],
1366
+ [1, 0, 1, 0],
1367
+ )
1368
+
1369
+ rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2
1370
+
1371
+ rotate_half_x1_path_2_1 = self.model.match_parent_path(
1372
+ node,
1373
+ ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
1374
+ [1, 0, 1, 2, 0, 0, 0, 0],
1375
+ )
1376
+
1377
+ rotate_half_x1_path_2_2 = self.model.match_parent_path(
1378
+ node,
1379
+ ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
1380
+ [1, 0, 1, 2, 0, 0, 0, 0],
1381
+ )
1382
+
1383
+ rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2
1384
+
1385
+ if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
1386
+ logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
1387
+ return
1388
+
1389
+ if (
1390
+ rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
1391
+ or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
1392
+ or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
1393
+ or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
1394
+ ):
1395
+ logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
1396
+ return
1397
+
1398
+ # Check path for x
1399
+ x_path_1 = self.model.match_parent_path(
1400
+ node,
1401
+ ["Mul", "Transpose"],
1402
+ [0, 0],
1403
+ )
1404
+
1405
+ x_path_2 = self.model.match_parent_path(
1406
+ node,
1407
+ ["Mul", "Slice"],
1408
+ [0, 0],
1409
+ )
1410
+
1411
+ x_path = x_path_1 or x_path_2
1412
+
1413
+ if x_path is None:
1414
+ logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
1415
+ return
1416
+
1417
+ # Check path for sin
1418
+ sin_path, sin_cache, position_ids = None, "", ""
1419
+ sin_path_1 = self.model.match_parent_path(
1420
+ node,
1421
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
1422
+ [1, 1, 0, 0, 0, 0, 2, 0, 0],
1423
+ )
1424
+ sin_path_2 = self.model.match_parent_path(
1425
+ node,
1426
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
1427
+ [1, 1, 0, 0, 0, 0, 2, 0],
1428
+ )
1429
+ sin_path_3 = self.model.match_parent_path(
1430
+ node,
1431
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
1432
+ [1, 1, 0, 0, 2, 0, 0],
1433
+ )
1434
+ sin_path_4 = self.model.match_parent_path(
1435
+ node,
1436
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
1437
+ [1, 1, 0, 0, 2, 0],
1438
+ )
1439
+ if sin_path_1 is not None:
1440
+ sin_path = sin_path_1
1441
+ sin_cache = sin_path[-4].input[0]
1442
+ elif sin_path_2 is not None:
1443
+ sin_path = sin_path_2
1444
+ sin_cache = sin_path[-3].input[0]
1445
+ elif sin_path_3 is not None:
1446
+ sin_path = sin_path_3
1447
+ sin_cache = sin_path[-4].input[0]
1448
+ position_ids = sin_path[2].input[1]
1449
+ elif sin_path_4 is not None:
1450
+ sin_path = sin_path_4
1451
+ sin_cache = sin_path[-3].input[0]
1452
+ position_ids = sin_path[2].input[1]
1453
+ else:
1454
+ logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
1455
+ return
1456
+
1457
+ # Check path for cos
1458
+ cos_path, cos_cache = None, ""
1459
+ cos_path_1 = self.model.match_parent_path(
1460
+ node,
1461
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
1462
+ [0, 1, 0, 0, 0, 0, 2, 0, 0],
1463
+ )
1464
+ cos_path_2 = self.model.match_parent_path(
1465
+ node,
1466
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
1467
+ [0, 1, 0, 0, 0, 0, 2, 0],
1468
+ )
1469
+ cos_path_3 = self.model.match_parent_path(
1470
+ node,
1471
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
1472
+ [0, 1, 0, 0, 2, 0, 0],
1473
+ )
1474
+ cos_path_4 = self.model.match_parent_path(
1475
+ node,
1476
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
1477
+ [0, 1, 0, 0, 2, 0],
1478
+ )
1479
+ if cos_path_1 is not None:
1480
+ cos_path = cos_path_1
1481
+ cos_cache = cos_path[-4].input[0]
1482
+ elif cos_path_2 is not None:
1483
+ cos_path = cos_path_2
1484
+ cos_cache = cos_path[-3].input[0]
1485
+ elif cos_path_3 is not None:
1486
+ cos_path = cos_path_3
1487
+ cos_cache = cos_path[-4].input[0]
1488
+ position_ids = cos_path[2].input[1]
1489
+ elif cos_path_4 is not None:
1490
+ cos_path = cos_path_4
1491
+ cos_cache = cos_path[-3].input[0]
1492
+ position_ids = cos_path[2].input[1]
1493
+ else:
1494
+ logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
1495
+ return
1496
+
1497
+ # Check path for position ids
1498
+ if position_ids == "":
1499
+ position_ids_from_sin_path = self.model.match_parent_path(
1500
+ sin_path[2],
1501
+ ["Reshape"],
1502
+ [1],
1503
+ )
1504
+ position_ids_from_cos_path = self.model.match_parent_path(
1505
+ cos_path[2],
1506
+ ["Reshape"],
1507
+ [1],
1508
+ )
1509
+ if (
1510
+ position_ids_from_sin_path is None
1511
+ or position_ids_from_cos_path is None
1512
+ or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
1513
+ ):
1514
+ logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
1515
+ return
1516
+ position_ids = position_ids_from_cos_path[0].input[0]
1517
+ else:
1518
+ position_ids_from_sin_path = []
1519
+ position_ids_from_cos_path = []
1520
+
1521
+ past_seq_len_path, curr_seq_len_path = None, None
1522
+ if (sin_path == sin_path_1 and cos_path == cos_path_1) or (
1523
+ sin_path == sin_path_3 and cos_path == cos_path_3
1524
+ ):
1525
+ if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
1526
+ logger.debug(
1527
+ "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
1528
+ )
1529
+ return
1530
+ elif (sin_path == sin_path_2 and cos_path == cos_path_2) or (
1531
+ sin_path == sin_path_4 and cos_path == cos_path_4
1532
+ ):
1533
+ if sin_path[-1].name != cos_path[-1].name:
1534
+ logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
1535
+ return
1536
+ # Match past sequence length path: past_key --> Shape --> Gather --> Add
1537
+ past_seq_len_path = self.model.match_parent_path(
1538
+ sin_path[-1],
1539
+ ["Gather", "Shape"],
1540
+ [1, 0],
1541
+ )
1542
+ # Match current sequence length path: transpose_k --> Shape --> Gather --> Add
1543
+ curr_seq_len_path = self.model.match_parent_path(
1544
+ sin_path[-1],
1545
+ ["Gather", "Shape", "Transpose"],
1546
+ [0, 0, 0],
1547
+ )
1548
+ if (
1549
+ past_seq_len_path is None
1550
+ or curr_seq_len_path is None
1551
+ or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
1552
+ or curr_seq_len_path[-1].op_type != "Transpose"
1553
+ ):
1554
+ logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
1555
+ return
1556
+ else:
1557
+ logger.debug("fuse_rotary_embeddings: failed to match common cache paths")
1558
+
1559
+ rotary_emb_node = self.create_rotary_embeddings_from_nodes(
1560
+ rotate_half_x1_path_1[-1].output[0],
1561
+ position_ids,
1562
+ cos_cache,
1563
+ sin_cache,
1564
+ node.output[0],
1565
+ )
1566
+ if rotary_emb_node is None:
1567
+ logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
1568
+ return
1569
+
1570
+ # Remove rotary embedding nodes
1571
+ self.add_nodes_to_remove([node])
1572
+ self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
1573
+ self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
1574
+ self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
1575
+ self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
1576
+ self.add_nodes_to_remove(x_path[:-1])
1577
+ self.add_nodes_to_remove(sin_path)
1578
+ self.add_nodes_to_remove(cos_path)
1579
+ self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
1580
+ self.add_nodes_to_remove(position_ids_from_cos_path[:-1])
1581
+
1582
+ if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
1583
+ # In merged HF model, output of Gather in past_seq_len_path is used twice
1584
+ # for past_key_values.0.key and once for other past_key_values
1585
+ self.add_nodes_to_remove(past_seq_len_path)
1586
+ if curr_seq_len_path is not None:
1587
+ self.add_nodes_to_remove(curr_seq_len_path[:-1])
1588
+
1589
+ self.increase_counter(self.base_name)
1590
+ self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
1591
+ self.nodes_to_add.append(rotary_emb_node)
1592
+ self.prune_graph = True