onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
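
This wheel ships the DirectML execution provider (see `onnxruntime/capi/DirectML.dll` and `onnxruntime_pybind11_state.pyd` above). As a minimal usage sketch, not part of the package diff itself, a session built from this wheel is typically created with `DmlExecutionProvider` ahead of the CPU fallback; the model path and dummy input below are placeholders.

```python
import numpy as np
import onnxruntime as ort

# Placeholder model path; any ONNX model supported by the DirectML EP will do.
session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Build a dummy tensor shaped like the model's first input (names/shapes are model-specific).
input_meta = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
outputs = session.run(None, {input_meta.name: np.zeros(shape, dtype=np.float32)})
print([o.shape for o in outputs])
```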
|
@@ -0,0 +1,1304 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
from typing import Tuple, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from fusion_base import Fusion
|
|
10
|
+
from fusion_utils import NumpyHelper
|
|
11
|
+
from onnx import NodeProto, TensorProto, helper
|
|
12
|
+
from onnx_model import OnnxModel
|
|
13
|
+
|
|
14
|
+
logger = getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FusionAttentionUnet(Fusion):
|
|
18
|
+
"""
|
|
19
|
+
Fuse Attention subgraph of UNet into one Attention node.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
model: OnnxModel,
|
|
25
|
+
hidden_size: int,
|
|
26
|
+
num_heads: int,
|
|
27
|
+
is_cross_attention: bool,
|
|
28
|
+
enable_packed_qkv: bool,
|
|
29
|
+
enable_packed_kv: bool,
|
|
30
|
+
):
|
|
31
|
+
super().__init__(
|
|
32
|
+
model,
|
|
33
|
+
"Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention",
|
|
34
|
+
["LayerNormalization"],
|
|
35
|
+
)
|
|
36
|
+
self.hidden_size = hidden_size
|
|
37
|
+
self.num_heads = num_heads
|
|
38
|
+
self.is_cross_attention = is_cross_attention
|
|
39
|
+
|
|
40
|
+
# Note: pack Q/K/V or K/V weights into one tensor make it harder for updating initializers for LoRA.
|
|
41
|
+
# To support LoRA, it is better to use separated Q, K and V inputs in offline optimization,
|
|
42
|
+
# and CUDA operator pre-packs those tensors to preferred format based on available kernels.
|
|
43
|
+
# In this way, we can support LoRA and get optimal performance at same time.
|
|
44
|
+
self.enable_packed_qkv = enable_packed_qkv
|
|
45
|
+
self.enable_packed_kv = enable_packed_kv
|
|
46
|
+
|
|
47
|
+
# Flags to show warning only once
|
|
48
|
+
self.num_heads_warning = True
|
|
49
|
+
self.hidden_size_warning = True
|
|
50
|
+
|
|
51
|
+
def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
|
|
52
|
+
"""Detect num_heads from a reshape node.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
reshape_q (NodeProto): reshape node for Q
|
|
56
|
+
is_torch2 (bool): graph pattern is from PyTorch 2.*
|
|
57
|
+
Returns:
|
|
58
|
+
int: num_heads, or 0 if not found
|
|
59
|
+
"""
|
|
60
|
+
num_heads = 0
|
|
61
|
+
if is_torch2:
|
|
62
|
+
# we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
|
|
63
|
+
reshape_parent = self.model.get_parent(reshape_q, 1)
|
|
64
|
+
if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
|
|
65
|
+
num_heads = self.model.get_constant_value(reshape_parent.input[2])
|
|
66
|
+
if isinstance(num_heads, np.ndarray) and list(num_heads.shape) == [1]:
|
|
67
|
+
num_heads = int(num_heads)
|
|
68
|
+
else:
|
|
69
|
+
# we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
|
|
70
|
+
q_shape_value = self.model.get_constant_value(reshape_q.input[1])
|
|
71
|
+
if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
|
|
72
|
+
num_heads = int(q_shape_value[2])
|
|
73
|
+
|
|
74
|
+
if isinstance(num_heads, int) and num_heads > 0:
|
|
75
|
+
return num_heads
|
|
76
|
+
|
|
77
|
+
return 0
|
|
78
|
+
|
|
79
|
+
def get_hidden_size(self, layernorm_node):
|
|
80
|
+
"""Detect hidden_size from LayerNormalization node.
|
|
81
|
+
Args:
|
|
82
|
+
layernorm_node (NodeProto): LayerNormalization node before Q, K and V
|
|
83
|
+
Returns:
|
|
84
|
+
int: hidden_size, or 0 if not found
|
|
85
|
+
"""
|
|
86
|
+
layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
|
|
87
|
+
if layernorm_bias:
|
|
88
|
+
return NumpyHelper.to_array(layernorm_bias).shape[0]
|
|
89
|
+
|
|
90
|
+
return 0
|
|
91
|
+
|
|
92
|
+
def get_num_heads_and_hidden_size(
|
|
93
|
+
self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
|
|
94
|
+
) -> Tuple[int, int]:
|
|
95
|
+
"""Detect num_heads and hidden_size.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
reshape_q (NodeProto): reshape node for Q
|
|
99
|
+
is_torch2 (bool): graph pattern is from PyTorch 2.*
|
|
100
|
+
layernorm_node (NodeProto): LayerNormalization node before Q, K, V
|
|
101
|
+
Returns:
|
|
102
|
+
Tuple[int, int]: num_heads and hidden_size
|
|
103
|
+
"""
|
|
104
|
+
num_heads = self.get_num_heads(reshape_q, is_torch2)
|
|
105
|
+
if num_heads <= 0:
|
|
106
|
+
num_heads = self.num_heads # Fall back to user specified value
|
|
107
|
+
|
|
108
|
+
if self.num_heads > 0 and num_heads != self.num_heads:
|
|
109
|
+
if self.num_heads_warning:
|
|
110
|
+
logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
|
|
111
|
+
self.num_heads_warning = False # Do not show the warning more than once
|
|
112
|
+
|
|
113
|
+
hidden_size = self.get_hidden_size(layernorm_node)
|
|
114
|
+
if hidden_size <= 0:
|
|
115
|
+
hidden_size = self.hidden_size # Fall back to user specified value
|
|
116
|
+
|
|
117
|
+
if self.hidden_size > 0 and hidden_size != self.hidden_size:
|
|
118
|
+
if self.hidden_size_warning:
|
|
119
|
+
logger.warning(
|
|
120
|
+
f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
|
|
121
|
+
)
|
|
122
|
+
self.hidden_size_warning = False # Do not show the warning more than once
|
|
123
|
+
|
|
124
|
+
return num_heads, hidden_size
|
|
125
|
+
|
|
126
|
+
def create_attention_node(
|
|
127
|
+
self,
|
|
128
|
+
q_matmul: NodeProto,
|
|
129
|
+
k_matmul: NodeProto,
|
|
130
|
+
v_matmul: NodeProto,
|
|
131
|
+
num_heads: int,
|
|
132
|
+
hidden_size: int,
|
|
133
|
+
input: str,
|
|
134
|
+
output: str,
|
|
135
|
+
) -> Union[NodeProto, None]:
|
|
136
|
+
"""Create an Attention node.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
q_matmul (NodeProto): MatMul node in fully connection for Q
|
|
140
|
+
k_matmul (NodeProto): MatMul node in fully connection for K
|
|
141
|
+
v_matmul (NodeProto): MatMul node in fully connection for V
|
|
142
|
+
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
|
|
143
|
+
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
|
|
144
|
+
input (str): input name
|
|
145
|
+
output (str): output name
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
Union[NodeProto, None]: the node created or None if failed.
|
|
149
|
+
"""
|
|
150
|
+
is_self_attention = not self.is_cross_attention
|
|
151
|
+
|
|
152
|
+
if is_self_attention:
|
|
153
|
+
if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
|
|
154
|
+
logger.debug(
|
|
155
|
+
"For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
|
|
156
|
+
q_matmul.input[0],
|
|
157
|
+
k_matmul.input[0],
|
|
158
|
+
v_matmul.input[0],
|
|
159
|
+
)
|
|
160
|
+
return None
|
|
161
|
+
else:
|
|
162
|
+
if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
|
|
163
|
+
logger.debug(
|
|
164
|
+
"For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
|
|
165
|
+
q_matmul.input[0],
|
|
166
|
+
k_matmul.input[0],
|
|
167
|
+
v_matmul.input[0],
|
|
168
|
+
)
|
|
169
|
+
return None
|
|
170
|
+
|
|
171
|
+
if hidden_size > 0 and (hidden_size % num_heads) != 0:
|
|
172
|
+
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
|
|
173
|
+
return None
|
|
174
|
+
|
|
175
|
+
q_weight = self.model.get_initializer(q_matmul.input[1])
|
|
176
|
+
k_weight = self.model.get_initializer(k_matmul.input[1])
|
|
177
|
+
v_weight = self.model.get_initializer(v_matmul.input[1])
|
|
178
|
+
if not (q_weight and k_weight and v_weight):
|
|
179
|
+
return None
|
|
180
|
+
|
|
181
|
+
# Sometimes weights are stored in fp16
|
|
182
|
+
float_type = q_weight.data_type
|
|
183
|
+
|
|
184
|
+
qw = NumpyHelper.to_array(q_weight)
|
|
185
|
+
kw = NumpyHelper.to_array(k_weight)
|
|
186
|
+
vw = NumpyHelper.to_array(v_weight)
|
|
187
|
+
logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
|
|
188
|
+
|
|
189
|
+
# assert q and k have same shape as expected
|
|
190
|
+
if is_self_attention:
|
|
191
|
+
if qw.shape != kw.shape or qw.shape != vw.shape:
|
|
192
|
+
return None
|
|
193
|
+
|
|
194
|
+
qw_in_size = qw.shape[0]
|
|
195
|
+
|
|
196
|
+
if hidden_size > 0 and hidden_size != qw_in_size:
|
|
197
|
+
raise ValueError(
|
|
198
|
+
f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
|
|
199
|
+
"Please provide a correct input hidden size or pass in 0"
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
# All the matrices can have the same shape or q, k matrics can have the same shape with v being different
|
|
203
|
+
# For 2d weights, the shapes would be [in_size, out_size].
|
|
204
|
+
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
|
|
205
|
+
qw_out_size = int(np.prod(qw.shape[1:]))
|
|
206
|
+
|
|
207
|
+
if self.enable_packed_qkv:
|
|
208
|
+
attention_node_name = self.model.create_node_name("MultiHeadAttention")
|
|
209
|
+
|
|
210
|
+
c = qw_in_size
|
|
211
|
+
n = num_heads
|
|
212
|
+
h = qw_out_size // num_heads
|
|
213
|
+
|
|
214
|
+
# Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape
|
|
215
|
+
qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
|
|
216
|
+
c, n * 3 * h
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
|
|
220
|
+
self.add_initializer(
|
|
221
|
+
name=matmul_node_name + "_weight",
|
|
222
|
+
data_type=float_type,
|
|
223
|
+
dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
|
|
224
|
+
vals=qkv_weight,
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
matmul_node = helper.make_node(
|
|
228
|
+
"MatMul",
|
|
229
|
+
inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
|
|
230
|
+
outputs=[matmul_node_name + "_out"],
|
|
231
|
+
name=matmul_node_name,
|
|
232
|
+
)
|
|
233
|
+
self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
|
|
234
|
+
|
|
235
|
+
self.add_initializer(
|
|
236
|
+
name=matmul_node_name + "_reshape_shape",
|
|
237
|
+
data_type=TensorProto.INT64,
|
|
238
|
+
dims=[5],
|
|
239
|
+
vals=[0, 0, n, 3, h],
|
|
240
|
+
raw=False,
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
reshape_node = helper.make_node(
|
|
244
|
+
"Reshape",
|
|
245
|
+
inputs=[
|
|
246
|
+
matmul_node_name + "_out",
|
|
247
|
+
matmul_node_name + "_reshape_shape",
|
|
248
|
+
],
|
|
249
|
+
outputs=[attention_node_name + "_qkv_input"],
|
|
250
|
+
name=matmul_node_name + "_reshape",
|
|
251
|
+
)
|
|
252
|
+
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
|
|
253
|
+
self.nodes_to_add.extend([matmul_node, reshape_node])
|
|
254
|
+
self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
|
|
255
|
+
|
|
256
|
+
else:
|
|
257
|
+
qkv_weight = np.stack((qw, kw, vw), axis=1)
|
|
258
|
+
qkv_weight_dim = 3 * qw_out_size
|
|
259
|
+
|
|
260
|
+
attention_node_name = self.model.create_node_name("Attention")
|
|
261
|
+
|
|
262
|
+
self.add_initializer(
|
|
263
|
+
name=attention_node_name + "_qkv_weight",
|
|
264
|
+
data_type=float_type,
|
|
265
|
+
dims=[qw_in_size, qkv_weight_dim],
|
|
266
|
+
vals=qkv_weight,
|
|
267
|
+
)
|
|
268
|
+
else: # cross attention
|
|
269
|
+
attention_node_name = self.model.create_node_name("MultiHeadAttention")
|
|
270
|
+
if self.enable_packed_kv:
|
|
271
|
+
if kw.shape != vw.shape:
|
|
272
|
+
return None
|
|
273
|
+
|
|
274
|
+
kw_in_size = kw.shape[0]
|
|
275
|
+
vw_in_size = vw.shape[0]
|
|
276
|
+
assert kw_in_size == vw_in_size
|
|
277
|
+
|
|
278
|
+
qw_out_size = qw.shape[1]
|
|
279
|
+
kw_out_size = kw.shape[1]
|
|
280
|
+
vw_out_size = vw.shape[1]
|
|
281
|
+
assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
|
|
282
|
+
|
|
283
|
+
c = kw_in_size
|
|
284
|
+
n = num_heads
|
|
285
|
+
h = kw_out_size // num_heads
|
|
286
|
+
|
|
287
|
+
# Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
|
|
288
|
+
kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
|
|
289
|
+
|
|
290
|
+
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
|
|
291
|
+
self.add_initializer(
|
|
292
|
+
name=matmul_node_name + "_weight",
|
|
293
|
+
data_type=float_type,
|
|
294
|
+
dims=[kv_weight.shape[0], kv_weight.shape[1]],
|
|
295
|
+
vals=kv_weight,
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
matmul_node = helper.make_node(
|
|
299
|
+
"MatMul",
|
|
300
|
+
inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
|
|
301
|
+
outputs=[matmul_node_name + "_out"],
|
|
302
|
+
name=matmul_node_name,
|
|
303
|
+
)
|
|
304
|
+
self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
|
|
305
|
+
|
|
306
|
+
self.add_initializer(
|
|
307
|
+
name=matmul_node_name + "_reshape_shape",
|
|
308
|
+
data_type=TensorProto.INT64,
|
|
309
|
+
dims=[5],
|
|
310
|
+
vals=[0, 0, n, 2, h],
|
|
311
|
+
raw=False,
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
reshape_node = helper.make_node(
|
|
315
|
+
"Reshape",
|
|
316
|
+
inputs=[
|
|
317
|
+
matmul_node_name + "_out",
|
|
318
|
+
matmul_node_name + "_reshape_shape",
|
|
319
|
+
],
|
|
320
|
+
outputs=[attention_node_name + "_kv_input"],
|
|
321
|
+
name=matmul_node_name + "_reshape",
|
|
322
|
+
)
|
|
323
|
+
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
|
|
324
|
+
self.nodes_to_add.extend([matmul_node, reshape_node])
|
|
325
|
+
self.nodes_to_remove.extend([k_matmul, v_matmul])
|
|
326
|
+
|
|
327
|
+
# No bias, use zeros
|
|
328
|
+
qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
|
|
329
|
+
qkv_bias_dim = 3 * hidden_size
|
|
330
|
+
|
|
331
|
+
self.add_initializer(
|
|
332
|
+
name=attention_node_name + "_qkv_bias",
|
|
333
|
+
data_type=float_type,
|
|
334
|
+
dims=[qkv_bias_dim],
|
|
335
|
+
vals=qkv_bias,
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
if is_self_attention:
|
|
339
|
+
if not self.enable_packed_qkv:
|
|
340
|
+
attention_inputs = [
|
|
341
|
+
input,
|
|
342
|
+
attention_node_name + "_qkv_weight",
|
|
343
|
+
attention_node_name + "_qkv_bias",
|
|
344
|
+
]
|
|
345
|
+
else:
|
|
346
|
+
attention_inputs = [attention_node_name + "_qkv_input"]
|
|
347
|
+
else:
|
|
348
|
+
if not self.enable_packed_kv:
|
|
349
|
+
attention_inputs = [
|
|
350
|
+
q_matmul.output[0],
|
|
351
|
+
k_matmul.output[0],
|
|
352
|
+
v_matmul.output[0],
|
|
353
|
+
attention_node_name + "_qkv_bias",
|
|
354
|
+
]
|
|
355
|
+
else:
|
|
356
|
+
attention_inputs = [
|
|
357
|
+
q_matmul.output[0],
|
|
358
|
+
attention_node_name + "_kv_input",
|
|
359
|
+
]
|
|
360
|
+
|
|
361
|
+
attention_node = helper.make_node(
|
|
362
|
+
"Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
|
|
363
|
+
inputs=attention_inputs,
|
|
364
|
+
outputs=[output],
|
|
365
|
+
name=attention_node_name,
|
|
366
|
+
)
|
|
367
|
+
attention_node.domain = "com.microsoft"
|
|
368
|
+
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
|
369
|
+
|
|
370
|
+
counter_name = (
|
|
371
|
+
"Attention (self attention)"
|
|
372
|
+
if is_self_attention and not self.enable_packed_qkv
|
|
373
|
+
else "MultiHeadAttention ({})".format(
|
|
374
|
+
"self attention with packed qkv"
|
|
375
|
+
if self.enable_packed_qkv
|
|
376
|
+
else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
|
|
377
|
+
)
|
|
378
|
+
)
|
|
379
|
+
self.increase_counter(counter_name)
|
|
380
|
+
return attention_node
|
|
381
|
+
|
|
382
|
+
def create_attention_node_lora(
|
|
383
|
+
self,
|
|
384
|
+
q_matmul_add: NodeProto,
|
|
385
|
+
k_matmul_add: NodeProto,
|
|
386
|
+
v_matmul_add: NodeProto,
|
|
387
|
+
num_heads: int,
|
|
388
|
+
hidden_size: int,
|
|
389
|
+
input: str,
|
|
390
|
+
output: str,
|
|
391
|
+
) -> Union[NodeProto, None]:
|
|
392
|
+
"""Create an Attention node.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
q_matmul (NodeProto): MatMul node in fully connection for Q
|
|
396
|
+
k_matmul (NodeProto): MatMul node in fully connection for K
|
|
397
|
+
v_matmul (NodeProto): MatMul node in fully connection for V
|
|
398
|
+
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
|
|
399
|
+
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
|
|
400
|
+
input (str): input name
|
|
401
|
+
output (str): output name
|
|
402
|
+
|
|
403
|
+
Returns:
|
|
404
|
+
Union[NodeProto, None]: the node created or None if failed.
|
|
405
|
+
"""
|
|
406
|
+
is_self_attention = not self.is_cross_attention
|
|
407
|
+
|
|
408
|
+
q_matmul = self.model.match_parent(q_matmul_add, "MatMul", 0)
|
|
409
|
+
k_matmul = self.model.match_parent(k_matmul_add, "MatMul", 0)
|
|
410
|
+
v_matmul = self.model.match_parent(v_matmul_add, "MatMul", 0)
|
|
411
|
+
|
|
412
|
+
q_lora_nodes = self.match_lora_path(q_matmul_add)
|
|
413
|
+
if q_lora_nodes is None:
|
|
414
|
+
return None
|
|
415
|
+
(q_lora_last_node, q_lora_matmul_1) = q_lora_nodes
|
|
416
|
+
|
|
417
|
+
k_lora_nodes = self.match_lora_path(k_matmul_add)
|
|
418
|
+
if k_lora_nodes is None:
|
|
419
|
+
return None
|
|
420
|
+
(k_lora_last_node, k_lora_matmul_1) = k_lora_nodes
|
|
421
|
+
|
|
422
|
+
v_lora_nodes = self.match_lora_path(v_matmul_add)
|
|
423
|
+
if v_lora_nodes is None:
|
|
424
|
+
return None
|
|
425
|
+
(v_lora_last_node, v_lora_matmul_1) = v_lora_nodes
|
|
426
|
+
|
|
427
|
+
if is_self_attention:
|
|
428
|
+
if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
|
|
429
|
+
logger.debug(
|
|
430
|
+
"For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
|
|
431
|
+
q_matmul.input[0],
|
|
432
|
+
k_matmul.input[0],
|
|
433
|
+
v_matmul.input[0],
|
|
434
|
+
)
|
|
435
|
+
return None
|
|
436
|
+
|
|
437
|
+
if (
|
|
438
|
+
q_lora_matmul_1.input[0] != input
|
|
439
|
+
or k_lora_matmul_1.input[0] != input
|
|
440
|
+
or v_lora_matmul_1.input[0] != input
|
|
441
|
+
):
|
|
442
|
+
logger.debug(
|
|
443
|
+
"For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %s",
|
|
444
|
+
q_lora_matmul_1.input[0],
|
|
445
|
+
k_lora_matmul_1.input[0],
|
|
446
|
+
v_lora_matmul_1.input[0],
|
|
447
|
+
)
|
|
448
|
+
return None
|
|
449
|
+
else:
|
|
450
|
+
if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
|
|
451
|
+
logger.debug(
|
|
452
|
+
"For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
|
|
453
|
+
q_matmul.input[0],
|
|
454
|
+
k_matmul.input[0],
|
|
455
|
+
v_matmul.input[0],
|
|
456
|
+
)
|
|
457
|
+
return None
|
|
458
|
+
|
|
459
|
+
if (
|
|
460
|
+
q_lora_matmul_1.input[0] != input
|
|
461
|
+
or (k_lora_matmul_1.input[0] != v_lora_matmul_1.input[0])
|
|
462
|
+
or (k_matmul.input[0] == input)
|
|
463
|
+
):
|
|
464
|
+
logger.debug(
|
|
465
|
+
(
|
|
466
|
+
"For cross attention, input hidden state for LoRA q and k/v weights shall be different. "
|
|
467
|
+
"Got %s, %s, %s"
|
|
468
|
+
),
|
|
469
|
+
q_lora_matmul_1.input[0],
|
|
470
|
+
k_lora_matmul_1.input[0],
|
|
471
|
+
v_lora_matmul_1.input[0],
|
|
472
|
+
)
|
|
473
|
+
return None
|
|
474
|
+
|
|
475
|
+
if hidden_size > 0 and (hidden_size % num_heads) != 0:
|
|
476
|
+
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
|
|
477
|
+
return None
|
|
478
|
+
|
|
479
|
+
q_weight = self.model.get_initializer(q_matmul.input[1])
|
|
480
|
+
k_weight = self.model.get_initializer(k_matmul.input[1])
|
|
481
|
+
v_weight = self.model.get_initializer(v_matmul.input[1])
|
|
482
|
+
if not (q_weight and k_weight and v_weight):
|
|
483
|
+
return None
|
|
484
|
+
|
|
485
|
+
# Sometimes weights are stored in fp16
|
|
486
|
+
if q_weight.data_type == 10:
|
|
487
|
+
logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
|
|
488
|
+
return None
|
|
489
|
+
|
|
490
|
+
qw = NumpyHelper.to_array(q_weight)
|
|
491
|
+
kw = NumpyHelper.to_array(k_weight)
|
|
492
|
+
vw = NumpyHelper.to_array(v_weight)
|
|
493
|
+
logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")
|
|
494
|
+
|
|
495
|
+
# assert q and k have same shape as expected
|
|
496
|
+
if is_self_attention:
|
|
497
|
+
if qw.shape != kw.shape or qw.shape != vw.shape:
|
|
498
|
+
return None
|
|
499
|
+
|
|
500
|
+
qw_in_size = qw.shape[0]
|
|
501
|
+
|
|
502
|
+
if hidden_size > 0 and hidden_size != qw_in_size:
|
|
503
|
+
raise ValueError(
|
|
504
|
+
f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
|
|
505
|
+
"Please provide a correct input hidden size or pass in 0"
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
# All the matrices can have the same shape or q, k matrics can have the same shape with v being different
|
|
509
|
+
# For 2d weights, the shapes would be [in_size, out_size].
|
|
510
|
+
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
|
|
511
|
+
qw_out_size = int(np.prod(qw.shape[1:]))
|
|
512
|
+
|
|
513
|
+
if self.enable_packed_qkv:
|
|
514
|
+
attention_node_name = self.model.create_node_name("MultiHeadAttention")
|
|
515
|
+
|
|
516
|
+
c = qw_in_size
|
|
517
|
+
n = num_heads
|
|
518
|
+
h = qw_out_size // num_heads
|
|
519
|
+
|
|
520
|
+
# Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape
|
|
521
|
+
qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
|
|
522
|
+
c, n * 3 * h
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
|
|
526
|
+
self.add_initializer(
|
|
527
|
+
name=matmul_node_name + "_weight",
|
|
528
|
+
data_type=TensorProto.FLOAT,
|
|
529
|
+
dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
|
|
530
|
+
vals=qkv_weight,
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
matmul_node = helper.make_node(
|
|
534
|
+
"MatMul",
|
|
535
|
+
inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
|
|
536
|
+
outputs=[matmul_node_name + "_out"],
|
|
537
|
+
name=matmul_node_name,
|
|
538
|
+
)
|
|
539
|
+
self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
|
|
540
|
+
|
|
541
|
+
# Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
|
|
542
|
+
# the Q/K/V weights to be changed without having to re-run the optimizer.
|
|
543
|
+
lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
|
|
544
|
+
|
|
545
|
+
self.add_initializer(
|
|
546
|
+
name=lora_weight_shape_tensor_name,
|
|
547
|
+
data_type=TensorProto.INT64,
|
|
548
|
+
dims=[4],
|
|
549
|
+
vals=[0, 0, n, h],
|
|
550
|
+
raw=False,
|
|
551
|
+
)
|
|
552
|
+
|
|
553
|
+
# Reshape the LoRA Q weights
|
|
554
|
+
q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q")
|
|
555
|
+
q_lora_reshape_node = helper.make_node(
|
|
556
|
+
"Reshape",
|
|
557
|
+
inputs=[q_lora_last_node.output[0], lora_weight_shape_tensor_name],
|
|
558
|
+
outputs=[q_lora_reshape_node_name + "_out"],
|
|
559
|
+
name=q_lora_reshape_node_name,
|
|
560
|
+
)
|
|
561
|
+
self.node_name_to_graph_name[q_lora_reshape_node.name] = self.this_graph_name
|
|
562
|
+
|
|
563
|
+
# Reshape the LoRA K weights
|
|
564
|
+
k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
|
|
565
|
+
k_lora_reshape_node = helper.make_node(
|
|
566
|
+
"Reshape",
|
|
567
|
+
inputs=[k_lora_last_node.output[0], lora_weight_shape_tensor_name],
|
|
568
|
+
outputs=[k_lora_reshape_node_name + "_out"],
|
|
569
|
+
name=k_lora_reshape_node_name,
|
|
570
|
+
)
|
|
571
|
+
self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
|
|
572
|
+
|
|
573
|
+
# Reshape the LoRA V weights
|
|
574
|
+
v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
|
|
575
|
+
v_lora_reshape_node = helper.make_node(
|
|
576
|
+
"Reshape",
|
|
577
|
+
inputs=[v_lora_last_node.output[0], lora_weight_shape_tensor_name],
|
|
578
|
+
outputs=[v_lora_reshape_node_name + "_out"],
|
|
579
|
+
name=v_lora_reshape_node_name,
|
|
580
|
+
)
|
|
581
|
+
self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
|
|
582
|
+
|
|
583
|
+
# Concat the reshaped LoRA Q/K/V weights together on the third axis
|
|
584
|
+
qkv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_QKV")
|
|
585
|
+
qkv_lora_concat_node = helper.make_node(
|
|
586
|
+
"Concat",
|
|
587
|
+
inputs=[
|
|
588
|
+
q_lora_reshape_node.output[0],
|
|
589
|
+
k_lora_reshape_node.output[0],
|
|
590
|
+
v_lora_reshape_node.output[0],
|
|
591
|
+
],
|
|
592
|
+
outputs=[qkv_lora_concat_node_name + "_out"],
|
|
593
|
+
name=qkv_lora_concat_node_name,
|
|
594
|
+
)
|
|
595
|
+
qkv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
|
|
596
|
+
self.node_name_to_graph_name[qkv_lora_concat_node.name] = self.this_graph_name
|
|
597
|
+
|
|
598
|
+
# Reshape the LoRA concatenated weights to [..., n * 3 * h]
|
|
599
|
+
reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape"
|
|
600
|
+
self.add_initializer(
|
|
601
|
+
name=reshaped_lora_weights_shape_tensor_name,
|
|
602
|
+
data_type=TensorProto.INT64,
|
|
603
|
+
dims=[3],
|
|
604
|
+
vals=[0, 0, n * 3 * h],
|
|
605
|
+
raw=False,
|
|
606
|
+
)
|
|
607
|
+
|
|
608
|
+
qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV")
|
|
609
|
+
qkv_lora_reshaped_node = helper.make_node(
|
|
610
|
+
"Reshape",
|
|
611
|
+
inputs=[qkv_lora_concat_node.output[0], reshaped_lora_weights_shape_tensor_name],
|
|
612
|
+
outputs=[qkv_lora_reshaped_node_name + "_out"],
|
|
613
|
+
name=qkv_lora_reshaped_node_name,
|
|
614
|
+
)
|
|
615
|
+
self.node_name_to_graph_name[qkv_lora_reshaped_node.name] = self.this_graph_name
|
|
616
|
+
|
|
617
|
+
# Add the LoRA Q/K/V weights to the base Q/K/V weights
|
|
618
|
+
add_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_QKV")
|
|
619
|
+
add_weights_node = helper.make_node(
|
|
620
|
+
"Add",
|
|
621
|
+
inputs=[qkv_lora_reshaped_node.output[0], matmul_node.output[0]],
|
|
622
|
+
outputs=[add_weights_node_name + "_out"],
|
|
623
|
+
name=add_weights_node_name,
|
|
624
|
+
)
|
|
625
|
+
self.node_name_to_graph_name[add_weights_node.name] = self.this_graph_name
|
|
626
|
+
|
|
627
|
+
# Finally, reshape the concatenated Q/K/V result to 5D
|
|
628
|
+
shape_tensor_name = add_weights_node_name + "_reshape_shape"
|
|
629
|
+
self.add_initializer(
|
|
630
|
+
name=shape_tensor_name,
|
|
631
|
+
data_type=TensorProto.INT64,
|
|
632
|
+
dims=[5],
|
|
633
|
+
vals=[0, 0, n, 3, h],
|
|
634
|
+
raw=False,
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
reshape_node = helper.make_node(
|
|
638
|
+
"Reshape",
|
|
639
|
+
inputs=[add_weights_node.output[0], shape_tensor_name],
|
|
640
|
+
outputs=[attention_node_name + "_qkv_input"],
|
|
641
|
+
name=add_weights_node_name + "_reshape",
|
|
642
|
+
)
|
|
643
|
+
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
|
|
644
|
+
|
|
645
|
+
self.nodes_to_add.extend(
|
|
646
|
+
[
|
|
647
|
+
matmul_node,
|
|
648
|
+
q_lora_reshape_node,
|
|
649
|
+
k_lora_reshape_node,
|
|
650
|
+
v_lora_reshape_node,
|
|
651
|
+
qkv_lora_concat_node,
|
|
652
|
+
qkv_lora_reshaped_node,
|
|
653
|
+
add_weights_node,
|
|
654
|
+
reshape_node,
|
|
655
|
+
]
|
|
656
|
+
)
|
|
657
|
+
self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul, q_matmul_add, k_matmul_add, v_matmul_add])
|
|
658
|
+
else:
|
|
659
|
+
# TODO: Support non-packed QKV
|
|
660
|
+
return None
|
|
661
|
+
else: # cross attention
|
|
662
|
+
attention_node_name = self.model.create_node_name("MultiHeadAttention")
|
|
663
|
+
if self.enable_packed_kv:
|
|
664
|
+
if kw.shape != vw.shape:
|
|
665
|
+
return None
|
|
666
|
+
|
|
667
|
+
kw_in_size = kw.shape[0]
|
|
668
|
+
vw_in_size = vw.shape[0]
|
|
669
|
+
assert kw_in_size == vw_in_size
|
|
670
|
+
|
|
671
|
+
qw_out_size = qw.shape[1]
|
|
672
|
+
kw_out_size = kw.shape[1]
|
|
673
|
+
vw_out_size = vw.shape[1]
|
|
674
|
+
assert qw_out_size == vw_out_size and kw_out_size == vw_out_size
|
|
675
|
+
|
|
676
|
+
c = kw_in_size
|
|
677
|
+
n = num_heads
|
|
678
|
+
h = kw_out_size // num_heads
|
|
679
|
+
|
|
680
|
+
# Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
|
|
681
|
+
kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
|
|
682
|
+
|
|
683
|
+
matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
|
|
684
|
+
self.add_initializer(
|
|
685
|
+
name=matmul_node_name + "_weight",
|
|
686
|
+
data_type=TensorProto.FLOAT,
|
|
687
|
+
dims=[kv_weight.shape[0], kv_weight.shape[1]],
|
|
688
|
+
vals=kv_weight,
|
|
689
|
+
)
|
|
690
|
+
|
|
691
|
+
matmul_node = helper.make_node(
|
|
692
|
+
"MatMul",
|
|
693
|
+
inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
|
|
694
|
+
outputs=[matmul_node_name + "_out"],
|
|
695
|
+
name=matmul_node_name,
|
|
696
|
+
)
|
|
697
|
+
self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
|
|
698
|
+
|
|
699
|
+
# Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
|
|
700
|
+
# the Q/K/V weights to be changed without having to re-run the optimizer.
|
|
701
|
+
kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
|
|
702
|
+
self.add_initializer(
|
|
703
|
+
name=kv_lora_weight_shape_tensor_name,
|
|
704
|
+
data_type=TensorProto.INT64,
|
|
705
|
+
dims=[4],
|
|
706
|
+
vals=[0, 0, n, h],
|
|
707
|
+
raw=False,
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
# Reshape the LoRA K weights
|
|
711
|
+
k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
|
|
712
|
+
k_lora_reshape_node = helper.make_node(
|
|
713
|
+
"Reshape",
|
|
714
|
+
inputs=[k_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
|
|
715
|
+
outputs=[k_lora_reshape_node_name + "_out"],
|
|
716
|
+
name=k_lora_reshape_node_name,
|
|
717
|
+
)
|
|
718
|
+
self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name
|
|
719
|
+
|
|
720
|
+
# Reshape the LoRA V weights
|
|
721
|
+
v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
|
|
722
|
+
v_lora_reshape_node = helper.make_node(
|
|
723
|
+
"Reshape",
|
|
724
|
+
inputs=[v_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
|
|
725
|
+
outputs=[v_lora_reshape_node_name + "_out"],
|
|
726
|
+
name=v_lora_reshape_node_name,
|
|
727
|
+
)
|
|
728
|
+
self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name
|
|
729
|
+
|
|
730
|
+
# Concat the reshaped LoRA K/V weights together on the third axis
|
|
731
|
+
kv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_KV")
|
|
732
|
+
kv_lora_concat_node = helper.make_node(
|
|
733
|
+
"Concat",
|
|
734
|
+
inputs=[k_lora_reshape_node.output[0], v_lora_reshape_node.output[0]],
|
|
735
|
+
outputs=[kv_lora_concat_node_name + "_out"],
|
|
736
|
+
name=kv_lora_concat_node_name,
|
|
737
|
+
)
|
|
738
|
+
kv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
|
|
739
|
+
self.node_name_to_graph_name[kv_lora_concat_node.name] = self.this_graph_name
|
|
740
|
+
|
|
741
|
+
# Reshape the LoRA concatenated weights to [..., n * 2 * h]
|
|
742
|
+
reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape"
|
|
743
|
+
self.add_initializer(
|
|
744
|
+
name=reshaped_kv_lora_weights_shape_tensor_name,
|
|
745
|
+
data_type=TensorProto.INT64,
|
|
746
|
+
dims=[3],
|
|
747
|
+
vals=[0, 0, n * 2 * h],
|
|
748
|
+
raw=False,
|
|
749
|
+
)
|
|
750
|
+
|
|
751
|
+
kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV")
|
|
752
|
+
kv_lora_reshaped_node = helper.make_node(
|
|
753
|
+
"Reshape",
|
|
754
|
+
inputs=[kv_lora_concat_node.output[0], reshaped_kv_lora_weights_shape_tensor_name],
|
|
755
|
+
outputs=[kv_lora_reshaped_node_name + "_out"],
|
|
756
|
+
name=kv_lora_reshaped_node_name,
|
|
757
|
+
)
|
|
758
|
+
self.node_name_to_graph_name[kv_lora_reshaped_node.name] = self.this_graph_name
|
|
759
|
+
|
|
760
|
+
# Add the LoRA K/V weights to the base K/V weights
|
|
761
|
+
add_kv_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_KV")
|
|
762
|
+
add_kv_weights_node = helper.make_node(
|
|
763
|
+
"Add",
|
|
764
|
+
inputs=[kv_lora_reshaped_node.output[0], matmul_node.output[0]],
|
|
765
|
+
outputs=[add_kv_weights_node_name + "_out"],
|
|
766
|
+
name=add_kv_weights_node_name,
|
|
767
|
+
)
|
|
768
|
+
self.node_name_to_graph_name[add_kv_weights_node.name] = self.this_graph_name
|
|
769
|
+
|
|
770
|
+
# Finally, reshape the concatenated K/V result to 5D
|
|
771
|
+
shape_tensor_name = add_kv_weights_node_name + "_reshape_shape"
|
|
772
|
+
self.add_initializer(
|
|
773
|
+
name=shape_tensor_name,
|
|
774
|
+
data_type=TensorProto.INT64,
|
|
775
|
+
dims=[5],
|
|
776
|
+
vals=[0, 0, n, 2, h],
|
|
777
|
+
raw=False,
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
reshape_node = helper.make_node(
|
|
781
|
+
"Reshape",
|
|
782
|
+
inputs=[add_kv_weights_node.output[0], shape_tensor_name],
|
|
783
|
+
outputs=[attention_node_name + "_kv_input"],
|
|
784
|
+
name=add_kv_weights_node_name + "_reshape",
|
|
785
|
+
)
|
|
786
|
+
self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
|
|
787
|
+
self.nodes_to_add.extend(
|
|
788
|
+
[
|
|
789
|
+
matmul_node,
|
|
790
|
+
k_lora_reshape_node,
|
|
791
|
+
v_lora_reshape_node,
|
|
792
|
+
kv_lora_concat_node,
|
|
793
|
+
kv_lora_reshaped_node,
|
|
794
|
+
add_kv_weights_node,
|
|
795
|
+
reshape_node,
|
|
796
|
+
]
|
|
797
|
+
)
|
|
798
|
+
self.nodes_to_remove.extend([k_matmul, v_matmul, k_matmul_add, v_matmul_add])
|
|
799
|
+
else:
|
|
800
|
+
# TODO: Support non-packed KV
|
|
801
|
+
return None
|
|
802
|
+
|
|
803
|
+
# No bias, use zeros
|
|
804
|
+
qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
|
|
805
|
+
qkv_bias_dim = 3 * hidden_size
|
|
806
|
+
self.add_initializer(
|
|
807
|
+
name=attention_node_name + "_qkv_bias",
|
|
808
|
+
data_type=TensorProto.FLOAT,
|
|
809
|
+
dims=[qkv_bias_dim],
|
|
810
|
+
vals=qkv_bias,
|
|
811
|
+
)
|
|
812
|
+
|
|
813
|
+
if is_self_attention:
|
|
814
|
+
if not self.enable_packed_qkv:
|
|
815
|
+
# TODO: Support non-packed QKV
|
|
816
|
+
return None
|
|
817
|
+
else:
|
|
818
|
+
attention_inputs = [attention_node_name + "_qkv_input"]
|
|
819
|
+
else:
|
|
820
|
+
if not self.enable_packed_kv:
|
|
821
|
+
# TODO: Support non-packed QKV
|
|
822
|
+
return None
|
|
823
|
+
else:
|
|
824
|
+
attention_inputs = [
|
|
825
|
+
q_matmul_add.output[0],
|
|
826
|
+
attention_node_name + "_kv_input",
|
|
827
|
+
]
|
|
828
|
+
|
|
829
|
+
attention_node = helper.make_node(
|
|
830
|
+
"Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
|
|
831
|
+
inputs=attention_inputs,
|
|
832
|
+
outputs=[output],
|
|
833
|
+
name=attention_node_name,
|
|
834
|
+
)
|
|
835
|
+
attention_node.domain = "com.microsoft"
|
|
836
|
+
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
|
|
837
|
+
|
|
838
|
+
counter_name = (
|
|
839
|
+
"Attention (self attention)"
|
|
840
|
+
if is_self_attention and not self.enable_packed_qkv
|
|
841
|
+
else "MultiHeadAttention ({})".format(
|
|
842
|
+
"self attention with packed qkv"
|
|
843
|
+
if self.enable_packed_qkv
|
|
844
|
+
else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
|
|
845
|
+
)
|
|
846
|
+
)
|
|
847
|
+
self.increase_counter(counter_name)
|
|
848
|
+
return attention_node
|
|
849
|
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node):
+            return
+
+        node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
+
+        # In SD 1.5, for self attention, LayerNorm has parent Reshape
+        if node_before_layernorm is None and not self.is_cross_attention:
+            node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0)
+
+        if node_before_layernorm is None:
+            return
+
+        root_input = node_before_layernorm.output[0]
+
+        children_nodes = input_name_to_nodes[root_input]
+        skip_add = None
+        for node in children_nodes:
+            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
+                skip_add = node
+                break
+        if skip_add is None:
+            return
+
+        match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
+        if match_qkv is not None:
+            is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv
+
+            attention_last_node = reshape_qkv
+
+            q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
+            if q_num_heads <= 0:
+                logger.debug("fuse_attention: failed to detect num_heads")
+                return
+
+            # The number of heads is the same for all paths, so q_num_heads is passed to create the attention node.
+            new_node = self.create_attention_node(
+                matmul_q,
+                matmul_k,
+                matmul_v,
+                q_num_heads,
+                q_hidden_size,
+                input=normalize_node.output[0],
+                output=attention_last_node.output[0],
+            )
+            if new_node is None:
+                return
+        else:
+            # Check if we have a LoRA pattern
+            match_qkv = self.match_qkv_torch1_lora(root_input, skip_add) or self.match_qkv_torch2_lora(
+                root_input, skip_add
+            )
+            if match_qkv is None:
+                return
+
+            is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v = match_qkv
+
+            attention_last_node = reshape_qkv
+
+            q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
+            if q_num_heads <= 0:
+                logger.debug("fuse_attention: failed to detect num_heads")
+                return
+
+            # The number of heads is the same for all paths, so q_num_heads is passed to create the attention node.
+            new_node = self.create_attention_node_lora(
+                matmul_add_q,
+                matmul_add_k,
+                matmul_add_v,
+                q_num_heads,
+                q_hidden_size,
+                input=normalize_node.output[0],
+                output=attention_last_node.output[0],
+            )
+            if new_node is None:
+                return
+
+        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
+        if q_num_heads <= 0:
+            logger.debug("fuse_attention: failed to detect num_heads")
+            return
+
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
+
+        # Use prune graph to remove nodes since they are shared by all attention nodes.
+        self.prune_graph = True
+
+    def match_qkv_torch1(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 1.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [another_input, None, None, 0, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
+
+        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
+        else:
+            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
+            if qk_nodes is not None:
+                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
+            else:
+                logger.debug("fuse_attention: failed to match qk path")
+                return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(
+            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
+        )
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_, _, _, _, matmul_k) = k_nodes
+
+        return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
+
+    def match_qkv_torch2(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 2.*"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [another_input, None, None, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, matmul_qk) = qk_nodes
+        else:
+            logger.debug("fuse_attention: failed to match qk path")
+            return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (mul_q, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_mul_k, _, _, matmul_k) = k_nodes
+
+        # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
+        mul_q_nodes = self.model.match_parent_path(
+            mul_q,
+            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
+            [None, 0, 1, 0, 0, 0, 0, 0],
+        )
+        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
+            logger.debug("fuse_attention: failed to match mul_q path")
+            return None
+
+        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
+
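As the comment in match_qkv_torch2 notes, the PyTorch 2 export scales Q and K each by sqrt(1.0/sqrt(head_size)) before the MatMul, so their product carries the usual 1/sqrt(head_size) attention scaling. A small numpy check of that identity (head_size and shapes are illustrative, not taken from the package):

    import numpy as np

    head_size = 64                            # illustrative value
    q = np.random.rand(4, head_size)
    k = np.random.rand(4, head_size)

    s = np.sqrt(1.0 / np.sqrt(head_size))     # the scalar matched by the Sqrt/Div/Sqrt path
    scaled = (s * q) @ (s * k).T              # Mul on Q and Mul on K, then MatMul
    reference = (q @ k.T) / np.sqrt(head_size)
    assert np.allclose(scaled, reference)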
+    def match_qkv_torch1_lora(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 1.* that contain LoRA patterns"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
+            [another_input, 0, None, None, 0, 0, 0],
+        )
+        if qkv_nodes is None:
+            return None
+
+        (_, _, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes
+
+        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match LoRA v path")
+            return None
+        (_, _, _, matmul_add_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
+        else:
+            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
+            if qk_nodes is not None:
+                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
+            else:
+                logger.debug("fuse_attention: failed to match LoRA qk path")
+                return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "Add"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match LoRA q path")
+            return None
+        (_, _transpose_q, reshape_q, matmul_add_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(
+            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0, 0]
+        )
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match LoRA k path")
+            return None
+
+        (_, _, _, _, matmul_add_k) = k_nodes
+
+        return False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
+
+    def match_qkv_torch2_lora(self, root_input, skip_add):
+        """Match Q, K and V paths exported by PyTorch 2.* that contain LoRA patterns"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [another_input, 0, None, None, 0, 0],
+        )
+        if qkv_nodes is None:
+            return None
+
+        (_, _, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add"], [1, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match LoRA v path")
+            return None
+        (_, _, matmul_add_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+        if qk_nodes is not None:
+            (_softmax_qk, matmul_qk) = qk_nodes
+        else:
+            logger.debug("fuse_attention: failed to match LoRA qk path")
+            return None
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [0, None, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match LoRA q path")
+            return None
+        (mul_q, _transpose_q, reshape_q, matmul_add_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [1, None, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match LoRA k path")
+            return None
+
+        (_mul_k, _, _, matmul_add_k) = k_nodes
+
+        # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
+        mul_q_nodes = self.model.match_parent_path(
+            mul_q,
+            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
+            [None, 0, 1, 0, 0, 0, 0, 0],
+        )
+        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
+            logger.debug("fuse_attention: failed to match LoRA mul_q path")
+            return None
+
+        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v
+
+    def match_lora_path(
+        self,
+        add_node: NodeProto,
+    ):
+        # LoRA paths can take one of the following forms:
+        # MatMul -> MatMul -> Add
+        # MatMul -> MatMul -> Mul -> Add
+        # MatMul -> MatMul -> Mul -> Mul -> Add
+
+        # Try matching MatMul -> MatMul -> Add
+        lora_nodes = self.model.match_parent_path(
+            add_node,
+            ["MatMul", "MatMul"],
+            [1, 0],
+        )
+
+        if lora_nodes is not None:
+            (lora_matmul_2_node, lora_matmul_1_node) = lora_nodes
+            return (lora_matmul_2_node, lora_matmul_1_node)
+
+        # Try matching MatMul -> MatMul -> Mul -> Add
+        lora_nodes = self.model.match_parent_path(
+            add_node,
+            ["Mul", "MatMul", "MatMul"],
+            [1, 0, 0],
+        )
+
+        if lora_nodes is not None:
+            (lora_mul_node, _, lora_matmul_1_node) = lora_nodes
+            return (lora_mul_node, lora_matmul_1_node)
+
+        # Try matching MatMul -> MatMul -> Mul -> Mul -> Add
+        lora_nodes = self.model.match_parent_path(
+            add_node,
+            ["Mul", "Mul", "MatMul", "MatMul"],
+            [1, 0, 0, 0],
+        )
+
+        if lora_nodes is not None:
+            (lora_mul_node, _, _, lora_matmul_1_node) = lora_nodes
+            return (lora_mul_node, lora_matmul_1_node)
+
+        return None
+
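The three shapes matched by match_lora_path correspond to the usual LoRA update: a down-projection MatMul feeding an up-projection MatMul, optionally scaled by one or two Mul nodes, then added onto the base projection output. A minimal numpy sketch of the arithmetic such a path encodes, assuming the standard LoRA formulation (sizes and scale are illustrative, not taken from any model):

    import numpy as np

    hidden, rank = 320, 4                     # illustrative sizes
    x = np.random.rand(2, hidden)             # token activations
    w_base = np.random.rand(hidden, hidden)   # original projection weight
    lora_down = np.random.rand(hidden, rank)  # first MatMul in the matched path
    lora_up = np.random.rand(rank, hidden)    # second MatMul in the matched path
    scale = 0.5                               # the optional Mul node(s)

    # MatMul -> MatMul -> Mul -> Add, as matched above
    out = x @ w_base + scale * ((x @ lora_down) @ lora_up)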
+    def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        """Fuse attention of fp16 UNet exported by the A1111 (stable diffusion webui) extension"""
+        entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0])
+        if entry_path is None:
+            entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0])
+        if entry_path is None:
+            return False
+        _cast, node_before_layernorm = entry_path
+
+        root_input = node_before_layernorm.output[0]
+
+        children_nodes = input_name_to_nodes[root_input]
+        skip_add = None
+        for node in children_nodes:
+            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
+                skip_add = node
+                break
+        if skip_add is None:
+            return False
+
+        match_qkv = self.match_qkv_a1111(root_input, skip_add)
+        if match_qkv is None:
+            return False
+
+        (
+            reshape_qkv,
+            transpose_qkv,
+            reshape_q,
+            matmul_q,
+            matmul_k,
+            matmul_v,
+        ) = match_qkv
+
+        cast_q = self.model.match_parent(matmul_q, "Cast", 0)
+        cast_k = self.model.match_parent(matmul_k, "Cast", 0)
+        cast_v = self.model.match_parent(matmul_v, "Cast", 0)
+        if not (
+            cast_q is not None
+            and cast_k is not None
+            and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
+            and cast_k == cast_v
+        ):
+            return False
+
+        if cast_q.input[0] != normalize_node.output[0]:
+            return False
+
+        attention_last_node = reshape_qkv
+
+        q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
+        if q_num_heads <= 0:
+            logger.debug("fuse_attention: failed to detect num_heads")
+            return False
+
+        q_hidden_size = self.get_hidden_size(normalize_node)
+
+        # The number of heads is the same for all paths, so q_num_heads is passed to create the attention node.
+        new_node = self.create_attention_node(
+            matmul_q,
+            matmul_k,
+            matmul_v,
+            q_num_heads,
+            q_hidden_size,
+            input=matmul_q.input[0],
+            output=attention_last_node.output[0],
+        )
+        if new_node is None:
+            return False
+
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
+
+        # Use prune graph to remove nodes since they are shared by all attention nodes.
+        self.prune_graph = True
+        return True
+
+    def match_qkv_a1111(self, root_input, skip_add):
+        """Match Q, K and V paths exported by the A1111 (stable diffusion webui) extension"""
+        another_input = 1 if skip_add.input[0] == root_input else 0
+        qkv_nodes = self.model.match_parent_path(
+            skip_add,
+            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"],
+            [another_input, None, None, 0, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            return None
+
+        (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return None
+        (_, _, _, matmul_v) = v_nodes
+
+        qk_nodes = self.model.match_parent_path(
+            einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None]
+        )
+        if qk_nodes is not None:
+            (_, _, _softmax_qk, _, einsum_qk) = qk_nodes
+        else:
+            logger.debug("fuse_attention: failed to match qk path")
+            return None
+
+        q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return None
+        (_, _transpose_q, reshape_q, matmul_q) = q_nodes
+
+        k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return None
+
+        (_, _, _, matmul_k) = k_nodes
+
+        return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
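These fusion passes are normally driven through the transformers optimizer shipped in this package rather than invoked directly. A hedged sketch of how a Stable Diffusion UNet export might be run through it (file paths, num_heads, and hidden_size are illustrative; the exact FusionOptions flags available should be checked against the installed version):

    from onnxruntime.transformers.fusion_options import FusionOptions
    from onnxruntime.transformers.optimizer import optimize_model

    # Fusion options preset for UNet-style models (packed QKV/KV attention, etc.).
    options = FusionOptions("unet")

    # Hypothetical input/output paths; adjust to the actual export.
    optimized = optimize_model(
        "unet/model.onnx",
        model_type="unet",
        num_heads=8,        # illustrative; 0 lets the fuser detect it from the graph
        hidden_size=320,    # illustrative
        optimization_options=options,
        opt_level=0,
    )
    optimized.save_model_to_file("unet/model_optimized.onnx")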