onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1591 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from fusion_attention import FusionAttention
|
|
8
|
+
from fusion_base import Fusion
|
|
9
|
+
from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
|
|
10
|
+
from onnx_model import OnnxModel
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FusionRotaryAttention(FusionAttention):
|
|
16
|
+
"""
|
|
17
|
+
Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
def __init__(
    self,
    model: OnnxModel,
    hidden_size: int,
    num_heads: int,
):
    """Initialize the rotary-attention fusion pass.

    Delegates to FusionAttention in multi-head-attention mode, anchoring
    the subgraph search on the normalization/Add ops that typically precede
    attention in models using rotary positional embeddings.
    """
    # Node types that can start an attention subgraph in rotary models.
    anchor_op_types = [
        "SimplifiedLayerNormalization",
        "SkipSimplifiedLayerNormalization",
        "LayerNormalization",
        "SkipLayerNormalization",
        "Add",
    ]
    super().__init__(
        model,
        hidden_size,
        num_heads,
        use_multi_head_attention=True,
        search_op_types=anchor_op_types,
    )
|
|
39
|
+
|
|
40
|
+
def create_mha_node(
    self,
    input: str,
    output: str,
    q_rotary: NodeProto,
    k_rotary: NodeProto,
    v_matmul: NodeProto,
    attn_mask: str = "",
    add_qk: str = "",
    past_k: str = "",
    past_v: str = "",
    present_k: str = "",
    present_v: str = "",
    scale: float | None = None,
) -> NodeProto | None:
    """Create a com.microsoft MultiHeadAttention node wired to fused Q/K/V.

    Args:
        input: Root input of the attention subgraph. Not referenced in this
            body; kept for interface compatibility with callers.
        output: Name of the MHA node's first output.
        q_rotary: Node producing the query (output[0] is used).
        k_rotary: Node producing the key (output[0] is used).
        v_matmul: Node producing the value (output[0] is used).
        attn_mask: Optional key-padding-mask tensor name ("" if absent).
        add_qk: Optional attention-bias tensor name ("" if absent).
        past_k: Optional past-key cache tensor name.
        past_v: Optional past-value cache tensor name.
        present_k: Optional present-key output name; emitted only together
            with present_v.
        present_v: Optional present-value output name.
        scale: Optional attention scale; added as an attribute when set.

    Returns:
        The new MultiHeadAttention NodeProto, or None when hidden_size is
        not divisible by num_heads.
    """
    assert self.num_heads > 0

    if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0:
        # Lazy %-style args: the message is only formatted when DEBUG logging
        # is enabled (the original used an eagerly-evaluated f-string).
        logger.debug(
            "fuse_rotary_attention: input hidden size %s is not a multiple of num of heads %s",
            self.hidden_size,
            self.num_heads,
        )
        return None

    mha_node_name = self.model.create_node_name("MultiHeadAttention")
    # Input order follows the com.microsoft MultiHeadAttention op schema.
    mha_inputs = [
        q_rotary.output[0],
        k_rotary.output[0],
        v_matmul.output[0],
        "",  # bias
        attn_mask,  # key_padding_mask
        add_qk,  # attention_bias
        past_k,
        past_v,
    ]

    # Present KV cache outputs are emitted only when both names are given.
    mha_outputs = [output]
    if present_k and present_v:
        mha_outputs.extend([present_k, present_v])

    mha_node = helper.make_node(
        "MultiHeadAttention",
        inputs=mha_inputs,
        outputs=mha_outputs,
        name=mha_node_name,
    )

    mha_node.domain = "com.microsoft"
    mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
    if scale is not None:
        mha_node.attribute.extend([helper.make_attribute("scale", scale)])
    if self.mask_filter_value is not None:
        mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

    self.increase_counter("MultiHeadAttention")
    return mha_node
|
|
95
|
+
|
|
96
|
+
def check_runtime_shape_paths_for_function(
    self,
    reshape_qkv_2,  # Reshape after Transpose
    reshape_qkv_1,  # Reshape before Transpose
    reshape_q_2,  # Reshape after RotaryEmbedding
    reshape_k_2,  # Reshape after RotaryEmbedding
    reshape_v_2,  # Reshape after Transpose
    reshape_v_1,  # Reshape before Transpose
    add_qk,  # Add before Softmax
    root_input,  # Root input to attention subgraph
):
    """Verify the runtime-shape (Shape/Gather/Unsqueeze/Concat) subgraphs feeding
    the attention Reshape nodes all derive from the same root input, so the
    subgraph can be safely fused.

    Returns True only when every shape path matches the expected pattern;
    any missing or mismatched path returns False.
    """
    # Check #1: check paths for qkv nodes
    concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
    concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1])
    if concat_qkv_2_path is None or concat_qkv_1_path is None:
        return False
    concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0]

    reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
    reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
    reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
    reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
    if (
        reshape_qkv_2_path_1 is None
        or reshape_qkv_2_path_2 is None
        or reshape_qkv_1_path_1 is None
        or reshape_qkv_1_path_2 is None
    ):
        return False

    # gather_1/gather_2 extract the batch-dim and another dim of root_input;
    # all later paths must reuse these exact Gather nodes.
    _, gather_1, shape_1 = reshape_qkv_2_path_1
    _, gather_2, shape_2 = reshape_qkv_2_path_2

    # Check root_input --> Shape --> Gather connection
    if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
        return False

    # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2
    if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name:
        return False

    # Check #2: check paths for v nodes
    concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1])
    concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1])
    if concat_v_2_path is None or concat_v_1_path is None:
        return False
    concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0]

    reshape_v_2_path_1 = self.model.match_parent_path(
        concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
    )
    reshape_v_2_path_2 = self.model.match_parent_path(
        concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0]
    )
    reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
    reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
    if (
        reshape_v_2_path_1 is None
        or reshape_v_2_path_2 is None
        or reshape_v_1_path_1 is None
        or reshape_v_1_path_2 is None
    ):
        return False

    # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1
    # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2
    # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2
    if (
        reshape_v_2_path_1[2].name != gather_1.name
        or reshape_v_2_path_2[2].name != gather_2.name
        or reshape_v_1_path_1[1].name != gather_1.name
        or reshape_v_1_path_2[1].name != gather_2.name
    ):
        return False

    # Check #3: check paths for k nodes
    concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1])
    if concat_k_2_path is None:
        return False
    concat_k_2 = concat_k_2_path[0]

    reshape_k_2_path_1 = self.model.match_parent_path(
        concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
    )
    reshape_k_2_path_2 = self.model.match_parent_path(
        concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0]
    )
    if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None:
        return False

    # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1
    # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2
    if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name:
        return False

    # Check #4: check paths for q nodes
    concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1])
    if concat_q_2_path is None:
        return False
    concat_q_2 = concat_q_2_path[0]

    reshape_q_2_path_1 = self.model.match_parent_path(
        concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
    )
    reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
    if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None:
        return False

    # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1
    # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2
    if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name:
        return False

    # Check #5: check Mul nodes are the same for q, k, v
    mul_q = reshape_q_2_path_1[1]
    mul_k = reshape_k_2_path_1[1]
    mul_v = reshape_v_2_path_1[1]
    gather_1_out = gather_1.output[0]
    if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
        return False

    # Check #6: check paths for attention mask nodes
    # Two variants: with or without a Cast between the Add and the Concat.
    attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
    attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
    if attn_mask_path_1 is not None:
        _, slice_qk_2, slice_qk_1 = attn_mask_path_1
    elif attn_mask_path_2 is not None:
        _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2
    else:
        return False
    # Check first input to Slice #1 is 3D attention mask of shape (B,S,T)
    if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}:
        return False

    slice_qk_2_path = self.model.match_parent_path(
        slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
    )
    slice_qk_1_path_1 = self.model.match_parent_path(
        slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
    )
    slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1])
    if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None:
        return False

    # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path
    # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1
    if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name:
        return False

    # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2
    # Check if first input to Add and Unsqueeze #1 is position ids
    if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]:
        return False

    return True
|
|
251
|
+
|
|
252
|
+
def check_runtime_shape_paths_for_nodes(
    self,
    reshape_qkv,  # Final reshape before o_proj MatMul
    reshape_q,  # Reshape before q_proj MatMul
    reshape_k,  # Reshape before k_proj MatMul
    reshape_v,  # Reshape before v_proj MatMul
    root_input,  # Root input to attention subgraph
):
    """Verify that the runtime-shape subgraphs feeding each Reshape are consistent.

    Each Reshape's target shape is expected to come from a Concat whose first two
    dims are produced by Unsqueeze <-- Gather <-- Shape chains. The qkv Reshape's
    Shape nodes must read from `root_input`, and the q/k/v Reshapes must reuse the
    exact same Gather nodes as the qkv Reshape. Returns True only when every
    connection matches.
    """

    def _dim_paths(reshape_node):
        # Return the two Unsqueeze <-- Gather <-- Shape parent paths of the Concat
        # feeding this Reshape's shape input, or None when the pattern is absent.
        concat_path = self.model.match_parent_path(reshape_node, ["Concat"], [1])
        if concat_path is None:
            return None
        concat_node = concat_path[0]
        first = self.model.match_parent_path(concat_node, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        second = self.model.match_parent_path(concat_node, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if first is None or second is None:
            return None
        return first, second

    # Check #1: check paths for qkv nodes
    qkv_paths = _dim_paths(reshape_qkv)
    if qkv_paths is None:
        return False
    (_, gather_1, shape_1), (_, gather_2, shape_2) = qkv_paths

    # Check root_input --> Shape --> Gather connection
    if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
        return False

    # Checks #2-#4: the v, k, and q Reshapes (verified in that order, matching the
    # original check sequence) must route their first two shape dims through the
    # same Gather nodes found above.
    for reshape_node in (reshape_v, reshape_k, reshape_q):
        paths = _dim_paths(reshape_node)
        if paths is None:
            return False
        first, second = paths
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection
        if first[1].name != gather_1.name or second[1].name != gather_2.name:
            return False

    return True
|
|
324
|
+
|
|
325
|
+
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
    """Fuse a decomposed rotary attention subgraph into one fused attention node.

    Walks upward from `normalize_node` to match the q/k/v projection,
    RotaryEmbedding, attention-mask, and past/present KV-cache subgraph variants
    produced by several exporters (LLaMA-2 Microsoft / Hugging Face / 70B
    distributed, Phi-2 DirectML). On a full match it rewires the RotaryEmbedding
    inputs, builds the fused node via `self.create_mha_node`, and schedules the
    matched nodes for removal. On any mismatch it logs a debug message and
    returns without touching the graph.

    Args:
        normalize_node: candidate root node (SkipSimplifiedLayerNormalization,
            SkipLayerNormalization, or Add) that consumes the attention output.
        input_name_to_nodes: unused here; part of the Fusion callback signature.
        output_name_to_node: unused here; part of the Fusion callback signature.
    """
    if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
        return

    # --- Step 1: match the output (o_proj) side of the attention subgraph. ---
    # qkv_nodes_1 is for LLaMA-2 Microsoft
    # qkv_nodes_2 is for LLaMA-2 Hugging Face
    # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model
    qkv_nodes = None
    qkv_nodes_1 = self.model.match_parent_path(
        normalize_node,
        ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
        [1, 0, 0, 0, 0],
    )
    qkv_nodes_2 = self.model.match_parent_path(
        normalize_node,
        ["MatMul", "Reshape", "Transpose", "MatMul"],
        [1, 0, 0, 0],
    )
    qkv_nodes_3 = self.model.match_parent_path(
        normalize_node,
        ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
        [1, 0, 0, 0, 0],
    )
    if qkv_nodes_1 is not None:
        _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
        qkv_nodes = qkv_nodes_1
    elif qkv_nodes_2 is not None:
        _, reshape_qkv, _, matmul_qkv = qkv_nodes_2
        qkv_nodes = qkv_nodes_2
    elif qkv_nodes_3 is not None:
        _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
        qkv_nodes = qkv_nodes_3
    else:
        logger.debug("fuse_rotary_attention: failed to match qkv nodes")
        return

    # --- Step 2: match the value (v) path above matmul_qkv. ---
    # v_nodes_1 is for LLaMA-2 Microsoft
    # v_nodes_3 is for LLaMA-2 Hugging Face
    # v_nodes_4 is for LLaMA-2 70B model
    # v_nodes_5 is for Phi-2 DirectML
    past_v, present_v, past_seq_len = "", "", ""
    v_nodes = None
    add_v = None
    v_nodes_1 = self.model.match_parent_path(
        matmul_qkv,
        ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
        [1, 0, 0, 1, 0, 0],
    )
    v_nodes_2 = self.model.match_parent_path(
        matmul_qkv,
        ["Concat", "Transpose", "Reshape", "MatMul"],
        [1, 1, 0, 0],
    )
    v_nodes_3 = self.model.match_parent_path(
        matmul_qkv,
        ["Transpose", "Reshape", "MatMul"],
        [1, 0, 0],
    )
    # The 70B (grouped-query) export repeats KV heads via Expand/Where subgraphs;
    # match_parent_paths_all returns every matching path variant at once.
    _, v_nodes_4, _ = self.model.match_parent_paths_all(
        matmul_qkv,
        [
            (
                ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
                [1, 0, 0, 0, 1, 0, 0],
            ),
            (
                [
                    "Reshape",
                    "Expand",
                    "Where",
                    "Equal",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
            ),
            (
                [
                    "Reshape",
                    "Expand",
                    "Where",
                    "Equal",
                    "Mul",
                    "ConstantOfShape",
                    "Shape",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
            ),
            (
                [
                    "Reshape",
                    "Expand",
                    "Where",
                    "ConstantOfShape",
                    "Shape",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
            ),
            (
                [
                    "Reshape",
                    "Expand",
                    "Where",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
            ),
            (
                ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                [1, 1, 0, 0, 0, 0, 1, 0, 0],
            ),
            (
                [
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Mul",
                    "Gather",
                    "Shape",
                    "Concat",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
            ),
            (
                ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                [1, 1, 2, 0, 0, 0, 1, 0, 0],
            ),
            (
                ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                [1, 1, 3, 0, 0, 0, 1, 0, 0],
            ),
        ],
        output_name_to_node=None,
    )
    v_nodes_5 = self.model.match_parent_path(
        matmul_qkv,
        ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
        [1, 1, 0, 0, 1],
    )
    if v_nodes_1 is not None:
        reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
        v_nodes = v_nodes_1

        # Microsoft export: past KV is sliced before the past/present Concat.
        concat_v_path = self.model.match_parent_path(
            concat_v,
            ["Slice", "Unsqueeze"],
            [0, 2],
        )
        if concat_v_path is None:
            logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
            return

        past_v = concat_v_path[0].input[0]
        past_seq_len = concat_v_path[-1].input[0]
        present_v = concat_v.output[0]
    elif v_nodes_2 is not None:
        concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
        v_nodes = v_nodes_2
        past_v = concat_v.input[0]
        present_v = concat_v.output[0]
    elif v_nodes_3 is not None:
        # No KV cache in this variant; only present_v exists.
        transpose_v, reshape_v, matmul_v = v_nodes_3
        v_nodes = v_nodes_3
        present_v = transpose_v.output[0]
    elif v_nodes_4 is not None and len(v_nodes_4) == 9:
        # All 9 candidate paths must have matched for the 70B variant.
        concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
        v_nodes = v_nodes_4
        past_v = concat_v.input[0]
        present_v = concat_v.output[0]
    elif v_nodes_5 is not None:
        concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
        # Phi-2: the projection has a bias Add, so treat the Add as the v input node.
        matmul_v = add_v
        v_nodes = v_nodes_5
        past_v = concat_v.input[0]
        present_v = concat_v.output[0]
    else:
        logger.debug("fuse_rotary_attention: failed to match v path")
        return

    # --- Step 3: match the scaled q*k softmax path. ---
    qk_nodes = self.model.match_parent_path(
        matmul_qkv,
        ["Softmax", "Add", "Div", "MatMul"],
        [0, 0, 0, 0],
    )
    add_qk, matmul_qk = None, None
    if qk_nodes is not None:
        _, add_qk, _, matmul_qk = qk_nodes
    else:
        logger.debug("fuse_rotary_attention: failed to match qk nodes")
        return

    # --- Step 4: match the attention-mask subgraph feeding add_qk. ---
    # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
    # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
    # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
    # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
    attn_mask, add_qk_str = "", ""
    attn_mask_nodes_1 = self.model.match_parent_path(
        add_qk,
        ["Concat", "Slice", "Slice"],
        [1, 0, 0],
    )
    attn_mask_nodes_2 = self.model.match_parent_path(
        add_qk,
        ["Cast", "Concat", "Slice", "Slice"],
        [1, 0, 0, 0],
    )
    attn_mask_nodes_3 = self.model.match_parent_path(
        add_qk,
        ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
        [1, 0, 2, 1, 0, 0, 0],
    )
    attn_mask_nodes_4 = self.model.match_parent_path(
        add_qk,
        ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
        [1, 2, 1, 0, 0, 0],
    )
    attn_mask_nodes_5 = self.model.match_parent_path(
        add_qk,
        ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
        [1, 0, 0, 2, 1, 0, 0, 0],
    )
    attn_mask_nodes_6 = self.model.match_parent_path(
        add_qk,
        ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
        [1, 0, 2, 1, 0, 0, 0],
    )
    attn_mask_nodes_7 = self.model.match_parent_path(
        add_qk,
        ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
        [1, 0, 0, 0, 0, 1, 0, 0, 0],
    )
    if attn_mask_nodes_1 is not None:
        _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
        attn_mask = slice_mask_1.output[0]
    elif attn_mask_nodes_2 is not None:
        _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
        attn_mask = slice_mask_1.output[0]
    elif attn_mask_nodes_3 is not None:
        # Reshape from (B,1,S,T) to (B,N,S,T)
        add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
    elif attn_mask_nodes_4 is not None:
        # Reshape from (B,1,S,T) to (B,N,S,T)
        add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
    elif attn_mask_nodes_5 is not None:
        # The mask has already been reshaped to (B,N,S,T)
        add_qk_str = attn_mask_nodes_5[0].output[0]
    elif attn_mask_nodes_6 is not None:
        # The mask has already been reshaped to (B,N,S,T)
        add_qk_str = attn_mask_nodes_6[0].output[0]
    elif attn_mask_nodes_7 is not None:
        # Reshape from (B,1,S,T) to (B,N,S,T)
        add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
    else:
        logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
        return

    # --- Step 5: match the key (k) path above matmul_qk. ---
    # k_nodes_1 is for LLaMA-2 Microsoft
    # k_nodes_2 is for LLaMA-2 Hugging Face
    # k_nodes_4 is for LLaMA-2 70B Hugging Face
    past_k, present_k = "", ""
    k_nodes = None
    slice_k = None
    concat_k_half = None
    k_nodes_1 = self.model.match_parent_path(
        matmul_qk,
        ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
        [1, 0, 0, 1, 0, 0],
    )
    k_nodes_2 = self.model.match_parent_path(
        matmul_qk,
        ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
        [1, 0, 0, 0, 0],
    )
    k_nodes_3 = self.model.match_parent_path(
        matmul_qk,
        ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
        [1, 0, 1, 0, 0, 0],
    )
    # Same grouped-query Expand/Where variants as v_nodes_4, but with a
    # RotaryEmbedding in the chain.
    _, k_nodes_4, _ = self.model.match_parent_paths_all(
        matmul_qk,
        [
            (
                [
                    "Transpose",
                    "Reshape",
                    "Expand",
                    "Unsqueeze",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Expand",
                    "Where",
                    "Equal",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Expand",
                    "Where",
                    "Equal",
                    "Mul",
                    "ConstantOfShape",
                    "Shape",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Expand",
                    "Where",
                    "ConstantOfShape",
                    "Shape",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Expand",
                    "Where",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Mul",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
            ),
            (
                [
                    "Transpose",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Gather",
                    "Shape",
                    "Concat",
                    "RotaryEmbedding",
                    "Transpose",
                    "Reshape",
                    "MatMul",
                ],
                [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
            ),
        ],
        output_name_to_node=None,
    )
    k_nodes_5 = self.model.match_parent_path(
        matmul_qk,
        ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
        [1, 0, 1, 0, 0, 0, 0, 0, 1],
    )
    if k_nodes_1 is not None:
        reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
        k_nodes = k_nodes_1

        concat_k_path = self.model.match_parent_path(
            concat_k,
            ["Slice", "Unsqueeze"],
            [0, 2],
        )
        if concat_k_path is None:
            logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
            return

        past_k = concat_k_path[0].input[0]
        shared_past_seq_len = concat_k_path[-1].input[0]
        present_k = concat_k.output[0]

        # The k path must share the same past-sequence-length input as the v path.
        assert past_seq_len == shared_past_seq_len
    elif k_nodes_2 is not None:
        _, rotary_k, _, reshape_k, matmul_k = k_nodes_2
        k_nodes = k_nodes_2
        present_k = rotary_k.output[0]
    elif k_nodes_3 is not None:
        _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
        k_nodes = k_nodes_3
        past_k = concat_k.input[0]
        present_k = concat_k.output[0]
    elif k_nodes_4 is not None and len(k_nodes_4) == 9:
        reshape_k, matmul_k = k_nodes_4[0][-2:]
        concat_k, rotary_k = k_nodes_4[0][-5:-3]
        k_nodes = k_nodes_4
        past_k = concat_k.input[0]
        present_k = concat_k.output[0]
    elif k_nodes_5 is not None:
        # Phi-2 partial rotary: only a Slice of k goes through RotaryEmbedding,
        # then rotary/pass-through halves are re-joined by concat_k_half.
        _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
        k_nodes = k_nodes_5
        past_k = concat_k.input[0]
        present_k = concat_k.output[0]
    else:
        logger.debug("fuse_rotary_attention: failed to match k nodes")
        return

    # --- Step 6: match the query (q) path above matmul_qk. ---
    # q_nodes_1 is for LLaMA-2 Microsoft
    # q_nodes_2 is for LLaMA-2 Hugging Face
    # q_nodes_3 is for Phi-2 DirectML
    q_nodes = None
    slice_q = None
    concat_q_half = None
    q_nodes_1 = self.model.match_parent_path(
        matmul_qk,
        ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
        [0, 0, 0, 0],
    )
    q_nodes_2 = self.model.match_parent_path(
        matmul_qk,
        ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
        [0, 0, 0, 0],
    )
    q_nodes_3 = self.model.match_parent_path(
        matmul_qk,
        ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
        [0, 0, 0, 0, 0, 0, 1],
    )
    if q_nodes_1 is not None:
        reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
        q_nodes = q_nodes_1
    elif q_nodes_2 is not None:
        rotary_q, _, reshape_q, matmul_q = q_nodes_2
        q_nodes = q_nodes_2
    elif q_nodes_3 is not None:
        concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
        q_nodes = q_nodes_3
    else:
        logger.debug("fuse_rotary_attention: failed to match q nodes")
        return

    # NOTE(review): "and" here means a mismatch in only ONE of the two pairs is
    # tolerated — possibly intended to be "or"; confirm against upstream intent.
    if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]:
        logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
        return

    # --- Step 7: verify the runtime shape-computation subgraphs. ---
    root_output = ""
    if qkv_nodes == qkv_nodes_1:
        if not self.check_runtime_shape_paths_for_function(
            reshape_qkv_2,
            reshape_qkv_1,
            reshape_q_2,
            reshape_k_2,
            reshape_v_2,
            reshape_v_1,
            add_qk,
            matmul_q.input[0],
        ):
            logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
            return
        root_output = reshape_qkv_2.output[0]

    elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
        if not self.check_runtime_shape_paths_for_nodes(
            reshape_qkv,
            reshape_q,
            reshape_k,
            reshape_v,
            matmul_q.input[0],
        ):
            logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
            return
        root_output = reshape_qkv.output[0]

    # Rename inputs of rotary_q/k so it connects with output of matmul_q/k
    # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
    # After: MatMul --> RotaryEmbedding
    rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
    rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]

    # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
    if concat_q_half is None:
        rotary_k.output[0] = rotary_k.name + "_output_0"

    if qkv_nodes == qkv_nodes_3:
        # Drop the AllReduce from the list so later removal keeps it in the graph.
        qkv_nodes = qkv_nodes[1:]

    def create_hidden_size_concat_node(reshape_q):
        """Detect num_heads and hidden_size for ONNX model from phi-2

        Args:
            reshape_q (NodeProto): reshape node for q
        Returns:
            hidden_size_concat_node(NodeProto): Concat node to be used by reshape
        """
        concat = self.model.match_parent(reshape_q, "Concat", 1)

        if concat is None:
            logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
            return None

        # The shape is a tensor like [?, ?, num_heads, head_size]
        num_head_constant_node = self.model.get_constant_value(concat.input[2])
        head_size_constant_node = self.model.get_constant_value(concat.input[3])

        if num_head_constant_node is None or head_size_constant_node is None:
            logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size")
            return None

        num_head_value = num_head_constant_node[0]
        head_size_value = head_size_constant_node[0]

        hidden_size = num_head_value * head_size_value

        hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
        if self.model.get_initializer(hidden_size_initilizer) is None:
            self.add_initializer(
                name=hidden_size_initilizer,
                data_type=TensorProto.INT64,
                dims=[1],
                vals=[hidden_size],
                raw=False,
            )

        hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")

        # NOTE(review): output name is `...output_0` (no underscore), unlike the
        # `_output_0` convention used elsewhere in this method — confirm intended.
        hidden_size_concat_node = helper.make_node(
            "Concat",
            inputs=[
                concat.input[0],
                concat.input[1],
                hidden_size_initilizer,
            ],
            outputs=[hidden_size_reshape_node_name + "output_0"],
            name=hidden_size_reshape_node_name,
        )
        hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])

        return hidden_size_concat_node

    # Add Transpose and Reshape nodes for partial rotary embedding applied in phi-2 before passing into MHA
    if concat_q_half and concat_k_half:
        # Transpose the key output of rotary Embedding
        k_transpose_node_name = self.model.create_node_name("Transpose")
        k_tranpose_output_name = k_transpose_node_name + "_output_0"
        k_transpose_node = helper.make_node(
            "Transpose",
            inputs=[concat_k_half.output[0]],
            outputs=[k_tranpose_output_name],
            name=k_transpose_node_name,
        )

        k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

        # Transpose the query output of rotary Embedding
        q_transpose_node_name = self.model.create_node_name("Transpose")
        q_tranpose_output_name = q_transpose_node_name + "_output_0"
        q_transpose_node = helper.make_node(
            "Transpose",
            inputs=[concat_q_half.output[0]],
            outputs=[q_tranpose_output_name],
            name=q_transpose_node_name,
        )

        q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

        hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
        if hidden_size_concat_node is None:
            logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
            return

        # Reshape the Rotary Embedding output for key for 4D to 3D
        concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
        concat_k_reshape_node = helper.make_node(
            "Reshape",
            inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
            outputs=[concat_k_reshape_node_name + "_output_0"],
            name=concat_k_reshape_node_name,
        )

        # Reshape the Rotary Embedding output for query from 4D to 3D
        concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
        concat_q_reshape_node = helper.make_node(
            "Reshape",
            inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
            outputs=[concat_q_reshape_node_name + "_output_0"],
            name=concat_q_reshape_node_name,
        )

        # From here on, the Reshape outputs stand in for the rotary q/k outputs.
        rotary_k = concat_k_reshape_node
        rotary_q = concat_q_reshape_node

        self.nodes_to_add.append(hidden_size_concat_node)
        self.nodes_to_add.append(k_transpose_node)
        self.nodes_to_add.append(q_transpose_node)
        self.nodes_to_add.append(concat_k_reshape_node)
        self.nodes_to_add.append(concat_q_reshape_node)

        self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
        self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
        self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
        self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
        self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name

    # --- Step 8: create the fused attention node and remove matched nodes. ---
    new_node = self.create_mha_node(
        matmul_q.input[0],
        root_output,
        rotary_q,
        rotary_k,
        matmul_v,
        attn_mask,
        add_qk_str,
        past_k,
        past_v,
        present_k,
        present_v,
    )
    if new_node is None:
        logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
        return

    self.nodes_to_add.append(new_node)
    self.node_name_to_graph_name[new_node.name] = self.this_graph_name

    self.nodes_to_remove.extend(qkv_nodes[1:])

    # For the multi-path 70B match (v_nodes_4), keep the shared producer nodes.
    if v_nodes != v_nodes_4:
        self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
    else:
        nodes_to_keep = [v_nodes[0][-1]]
        for temp_path in v_nodes:
            self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

    self.nodes_to_remove.extend(qk_nodes)

    # Per-variant removal: keep the MatMul projections and RotaryEmbedding nodes.
    if k_nodes == k_nodes_1:
        self.nodes_to_remove.extend(k_nodes[:-2])
    elif k_nodes == k_nodes_2:
        self.nodes_to_remove.append(k_nodes[0])
        self.nodes_to_remove.append(k_nodes[2])
        self.nodes_to_remove.append(k_nodes[3])
    elif k_nodes == k_nodes_3:
        self.nodes_to_remove.append(k_nodes[0])
        self.nodes_to_remove.append(k_nodes[1])
        self.nodes_to_remove.append(k_nodes[3])
        self.nodes_to_remove.append(k_nodes[4])
    elif k_nodes == k_nodes_5:
        self.nodes_to_remove.append(k_nodes[0])
        self.nodes_to_remove.append(k_nodes[1])
    elif k_nodes == k_nodes_4:
        nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
        for temp_path in k_nodes:
            self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

    if q_nodes == q_nodes_1:
        self.nodes_to_remove.extend(q_nodes[:-2])
    elif q_nodes == q_nodes_2:
        self.nodes_to_remove.append(q_nodes[1])
        self.nodes_to_remove.append(q_nodes[2])
    # Dead matched nodes not explicitly listed above are cleaned up by pruning.
    self.prune_graph = True
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
class FusionRotaryEmbeddings(Fusion):
    """Fuse rotary position embedding (RoPE) patterns into a single
    `com.microsoft.RotaryEmbedding` contrib node.

    Two export styles are handled:
      1. RoPE exported as an ONNX function (e.g. via
         `export_modules_as_functions` in the TorchScript exporter), matched by
         op type "RotaryEmbedding" / "RotaryEmbedding.1".
      2. RoPE inlined as primitive ops, matched starting from the final "Add"
         of `x*cos + rotate_half(x)*sin`.
    """

    def __init__(self, model: OnnxModel):
        # NOTE: base_name is assigned before super().__init__ because it is
        # used to build the list of op types this fusion searches for.
        # "RotaryEmbedding.1" is a duplicate-function name variant
        # (presumably produced by the TorchScript exporter — TODO confirm).
        self.base_name = "RotaryEmbedding"
        super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"])
+
|
|
1112
|
+
# The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output.
|
|
1113
|
+
# This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter.
|
|
1114
|
+
# To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used.
|
|
1115
|
+
def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
|
|
1116
|
+
# Find extra outputs and Constant nodes attached to those outputs
|
|
1117
|
+
extra_constants, extra_outputs = [], []
|
|
1118
|
+
for fn_node in function.node:
|
|
1119
|
+
if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output:
|
|
1120
|
+
extra_constants.append(fn_node)
|
|
1121
|
+
output_index = list(function.output).index(fn_node.output[0])
|
|
1122
|
+
extra_outputs.append(rot_emb_node.output[output_index])
|
|
1123
|
+
|
|
1124
|
+
# Set extra Constant node outputs as initializers
|
|
1125
|
+
extra_initializers = []
|
|
1126
|
+
for extra_constant in extra_constants:
|
|
1127
|
+
constant_tensorproto = extra_constant.attribute[0].t
|
|
1128
|
+
constant_tensorproto.name = self.model.create_node_name("Constant")
|
|
1129
|
+
self.model.add_initializer(constant_tensorproto)
|
|
1130
|
+
extra_initializers.append(constant_tensorproto.name)
|
|
1131
|
+
|
|
1132
|
+
# Update references of Constant node outputs to initializer references
|
|
1133
|
+
for extra_output, extra_initializer in zip(extra_outputs, extra_initializers, strict=False):
|
|
1134
|
+
nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node))
|
|
1135
|
+
for node_to_update in nodes_to_update:
|
|
1136
|
+
OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer)
|
|
1137
|
+
|
|
1138
|
+
return extra_outputs
|
|
1139
|
+
|
|
1140
|
+
def create_rotary_embeddings_from_function(self, node: NodeProto):
|
|
1141
|
+
rotary_emb_node_name = self.model.create_node_name(self.base_name)
|
|
1142
|
+
|
|
1143
|
+
matmul_path = self.model.match_parent_path(
|
|
1144
|
+
node,
|
|
1145
|
+
["Reshape", "MatMul"],
|
|
1146
|
+
[0, 0],
|
|
1147
|
+
)
|
|
1148
|
+
if matmul_path is not None:
|
|
1149
|
+
reshape_node, matmul_node = matmul_path
|
|
1150
|
+
else:
|
|
1151
|
+
logger.debug("fuse_rotary_embeddings: failed to match MatMul")
|
|
1152
|
+
return
|
|
1153
|
+
|
|
1154
|
+
rotary_emb_inputs = [
|
|
1155
|
+
matmul_node.output[0], # x is of shape (B,S,D) instead of (B,S,N,H)
|
|
1156
|
+
node.input[1], # position_ids
|
|
1157
|
+
]
|
|
1158
|
+
|
|
1159
|
+
# Convert cos_cache and sin_cache from node attributes to model initializers
|
|
1160
|
+
cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node))
|
|
1161
|
+
sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node))
|
|
1162
|
+
cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
|
|
1163
|
+
|
|
1164
|
+
if (
|
|
1165
|
+
len(cos_cache_node) == 1
|
|
1166
|
+
and len(sin_cache_node) == 1
|
|
1167
|
+
and self.model.get_initializer(cos_cache_name) is None
|
|
1168
|
+
and self.model.get_initializer(sin_cache_name) is None
|
|
1169
|
+
):
|
|
1170
|
+
cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
|
|
1171
|
+
sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
|
|
1172
|
+
|
|
1173
|
+
cos_cache_tensor = helper.make_tensor(
|
|
1174
|
+
name=cos_cache_name,
|
|
1175
|
+
data_type=TensorProto.FLOAT,
|
|
1176
|
+
dims=list(cos_cache.shape),
|
|
1177
|
+
vals=cos_cache.flatten().tolist(),
|
|
1178
|
+
)
|
|
1179
|
+
self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
|
|
1180
|
+
sin_cache_tensor = helper.make_tensor(
|
|
1181
|
+
name=sin_cache_name,
|
|
1182
|
+
data_type=TensorProto.FLOAT,
|
|
1183
|
+
dims=list(sin_cache.shape),
|
|
1184
|
+
vals=sin_cache.flatten().tolist(),
|
|
1185
|
+
)
|
|
1186
|
+
self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
|
|
1187
|
+
|
|
1188
|
+
self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
|
|
1189
|
+
|
|
1190
|
+
rotary_emb_inputs.extend([cos_cache_name, sin_cache_name])
|
|
1191
|
+
|
|
1192
|
+
rotary_emb_outputs = node.output
|
|
1193
|
+
if len(rotary_emb_outputs) > 1:
|
|
1194
|
+
# Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers
|
|
1195
|
+
func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions))
|
|
1196
|
+
assert len(func) == 1
|
|
1197
|
+
extra_outputs = self.reassign_extra_outputs(node, func[0])
|
|
1198
|
+
rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs))
|
|
1199
|
+
assert len(rotary_emb_outputs) == 1
|
|
1200
|
+
|
|
1201
|
+
rotary_emb_node = helper.make_node(
|
|
1202
|
+
self.base_name,
|
|
1203
|
+
inputs=rotary_emb_inputs,
|
|
1204
|
+
outputs=rotary_emb_outputs,
|
|
1205
|
+
name=rotary_emb_node_name,
|
|
1206
|
+
interleaved=1,
|
|
1207
|
+
)
|
|
1208
|
+
rotary_emb_node.domain = "com.microsoft"
|
|
1209
|
+
|
|
1210
|
+
self.nodes_to_remove.append(reshape_node)
|
|
1211
|
+
|
|
1212
|
+
return rotary_emb_node
|
|
1213
|
+
|
|
1214
|
+
def create_rotary_embeddings_from_nodes(
|
|
1215
|
+
self,
|
|
1216
|
+
root_input: str,
|
|
1217
|
+
position_ids: str,
|
|
1218
|
+
cos_slice: str,
|
|
1219
|
+
sin_slice: str,
|
|
1220
|
+
output: str,
|
|
1221
|
+
):
|
|
1222
|
+
rotary_emb_node_name = self.model.create_node_name(self.base_name)
|
|
1223
|
+
|
|
1224
|
+
# Convert cos_cache and sin_cache from node attributes to model initializers
|
|
1225
|
+
cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node))
|
|
1226
|
+
sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node))
|
|
1227
|
+
cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
|
|
1228
|
+
|
|
1229
|
+
if (
|
|
1230
|
+
len(cos_cache_node) == 1
|
|
1231
|
+
and len(sin_cache_node) == 1
|
|
1232
|
+
and self.model.get_initializer(cos_cache_name) is None
|
|
1233
|
+
and self.model.get_initializer(sin_cache_name) is None
|
|
1234
|
+
):
|
|
1235
|
+
cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
|
|
1236
|
+
sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
|
|
1237
|
+
|
|
1238
|
+
# Reshape cos/sin cache from (M, H) to (M, H/2)
|
|
1239
|
+
head_size = cos_cache.shape[1]
|
|
1240
|
+
cos_cache = cos_cache[:, : (head_size // 2)]
|
|
1241
|
+
sin_cache = sin_cache[:, : (head_size // 2)]
|
|
1242
|
+
|
|
1243
|
+
cos_cache_tensor = helper.make_tensor(
|
|
1244
|
+
name=cos_cache_name,
|
|
1245
|
+
data_type=TensorProto.FLOAT,
|
|
1246
|
+
dims=list(cos_cache.shape),
|
|
1247
|
+
vals=cos_cache.flatten().tolist(),
|
|
1248
|
+
)
|
|
1249
|
+
self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
|
|
1250
|
+
sin_cache_tensor = helper.make_tensor(
|
|
1251
|
+
name=sin_cache_name,
|
|
1252
|
+
data_type=TensorProto.FLOAT,
|
|
1253
|
+
dims=list(sin_cache.shape),
|
|
1254
|
+
vals=sin_cache.flatten().tolist(),
|
|
1255
|
+
)
|
|
1256
|
+
self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
|
|
1257
|
+
|
|
1258
|
+
self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
|
|
1259
|
+
|
|
1260
|
+
rotary_emb_node = helper.make_node(
|
|
1261
|
+
self.base_name,
|
|
1262
|
+
inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
|
|
1263
|
+
outputs=[output],
|
|
1264
|
+
name=rotary_emb_node_name,
|
|
1265
|
+
interleaved=0,
|
|
1266
|
+
)
|
|
1267
|
+
rotary_emb_node.domain = "com.microsoft"
|
|
1268
|
+
return rotary_emb_node
|
|
1269
|
+
|
|
1270
|
+
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
|
1271
|
+
# Node is either RotaryEmbedding function or Add
|
|
1272
|
+
if self.base_name not in node.op_type and node.op_type != "Add":
|
|
1273
|
+
return
|
|
1274
|
+
|
|
1275
|
+
# Check if node is "RotaryEmbedding nn.Module" exported as a function
|
|
1276
|
+
# (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
|
|
1277
|
+
rotary_emb_node = None
|
|
1278
|
+
if node.op_type != "Add":
|
|
1279
|
+
# Verify that function has the correct inputs
|
|
1280
|
+
if len(node.input) not in {4, 5} or node.input[1] not in {
|
|
1281
|
+
"pos",
|
|
1282
|
+
"pos_id",
|
|
1283
|
+
"position_id",
|
|
1284
|
+
"pos_ids",
|
|
1285
|
+
"position_ids",
|
|
1286
|
+
}:
|
|
1287
|
+
logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
|
|
1288
|
+
return
|
|
1289
|
+
|
|
1290
|
+
rotary_emb_node = self.create_rotary_embeddings_from_function(node)
|
|
1291
|
+
if rotary_emb_node is None:
|
|
1292
|
+
logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
|
|
1293
|
+
return
|
|
1294
|
+
|
|
1295
|
+
# Remove RotaryEmbedding function
|
|
1296
|
+
self.nodes_to_remove.append(node)
|
|
1297
|
+
|
|
1298
|
+
# Remove RotaryEmbedding function's shape inference stored in value_info
|
|
1299
|
+
# The new shape will be calculated during symbolic shape inference
|
|
1300
|
+
old_shape_infer = list(
|
|
1301
|
+
filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info)
|
|
1302
|
+
)
|
|
1303
|
+
assert len(old_shape_infer) == 1
|
|
1304
|
+
self.model.model.graph.value_info.remove(old_shape_infer[0])
|
|
1305
|
+
|
|
1306
|
+
else:
|
|
1307
|
+
# Rotary embeddings are defined using the below functions:
|
|
1308
|
+
#
|
|
1309
|
+
# def rotate_half(x):
|
|
1310
|
+
# """Rotates half the hidden dims of the input."""
|
|
1311
|
+
# x1 = x[..., : x.shape[-1] // 2]
|
|
1312
|
+
# x2 = x[..., x.shape[-1] // 2 :]
|
|
1313
|
+
# return torch.cat((-x2, x1), dim=-1)
|
|
1314
|
+
#
|
|
1315
|
+
# def apply_rope(x, cos, sin, position_ids):
|
|
1316
|
+
# cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
|
1317
|
+
# sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
|
1318
|
+
# cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
|
1319
|
+
# sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
|
1320
|
+
# x_embed = (x * cos) + (rotate_half(x) * sin)
|
|
1321
|
+
# return x_embed
|
|
1322
|
+
|
|
1323
|
+
# Check paths for rotate_half(x)
|
|
1324
|
+
rotate_half_x2_path_1_1 = self.model.match_parent_path(
|
|
1325
|
+
node,
|
|
1326
|
+
["Mul", "Concat", "Neg", "Slice", "Transpose"],
|
|
1327
|
+
[1, 0, 0, 0, 0],
|
|
1328
|
+
)
|
|
1329
|
+
|
|
1330
|
+
rotate_half_x2_path_1_2 = self.model.match_parent_path(
|
|
1331
|
+
node,
|
|
1332
|
+
["Mul", "Concat", "Neg", "Slice", "Slice"],
|
|
1333
|
+
[1, 0, 0, 0, 0],
|
|
1334
|
+
)
|
|
1335
|
+
|
|
1336
|
+
rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2
|
|
1337
|
+
|
|
1338
|
+
rotate_half_x2_path_2_1 = self.model.match_parent_path(
|
|
1339
|
+
node,
|
|
1340
|
+
["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
|
|
1341
|
+
[1, 0, 0, 0, 1, 0, 0, 0, 0],
|
|
1342
|
+
)
|
|
1343
|
+
|
|
1344
|
+
rotate_half_x2_path_2_2 = self.model.match_parent_path(
|
|
1345
|
+
node,
|
|
1346
|
+
["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
|
|
1347
|
+
[1, 0, 0, 0, 1, 0, 0, 0, 0],
|
|
1348
|
+
)
|
|
1349
|
+
|
|
1350
|
+
rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2
|
|
1351
|
+
|
|
1352
|
+
if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
|
|
1353
|
+
logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
|
|
1354
|
+
return
|
|
1355
|
+
|
|
1356
|
+
rotate_half_x1_path_1_1 = self.model.match_parent_path(
|
|
1357
|
+
node,
|
|
1358
|
+
["Mul", "Concat", "Slice", "Transpose"],
|
|
1359
|
+
[1, 0, 1, 0],
|
|
1360
|
+
)
|
|
1361
|
+
|
|
1362
|
+
rotate_half_x1_path_1_2 = self.model.match_parent_path(
|
|
1363
|
+
node,
|
|
1364
|
+
["Mul", "Concat", "Slice", "Slice"],
|
|
1365
|
+
[1, 0, 1, 0],
|
|
1366
|
+
)
|
|
1367
|
+
|
|
1368
|
+
rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2
|
|
1369
|
+
|
|
1370
|
+
rotate_half_x1_path_2_1 = self.model.match_parent_path(
|
|
1371
|
+
node,
|
|
1372
|
+
["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
|
|
1373
|
+
[1, 0, 1, 2, 0, 0, 0, 0],
|
|
1374
|
+
)
|
|
1375
|
+
|
|
1376
|
+
rotate_half_x1_path_2_2 = self.model.match_parent_path(
|
|
1377
|
+
node,
|
|
1378
|
+
["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
|
|
1379
|
+
[1, 0, 1, 2, 0, 0, 0, 0],
|
|
1380
|
+
)
|
|
1381
|
+
|
|
1382
|
+
rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2
|
|
1383
|
+
|
|
1384
|
+
if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
|
|
1385
|
+
logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
|
|
1386
|
+
return
|
|
1387
|
+
|
|
1388
|
+
if (
|
|
1389
|
+
rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
|
|
1390
|
+
or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
|
|
1391
|
+
or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
|
|
1392
|
+
or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
|
|
1393
|
+
):
|
|
1394
|
+
logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
|
|
1395
|
+
return
|
|
1396
|
+
|
|
1397
|
+
# Check path for x
|
|
1398
|
+
x_path_1 = self.model.match_parent_path(
|
|
1399
|
+
node,
|
|
1400
|
+
["Mul", "Transpose"],
|
|
1401
|
+
[0, 0],
|
|
1402
|
+
)
|
|
1403
|
+
|
|
1404
|
+
x_path_2 = self.model.match_parent_path(
|
|
1405
|
+
node,
|
|
1406
|
+
["Mul", "Slice"],
|
|
1407
|
+
[0, 0],
|
|
1408
|
+
)
|
|
1409
|
+
|
|
1410
|
+
x_path = x_path_1 or x_path_2
|
|
1411
|
+
|
|
1412
|
+
if x_path is None:
|
|
1413
|
+
logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
|
|
1414
|
+
return
|
|
1415
|
+
|
|
1416
|
+
# Check path for sin
|
|
1417
|
+
sin_path, sin_cache, position_ids = None, "", ""
|
|
1418
|
+
sin_path_1 = self.model.match_parent_path(
|
|
1419
|
+
node,
|
|
1420
|
+
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
|
|
1421
|
+
[1, 1, 0, 0, 0, 0, 2, 0, 0],
|
|
1422
|
+
)
|
|
1423
|
+
sin_path_2 = self.model.match_parent_path(
|
|
1424
|
+
node,
|
|
1425
|
+
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
|
|
1426
|
+
[1, 1, 0, 0, 0, 0, 2, 0],
|
|
1427
|
+
)
|
|
1428
|
+
sin_path_3 = self.model.match_parent_path(
|
|
1429
|
+
node,
|
|
1430
|
+
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
|
|
1431
|
+
[1, 1, 0, 0, 2, 0, 0],
|
|
1432
|
+
)
|
|
1433
|
+
sin_path_4 = self.model.match_parent_path(
|
|
1434
|
+
node,
|
|
1435
|
+
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
|
|
1436
|
+
[1, 1, 0, 0, 2, 0],
|
|
1437
|
+
)
|
|
1438
|
+
if sin_path_1 is not None:
|
|
1439
|
+
sin_path = sin_path_1
|
|
1440
|
+
sin_cache = sin_path[-4].input[0]
|
|
1441
|
+
elif sin_path_2 is not None:
|
|
1442
|
+
sin_path = sin_path_2
|
|
1443
|
+
sin_cache = sin_path[-3].input[0]
|
|
1444
|
+
elif sin_path_3 is not None:
|
|
1445
|
+
sin_path = sin_path_3
|
|
1446
|
+
sin_cache = sin_path[-4].input[0]
|
|
1447
|
+
position_ids = sin_path[2].input[1]
|
|
1448
|
+
elif sin_path_4 is not None:
|
|
1449
|
+
sin_path = sin_path_4
|
|
1450
|
+
sin_cache = sin_path[-3].input[0]
|
|
1451
|
+
position_ids = sin_path[2].input[1]
|
|
1452
|
+
else:
|
|
1453
|
+
logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
|
|
1454
|
+
return
|
|
1455
|
+
|
|
1456
|
+
# Check path for cos
|
|
1457
|
+
cos_path, cos_cache = None, ""
|
|
1458
|
+
cos_path_1 = self.model.match_parent_path(
|
|
1459
|
+
node,
|
|
1460
|
+
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
|
|
1461
|
+
[0, 1, 0, 0, 0, 0, 2, 0, 0],
|
|
1462
|
+
)
|
|
1463
|
+
cos_path_2 = self.model.match_parent_path(
|
|
1464
|
+
node,
|
|
1465
|
+
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
|
|
1466
|
+
[0, 1, 0, 0, 0, 0, 2, 0],
|
|
1467
|
+
)
|
|
1468
|
+
cos_path_3 = self.model.match_parent_path(
|
|
1469
|
+
node,
|
|
1470
|
+
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
|
|
1471
|
+
[0, 1, 0, 0, 2, 0, 0],
|
|
1472
|
+
)
|
|
1473
|
+
cos_path_4 = self.model.match_parent_path(
|
|
1474
|
+
node,
|
|
1475
|
+
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
|
|
1476
|
+
[0, 1, 0, 0, 2, 0],
|
|
1477
|
+
)
|
|
1478
|
+
if cos_path_1 is not None:
|
|
1479
|
+
cos_path = cos_path_1
|
|
1480
|
+
cos_cache = cos_path[-4].input[0]
|
|
1481
|
+
elif cos_path_2 is not None:
|
|
1482
|
+
cos_path = cos_path_2
|
|
1483
|
+
cos_cache = cos_path[-3].input[0]
|
|
1484
|
+
elif cos_path_3 is not None:
|
|
1485
|
+
cos_path = cos_path_3
|
|
1486
|
+
cos_cache = cos_path[-4].input[0]
|
|
1487
|
+
position_ids = cos_path[2].input[1]
|
|
1488
|
+
elif cos_path_4 is not None:
|
|
1489
|
+
cos_path = cos_path_4
|
|
1490
|
+
cos_cache = cos_path[-3].input[0]
|
|
1491
|
+
position_ids = cos_path[2].input[1]
|
|
1492
|
+
else:
|
|
1493
|
+
logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
|
|
1494
|
+
return
|
|
1495
|
+
|
|
1496
|
+
# Check path for position ids
|
|
1497
|
+
if position_ids == "":
|
|
1498
|
+
position_ids_from_sin_path = self.model.match_parent_path(
|
|
1499
|
+
sin_path[2],
|
|
1500
|
+
["Reshape"],
|
|
1501
|
+
[1],
|
|
1502
|
+
)
|
|
1503
|
+
position_ids_from_cos_path = self.model.match_parent_path(
|
|
1504
|
+
cos_path[2],
|
|
1505
|
+
["Reshape"],
|
|
1506
|
+
[1],
|
|
1507
|
+
)
|
|
1508
|
+
if (
|
|
1509
|
+
position_ids_from_sin_path is None
|
|
1510
|
+
or position_ids_from_cos_path is None
|
|
1511
|
+
or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
|
|
1512
|
+
):
|
|
1513
|
+
logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
|
|
1514
|
+
return
|
|
1515
|
+
position_ids = position_ids_from_cos_path[0].input[0]
|
|
1516
|
+
else:
|
|
1517
|
+
position_ids_from_sin_path = []
|
|
1518
|
+
position_ids_from_cos_path = []
|
|
1519
|
+
|
|
1520
|
+
past_seq_len_path, curr_seq_len_path = None, None
|
|
1521
|
+
if (sin_path == sin_path_1 and cos_path == cos_path_1) or (
|
|
1522
|
+
sin_path == sin_path_3 and cos_path == cos_path_3
|
|
1523
|
+
):
|
|
1524
|
+
if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
|
|
1525
|
+
logger.debug(
|
|
1526
|
+
"fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
|
|
1527
|
+
)
|
|
1528
|
+
return
|
|
1529
|
+
elif (sin_path == sin_path_2 and cos_path == cos_path_2) or (
|
|
1530
|
+
sin_path == sin_path_4 and cos_path == cos_path_4
|
|
1531
|
+
):
|
|
1532
|
+
if sin_path[-1].name != cos_path[-1].name:
|
|
1533
|
+
logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
|
|
1534
|
+
return
|
|
1535
|
+
# Match past sequence length path: past_key --> Shape --> Gather --> Add
|
|
1536
|
+
past_seq_len_path = self.model.match_parent_path(
|
|
1537
|
+
sin_path[-1],
|
|
1538
|
+
["Gather", "Shape"],
|
|
1539
|
+
[1, 0],
|
|
1540
|
+
)
|
|
1541
|
+
# Match current sequence length path: transpose_k --> Shape --> Gather --> Add
|
|
1542
|
+
curr_seq_len_path = self.model.match_parent_path(
|
|
1543
|
+
sin_path[-1],
|
|
1544
|
+
["Gather", "Shape", "Transpose"],
|
|
1545
|
+
[0, 0, 0],
|
|
1546
|
+
)
|
|
1547
|
+
if (
|
|
1548
|
+
past_seq_len_path is None
|
|
1549
|
+
or curr_seq_len_path is None
|
|
1550
|
+
or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
|
|
1551
|
+
or curr_seq_len_path[-1].op_type != "Transpose"
|
|
1552
|
+
):
|
|
1553
|
+
logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
|
|
1554
|
+
return
|
|
1555
|
+
else:
|
|
1556
|
+
logger.debug("fuse_rotary_embeddings: failed to match common cache paths")
|
|
1557
|
+
|
|
1558
|
+
rotary_emb_node = self.create_rotary_embeddings_from_nodes(
|
|
1559
|
+
rotate_half_x1_path_1[-1].output[0],
|
|
1560
|
+
position_ids,
|
|
1561
|
+
cos_cache,
|
|
1562
|
+
sin_cache,
|
|
1563
|
+
node.output[0],
|
|
1564
|
+
)
|
|
1565
|
+
if rotary_emb_node is None:
|
|
1566
|
+
logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
|
|
1567
|
+
return
|
|
1568
|
+
|
|
1569
|
+
# Remove rotary embedding nodes
|
|
1570
|
+
self.add_nodes_to_remove([node])
|
|
1571
|
+
self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
|
|
1572
|
+
self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
|
|
1573
|
+
self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
|
|
1574
|
+
self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
|
|
1575
|
+
self.add_nodes_to_remove(x_path[:-1])
|
|
1576
|
+
self.add_nodes_to_remove(sin_path)
|
|
1577
|
+
self.add_nodes_to_remove(cos_path)
|
|
1578
|
+
self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
|
|
1579
|
+
self.add_nodes_to_remove(position_ids_from_cos_path[:-1])
|
|
1580
|
+
|
|
1581
|
+
if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
|
|
1582
|
+
# In merged HF model, output of Gather in past_seq_len_path is used twice
|
|
1583
|
+
# for past_key_values.0.key and once for other past_key_values
|
|
1584
|
+
self.add_nodes_to_remove(past_seq_len_path)
|
|
1585
|
+
if curr_seq_len_path is not None:
|
|
1586
|
+
self.add_nodes_to_remove(curr_seq_len_path[:-1])
|
|
1587
|
+
|
|
1588
|
+
self.increase_counter(self.base_name)
|
|
1589
|
+
self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
|
|
1590
|
+
self.nodes_to_add.append(rotary_emb_node)
|
|
1591
|
+
self.prune_graph = True
|