onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1591 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import logging
6
+
7
+ from fusion_attention import FusionAttention
8
+ from fusion_base import Fusion
9
+ from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class FusionRotaryAttention(FusionAttention):
16
+ """
17
+ Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ model: OnnxModel,
23
+ hidden_size: int,
24
+ num_heads: int,
25
+ ):
26
+ super().__init__(
27
+ model,
28
+ hidden_size,
29
+ num_heads,
30
+ use_multi_head_attention=True,
31
+ search_op_types=[
32
+ "SimplifiedLayerNormalization",
33
+ "SkipSimplifiedLayerNormalization",
34
+ "LayerNormalization",
35
+ "SkipLayerNormalization",
36
+ "Add",
37
+ ],
38
+ )
39
+
40
+ def create_mha_node(
41
+ self,
42
+ input: str,
43
+ output: str,
44
+ q_rotary: NodeProto,
45
+ k_rotary: NodeProto,
46
+ v_matmul: NodeProto,
47
+ attn_mask: str = "",
48
+ add_qk: str = "",
49
+ past_k: str = "",
50
+ past_v: str = "",
51
+ present_k: str = "",
52
+ present_v: str = "",
53
+ scale: float | None = None,
54
+ ) -> NodeProto | None:
55
+ assert self.num_heads > 0
56
+
57
+ if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0:
58
+ logger.debug(
59
+ f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}"
60
+ )
61
+ return None
62
+
63
+ mha_node_name = self.model.create_node_name("MultiHeadAttention")
64
+ mha_inputs = [
65
+ q_rotary.output[0],
66
+ k_rotary.output[0],
67
+ v_matmul.output[0],
68
+ "", # bias
69
+ attn_mask, # key_padding_mask
70
+ add_qk, # attention_bias
71
+ past_k,
72
+ past_v,
73
+ ]
74
+
75
+ mha_outputs = [output]
76
+ if present_k and present_v:
77
+ mha_outputs.extend([present_k, present_v])
78
+
79
+ mha_node = helper.make_node(
80
+ "MultiHeadAttention",
81
+ inputs=mha_inputs,
82
+ outputs=mha_outputs,
83
+ name=mha_node_name,
84
+ )
85
+
86
+ mha_node.domain = "com.microsoft"
87
+ mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
88
+ if scale is not None:
89
+ mha_node.attribute.extend([helper.make_attribute("scale", scale)])
90
+ if self.mask_filter_value is not None:
91
+ mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
92
+
93
+ self.increase_counter("MultiHeadAttention")
94
+ return mha_node
95
+
96
    def check_runtime_shape_paths_for_function(
        self,
        reshape_qkv_2,  # Reshape after Transpose
        reshape_qkv_1,  # Reshape before Transpose
        reshape_q_2,  # Reshape after RotaryEmbedding
        reshape_k_2,  # Reshape after RotaryEmbedding
        reshape_v_2,  # Reshape after Transpose
        reshape_v_1,  # Reshape before Transpose
        add_qk,  # Add before Softmax
        root_input,  # Root input to attention subgraph
    ) -> bool:
        """Verify the runtime-shape (Shape/Gather/Unsqueeze/Concat) side paths of an
        attention subgraph expressed as an ONNX function body.

        All candidate Reshape nodes must derive their target shapes from the same
        two Gather nodes (gather_1 / gather_2, presumably batch and sequence
        dimensions of root_input -- inferred from usage, not stated here), and the
        attention-mask Slice chain must be consistent with those shapes.

        Returns True only if every expected connection is present; False aborts
        the fusion for this candidate subgraph.
        """
        # Check #1: check paths for qkv nodes
        concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
        concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1])
        if concat_qkv_2_path is None or concat_qkv_1_path is None:
            return False
        concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0]

        reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
        if (
            reshape_qkv_2_path_1 is None
            or reshape_qkv_2_path_2 is None
            or reshape_qkv_1_path_1 is None
            or reshape_qkv_1_path_2 is None
        ):
            return False

        # gather_1/gather_2 (and their Shape producers) become the reference nodes
        # that every later path must route through.
        _, gather_1, shape_1 = reshape_qkv_2_path_1
        _, gather_2, shape_2 = reshape_qkv_2_path_2

        # Check root_input --> Shape --> Gather connection
        if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
            return False

        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2
        if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name:
            return False

        # Check #2: check paths for v nodes
        concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1])
        concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1])
        if concat_v_2_path is None or concat_v_1_path is None:
            return False
        concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0]

        reshape_v_2_path_1 = self.model.match_parent_path(
            concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_v_2_path_2 = self.model.match_parent_path(
            concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0]
        )
        reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if (
            reshape_v_2_path_1 is None
            or reshape_v_2_path_2 is None
            or reshape_v_1_path_1 is None
            or reshape_v_1_path_2 is None
        ):
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1
        # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2
        if (
            reshape_v_2_path_1[2].name != gather_1.name
            or reshape_v_2_path_2[2].name != gather_2.name
            or reshape_v_1_path_1[1].name != gather_1.name
            or reshape_v_1_path_2[1].name != gather_2.name
        ):
            return False

        # Check #3: check paths for k nodes
        concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1])
        if concat_k_2_path is None:
            return False
        concat_k_2 = concat_k_2_path[0]

        reshape_k_2_path_1 = self.model.match_parent_path(
            concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_k_2_path_2 = self.model.match_parent_path(
            concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0]
        )
        if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None:
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1
        # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2
        if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name:
            return False

        # Check #4: check paths for q nodes
        concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1])
        if concat_q_2_path is None:
            return False
        concat_q_2 = concat_q_2_path[0]

        reshape_q_2_path_1 = self.model.match_parent_path(
            concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None:
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2
        if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name:
            return False

        # Check #5: check Mul nodes are the same for q, k, v
        # All three Mul nodes must consume the same Gather output.
        mul_q = reshape_q_2_path_1[1]
        mul_k = reshape_k_2_path_1[1]
        mul_v = reshape_v_2_path_1[1]
        gather_1_out = gather_1.output[0]
        if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
            return False

        # Check #6: check paths for attention mask nodes
        # Two accepted variants: with or without a leading Cast on the mask branch.
        attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
        attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
        if attn_mask_path_1 is not None:
            _, slice_qk_2, slice_qk_1 = attn_mask_path_1
        elif attn_mask_path_2 is not None:
            _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2
        else:
            return False
        # Check first input to Slice #1 is 3D attention mask of shape (B,S,T)
        # NOTE(review): the mask is identified purely by graph-input name here.
        if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}:
            return False

        slice_qk_2_path = self.model.match_parent_path(
            slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
        )
        slice_qk_1_path_1 = self.model.match_parent_path(
            slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
        )
        slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1])
        if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None:
            return False

        # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path
        # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1
        if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name:
            return False

        # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2
        # Check if first input to Add and Unsqueeze #1 is position ids
        if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]:
            return False

        return True
251
+
252
+ def check_runtime_shape_paths_for_nodes(
253
+ self,
254
+ reshape_qkv, # Final reshape before o_proj MatMul
255
+ reshape_q, # Reshape before q_proj MatMul
256
+ reshape_k, # Reshape before k_proj MatMul
257
+ reshape_v, # Reshape before v_proj MatMul
258
+ root_input, # Root input to attention subgraph
259
+ ):
260
+ # Check #1: check paths for qkv nodes
261
+ concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1])
262
+ if concat_qkv_path is None:
263
+ return False
264
+ concat_qkv = concat_qkv_path[0]
265
+
266
+ reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
267
+ reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
268
+ if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None:
269
+ return False
270
+
271
+ _, gather_1, shape_1 = reshape_qkv_path_1
272
+ _, gather_2, shape_2 = reshape_qkv_path_2
273
+
274
+ # Check root_input --> Shape --> Gather connection
275
+ if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
276
+ return False
277
+
278
+ # Check #2: check paths for v nodes
279
+ concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1])
280
+ if concat_v_path is None:
281
+ return False
282
+ concat_v = concat_v_path[0]
283
+
284
+ reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
285
+ reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
286
+ if reshape_v_path_1 is None or reshape_v_path_2 is None:
287
+ return False
288
+
289
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
290
+ if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name:
291
+ return False
292
+
293
+ # Check #3: check paths for k nodes
294
+ concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1])
295
+ if concat_k_path is None:
296
+ return False
297
+ concat_k = concat_k_path[0]
298
+
299
+ reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
300
+ reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
301
+ if reshape_k_path_1 is None or reshape_k_path_2 is None:
302
+ return False
303
+
304
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
305
+ if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name:
306
+ return False
307
+
308
+ # Check #4: check paths for q nodes
309
+ concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1])
310
+ if concat_q_path is None:
311
+ return False
312
+ concat_q = concat_q_path[0]
313
+
314
+ reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
315
+ reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
316
+ if reshape_q_path_1 is None or reshape_q_path_2 is None:
317
+ return False
318
+
319
+ # Check Gather --> Unsqueeze --> Concat --> Reshape connection
320
+ if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name:
321
+ return False
322
+
323
+ return True
324
+
325
    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match one rotary-attention subgraph above `normalize_node` and fuse it.

        Tries several exporter variants (LLaMA-2 Microsoft, LLaMA-2 Hugging Face,
        distributed Hugging Face, LLaMA-2 70B, Phi-2 DirectML). On any failed
        pattern match, logs a debug message and returns without changing the graph;
        on success, appends a fused multi-head attention node (plus Phi-2 glue
        nodes when needed) and schedules the matched subgraph nodes for removal.
        """
        if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
            return

        # qkv_nodes_1 is for LLaMA-2 Microsoft
        # qkv_nodes_2 is for LLaMA-2 Hugging Face
        # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model
        qkv_nodes = None
        qkv_nodes_1 = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        qkv_nodes_2 = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0],
        )
        qkv_nodes_3 = self.model.match_parent_path(
            normalize_node,
            ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        if qkv_nodes_1 is not None:
            _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
            qkv_nodes = qkv_nodes_1
        elif qkv_nodes_2 is not None:
            _, reshape_qkv, _, matmul_qkv = qkv_nodes_2
            qkv_nodes = qkv_nodes_2
        elif qkv_nodes_3 is not None:
            _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
            qkv_nodes = qkv_nodes_3
        else:
            logger.debug("fuse_rotary_attention: failed to match qkv nodes")
            return

        # v_nodes_1 is for LLaMA-2 Microsoft
        # v_nodes_3 is for LLaMA-2 Hugging Face
        # v_nodes_4 is for LLaMA-2 70B model
        # v_nodes_5 is for Phi-2 DirectML
        past_v, present_v, past_seq_len = "", "", ""
        v_nodes = None
        add_v = None
        v_nodes_1 = self.model.match_parent_path(
            matmul_qkv,
            ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        v_nodes_2 = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "MatMul"],
            [1, 1, 0, 0],
        )
        v_nodes_3 = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        # Nine alternative v paths for the 70B (grouped-query attention) export;
        # all alternatives must match for v_nodes_4 to be used.
        _, v_nodes_4, _ = self.model.match_parent_paths_all(
            matmul_qkv,
            [
                (
                    ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Expand", "Where", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 2, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 3, 0, 0, 0, 1, 0, 0],
                ),
            ],
            output_name_to_node=None,
        )
        v_nodes_5 = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )
        if v_nodes_1 is not None:
            reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
            v_nodes = v_nodes_1

            # Microsoft export keeps the past KV cache behind a Slice driven by
            # past_seq_len; recover past/present v and the shared seq-len input.
            concat_v_path = self.model.match_parent_path(
                concat_v,
                ["Slice", "Unsqueeze"],
                [0, 2],
            )
            if concat_v_path is None:
                logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
                return

            past_v = concat_v_path[0].input[0]
            past_seq_len = concat_v_path[-1].input[0]
            present_v = concat_v.output[0]
        elif v_nodes_2 is not None:
            concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
            v_nodes = v_nodes_2
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif v_nodes_3 is not None:
            transpose_v, reshape_v, matmul_v = v_nodes_3
            v_nodes = v_nodes_3
            present_v = transpose_v.output[0]
        elif v_nodes_4 is not None and len(v_nodes_4) == 9:
            # All nine 70B alternatives matched; the shared tail of path 0 holds
            # the KV-cache Concat and the v projection MatMul.
            concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
            v_nodes = v_nodes_4
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif v_nodes_5 is not None:
            concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
            # Phi-2 has a bias Add after the v MatMul; treat the Add as the end
            # of the v projection so the MHA node consumes its output.
            matmul_v = add_v
            v_nodes = v_nodes_5
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        else:
            logger.debug("fuse_rotary_attention: failed to match v path")
            return

        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "Div", "MatMul"],
            [0, 0, 0, 0],
        )
        add_qk, matmul_qk = None, None
        if qk_nodes is not None:
            _, add_qk, _, matmul_qk = qk_nodes
        else:
            logger.debug("fuse_rotary_attention: failed to match qk nodes")
            return

        # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
        # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
        # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
        # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
        attn_mask, add_qk_str = "", ""
        attn_mask_nodes_1 = self.model.match_parent_path(
            add_qk,
            ["Concat", "Slice", "Slice"],
            [1, 0, 0],
        )
        attn_mask_nodes_2 = self.model.match_parent_path(
            add_qk,
            ["Cast", "Concat", "Slice", "Slice"],
            [1, 0, 0, 0],
        )
        attn_mask_nodes_3 = self.model.match_parent_path(
            add_qk,
            ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_4 = self.model.match_parent_path(
            add_qk,
            ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_5 = self.model.match_parent_path(
            add_qk,
            ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_6 = self.model.match_parent_path(
            add_qk,
            ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_7 = self.model.match_parent_path(
            add_qk,
            ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 0, 0, 0, 1, 0, 0, 0],
        )
        if attn_mask_nodes_1 is not None:
            _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
            attn_mask = slice_mask_1.output[0]
        elif attn_mask_nodes_2 is not None:
            _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
            attn_mask = slice_mask_1.output[0]
        elif attn_mask_nodes_3 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
        elif attn_mask_nodes_4 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
        elif attn_mask_nodes_5 is not None:
            # The mask has already been reshaped to (B,N,S,T)
            add_qk_str = attn_mask_nodes_5[0].output[0]
        elif attn_mask_nodes_6 is not None:
            # The mask has already been reshaped to (B,N,S,T)
            add_qk_str = attn_mask_nodes_6[0].output[0]
        elif attn_mask_nodes_7 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
        else:
            logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
            return

        # k_nodes_1 is for LLaMA-2 Microsoft
        # k_nodes_2 is for LLaMA-2 Hugging Face
        # k_nodes_4 is for LLaMA-2 70B Hugging Face
        past_k, present_k = "", ""
        k_nodes = None
        slice_k = None
        concat_k_half = None
        k_nodes_1 = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        k_nodes_2 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        k_nodes_3 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [1, 0, 1, 0, 0, 0],
        )
        # Nine alternative k paths mirroring the v_nodes_4 70B alternatives, with
        # a RotaryEmbedding inserted before the KV-cache Concat.
        _, k_nodes_4, _ = self.model.match_parent_paths_all(
            matmul_qk,
            [
                (
                    ["Transpose", "Reshape", "Expand", "Unsqueeze", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Expand", "Where", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
                ),
            ],
            output_name_to_node=None,
        )
        k_nodes_5 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 0, 0, 0, 1],
        )
        if k_nodes_1 is not None:
            reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
            k_nodes = k_nodes_1

            concat_k_path = self.model.match_parent_path(
                concat_k,
                ["Slice", "Unsqueeze"],
                [0, 2],
            )
            if concat_k_path is None:
                logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
                return

            past_k = concat_k_path[0].input[0]
            shared_past_seq_len = concat_k_path[-1].input[0]
            present_k = concat_k.output[0]

            # k path must share the same past-sequence-length input as the v path.
            assert past_seq_len == shared_past_seq_len
        elif k_nodes_2 is not None:
            _, rotary_k, _, reshape_k, matmul_k = k_nodes_2
            k_nodes = k_nodes_2
            present_k = rotary_k.output[0]
        elif k_nodes_3 is not None:
            _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
            k_nodes = k_nodes_3
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif k_nodes_4 is not None and len(k_nodes_4) == 9:
            reshape_k, matmul_k = k_nodes_4[0][-2:]
            concat_k, rotary_k = k_nodes_4[0][-5:-3]
            k_nodes = k_nodes_4
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif k_nodes_5 is not None:
            _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
            k_nodes = k_nodes_5
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        else:
            logger.debug("fuse_rotary_attention: failed to match k nodes")
            return

        # q_nodes_1 is for LLaMA-2 Microsoft
        # q_nodes_2 is for LLaMA-2 Hugging Face
        # q_nodes_3 is for Phi-2 DirectML
        q_nodes = None
        slice_q = None
        concat_q_half = None
        q_nodes_1 = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
            [0, 0, 0, 0],
        )
        q_nodes_2 = self.model.match_parent_path(
            matmul_qk,
            ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [0, 0, 0, 0],
        )
        q_nodes_3 = self.model.match_parent_path(
            matmul_qk,
            ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 0, 1],
        )
        if q_nodes_1 is not None:
            reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
            q_nodes = q_nodes_1
        elif q_nodes_2 is not None:
            rotary_q, _, reshape_q, matmul_q = q_nodes_2
            q_nodes = q_nodes_2
        elif q_nodes_3 is not None:
            concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
            q_nodes = q_nodes_3
        else:
            logger.debug("fuse_rotary_attention: failed to match q nodes")
            return

        # NOTE(review): with `and`, this only rejects when q != k AND k != v; a
        # q/k mismatch alone passes when k == v. An `or` would enforce that all
        # three projections share one root input -- confirm the intent upstream.
        if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]:
            logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
            return

        # NOTE(review): the qkv_nodes_1 branch below relies on reshape_q_2 /
        # reshape_k_2 / reshape_v_2 / reshape_v_1, which are bound only when the
        # matching *_nodes_1 variants matched (i.e. the Microsoft export matches
        # consistently across q/k/v/qkv) -- a mixed match would raise NameError.
        root_output = ""
        if qkv_nodes == qkv_nodes_1:
            if not self.check_runtime_shape_paths_for_function(
                reshape_qkv_2,
                reshape_qkv_1,
                reshape_q_2,
                reshape_k_2,
                reshape_v_2,
                reshape_v_1,
                add_qk,
                matmul_q.input[0],
            ):
                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
                return
            root_output = reshape_qkv_2.output[0]

        elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
            if not self.check_runtime_shape_paths_for_nodes(
                reshape_qkv,
                reshape_q,
                reshape_k,
                reshape_v,
                matmul_q.input[0],
            ):
                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
                return
            root_output = reshape_qkv.output[0]

        # Rename inputs of rotary_q/k so it connects with output of matmul_q/k
        # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
        # After: MatMul --> RotaryEmbedding
        rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
        rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]

        # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
        if concat_q_half is None:
            rotary_k.output[0] = rotary_k.name + "_output_0"

        if qkv_nodes == qkv_nodes_3:
            # Drop the leading AllReduce so it survives the qkv_nodes removal below.
            qkv_nodes = qkv_nodes[1:]

        def create_hidden_size_concat_node(reshape_q):
            """Detect num_heads and hidden_size for ONNX model from phi-2
            Args:
                reshape_q (NodeProto): reshape node for q
            Returns:
                hidden_size_concat_node(NodeProto): Concat node to be used by reshape
            """
            concat = self.model.match_parent(reshape_q, "Concat", 1)

            if concat is None:
                logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
                return None

            # The shape is a tensor like [?, ?, num_heads, head_size]
            num_head_constant_node = self.model.get_constant_value(concat.input[2])
            head_size_constant_node = self.model.get_constant_value(concat.input[3])

            if num_head_constant_node is None or head_size_constant_node is None:
                logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size")
                return None

            num_head_value = num_head_constant_node[0]
            head_size_value = head_size_constant_node[0]

            hidden_size = num_head_value * head_size_value

            hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
            if self.model.get_initializer(hidden_size_initilizer) is None:
                self.add_initializer(
                    name=hidden_size_initilizer,
                    data_type=TensorProto.INT64,
                    dims=[1],
                    vals=[hidden_size],
                    raw=False,
                )

            hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")

            # Build [?, ?, hidden_size] from the first two dims of the original shape.
            hidden_size_concat_node = helper.make_node(
                "Concat",
                inputs=[
                    concat.input[0],
                    concat.input[1],
                    hidden_size_initilizer,
                ],
                outputs=[hidden_size_reshape_node_name + "output_0"],
                name=hidden_size_reshape_node_name,
            )
            hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])

            return hidden_size_concat_node

        # Add Transpose and Reshape nodes for partial rotary embedding applied in phi-2 before passing into MHA
        if concat_q_half and concat_k_half:
            # Transpose the key output of rotary Embedding
            k_transpose_node_name = self.model.create_node_name("Transpose")
            k_tranpose_output_name = k_transpose_node_name + "_output_0"
            k_transpose_node = helper.make_node(
                "Transpose",
                inputs=[concat_k_half.output[0]],
                outputs=[k_tranpose_output_name],
                name=k_transpose_node_name,
            )

            k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

            # Transpose the query output of rotary Embedding
            q_transpose_node_name = self.model.create_node_name("Transpose")
            q_tranpose_output_name = q_transpose_node_name + "_output_0"
            q_transpose_node = helper.make_node(
                "Transpose",
                inputs=[concat_q_half.output[0]],
                outputs=[q_tranpose_output_name],
                name=q_transpose_node_name,
            )

            q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

            hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
            if hidden_size_concat_node is None:
                logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
                return

            # Reshape the Rotary Embedding output for key for 4D to 3D
            concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
            concat_k_reshape_node = helper.make_node(
                "Reshape",
                inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
                outputs=[concat_k_reshape_node_name + "_output_0"],
                name=concat_k_reshape_node_name,
            )

            # Reshape the Rotary Embedding output for query from 4D to 3D
            concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
            concat_q_reshape_node = helper.make_node(
                "Reshape",
                inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
                outputs=[concat_q_reshape_node_name + "_output_0"],
                name=concat_q_reshape_node_name,
            )

            # The fused MHA consumes the 3D reshaped outputs instead of the raw
            # RotaryEmbedding outputs.
            rotary_k = concat_k_reshape_node
            rotary_q = concat_q_reshape_node

            self.nodes_to_add.append(hidden_size_concat_node)
            self.nodes_to_add.append(k_transpose_node)
            self.nodes_to_add.append(q_transpose_node)
            self.nodes_to_add.append(concat_k_reshape_node)
            self.nodes_to_add.append(concat_q_reshape_node)

            self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
            self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
            self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
            self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
            self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name

        new_node = self.create_mha_node(
            matmul_q.input[0],
            root_output,
            rotary_q,
            rotary_k,
            matmul_v,
            attn_mask,
            add_qk_str,
            past_k,
            past_v,
            present_k,
            present_v,
        )
        if new_node is None:
            logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        # Remove the matched subgraph, keeping the projection MatMuls (and, for
        # Phi-2, the bias Add) that now feed the fused node.
        self.nodes_to_remove.extend(qkv_nodes[1:])

        if v_nodes != v_nodes_4:
            self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
        else:
            nodes_to_keep = [v_nodes[0][-1]]
            for temp_path in v_nodes:
                self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

        self.nodes_to_remove.extend(qk_nodes)

        if k_nodes == k_nodes_1:
            self.nodes_to_remove.extend(k_nodes[:-2])
        elif k_nodes == k_nodes_2:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[2])
            self.nodes_to_remove.append(k_nodes[3])
        elif k_nodes == k_nodes_3:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[1])
            self.nodes_to_remove.append(k_nodes[3])
            self.nodes_to_remove.append(k_nodes[4])
        elif k_nodes == k_nodes_5:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[1])
        elif k_nodes == k_nodes_4:
            nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
            for temp_path in k_nodes:
                self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

        if q_nodes == q_nodes_1:
            self.nodes_to_remove.extend(q_nodes[:-2])
        elif q_nodes == q_nodes_2:
            self.nodes_to_remove.append(q_nodes[1])
            self.nodes_to_remove.append(q_nodes[2])
        self.prune_graph = True
1105
+
1106
+
1107
+ class FusionRotaryEmbeddings(Fusion):
1108
+ def __init__(self, model: OnnxModel):
1109
+ self.base_name = "RotaryEmbedding"
1110
+ super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"])
1111
+
1112
+ # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output.
1113
+ # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter.
1114
+ # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used.
1115
+ def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
1116
+ # Find extra outputs and Constant nodes attached to those outputs
1117
+ extra_constants, extra_outputs = [], []
1118
+ for fn_node in function.node:
1119
+ if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output:
1120
+ extra_constants.append(fn_node)
1121
+ output_index = list(function.output).index(fn_node.output[0])
1122
+ extra_outputs.append(rot_emb_node.output[output_index])
1123
+
1124
+ # Set extra Constant node outputs as initializers
1125
+ extra_initializers = []
1126
+ for extra_constant in extra_constants:
1127
+ constant_tensorproto = extra_constant.attribute[0].t
1128
+ constant_tensorproto.name = self.model.create_node_name("Constant")
1129
+ self.model.add_initializer(constant_tensorproto)
1130
+ extra_initializers.append(constant_tensorproto.name)
1131
+
1132
+ # Update references of Constant node outputs to initializer references
1133
+ for extra_output, extra_initializer in zip(extra_outputs, extra_initializers, strict=False):
1134
+ nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node))
1135
+ for node_to_update in nodes_to_update:
1136
+ OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer)
1137
+
1138
+ return extra_outputs
1139
+
1140
    def create_rotary_embeddings_from_function(self, node: NodeProto):
        """Replace a TorchScript-exported RotaryEmbedding function call with a
        com.microsoft RotaryEmbedding contrib node.

        Rewires the node to read directly from the preceding MatMul output
        (skipping the intermediate Reshape, which is scheduled for removal),
        converts the cos/sin cache Constant nodes into model initializers, and
        returns the new node. Returns None when the MatMul pattern does not match.
        """
        rotary_emb_node_name = self.model.create_node_name(self.base_name)

        matmul_path = self.model.match_parent_path(
            node,
            ["Reshape", "MatMul"],
            [0, 0],
        )
        if matmul_path is not None:
            reshape_node, matmul_node = matmul_path
        else:
            logger.debug("fuse_rotary_embeddings: failed to match MatMul")
            return

        rotary_emb_inputs = [
            matmul_node.output[0],  # x is of shape (B,S,D) instead of (B,S,N,H)
            node.input[1],  # position_ids
        ]

        # Convert cos_cache and sin_cache from node attributes to model initializers
        cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node))
        sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node))
        cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"

        # Only materialize the caches once: exactly one producer for each, and the
        # shared initializers must not already exist from an earlier fusion.
        if (
            len(cos_cache_node) == 1
            and len(sin_cache_node) == 1
            and self.model.get_initializer(cos_cache_name) is None
            and self.model.get_initializer(sin_cache_name) is None
        ):
            cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
            sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()

            cos_cache_tensor = helper.make_tensor(
                name=cos_cache_name,
                data_type=TensorProto.FLOAT,
                dims=list(cos_cache.shape),
                vals=cos_cache.flatten().tolist(),
            )
            self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
            sin_cache_tensor = helper.make_tensor(
                name=sin_cache_name,
                data_type=TensorProto.FLOAT,
                dims=list(sin_cache.shape),
                vals=sin_cache.flatten().tolist(),
            )
            self.model.add_initializer(sin_cache_tensor, self.this_graph_name)

            self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])

        rotary_emb_inputs.extend([cos_cache_name, sin_cache_name])

        rotary_emb_outputs = node.output
        if len(rotary_emb_outputs) > 1:
            # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers
            func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions))
            assert len(func) == 1
            extra_outputs = self.reassign_extra_outputs(node, func[0])
            rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs))
            assert len(rotary_emb_outputs) == 1

        rotary_emb_node = helper.make_node(
            self.base_name,
            inputs=rotary_emb_inputs,
            outputs=rotary_emb_outputs,
            name=rotary_emb_node_name,
            interleaved=1,
        )
        rotary_emb_node.domain = "com.microsoft"

        self.nodes_to_remove.append(reshape_node)

        return rotary_emb_node
1213
+
1214
+ def create_rotary_embeddings_from_nodes(
1215
+ self,
1216
+ root_input: str,
1217
+ position_ids: str,
1218
+ cos_slice: str,
1219
+ sin_slice: str,
1220
+ output: str,
1221
+ ):
1222
+ rotary_emb_node_name = self.model.create_node_name(self.base_name)
1223
+
1224
+ # Convert cos_cache and sin_cache from node attributes to model initializers
1225
+ cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node))
1226
+ sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node))
1227
+ cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
1228
+
1229
+ if (
1230
+ len(cos_cache_node) == 1
1231
+ and len(sin_cache_node) == 1
1232
+ and self.model.get_initializer(cos_cache_name) is None
1233
+ and self.model.get_initializer(sin_cache_name) is None
1234
+ ):
1235
+ cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
1236
+ sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
1237
+
1238
+ # Reshape cos/sin cache from (M, H) to (M, H/2)
1239
+ head_size = cos_cache.shape[1]
1240
+ cos_cache = cos_cache[:, : (head_size // 2)]
1241
+ sin_cache = sin_cache[:, : (head_size // 2)]
1242
+
1243
+ cos_cache_tensor = helper.make_tensor(
1244
+ name=cos_cache_name,
1245
+ data_type=TensorProto.FLOAT,
1246
+ dims=list(cos_cache.shape),
1247
+ vals=cos_cache.flatten().tolist(),
1248
+ )
1249
+ self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
1250
+ sin_cache_tensor = helper.make_tensor(
1251
+ name=sin_cache_name,
1252
+ data_type=TensorProto.FLOAT,
1253
+ dims=list(sin_cache.shape),
1254
+ vals=sin_cache.flatten().tolist(),
1255
+ )
1256
+ self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
1257
+
1258
+ self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
1259
+
1260
+ rotary_emb_node = helper.make_node(
1261
+ self.base_name,
1262
+ inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
1263
+ outputs=[output],
1264
+ name=rotary_emb_node_name,
1265
+ interleaved=0,
1266
+ )
1267
+ rotary_emb_node.domain = "com.microsoft"
1268
+ return rotary_emb_node
1269
+
1270
+ def fuse(self, node, input_name_to_nodes, output_name_to_node):
1271
+ # Node is either RotaryEmbedding function or Add
1272
+ if self.base_name not in node.op_type and node.op_type != "Add":
1273
+ return
1274
+
1275
+ # Check if node is "RotaryEmbedding nn.Module" exported as a function
1276
+ # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
1277
+ rotary_emb_node = None
1278
+ if node.op_type != "Add":
1279
+ # Verify that function has the correct inputs
1280
+ if len(node.input) not in {4, 5} or node.input[1] not in {
1281
+ "pos",
1282
+ "pos_id",
1283
+ "position_id",
1284
+ "pos_ids",
1285
+ "position_ids",
1286
+ }:
1287
+ logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
1288
+ return
1289
+
1290
+ rotary_emb_node = self.create_rotary_embeddings_from_function(node)
1291
+ if rotary_emb_node is None:
1292
+ logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
1293
+ return
1294
+
1295
+ # Remove RotaryEmbedding function
1296
+ self.nodes_to_remove.append(node)
1297
+
1298
+ # Remove RotaryEmbedding function's shape inference stored in value_info
1299
+ # The new shape will be calculated during symbolic shape inference
1300
+ old_shape_infer = list(
1301
+ filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info)
1302
+ )
1303
+ assert len(old_shape_infer) == 1
1304
+ self.model.model.graph.value_info.remove(old_shape_infer[0])
1305
+
1306
+ else:
1307
+ # Rotary embeddings are defined using the below functions:
1308
+ #
1309
+ # def rotate_half(x):
1310
+ # """Rotates half the hidden dims of the input."""
1311
+ # x1 = x[..., : x.shape[-1] // 2]
1312
+ # x2 = x[..., x.shape[-1] // 2 :]
1313
+ # return torch.cat((-x2, x1), dim=-1)
1314
+ #
1315
+ # def apply_rope(x, cos, sin, position_ids):
1316
+ # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
1317
+ # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
1318
+ # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
1319
+ # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
1320
+ # x_embed = (x * cos) + (rotate_half(x) * sin)
1321
+ # return x_embed
1322
+
1323
+ # Check paths for rotate_half(x)
1324
+ rotate_half_x2_path_1_1 = self.model.match_parent_path(
1325
+ node,
1326
+ ["Mul", "Concat", "Neg", "Slice", "Transpose"],
1327
+ [1, 0, 0, 0, 0],
1328
+ )
1329
+
1330
+ rotate_half_x2_path_1_2 = self.model.match_parent_path(
1331
+ node,
1332
+ ["Mul", "Concat", "Neg", "Slice", "Slice"],
1333
+ [1, 0, 0, 0, 0],
1334
+ )
1335
+
1336
+ rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2
1337
+
1338
+ rotate_half_x2_path_2_1 = self.model.match_parent_path(
1339
+ node,
1340
+ ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
1341
+ [1, 0, 0, 0, 1, 0, 0, 0, 0],
1342
+ )
1343
+
1344
+ rotate_half_x2_path_2_2 = self.model.match_parent_path(
1345
+ node,
1346
+ ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
1347
+ [1, 0, 0, 0, 1, 0, 0, 0, 0],
1348
+ )
1349
+
1350
+ rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2
1351
+
1352
+ if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
1353
+ logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
1354
+ return
1355
+
1356
+ rotate_half_x1_path_1_1 = self.model.match_parent_path(
1357
+ node,
1358
+ ["Mul", "Concat", "Slice", "Transpose"],
1359
+ [1, 0, 1, 0],
1360
+ )
1361
+
1362
+ rotate_half_x1_path_1_2 = self.model.match_parent_path(
1363
+ node,
1364
+ ["Mul", "Concat", "Slice", "Slice"],
1365
+ [1, 0, 1, 0],
1366
+ )
1367
+
1368
+ rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2
1369
+
1370
+ rotate_half_x1_path_2_1 = self.model.match_parent_path(
1371
+ node,
1372
+ ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
1373
+ [1, 0, 1, 2, 0, 0, 0, 0],
1374
+ )
1375
+
1376
+ rotate_half_x1_path_2_2 = self.model.match_parent_path(
1377
+ node,
1378
+ ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
1379
+ [1, 0, 1, 2, 0, 0, 0, 0],
1380
+ )
1381
+
1382
+ rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2
1383
+
1384
+ if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
1385
+ logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
1386
+ return
1387
+
1388
+ if (
1389
+ rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
1390
+ or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
1391
+ or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
1392
+ or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
1393
+ ):
1394
+ logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
1395
+ return
1396
+
1397
+ # Check path for x
1398
+ x_path_1 = self.model.match_parent_path(
1399
+ node,
1400
+ ["Mul", "Transpose"],
1401
+ [0, 0],
1402
+ )
1403
+
1404
+ x_path_2 = self.model.match_parent_path(
1405
+ node,
1406
+ ["Mul", "Slice"],
1407
+ [0, 0],
1408
+ )
1409
+
1410
+ x_path = x_path_1 or x_path_2
1411
+
1412
+ if x_path is None:
1413
+ logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
1414
+ return
1415
+
1416
+ # Check path for sin
1417
+ sin_path, sin_cache, position_ids = None, "", ""
1418
+ sin_path_1 = self.model.match_parent_path(
1419
+ node,
1420
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
1421
+ [1, 1, 0, 0, 0, 0, 2, 0, 0],
1422
+ )
1423
+ sin_path_2 = self.model.match_parent_path(
1424
+ node,
1425
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
1426
+ [1, 1, 0, 0, 0, 0, 2, 0],
1427
+ )
1428
+ sin_path_3 = self.model.match_parent_path(
1429
+ node,
1430
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
1431
+ [1, 1, 0, 0, 2, 0, 0],
1432
+ )
1433
+ sin_path_4 = self.model.match_parent_path(
1434
+ node,
1435
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
1436
+ [1, 1, 0, 0, 2, 0],
1437
+ )
1438
+ if sin_path_1 is not None:
1439
+ sin_path = sin_path_1
1440
+ sin_cache = sin_path[-4].input[0]
1441
+ elif sin_path_2 is not None:
1442
+ sin_path = sin_path_2
1443
+ sin_cache = sin_path[-3].input[0]
1444
+ elif sin_path_3 is not None:
1445
+ sin_path = sin_path_3
1446
+ sin_cache = sin_path[-4].input[0]
1447
+ position_ids = sin_path[2].input[1]
1448
+ elif sin_path_4 is not None:
1449
+ sin_path = sin_path_4
1450
+ sin_cache = sin_path[-3].input[0]
1451
+ position_ids = sin_path[2].input[1]
1452
+ else:
1453
+ logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
1454
+ return
1455
+
1456
+ # Check path for cos
1457
+ cos_path, cos_cache = None, ""
1458
+ cos_path_1 = self.model.match_parent_path(
1459
+ node,
1460
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
1461
+ [0, 1, 0, 0, 0, 0, 2, 0, 0],
1462
+ )
1463
+ cos_path_2 = self.model.match_parent_path(
1464
+ node,
1465
+ ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
1466
+ [0, 1, 0, 0, 0, 0, 2, 0],
1467
+ )
1468
+ cos_path_3 = self.model.match_parent_path(
1469
+ node,
1470
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
1471
+ [0, 1, 0, 0, 2, 0, 0],
1472
+ )
1473
+ cos_path_4 = self.model.match_parent_path(
1474
+ node,
1475
+ ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
1476
+ [0, 1, 0, 0, 2, 0],
1477
+ )
1478
+ if cos_path_1 is not None:
1479
+ cos_path = cos_path_1
1480
+ cos_cache = cos_path[-4].input[0]
1481
+ elif cos_path_2 is not None:
1482
+ cos_path = cos_path_2
1483
+ cos_cache = cos_path[-3].input[0]
1484
+ elif cos_path_3 is not None:
1485
+ cos_path = cos_path_3
1486
+ cos_cache = cos_path[-4].input[0]
1487
+ position_ids = cos_path[2].input[1]
1488
+ elif cos_path_4 is not None:
1489
+ cos_path = cos_path_4
1490
+ cos_cache = cos_path[-3].input[0]
1491
+ position_ids = cos_path[2].input[1]
1492
+ else:
1493
+ logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
1494
+ return
1495
+
1496
+ # Check path for position ids
1497
+ if position_ids == "":
1498
+ position_ids_from_sin_path = self.model.match_parent_path(
1499
+ sin_path[2],
1500
+ ["Reshape"],
1501
+ [1],
1502
+ )
1503
+ position_ids_from_cos_path = self.model.match_parent_path(
1504
+ cos_path[2],
1505
+ ["Reshape"],
1506
+ [1],
1507
+ )
1508
+ if (
1509
+ position_ids_from_sin_path is None
1510
+ or position_ids_from_cos_path is None
1511
+ or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
1512
+ ):
1513
+ logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
1514
+ return
1515
+ position_ids = position_ids_from_cos_path[0].input[0]
1516
+ else:
1517
+ position_ids_from_sin_path = []
1518
+ position_ids_from_cos_path = []
1519
+
1520
+ past_seq_len_path, curr_seq_len_path = None, None
1521
+ if (sin_path == sin_path_1 and cos_path == cos_path_1) or (
1522
+ sin_path == sin_path_3 and cos_path == cos_path_3
1523
+ ):
1524
+ if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
1525
+ logger.debug(
1526
+ "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
1527
+ )
1528
+ return
1529
+ elif (sin_path == sin_path_2 and cos_path == cos_path_2) or (
1530
+ sin_path == sin_path_4 and cos_path == cos_path_4
1531
+ ):
1532
+ if sin_path[-1].name != cos_path[-1].name:
1533
+ logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
1534
+ return
1535
+ # Match past sequence length path: past_key --> Shape --> Gather --> Add
1536
+ past_seq_len_path = self.model.match_parent_path(
1537
+ sin_path[-1],
1538
+ ["Gather", "Shape"],
1539
+ [1, 0],
1540
+ )
1541
+ # Match current sequence length path: transpose_k --> Shape --> Gather --> Add
1542
+ curr_seq_len_path = self.model.match_parent_path(
1543
+ sin_path[-1],
1544
+ ["Gather", "Shape", "Transpose"],
1545
+ [0, 0, 0],
1546
+ )
1547
+ if (
1548
+ past_seq_len_path is None
1549
+ or curr_seq_len_path is None
1550
+ or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
1551
+ or curr_seq_len_path[-1].op_type != "Transpose"
1552
+ ):
1553
+ logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
1554
+ return
1555
+ else:
1556
+ logger.debug("fuse_rotary_embeddings: failed to match common cache paths")
1557
+
1558
+ rotary_emb_node = self.create_rotary_embeddings_from_nodes(
1559
+ rotate_half_x1_path_1[-1].output[0],
1560
+ position_ids,
1561
+ cos_cache,
1562
+ sin_cache,
1563
+ node.output[0],
1564
+ )
1565
+ if rotary_emb_node is None:
1566
+ logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
1567
+ return
1568
+
1569
+ # Remove rotary embedding nodes
1570
+ self.add_nodes_to_remove([node])
1571
+ self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
1572
+ self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
1573
+ self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
1574
+ self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
1575
+ self.add_nodes_to_remove(x_path[:-1])
1576
+ self.add_nodes_to_remove(sin_path)
1577
+ self.add_nodes_to_remove(cos_path)
1578
+ self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
1579
+ self.add_nodes_to_remove(position_ids_from_cos_path[:-1])
1580
+
1581
+ if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
1582
+ # In merged HF model, output of Gather in past_seq_len_path is used twice
1583
+ # for past_key_values.0.key and once for other past_key_values
1584
+ self.add_nodes_to_remove(past_seq_len_path)
1585
+ if curr_seq_len_path is not None:
1586
+ self.add_nodes_to_remove(curr_seq_len_path[:-1])
1587
+
1588
+ self.increase_counter(self.base_name)
1589
+ self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
1590
+ self.nodes_to_add.append(rotary_emb_node)
1591
+ self.prune_graph = True