onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1235 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
from typing import List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from fusion_base import Fusion
|
|
10
|
+
from fusion_options import AttentionMaskFormat
|
|
11
|
+
from fusion_utils import FusionUtils, NumpyHelper
|
|
12
|
+
from onnx import NodeProto, TensorProto, helper, numpy_helper
|
|
13
|
+
from onnx_model import OnnxModel
|
|
14
|
+
|
|
15
|
+
logger = getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AttentionMask:
    """
    Convert an attention mask input into the mask tensor consumed by fused Attention nodes.

    Results are cached per mask input name: repeated calls to process_mask for the same
    mask return the cached tensor name instead of adding another Cast/ReduceSum subgraph.
    """

    def __init__(self, model: OnnxModel):
        self.model = model
        # A lookup table with mask input as key, and mask index output as value
        self.mask_indice = {}
        # A lookup table with mask input as key, and cast (to int32) output as value
        self.mask_casted = {}
        self.utils = FusionUtils(model)
        self.mask_format = AttentionMaskFormat.MaskIndexEnd
        # Needed to choose between the attribute and input forms of ReduceSum's axes.
        self.opset_version = model.get_opset_version()

    def set_mask_format(self, mask_format: AttentionMaskFormat):
        """Select how process_mask converts masks (NoMask, AttentionMask, or mask index)."""
        self.mask_format = mask_format

    def set_mask_indice(self, mask, mask_index):
        """Register mask_index as the processed output for mask.

        Re-registering the same mask with a different output is treated as an internal
        inconsistency (AssertionError).
        """
        if mask in self.mask_indice:
            assert mask_index == self.mask_indice[mask]
        self.mask_indice[mask] = mask_index

    def get_first_mask(self):
        """Return the first mask input name that was registered (dicts keep insertion order)."""
        assert len(self.mask_indice) > 0
        return next(iter(self.mask_indice))

    def process_mask(self, input: str) -> Optional[str]:
        """Return the name of the mask tensor to feed into a fused Attention node.

        Args:
            input: name of the attention mask tensor in the graph.

        Returns:
            None when mask_format is NoMask. The int32-cast mask name when mask_format is
            AttentionMask. Otherwise the output name of a new ReduceSum node that sums the
            cast mask over axis 1 with keepdims=0 (mask index).
        """
        if self.mask_format == AttentionMaskFormat.NoMask:
            return None

        if input in self.mask_indice:
            return self.mask_indice[input]

        # Add cast to convert int64 to int32
        if self.model.find_graph_input(input):
            casted, input_name = self.utils.cast_graph_input_to_int32(input)
        else:
            input_name, _ = self.utils.cast_input_to_int32(input)
            casted = True

        if casted:
            self.mask_casted[input] = input_name

        # Attention supports int32 attention mask (2D) since 1.4.0
        if self.mask_format == AttentionMaskFormat.AttentionMask:
            self.mask_indice[input] = input_name
            return input_name

        # Add a mask processing node to convert attention mask to mask index (1D)
        output_name = self.model.create_node_name("mask_index")
        mask_index_node = self._make_reduce_sum_node(input_name, output_name)
        self.model.add_node(mask_index_node)

        self.mask_indice[input] = output_name
        return output_name

    def _make_reduce_sum_node(self, input_name: str, output_name: str) -> NodeProto:
        """Build ReduceSum(axes=[1], keepdims=0) over input_name in the form the opset requires."""
        attributes = []
        if self.opset_version < 13:
            # Before opset 13, ReduceSum takes axes as an attribute.
            inputs = [input_name]
            attributes.append(helper.make_attribute("axes", [1]))
        else:
            # ReduceSum-13: axes is moved from attribute to input
            axes_name = "ort_const_1_reduce_sum_axes"
            if self.model.get_initializer(axes_name) is None:
                self.model.add_initializer(
                    helper.make_tensor(
                        name=axes_name,
                        data_type=TensorProto.INT64,
                        dims=[1],
                        vals=[1],
                        raw=False,
                    )
                )
            inputs = [input_name, axes_name]
        attributes.append(helper.make_attribute("keepdims", 0))

        node = helper.make_node(
            "ReduceSum",
            inputs=inputs,
            outputs=[output_name],
            name=self.model.create_node_name("ReduceSum", "MaskReduceSum"),
        )
        node.attribute.extend(attributes)
        return node
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class FusionAttention(Fusion):
|
|
105
|
+
"""
|
|
106
|
+
Fuse Attention subgraph into one Attention node.
|
|
107
|
+
"""
|
|
108
|
+
|
|
109
|
+
def __init__(
    self,
    model: OnnxModel,
    hidden_size: int,
    num_heads: int,
    attention_mask: Optional[AttentionMask] = None,
    use_multi_head_attention: bool = False,
    disable_multi_head_attention_bias: bool = False,
    search_op_types: Optional[List[str]] = None,
):
    """Create an Attention (or MultiHeadAttention) fusion.

    Args:
        model: the model being optimized.
        hidden_size: user-specified hidden size; used when it cannot be detected from the graph.
        num_heads: user-specified number of heads; used when it cannot be detected from the graph.
        attention_mask: shared mask helper; a new AttentionMask(model) is created when not given.
        use_multi_head_attention: fuse into "MultiHeadAttention" instead of "Attention".
        disable_multi_head_attention_bias: stored for use by node creation.
        search_op_types: op types passed to the Fusion base class; defaults to
            ["SkipLayerNormalization", "LayerNormalization"].
    """
    # Fresh list per instance, so no default list object is shared between fusions.
    if search_op_types is None:
        search_op_types = ["SkipLayerNormalization", "LayerNormalization"]

    attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
    super().__init__(model, attention_op_name, search_op_types)
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_mask = attention_mask if attention_mask else AttentionMask(model)
    self.use_multi_head_attention = use_multi_head_attention
    self.disable_multi_head_attention_bias = disable_multi_head_attention_bias
    self.mask_filter_value = None

    # Flags to show warning only once
    self.num_heads_warning = True
    self.hidden_size_warning = True

    self.shape_infer = None
    # NOTE(review): starting at True means get_add_qk_str never runs infer_runtime_shape, so
    # shape_infer stays None unless a caller resets these attributes. Presumably this
    # deliberately disables costly runtime shape inference; confirm.
    self.shape_infer_done = True
|
|
134
|
+
|
|
135
|
+
def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]:
    """
    Detect num_heads and hidden_size from Concat node in the following subgraph:

        SkipLayerNormalization or EmbedLayerNormalization
                    /        |
                 MatMul    Shape
                    |        |
                   Add     Gather(indices=0)
                    |        |
                    |      Unsqueeze
                    |        |
                    |      Concat (*, -1, 12, 64)
                    |     /
                   Reshape
                      |
                   Transpose

    Returns:
        (num_heads, num_heads * head_size) read from the constant 3rd and 4th Concat inputs.
        Falls back to the user-specified (self.num_heads, self.hidden_size) when the Concat
        does not have 4 inputs, or those inputs are not single-element constants, or their
        values are not positive.
    """
    if len(concat.input) == 4:
        num_heads = self.model.get_constant_value(concat.input[2])
        head_size = self.model.get_constant_value(concat.input[3])
        if (
            isinstance(num_heads, np.ndarray)
            and num_heads.size == 1
            and isinstance(head_size, np.ndarray)
            and head_size.size == 1
        ):
            # item() handles both 0-d and 1-element 1-D arrays (indexing [0] fails on 0-d).
            num_heads_value = num_heads.item()
            head_size_value = head_size.item()
            # Non-positive entries (e.g. -1 / 0 Reshape placeholders) are not real sizes;
            # fall back like the initializer path in get_num_heads_and_hidden_size does.
            if num_heads_value > 0 and head_size_value > 0:
                return num_heads_value, num_heads_value * head_size_value

    return self.num_heads, self.hidden_size
|
|
165
|
+
|
|
166
|
+
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
    """Detect num_heads and hidden_size from the shape input of the Q Reshape node.

    Args:
        reshape_q (NodeProto): reshape node for Q

    Returns:
        Tuple[int, int]: num_heads and hidden_size. If the shape is a Concat output, detection
        is delegated to get_num_heads_and_hidden_size_from_concat. If the shape is not a
        usable initializer, the user-specified values are returned. When detected values
        differ from positive user-specified ones, a warning is logged once per attribute and
        the detected values win.
    """
    user_values = (self.num_heads, self.hidden_size)

    # Reshape fusion is assumed to have run already, leaving a constant shape like
    # [0, 0, num_heads, head_size].
    shape_initializer = self.model.get_initializer(reshape_q.input[1])
    if shape_initializer is None:
        shape_producer = self.model.get_parent(reshape_q, 1)
        if shape_producer is not None and shape_producer.op_type == "Concat":
            return self.get_num_heads_and_hidden_size_from_concat(shape_producer)
        logger.debug(f"{reshape_q.input[1]} is not initializer.")
        return user_values  # Fall back to user specified value

    shape_value = NumpyHelper.to_array(shape_initializer)
    looks_valid = len(shape_value) == 4 and shape_value[2] > 0 and shape_value[3] > 0
    if not looks_valid:
        logger.debug(f"q_shape_value={shape_value}. Expected value are like [0, 0, num_heads, head_size].")
        return user_values  # Fall back to user specified value

    detected_num_heads = shape_value[2]
    detected_head_size = shape_value[3]
    detected_hidden_size = detected_num_heads * detected_head_size

    # Warn at most once per attribute when the user-provided value disagrees.
    if self.num_heads > 0 and detected_num_heads != self.num_heads and self.num_heads_warning:
        logger.warning(
            f"--num_heads is {self.num_heads}. Detected value is {detected_num_heads}. Using detected value."
        )
        self.num_heads_warning = False

    if self.hidden_size > 0 and detected_hidden_size != self.hidden_size and self.hidden_size_warning:
        logger.warning(
            f"--hidden_size is {self.hidden_size}. Detected value is {detected_hidden_size}. Using detected value."
        )
        self.hidden_size_warning = False

    return detected_num_heads, detected_hidden_size
|
|
206
|
+
|
|
207
|
+
def get_add_qk_str(self, add_qk: NodeProto):
    """Return add_qk.input[1] when both inputs of add_qk have the same inferred shape.

    Runtime shape inference runs (once) only when self.shape_infer_done is False. Returns
    None when no shape inference result is available, when either input shape is unknown,
    or when the two shapes differ.
    """
    if not self.shape_infer_done:
        self.shape_infer = self.model.infer_runtime_shape(update=True)
        self.shape_infer_done = True

    inferred = self.shape_infer
    if inferred is None:
        return None

    lhs_shape = inferred.get_edge_shape(add_qk.input[0])
    rhs_shape = inferred.get_edge_shape(add_qk.input[1])

    if lhs_shape is None or rhs_shape is None:
        logger.debug(f"one of the inputs of {add_qk} is None")
        return None

    if lhs_shape != rhs_shape:
        logger.debug(f"the shape of two inputs of {add_qk} is not same")
        return None

    return add_qk.input[1]
|
|
227
|
+
|
|
228
|
+
def reshape_add_qk(self, add_qk: str):
    """Tile a 4D mask from (B,1,S,T) to (B,N,S,T) by concatenating it self.num_heads times on axis 1.

    B = batch size, N = num heads, S = source sequence length, T = target sequence length.
    The Concat node is queued in self.nodes_to_add at most once per mask. A later call with
    the same add_qk returns the existing output name.

    Args:
        add_qk: name of the (B,1,S,T) mask tensor.

    Returns:
        Name of the tiled mask tensor (add_qk + "_mask").
    """
    tiled_mask_name = add_qk + "_mask"

    # Reuse a Concat that was already queued for this mask.
    queued = [pending for pending in self.nodes_to_add if pending.output[0] == tiled_mask_name]
    if len(queued) == 1:
        return tiled_mask_name
    assert len(queued) == 0

    # NOTE(review): tiling uses self.num_heads (the user-specified value), not a detected one.
    node_name = self.model.create_node_name("Concat")
    tile_node = helper.make_node(
        "Concat",
        inputs=[add_qk for _head in range(self.num_heads)],
        outputs=[tiled_mask_name],
        name=node_name,
        axis=1,
    )

    self.nodes_to_add.append(tile_node)
    self.node_name_to_graph_name[node_name] = self.this_graph_name
    return tiled_mask_name
|
|
252
|
+
|
|
253
|
+
def concat_kv(self, past_k: str, past_v: str) -> str:
    """Concatenate past_k and past_v inputs to create past_kv input.

    Each input is unsqueezed on axis 0 and the results are concatenated on axis 0. The
    Unsqueeze nodes use the form required by the model opset: an `axes` attribute before
    opset 13, and an int64 `axes` initializer input from opset 13 on.

    Args:
        past_k (str): name of past K value
        past_v (str): name of past V value

    Returns:
        kv_output_name (str): name of past KV value
    """
    # Unsqueeze K and V nodes from (B,N,P,H) to (1,B,N,P,H)
    # B = batch size, N = num heads, P = past sequence length, H = head size
    unsqueeze_k_name = self.model.create_node_name("Unsqueeze")
    unsqueeze_v_name = self.model.create_node_name("Unsqueeze")
    k_5d_name = (past_k + "_5d").replace(".", "_")
    v_5d_name = (past_v + "_5d").replace(".", "_")

    if self.model.get_opset_version() < 13:
        # Before opset 13, Unsqueeze takes axes as an attribute.
        unsqueeze_axes_inputs = []
        unsqueeze_attributes = {"axes": [0]}
    else:
        # Unsqueeze-13: axes is moved from attribute to input (an `axes` attribute is invalid).
        axes_name = "ort_const_0_unsqueeze_axes"
        if self.model.get_initializer(axes_name) is None:
            self.model.add_initializer(
                helper.make_tensor(
                    name=axes_name,
                    data_type=TensorProto.INT64,
                    dims=[1],
                    vals=[0],
                    raw=False,
                ),
                self.this_graph_name,
            )
        unsqueeze_axes_inputs = [axes_name]
        unsqueeze_attributes = {}

    k_5d = helper.make_node(
        "Unsqueeze",
        inputs=[past_k, *unsqueeze_axes_inputs],
        outputs=[k_5d_name],
        name=unsqueeze_k_name,
        **unsqueeze_attributes,
    )
    v_5d = helper.make_node(
        "Unsqueeze",
        inputs=[past_v, *unsqueeze_axes_inputs],
        outputs=[v_5d_name],
        name=unsqueeze_v_name,
        **unsqueeze_attributes,
    )

    # Add unsqueeze nodes to graph
    self.nodes_to_add.append(k_5d)
    self.nodes_to_add.append(v_5d)
    self.node_name_to_graph_name[unsqueeze_k_name] = self.this_graph_name
    self.node_name_to_graph_name[unsqueeze_v_name] = self.this_graph_name

    # Concat K and V to get one node of size (2,B,N,P,H)
    concat_node_name = self.model.create_node_name("Concat")
    kv_output_name = past_v.replace(".value", ".kv").replace(".", "_").replace("_value", "_kv")
    concat_kv = helper.make_node(
        "Concat",
        inputs=[k_5d_name, v_5d_name],
        outputs=[kv_output_name],
        name=concat_node_name,
        axis=0,
    )

    # Add concat node to graph
    self.nodes_to_add.append(concat_kv)
    self.node_name_to_graph_name[concat_node_name] = self.this_graph_name

    return kv_output_name
|
|
307
|
+
|
|
308
|
+
def reshape_kv(self, past_k: str, past_v: str) -> (str, str):
    """Queue Reshape nodes that turn the 4D past K/V tensors into 3D tensors for MultiHeadAttention.

    Both Reshape nodes share one int64 shape initializer named ``kv_4d_to_3d`` holding
    ``[0, -1, self.model.hidden_size]``. It is created on the first call and reused afterwards.

    Args:
        past_k (str): name of past K value of shape 4D
        past_v (str): name of past V value of shape 4D

    Returns:
        k_3d (str): name of past K value of shape 3D
        v_3d (str): name of past V value of shape 3D
    """
    # Intended mapping: (B,N,P,H) -> (B,P,N*H), where B = batch size, N = num heads,
    # P = past seq len, H = head size. In the target shape, 0 keeps the batch dim and -1 is inferred.
    # NOTE(review): Reshape does not reorder elements, so the result only matches a (B,P,N*H) layout
    # when the input is already laid out as (B,P,N,H) or N == 1 - confirm against callers.
    target_shape_name = "kv_4d_to_3d"
    if self.model.get_initializer(target_shape_name) is None:
        target_shape = np.array([0, -1, self.model.hidden_size], dtype="int64")
        self.model.add_initializer(
            numpy_helper.from_array(target_shape, name=target_shape_name),
            self.this_graph_name,
        )

    pasts = (past_k, past_v)
    # Node names are reserved for K before V.
    reshape_names = [self.model.create_node_name("Reshape") for _ in pasts]
    flat_names = [(past + "_3d").replace(".", "_") for past in pasts]

    for past, flat_name, reshape_name in zip(pasts, flat_names, reshape_names):
        reshape = helper.make_node(
            "Reshape",
            inputs=[past, target_shape_name],
            outputs=[flat_name],
            name=reshape_name,
        )
        self.nodes_to_add.append(reshape)
        self.node_name_to_graph_name[reshape_name] = self.this_graph_name

    return flat_names[0], flat_names[1]
|
|
356
|
+
|
|
357
|
+
def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
    """Queue Gather nodes that split a stacked present KV tensor into separate present K and V outputs.

    Entry 0 along axis 0 of ``kv_node`` becomes the present K output, and entry 1 becomes the present V
    output. The scalar int64 index initializers ``index_0`` and ``index_1`` are shared and only created
    when missing.

    Args:
        present_k_name (str): name of output to store present K value in
        present_v_name (str): name of output to store present V value in
        kv_node (str): name of present KV values (stacked on axis 0, expected shape (2,B,N,P,H))
    """
    index_names = ("index_0", "index_1")
    for position, index_name in enumerate(index_names):
        if self.model.get_initializer(index_name) is None:
            index_tensor = numpy_helper.from_array(np.array(position, dtype="int64"), name=index_name)
            self.model.add_initializer(index_tensor, self.this_graph_name)

    # Node names are reserved for K before V.
    gather_names = [self.model.create_node_name("Gather") for _ in index_names]
    targets = (present_k_name, present_v_name)

    for gather_name, index_name, target in zip(gather_names, index_names, targets):
        gather = helper.make_node(
            "Gather",
            inputs=[kv_node, index_name],
            outputs=[target],
            name=gather_name,
            axis=0,
        )
        self.nodes_to_add.append(gather)
        self.node_name_to_graph_name[gather_name] = self.this_graph_name
|
|
401
|
+
|
|
402
|
+
def transpose_kv(self, past_k: str, past_v: str):
    """Queue Transpose nodes that permute past K and past V from (B,N,P,H) to (B,P,N,H).

    Args:
        past_k (str): name of past K value of shape (B,N,P,H)
        past_v (str): name of past V value of shape (B,N,P,H)

    Returns:
        past_k_transpose (str): name of past K value of shape (B,P,N,H)
        past_v_transpose (str): name of past V value of shape (B,P,N,H)
    """
    sources = (past_k, past_v)
    transposed_names = [(source + "_transposed").replace(".", "_") for source in sources]
    # Node names are reserved for K before V.
    node_names = [self.model.create_node_name("Transpose") for _ in sources]

    # perm=[0, 2, 1, 3] swaps axes 1 and 2: (B,N,P,H) -> (B,P,N,H).
    for source, transposed_name, node_name in zip(sources, transposed_names, node_names):
        transpose = helper.make_node(
            "Transpose",
            inputs=[source],
            outputs=[transposed_name],
            name=node_name,
            perm=[0, 2, 1, 3],
        )
        # Queue the Transpose node and record which graph it belongs to.
        self.nodes_to_add.append(transpose)
        self.node_name_to_graph_name[node_name] = self.this_graph_name

    return transposed_names[0], transposed_names[1]
|
|
440
|
+
|
|
441
|
+
def create_combined_qkv_bias(
    self,
    q_add: NodeProto,
    k_add: Union[NodeProto, None],
    v_add: Union[NodeProto, None],
    name_prefix: str,
) -> str:
    """Pack the Q, K and V biases into a single initializer for a MultiHeadAttention node.

    The bias of each path is read from whichever input of its Add node is an initializer. When the K or V
    Add node is None, a zero bias with the same shape and dtype as the Q bias is used for that path.

    Args:
        q_add (NodeProto): Add node of the Q path
        k_add (NodeProto | None): Add node of the K path, or None when K has no bias
        v_add (NodeProto | None): Add node of the V path, or None when V has no bias
        name_prefix (str): prefix used to name the new initializer

    Returns:
        str: name of the created initializer, ``name_prefix + "_qkv_bias"``.
    """
    # The bias initializer may be either operand of the Add node, so try input 1 first, then input 0.
    q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
    qb = NumpyHelper.to_array(q_bias)
    kb = np.zeros_like(qb)
    vb = np.zeros_like(qb)
    if k_add is not None:
        k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
        kb = NumpyHelper.to_array(k_bias)
    if v_add is not None:
        v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
        vb = NumpyHelper.to_array(v_bias)

    # Stack on a new leading axis so the flattened layout is [q..., k..., v...].
    qkv_bias = np.stack((qb, kb, vb), axis=0)
    qkv_bias_dim = 3 * np.prod(qb.shape)

    bias_name = name_prefix + "_qkv_bias"
    self.add_initializer(
        name=bias_name,
        data_type=q_bias.data_type,
        dims=[qkv_bias_dim],
        vals=qkv_bias,
    )
    return bias_name
|
|
470
|
+
|
|
471
|
+
def create_packed_qkv_matmul_node(
    self,
    q_matmul: NodeProto,
    k_matmul: NodeProto,
    v_matmul: NodeProto,
    q_add: NodeProto,
    k_add: Union[NodeProto, None],
    v_add: Union[NodeProto, None],
    num_heads: int,
) -> tuple[NodeProto, NodeProto, NodeProto]:
    """Create packed QKV MatMul node before MultiHeadAttention node.

    This is for the scenario where an Attention node should be created but cannot be created
    because past_key and past_value are separate inputs and not one concatenated input.

    One MatMul computes Q, K and V together from a column-wise concatenated weight. Three Slice nodes
    then split its last axis back into the Q, K and V parts. When ``self.disable_multi_head_attention_bias``
    is set, each existing Add node whose bias is non-zero is rewired to consume its slice. All created or
    rewired nodes are appended to ``self.nodes_to_add``.

    Args:
        q_matmul (NodeProto): MatMul of the Q path - (batch_size, sequence_length, hidden_size)
        k_matmul (NodeProto): MatMul of the K path; must share q_matmul's input
        v_matmul (NodeProto): MatMul of the V path; must share q_matmul's input
        q_add (NodeProto): Add of the Q path
        k_add (NodeProto | None): Add of the K path, or None
        v_add (NodeProto | None): Add of the V path, or None
        num_heads (int): number of heads (not used by this method)

    Returns:
        tuple[NodeProto, NodeProto, NodeProto]: nodes whose first outputs hold Q, K and V - the Slice
        nodes, or the rewired Add nodes when the bias is applied outside MultiHeadAttention.
    """
    matmul_node_name = self.model.create_node_name("MatMul")

    # Check that input for Q, K, V is the same
    assert q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]

    # Create packed QKV weight
    q_weight = self.model.get_initializer(q_matmul.input[1])
    k_weight = self.model.get_initializer(k_matmul.input[1])
    v_weight = self.model.get_initializer(v_matmul.input[1])

    qw = NumpyHelper.to_array(q_weight)
    kw = NumpyHelper.to_array(k_weight)
    vw = NumpyHelper.to_array(v_weight)

    assert qw.shape == kw.shape and kw.shape == vw.shape
    # in_dim is the MatMul input size; out_dim is the size of each of Q, K and V. Flattening each weight
    # to 2D and concatenating column-wise gives an (in_dim, 3 * out_dim) weight, so the weights need not
    # be square.
    in_dim = qw.shape[0]
    out_dim = int(np.prod(qw.shape[1:]))

    qkv_weight = np.concatenate([w.reshape(in_dim, out_dim) for w in (qw, kw, vw)], axis=1)
    qkv_weight_name = matmul_node_name + "_qkv_weight"

    self.add_initializer(
        name=qkv_weight_name,
        data_type=q_weight.data_type,
        dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
        vals=qkv_weight,
    )

    # Create packed QKV MatMul with output (B, S, 3*out_dim).
    # Along the last axis, each row is laid out as:
    #
    #   [Q Q ... Q Q K K ... K K V V ... V V]
    qkv_matmul_output = matmul_node_name + "_qkv_out"
    qkv_matmul = helper.make_node(
        "MatMul",
        inputs=[q_matmul.input[0], qkv_weight_name],
        outputs=[qkv_matmul_output],
        name=matmul_node_name,
    )
    self.node_name_to_graph_name[matmul_node_name] = self.this_graph_name

    qkv_nodes = [qkv_matmul]

    # Create Slice nodes over the last axis to access Q = [0, out_dim), K = [out_dim, 2*out_dim)
    # and V = [2*out_dim, 3*out_dim).
    q_slice_name = matmul_node_name + "_q_start_index"
    self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
    k_slice_name = matmul_node_name + "_k_start_index"
    self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[out_dim], raw=False)
    v_slice_name = matmul_node_name + "_v_start_index"
    self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * out_dim], raw=False)
    end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
    self.add_initializer(
        name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * out_dim], raw=False
    )
    qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
    self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)

    q_slice_output = matmul_node_name + "_q_out"
    q_slice = helper.make_node(
        "Slice",
        inputs=[qkv_matmul_output, q_slice_name, k_slice_name, qkv_last_axis_name],
        outputs=[q_slice_output],
        name=self.model.create_node_name("Slice"),
    )
    self.node_name_to_graph_name[q_slice.name] = self.this_graph_name
    k_slice_output = matmul_node_name + "_k_out"
    k_slice = helper.make_node(
        "Slice",
        inputs=[qkv_matmul_output, k_slice_name, v_slice_name, qkv_last_axis_name],
        outputs=[k_slice_output],
        name=self.model.create_node_name("Slice"),
    )
    self.node_name_to_graph_name[k_slice.name] = self.this_graph_name
    v_slice_output = matmul_node_name + "_v_out"
    v_slice = helper.make_node(
        "Slice",
        inputs=[qkv_matmul_output, v_slice_name, end_of_qkv_name, qkv_last_axis_name],
        outputs=[v_slice_output],
        name=self.model.create_node_name("Slice"),
    )
    self.node_name_to_graph_name[v_slice.name] = self.this_graph_name

    q_output = q_slice
    k_output = k_slice
    v_output = v_slice
    qkv_nodes.extend([q_slice, k_slice, v_slice])

    if self.disable_multi_head_attention_bias:
        # Bias is applied outside MultiHeadAttention: rewire each Add that has a non-zero bias so its
        # non-initializer operand becomes the corresponding slice. All-zero biases are skipped.
        if q_add is not None:
            initializer_input = 1 if self.model.get_initializer(q_add.input[1]) else 0
            if np.any(NumpyHelper.to_array(self.model.get_initializer(q_add.input[initializer_input]))):
                q_add.input[1 - initializer_input] = q_slice_output
                q_output = q_add
                qkv_nodes.append(q_add)
                self.node_name_to_graph_name[q_add.name] = self.this_graph_name
        if k_add is not None:
            initializer_input = 1 if self.model.get_initializer(k_add.input[1]) else 0
            if np.any(NumpyHelper.to_array(self.model.get_initializer(k_add.input[initializer_input]))):
                k_add.input[1 - initializer_input] = k_slice_output
                k_output = k_add
                qkv_nodes.append(k_add)
                self.node_name_to_graph_name[k_add.name] = self.this_graph_name
        if v_add is not None:
            initializer_input = 1 if self.model.get_initializer(v_add.input[1]) else 0
            if np.any(NumpyHelper.to_array(self.model.get_initializer(v_add.input[initializer_input]))):
                v_add.input[1 - initializer_input] = v_slice_output
                v_output = v_add
                qkv_nodes.append(v_add)
                self.node_name_to_graph_name[v_add.name] = self.this_graph_name

    # Add nodes to graph
    self.nodes_to_add.extend(qkv_nodes)
    return q_output, k_output, v_output
|
|
613
|
+
|
|
614
|
+
def create_multihead_attention_node(
    self,
    q_matmul: NodeProto,
    k_matmul: Union[NodeProto, str, None],
    v_matmul: Union[NodeProto, str, None],
    q_add: NodeProto,
    k_add: Union[NodeProto, None],
    v_add: Union[NodeProto, None],
    num_heads: int,
    hidden_size: int,
    output: str,
    key_padding_mask: str = "",
    add_qk: str = "",
    past_k: str = "",
    past_v: str = "",
    present_k: str = "",
    present_v: str = "",
    packed_qkv: bool = False,
) -> Union[NodeProto, None]:
    """Create a MultiHeadAttention node.

    This method returns the node but does not append it to ``self.nodes_to_add``. The helper nodes created
    when ``packed_qkv`` is set are appended by ``create_packed_qkv_matmul_node``.

    Args:
        q_matmul (NodeProto): MatMul of the Q path - (batch_size, sequence_length, hidden_size)
        k_matmul (NodeProto | str): MatMul of the K path - (batch_size, sequence_length, hidden_size) -
            or the name of a graph input - (batch_size, num_heads, past_sequence_length, head_size)
        v_matmul (NodeProto | str): MatMul of the V path - (batch_size, sequence_length, hidden_size) -
            or the name of a graph input - (batch_size, num_heads, past_sequence_length, head_size)
        q_add (NodeProto): Add of the Q path
        k_add (NodeProto | None): Add of the K path
        v_add (NodeProto | None): Add of the V path
        num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
        hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
        output (str): output name of MHA
        key_padding_mask (str): name of key padding mask
        add_qk (str): name of add after Q x K'
        past_k (str): name of past K value - (batch_size, num_heads, past_sequence_length, head_size)
        past_v (str): name of past V value - (batch_size, num_heads, past_sequence_length, head_size)
        present_k (str): name of present K value - (batch_size, num_heads, sequence_length, head_size)
        present_v (str): name of present V value - (batch_size, num_heads, sequence_length, head_size)
        packed_qkv (bool): whether to combine MatMuls from Q, K, V paths.
            Note: This is for the scenario where an Attention node should be created but cannot be created
            because past_key and past_value are separate inputs and not one concatenated input.

    Returns:
        Union[NodeProto, None]: the node created, or None if K/V are neither MatMul nodes nor graph inputs,
        or hidden_size is not divisible by num_heads.
    """
    # B = batch size, N = num heads, P = past seq len, H = head size
    assert num_heads > 0

    if hidden_size > 0 and (hidden_size % num_heads) != 0:
        logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
        return None

    graph_input_names = {graph_input.name for graph_input in self.model.graph().input}
    mha_node_name = self.model.create_node_name("Attention")

    # Add initial Q/K/V inputs for MHA
    mha_inputs = []
    if packed_qkv:
        q_slice, k_slice, v_slice = self.create_packed_qkv_matmul_node(
            q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, num_heads
        )
        mha_inputs.extend([q_slice.output[0], k_slice.output[0], v_slice.output[0]])
    elif isinstance(k_matmul, NodeProto) and isinstance(v_matmul, NodeProto):
        if self.disable_multi_head_attention_bias:
            # NOTE(review): Q and V use their Add outputs, but K uses the bare MatMul output, so any K bias
            # is dropped here - presumably the targeted models have no K bias; confirm.
            mha_inputs.extend([q_add.output[0], k_matmul.output[0], v_add.output[0]])
        else:
            mha_inputs.extend([q_matmul.output[0], k_matmul.output[0], v_matmul.output[0]])
    elif (
        isinstance(k_matmul, str)
        and isinstance(v_matmul, str)
        and k_matmul in graph_input_names
        and v_matmul in graph_input_names
    ):
        if self.disable_multi_head_attention_bias:
            mha_inputs.extend([q_add.output[0], k_matmul, v_matmul])
        else:
            mha_inputs.extend([q_matmul.output[0], k_matmul, v_matmul])
    else:
        return None

    # Add bias to inputs for MHA
    # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume
    # bias has been added to key and value when they are in BNSH format, so only bias for query is used.
    # Need add checks if we found such assumption is not true.
    if not self.disable_multi_head_attention_bias:
        bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name)
        mha_inputs.append(bias_name)
    else:
        mha_inputs.append("")

    # Add optional inputs for MHA; empty strings keep positional slots for omitted inputs.
    if past_k and past_v:
        mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v])
    elif key_padding_mask or add_qk:
        mha_inputs.extend([key_padding_mask, add_qk])

    # Add outputs for MHA
    mha_outputs = [output]
    if present_k and present_v:
        mha_outputs.extend([present_k, present_v])

    mha_node = helper.make_node(
        "MultiHeadAttention",
        inputs=mha_inputs,
        outputs=mha_outputs,
        name=mha_node_name,
    )
    mha_node.domain = "com.microsoft"
    mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
    return mha_node
|
|
724
|
+
|
|
725
|
+
def create_attention_node(
    self,
    mask_index: str,
    q_matmul: NodeProto,
    k_matmul: NodeProto,
    v_matmul: NodeProto,
    q_add: NodeProto,
    k_add: NodeProto,
    v_add: NodeProto,
    num_heads: int,
    hidden_size: int,
    input: str,
    output: str,
    add_qk_str: str = "",
    past_k: str = "",
    past_v: str = "",
    present_k: str = "",
    present_v: str = "",
    scale: Optional[float] = None,
    causal: bool = False,
) -> Union[NodeProto, None]:
    """Create an Attention node, or a MultiHeadAttention node when ``self.use_multi_head_attention`` is set.

    The Q/K/V MatMul weights (Attention only) and Add biases (if any) are packed into initializers named
    after the created node. The node is returned and not appended to ``self.nodes_to_add`` by this method.

    Args:
        mask_index (str): mask input, or None for no mask
        q_matmul (NodeProto): MatMul node in fully connection for Q
        k_matmul (NodeProto): MatMul node in fully connection for K
        v_matmul (NodeProto): MatMul node in fully connection for V
        q_add (NodeProto): Add bias node in fully connection for Q
        k_add (NodeProto): Add bias node in fully connection for K
        v_add (NodeProto): Add bias node in fully connection for V
        num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
        hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
        input (str): input name
        output (str): output name
        add_qk_str (str): name of Add node after Q x K'; empty string or None when there is none
        past_k (str): name of input for past K value
        past_v (str): name of input for past V value
        present_k (str): name of output to store present K value
        present_v (str): name of output to store present V value
        scale: scale before softmax
        causal: whether it is uni-directional mask.

    Returns:
        Union[NodeProto, None]: the node created or None if failed.
    """
    assert num_heads > 0

    if hidden_size > 0 and (hidden_size % num_heads) != 0:
        logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
        return None

    # Checked before any initializer is created so a rejected fusion leaves no orphan initializers behind.
    if self.use_multi_head_attention and add_qk_str:
        logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.")
        return None

    # Bias is fused only when Add nodes were matched. The packed bias needs all three of them, and a
    # partial match would otherwise fail on `None.input` below.
    has_bias = not (q_add is None and k_add is None and v_add is None)
    if has_bias and (q_add is None or k_add is None or v_add is None):
        logger.debug("Q, K and V must either all have a bias Add node or none: cannot fuse the attention.")
        return None

    q_weight = self.model.get_initializer(q_matmul.input[1])
    k_weight = self.model.get_initializer(k_matmul.input[1])
    v_weight = self.model.get_initializer(v_matmul.input[1])

    if q_weight is None:
        logger.warning(
            f"{q_matmul.input[1]} is not an initializer. "
            "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
        )
        return None
    if k_weight is None or v_weight is None:
        return None

    q_bias, k_bias, v_bias = None, None, None
    if has_bias:
        # The bias initializer may be either operand of each Add node.
        q_bias = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
        k_bias = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
        v_bias = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
        if q_bias is None or k_bias is None or v_bias is None:
            return None

    qw = NumpyHelper.to_array(q_weight)
    kw = NumpyHelper.to_array(k_weight)
    vw = NumpyHelper.to_array(v_weight)

    # assert q and k have same shape as expected
    assert qw.shape == kw.shape

    qw_in_size = qw.shape[0]
    kw_in_size = kw.shape[0]
    vw_in_size = vw.shape[0]

    assert qw_in_size == kw_in_size == vw_in_size

    if hidden_size > 0 and hidden_size != qw_in_size:
        logger.warning(
            f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
            "Please provide a correct input hidden size or pass in 0"
        )

    is_qkv_diff_dims = qw.shape != vw.shape

    # All the matrices can have the same shape or q, k matrices can have the same shape with v being different
    # For 2d weights, the shapes would be [in_size, out_size].
    # For 3d weights, shape would be [in_size, a, b] where a*b = out_size
    qw_out_size = np.prod(qw.shape[1:])
    kw_out_size = np.prod(kw.shape[1:])
    vw_out_size = np.prod(vw.shape[1:])

    if is_qkv_diff_dims:
        qkv_weight = np.concatenate((qw, kw, vw), axis=1)
        qkv_weight_dim = qw_out_size + kw_out_size + vw_out_size
    else:
        qkv_weight = np.stack((qw, kw, vw), axis=1)
        qkv_weight_dim = 3 * qw_out_size

    if has_bias:
        qb = NumpyHelper.to_array(q_bias)
        kb = NumpyHelper.to_array(k_bias)
        vb = NumpyHelper.to_array(v_bias)

        q_bias_shape = np.prod(qb.shape)
        k_bias_shape = np.prod(kb.shape)
        v_bias_shape = np.prod(vb.shape)

        assert q_bias_shape == k_bias_shape == qw_out_size
        assert v_bias_shape == vw_out_size

        if is_qkv_diff_dims:
            qkv_bias = np.concatenate((qb, kb, vb), axis=0)
            qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
        else:
            qkv_bias = np.stack((qb, kb, vb), axis=0)
            qkv_bias_dim = 3 * q_bias_shape

    attention_node_name = self.model.create_node_name("Attention")
    # Empty string marks the optional bias input as absent when no bias initializer is created.
    qkv_bias_name = attention_node_name + "_qkv_bias" if has_bias else ""

    if not self.use_multi_head_attention:
        self.add_initializer(
            name=attention_node_name + "_qkv_weight",
            data_type=q_weight.data_type,
            dims=[qw_in_size, qkv_weight_dim],
            vals=qkv_weight,
        )

    if has_bias:
        self.add_initializer(
            name=qkv_bias_name,
            data_type=q_bias.data_type,
            dims=[qkv_bias_dim],
            vals=qkv_bias,
        )

    # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
    if self.use_multi_head_attention:
        attention_inputs = [
            q_matmul.output[0],
            k_matmul.output[0],
            v_matmul.output[0],
            qkv_bias_name,
        ]

        if mask_index is not None:
            attention_inputs.append(mask_index)

        attention_node = helper.make_node(
            "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
    else:
        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            qkv_bias_name,
            mask_index if mask_index is not None else "",
        ]

        past_exists = past_k and past_v
        if past_exists:
            past_kv = self.concat_kv(past_k, past_v)
            attention_inputs.append(past_kv)

        # Truthiness (not `is not None`) so the default "" does not create a bogus add_qk input.
        if add_qk_str:
            mask_output_name = self.reshape_add_qk(add_qk_str)

            # The reshaped add_qk follows the past input, so an empty placeholder fills the past slot.
            if not past_exists:
                attention_inputs.append("")
            attention_inputs.append(mask_output_name)

        attention_outputs = [output]
        if present_k and present_v:
            present_kv = present_k.replace(".key", "").replace("_key", "").replace(".", "_")
            attention_outputs.append(present_kv)
            self.split_kv(present_k, present_v, present_kv)

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=attention_outputs,
            name=attention_node_name,
        )

    attention_node.domain = "com.microsoft"
    attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

    if causal:
        attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)])

    if scale is not None:
        attention_node.attribute.extend([helper.make_attribute("scale", scale)])

    if is_qkv_diff_dims:
        attention_node.attribute.extend(
            [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
        )

    if self.mask_filter_value is not None:
        attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

    return attention_node
|
|
953
|
+
|
|
954
|
+
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
    """Try to fuse one self-attention subgraph feeding ``normalize_node`` into an Attention node.

    Matching starts at ``normalize_node``, or at the ``Add`` feeding it when it is a
    plain LayerNormalization. The method walks upward through the output
    projection, the Q/K/V projections, the QK score path and (optionally) the
    mask path.

    On success:
    - the new node returned by ``create_attention_node`` is appended to
      ``self.nodes_to_add``;
    - the matched subgraph nodes are queued (once each) in ``self.nodes_to_remove``;
    - ``self.prune_graph`` is set so shared mask nodes are pruned later.

    Whenever a sub-pattern does not match, the method returns without queuing
    anything. Note that ``self.mask_filter_value`` may already have been updated
    before some of those late returns.

    Args:
        normalize_node: LayerNormalization / SkipLayerNormalization node that
            consumes the attention block output.
        input_name_to_nodes: mapping from tensor name to the list of consumer nodes.
        output_name_to_node: mapping from tensor name to its producer node.

    Returns:
        None.
    """
    # Sometimes SkipLayerNormalization cannot be fused because the Add before
    # LayerNormalization has an output used by nodes outside the pattern.
    # Conceptually that Add plays the role of the SkipLayerNormalization node,
    # since both share the same pattern.
    start_node = normalize_node
    if normalize_node.op_type == "LayerNormalization":
        add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
        if add_before_layernorm is None:
            return
        start_node = add_before_layernorm

    # SkipLayerNormalization has two inputs; one of them is the root input for attention.
    einsum_node = None
    qkv_nodes = self.model.match_parent_path(
        start_node,
        ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
        [None, None, 0, 0, 0],
    )
    if qkv_nodes is not None:
        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes
    else:
        # Albert variant: the output projection is expressed with Einsum.
        qkv_nodes = self.model.match_parent_path(
            start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0]
        )
        if qkv_nodes is None:
            return
        (_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes

    # The root input is the single produced input of start_node that is not the
    # attention output itself (graph inputs/initializers have no producer).
    other_inputs = [
        name
        for name in start_node.input
        if name in output_name_to_node and name != qkv_nodes[0].output[0]
    ]
    if len(other_inputs) != 1:
        return
    root_input = other_inputs[0]

    # Match Flaubert mask:
    #
    #   Mul --> LayerNormalization --> Attention --> MatMul --> Add
    #    |                                                       |
    #    +-------------------------------------------------------+
    mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0)
    if mul_before_layernorm is not None:
        mul_children = input_name_to_nodes[mul_before_layernorm.output[0]]
        if mul_children is not None and len(mul_children) == 2:
            layernorm_node = mul_children[1]
            if layernorm_node.op_type != "LayerNormalization":
                return
            root_input = layernorm_node.output[0]
        elif mul_children is not None and len(mul_children) == 5:
            root_input = mul_before_layernorm.output[0]
        else:
            return
    elif normalize_node.op_type == "LayerNormalization":
        for child in input_name_to_nodes[root_input]:
            if child.op_type == "LayerNormalization":
                root_input = child.output[0]

    # When the Add before LayerNormalization produces an output that is also
    # consumed by nodes other than the LayerNormalization itself, the fused
    # SkipLayerNormalization has several outputs. Use output 0, the one that
    # feeds Attention. This is the case for ViT, for example:
    #
    #   SkipLayerNormalization --> Attention --> MatMul --> Add --> SkipLayerNormalization
    #          |                                                            |
    #          +------------------------------------------------------------+
    parent_node = output_name_to_node[root_input]
    if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
        root_input = parent_node.output[0]

    # The root must feed exactly three MatMuls (the Q, K and V projections).
    children_types = [child.op_type for child in input_name_to_nodes[root_input]]
    if children_types.count("MatMul") != 3:
        return

    v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
    if v_nodes is None:
        logger.debug("fuse_attention: failed to match v path")
        return
    (_, _, add_v, matmul_v) = v_nodes

    # Candidate QK score paths, tried in order; the first match wins.
    qk_paths = {
        "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]),
        "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]),
        "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]),
        "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]),
        "path5": (["Softmax", "Div", "MatMul"], [0, 0, 0]),
    }
    qk_nodes = None
    matched_path = None
    for path_name, (op_types, input_indices) in qk_paths.items():
        qk_nodes = self.model.match_parent_path(matmul_qkv, op_types, input_indices)
        if qk_nodes is not None:
            matched_path = path_name
            break
    if qk_nodes is None:
        logger.debug("fuse_attention: failed to match qk path")
        return

    is_distill = matched_path == "path3"
    is_distill_add = matched_path == "path4"
    is_no_mask_attention = matched_path == "path5"

    add_qk = None
    where_qk = None
    if is_distill:
        (_, where_qk, matmul_qk, _) = qk_nodes
    elif is_distill_add:
        (_, add_qk, where_qk, matmul_qk) = qk_nodes
    elif is_no_mask_attention:
        (_, _, matmul_qk) = qk_nodes
    else:
        (_, add_qk, _, matmul_qk) = qk_nodes

    q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None])
    if q_nodes is None:
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, None],
        )
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
    reshape_q = q_nodes[-3]
    add_q = q_nodes[-2]
    matmul_q = q_nodes[-1]

    k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
    if k_nodes is None:
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, 0, None],
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return
    add_k = k_nodes[-2]
    matmul_k = k_nodes[-1]

    # Note that Cast might be removed by OnnxRuntime so we match two patterns here.
    mask_nodes = None
    add_qk_str = None
    if is_distill:
        _, mask_nodes, _ = self.model.match_parent_paths(
            where_qk,
            [
                (["Expand", "Reshape", "Equal"], [0, 0, 0]),
                (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
                (["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
    elif is_distill_add:
        _, mask_nodes, _ = self.model.match_parent_paths(
            where_qk,
            [
                (["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]),
                (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]),
            ],
            output_name_to_node,
        )
        if add_qk is not None:
            add_qk_str = self.get_add_qk_str(add_qk)
            if add_qk_str is None:
                logger.debug("fuse_attention: failed to verify shape inference of %s", add_qk)
                return
    elif not is_no_mask_attention:
        _, mask_nodes, _ = self.model.match_parent_paths(
            add_qk,
            [
                (
                    ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
                    [None, 0, 1, 0, 0],
                ),
                (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]),
            ],
            output_name_to_node,
        )

    if not is_no_mask_attention and mask_nodes is None:
        logger.debug("fuse_attention: failed to match mask path")
        return

    # A non-default constant in the mask Mul becomes the mask_filter_value attribute.
    if not is_no_mask_attention and len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
        _, mul_val = self.model.get_constant_input(mask_nodes[0])
        if mul_val != -10000:
            self.mask_filter_value = mul_val

    # Q, K and V projections must all read the same root input.
    if any(matmul.input[0] != root_input for matmul in (matmul_q, matmul_k, matmul_v)):
        return

    mask_index = None if is_no_mask_attention else self.attention_mask.process_mask(mask_nodes[-1].input[0])

    attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv

    q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
    if q_num_heads <= 0 or q_hidden_size <= 0:
        logger.warning(
            "Failed to detect num_heads and hidden_size for Attention fusion. "
            "Please specify those parameters in argument."
        )
        return

    # The number of heads is the same for all paths, so q_num_heads is passed.
    # q_hidden_size is the input hidden size; hidden sizes for Q and K are
    # extracted as needed by create_attention_node.
    new_node = self.create_attention_node(
        mask_index,
        matmul_q,
        matmul_k,
        matmul_v,
        add_q,
        add_k,
        add_v,
        q_num_heads,
        q_hidden_size,
        root_input,
        attention_last_node.output[0],
        add_qk_str,
    )
    if new_node is None:
        return

    self.nodes_to_add.append(new_node)
    self.node_name_to_graph_name[new_node.name] = self.this_graph_name

    if einsum_node is not None:
        # Einsum expects a 4-D (batch, seq, heads, head_size) input, so the
        # Attention output is reshaped before being fed back into it.
        unique_index = einsum_node.input[0]
        new_edge = "edge_modified_" + unique_index

        shape_tensor = self.add_initializer(
            name="shape_modified_tensor" + unique_index,
            data_type=TensorProto.INT64,
            dims=[4],
            vals=np.int64([0, 0, q_num_heads, int(q_hidden_size // q_num_heads)]),
            raw=False,
        )

        self.model.add_node(
            helper.make_node(
                "Reshape",
                [attention_last_node.output[0], shape_tensor.name],
                [new_edge],
                "reshape_modified_" + unique_index,
            ),
            self.this_graph_name,
        )
        einsum_node.input[0] = new_edge

    # For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
    keep_projection = self.use_multi_head_attention
    candidates = [attention_last_node, transpose_qkv, matmul_qkv, *qk_nodes]
    for path in (q_nodes, k_nodes, v_nodes):
        candidates.extend(path[:-1] if keep_projection else path)

    # Queue each node only once. Some nodes are matched twice: transpose_qkv is
    # attention_last_node in the Einsum variant, and qk "path3" ends with the
    # same Div that starts the fallback q path. Removing a node twice would make
    # the second removal fail.
    unique_nodes = []
    for node in candidates:
        if node not in unique_nodes:
            unique_nodes.append(node)
    self.nodes_to_remove.extend(unique_nodes)

    # Use prune graph to remove mask nodes since they are shared by all attention nodes.
    self.prune_graph = True
|