onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
from typing import Tuple, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from fusion_base import Fusion
|
|
10
|
+
from fusion_utils import NumpyHelper
|
|
11
|
+
from onnx import NodeProto, helper, numpy_helper
|
|
12
|
+
from onnx_model import OnnxModel
|
|
13
|
+
|
|
14
|
+
# Module-level logger: warns once about --num_heads/--hidden_size mismatches and emits
# debug messages explaining why an attention pattern failed to match.
logger = getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FusionMultiHeadAttentionSam2(Fusion):
|
|
18
|
+
"""
|
|
19
|
+
Fuse MultiHeadAttention subgraph of Segment Anything v2 (SAM2).
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int):
    """Set up the SAM2 MultiHeadAttention fusion.

    Args:
        model (OnnxModel): model whose graph is searched for SAM2 attention subgraphs.
        hidden_size (int): user-specified hidden size; used only when it cannot be detected from the graph.
        num_heads (int): user-specified number of heads; used only when it cannot be detected from the graph.
    """
    # Fused op type is MultiHeadAttention; candidate subgraphs are searched from LayerNormalization nodes.
    super().__init__(model, "MultiHeadAttention", ["LayerNormalization"])

    # Fallback values supplied by the caller (reported as --num_heads / --hidden_size in warnings).
    self.num_heads, self.hidden_size = num_heads, hidden_size

    # Each "detected value differs" warning is logged at most once.
    self.hidden_size_warning = self.num_heads_warning = True
|
|
35
|
+
|
|
36
|
+
def get_decoder_num_heads(self, reshape_q: NodeProto) -> int:
    """Read num_heads from the constant shape input of the Q Reshape node.

    Reshape fusion is assumed to have run already, so the target shape is a
    4-element constant such as [0, 0, num_heads, head_size].

    Args:
        reshape_q (NodeProto): Reshape node on the Q path.
    Returns:
        int: num_heads, or 0 if it cannot be determined.
    """
    target_shape = self.model.get_constant_value(reshape_q.input[1])

    # isinstance() is False for None, so a non-constant shape also ends here.
    if not (isinstance(target_shape, np.ndarray) and list(target_shape.shape) == [4]):
        return 0

    detected = int(target_shape[2])
    return detected if detected > 0 else 0
|
|
56
|
+
|
|
57
|
+
def get_encoder_num_heads(self, reshape_in: NodeProto) -> int:
    """Detect num_heads from a Reshape node in the SAM2 encoder attention pattern.

    Two forms of the Reshape target shape (input 1) are recognized:
      * a 5-element constant, where element 3 is num_heads;
      * a non-constant shape produced by a 5-input Concat, where Concat input 3
        is a 1-element constant holding num_heads.

    Args:
        reshape_in (NodeProto): Reshape node whose shape input encodes num_heads.
    Returns:
        int: num_heads, or 0 if not found.
    """
    num_heads = 0

    shape_value = self.model.get_constant_value(reshape_in.input[1])
    if shape_value is not None:
        # Constant 5-D target shape; presumably [batch, seq_len, 3, num_heads, head_size]
        # for the packed QKV projection (TODO confirm). Only index 3 is used.
        if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [5]:
            num_heads = int(shape_value[3])
    else:
        # Shape computed at runtime: expect Concat(d0, d1, d2, num_heads, d4) feeding input 1.
        concat_shape = self.model.match_parent(reshape_in, "Concat", 1)
        if concat_shape is not None and len(concat_shape.input) == 5:
            heads_value = self.model.get_constant_value(concat_shape.input[3])
            # isinstance() is False for None, so a non-constant Concat input is rejected here.
            if isinstance(heads_value, np.ndarray) and list(heads_value.shape) == [1]:
                num_heads = int(heads_value[0])

    return num_heads if num_heads > 0 else 0
|
|
84
|
+
|
|
85
|
+
def get_hidden_size(self, layernorm_node):
    """Detect hidden_size from a LayerNormalization node.

    The size is the length of the bias initializer (input 2). Bias is an optional
    input of LayerNormalization, so when the node has no bias the scale
    initializer (input 1) is used instead.

    Args:
        layernorm_node (NodeProto): LayerNormalization node before Q, K and V
    Returns:
        int: hidden_size, or 0 if not found (e.g. the tensor is not an initializer)
    """
    inputs = layernorm_node.input

    # Bias is optional: the input list may stop after scale, or carry an empty
    # name for an omitted input. Guard the index instead of raising IndexError.
    if len(inputs) > 2 and inputs[2]:
        tensor_name = inputs[2]
    elif len(inputs) > 1:
        # NOTE(review): assumes normalization over the last axis, so scale is 1-D of length hidden_size.
        tensor_name = inputs[1]
    else:
        return 0

    initializer = self.model.get_initializer(tensor_name)
    if initializer is None:
        return 0

    return NumpyHelper.to_array(initializer).shape[0]
|
|
97
|
+
|
|
98
|
+
def get_num_heads_and_hidden_size(
    self, reshape_q: NodeProto, layernorm_node: NodeProto, is_encoder: bool = False
) -> Tuple[int, int]:
    """Return (num_heads, hidden_size) for one attention subgraph.

    Values detected from the graph win. The constructor's values are used only
    when detection yields a non-positive result. When a detected value differs
    from a positive user-specified one, a warning is logged (once per setting).

    Args:
        reshape_q (NodeProto): Reshape node on the Q path.
        layernorm_node (NodeProto): LayerNormalization node before Q, K, V.
        is_encoder (bool): use the encoder detection (5-element shape) instead of
            the decoder detection (4-element shape).
    Returns:
        Tuple[int, int]: num_heads and hidden_size
    """
    detect_heads = self.get_encoder_num_heads if is_encoder else self.get_decoder_num_heads
    num_heads = detect_heads(reshape_q)
    if num_heads <= 0:
        num_heads = self.num_heads  # nothing detected: use the user-specified value

    if self.num_heads_warning and self.num_heads > 0 and num_heads != self.num_heads:
        logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
        self.num_heads_warning = False  # warn only once

    hidden_size = self.get_hidden_size(layernorm_node)
    if hidden_size <= 0:
        hidden_size = self.hidden_size  # nothing detected: use the user-specified value

    if self.hidden_size_warning and self.hidden_size > 0 and hidden_size != self.hidden_size:
        logger.warning(
            f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
        )
        self.hidden_size_warning = False  # warn only once

    return num_heads, hidden_size
|
|
133
|
+
|
|
134
|
+
def create_attention_node(
    self,
    q_matmul: NodeProto,
    q_add: NodeProto,
    k_matmul: NodeProto,
    k_add: NodeProto,
    v_matmul: NodeProto,
    v_add: NodeProto,
    num_heads: int,
    hidden_size: int,
    output: str,
) -> Union[NodeProto, None]:
    """Create a com.microsoft MultiHeadAttention node.

    The outputs of the Q, K and V bias Add nodes become the query, key and value
    inputs of the new node. The MatMul weights are not consumed by the node, but
    all three must be initializers for the fusion to proceed.

    Args:
        q_matmul (NodeProto): MatMul node in fully connection for Q
        q_add (NodeProto): Add bias node in fully connection for Q
        k_matmul (NodeProto): MatMul node in fully connection for K
        k_add (NodeProto): Add bias node in fully connection for K
        v_matmul (NodeProto): MatMul node in fully connection for V
        v_add (NodeProto): Add bias node in fully connection for V
        num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
        hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            A non-positive value skips the divisibility check.
        output (str): output name

    Returns:
        Union[NodeProto, None]: the node created or None if failed.
    """
    # num_heads is also the divisor below; zero would raise ZeroDivisionError.
    if num_heads <= 0:
        logger.debug("invalid num_heads %s", num_heads)
        return None

    if hidden_size > 0 and (hidden_size % num_heads) != 0:
        logger.debug("input hidden size %s is not a multiple of num of heads %s", hidden_size, num_heads)
        return None

    q_weight = self.model.get_initializer(q_matmul.input[1])
    k_weight = self.model.get_initializer(k_matmul.input[1])
    v_weight = self.model.get_initializer(v_matmul.input[1])
    if q_weight is None or k_weight is None or v_weight is None:
        return None

    # Only the shapes are logged, so read them from the TensorProto dims instead of
    # materializing each (potentially large) weight matrix as a NumPy array.
    logger.debug(
        "qw=%s kw=%s vw=%s hidden_size=%s",
        tuple(q_weight.dims),
        tuple(k_weight.dims),
        tuple(v_weight.dims),
        hidden_size,
    )

    attention_node_name = self.model.create_node_name("MultiHeadAttention")

    attention_inputs = [
        q_add.output[0],
        k_add.output[0],
        v_add.output[0],
    ]

    attention_node = helper.make_node(
        "MultiHeadAttention",
        inputs=attention_inputs,
        outputs=[output],
        name=attention_node_name,
    )
    attention_node.domain = "com.microsoft"
    attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

    # NOTE(review): every node fused here is counted under "cross attention"; the key is kept
    # unchanged so fusion statistics stay comparable.
    self.increase_counter("MultiHeadAttention (cross attention)")
    return attention_node
|
|
197
|
+
|
|
198
|
+
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
    """Try to fuse the attention block that feeds `normalize_node`.

    The SAM2 encoder pattern is tried first. Otherwise the PyTorch 2.* attention
    subgraph is matched above the LayerNormalization node itself. If that fails,
    it is matched above the Add that produces the LayerNormalization's first input.
    """
    if self.fuse_sam_encoder_pattern(normalize_node, input_name_to_nodes, output_name_to_node):
        return

    matched = self.match_attention_subgraph(normalize_node)
    if matched is None:
        # Second chance: start from the residual (skip) Add producing the LayerNorm input.
        ln_input = normalize_node.input[0]
        if ln_input not in output_name_to_node:
            return
        skip_add = output_name_to_node[ln_input]
        if skip_add.op_type != "Add":
            return
        matched = self.match_attention_subgraph(skip_add)
        if matched is None:
            return

    (
        reshape_qkv,
        transpose_qkv,
        reshape_q,
        matmul_q,
        add_q,
        matmul_k,
        add_k,
        matmul_v,
        add_v,
    ) = matched

    q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_encoder=False)
    if q_num_heads <= 0:
        logger.debug("fuse_attention: failed to detect num_heads")
        return

    # Q, K and V share one head count, so the value detected on the Q path is used.
    fused = self.create_attention_node(
        matmul_q,
        add_q,
        matmul_k,
        add_k,
        matmul_v,
        add_v,
        q_num_heads,
        q_hidden_size,
        output=reshape_qkv.output[0],
    )
    if fused is None:
        return

    self.nodes_to_add.append(fused)
    self.node_name_to_graph_name[fused.name] = self.this_graph_name

    # The fused node now produces reshape_qkv's output, so the output Reshape/Transpose are dropped.
    self.nodes_to_remove.extend([reshape_qkv, transpose_qkv])

    # Leave the rest of the now-unused subgraph to graph pruning.
    self.prune_graph = True
|
|
247
|
+
|
|
248
|
+
def match_attention_subgraph(self, node_after_output_projection):
    """Match the Q, K and V paths of an attention block exported by PyTorch 2.*.

    Args:
        node_after_output_projection (NodeProto): node consuming the attention output
            projection (the LayerNormalization itself, or the Add before it).
    Returns:
        None if the pattern does not match, otherwise the tuple
        (reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v).
    """
    match = self.model.match_parent_path

    # Output side: Add <- MatMul (projection) <- Reshape <- Transpose <- MatMul (attention @ V).
    out_path = match(
        node_after_output_projection,
        ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
        [None, None, None, 0, 0],
    )
    if out_path is None:
        return None
    reshape_qkv, transpose_qkv, matmul_qkv = out_path[2], out_path[3], out_path[4]

    v_path = match(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
    if v_path is None:
        logger.debug("fuse_attention: failed to match v path")
        return None
    add_v, matmul_v = v_path[2], v_path[3]

    scores_path = match(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
    if scores_path is None:
        logger.debug("fuse_attention: failed to match qk path")
        return None
    matmul_qk = scores_path[1]

    # Both Q and K pass through a Mul (scaling) before the QK MatMul.
    q_path = match(matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, None, 0, 0, None])
    if q_path is None:
        logger.debug("fuse_attention: failed to match q path")
        return None
    mul_q, reshape_q, add_q, matmul_q = q_path[0], q_path[2], q_path[3], q_path[4]

    k_path = match(matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [1, None, 0, 0, None])
    if k_path is None:
        logger.debug("fuse_attention: failed to match k path")
        return None
    add_k, matmul_k = k_path[3], k_path[4]

    # The scalar for Q and K is sqrt(1.0/sqrt(head_size)). Only the Q side is checked:
    # its scale chain must lead back (via Shape/Transpose) to the same reshape_q.
    scale_path = match(
        mul_q,
        ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
        [None, 0, 1, 0, 0, 0, 0, 0],
    )
    if scale_path is None or scale_path[-1] != reshape_q:
        logger.debug("fuse_attention: failed to match mul_q path")
        return None

    return reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v
|
|
302
|
+
|
|
303
|
+
# --------------------------------------------------------
|
|
304
|
+
# The following are for SAM encoder
|
|
305
|
+
# --------------------------------------------------------
|
|
306
|
+
def fuse_sam_encoder_pattern(self, normalize_node, input_name_to_nodes, output_name_to_node) -> bool:
    """Fuse one SAM2 encoder self-attention subgraph into a com.microsoft MultiHeadAttention node.

    Matching starts at a normalization node and walks up its first input through the residual
    Add to the attention output projection. All structural checks run before the graph is
    modified, so a rejected match leaves no partial edits behind.

    Args:
        normalize_node: the LayerNorm node the match starts from.
        input_name_to_nodes: passed to ``self.model.get_children`` to count the consumers of the
            attention output Transpose (presumably tensor name -> consuming nodes).
        output_name_to_node: not used by this method; kept for interface compatibility.

    Returns:
        True if the subgraph was fused, otherwise False.
    """
    # SAM encoder attention layer pattern:
    #   Add -----------+
    #    |             |
    #   LayerNorm      |
    #    |             |
    #   Reshape        |
    #    |             |
    #   Transpose      |
    #    |             |
    #   MatMul         |
    #    |             |
    #   Add            |
    #    |             |
    #   Reshape        |
    #    |             |
    #   Split          |
    #    |             |
    #   Self Attention subgraph
    #    |             |
    #   Reshape        |
    #    |             |
    #   Transpose      |
    #    |             |
    #   Reshape        |
    #    |             |
    #   Add -----------+
    #    |
    #   LayerNorm (starts from here)

    nodes = self.model.match_parent_path(
        normalize_node,
        ["Add", "Reshape", "Transpose", "Reshape"],
        [0, None, 0, 0],
    )
    if nodes is None:
        nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "Slice", "Slice", "Reshape", "Transpose", "Reshape"],
            [0, None, 0, 0, 0, 0],
        )
    if nodes is None:
        nodes = self.model.match_parent_path(
            normalize_node,
            ["Add"],
            [0],
        )
    if nodes is None:
        return False

    node_after_output_projection = nodes[-1]
    # When only the residual Add matched, the projection feeds its input 1; otherwise any input.
    matched_sdpa = self.match_sam_encoder_attention_subgraph(
        node_after_output_projection, input_index=1 if len(nodes) == 1 else None
    )
    if matched_sdpa is None:
        return False

    reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v = matched_sdpa

    # transpose_out is removed below and reshape_out is rewired to the MHA output, so
    # transpose_out must have no other consumer. Checked before any graph edit, and as a
    # real condition rather than an `assert` (which `python -O` strips).
    if len(self.model.get_children(transpose_out, input_name_to_nodes)) != 1:
        return False

    # Required layouts: Q and V are BSNH -> BNSH, K is BSNH -> BNHS.
    # Checked in the order Q, K, V, stopping at the first mismatch.
    expected_permutations = (
        (transpose_q, [0, 2, 1, 3]),
        (transpose_k, [0, 2, 3, 1]),
        (transpose_v, [0, 2, 1, 3]),
    )
    for transpose_node, expected in expected_permutations:
        permutation = OnnxModel.get_node_attribute(transpose_node, "perm")
        if not isinstance(permutation, list) or permutation != expected:
            return False

    input_projection_nodes = self.model.match_parent_path(
        split_qkv,
        ["Reshape", "Add", "MatMul"],
        [0, 0, None],
    )
    if input_projection_nodes is None:
        return False
    reshape_in = input_projection_nodes[0]
    q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_in, normalize_node, True)
    if q_num_heads <= 0:
        logger.debug("fuse_attention: failed to detect num_heads")
        return False

    # ---- All checks passed: graph edits start here. ----

    # Add a shape to convert 4D BxSxNxH to 3D BxSxD, which is required by MHA operator.
    new_dims_name = "bsnh_to_bsd_reshape_dims"
    new_dims = self.model.get_initializer(new_dims_name)
    if new_dims is None:
        new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name)
        self.model.add_initializer(new_dims, self.this_graph_name)
    reshape_q_name = self.model.create_node_name("Reshape")
    reshape_q = helper.make_node(
        "Reshape",
        inputs=[transpose_q.input[0], new_dims_name],
        outputs=[transpose_q.input[0] + "_BSD"],
        name=reshape_q_name,
    )
    self.nodes_to_add.append(reshape_q)
    self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name

    # Reuse the transpose_q node (perm verified as [0, 2, 1, 3] above) to transpose K from
    # BSNH to BNSH. Here we update the input and output of the node.
    transpose_k_bnsh = transpose_q
    transpose_k_bnsh.input[0] = transpose_k.input[0]
    transpose_k_bnsh.output[0] = transpose_k.input[0] + "_BNSH"

    logger.debug(f"Found MHA: {q_num_heads=} {q_hidden_size=}")

    # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads
    new_node = self.create_mha_node(
        reshape_q,
        transpose_k_bnsh,
        transpose_v,
        q_num_heads,
    )
    if new_node is None:
        return False

    # reshape_out previously consumed transpose_out; route the MHA output into it instead.
    reshape_out.input[0] = new_node.output[0]

    self.nodes_to_add.append(new_node)
    self.node_name_to_graph_name[new_node.name] = self.this_graph_name
    self.nodes_to_remove.extend([transpose_out])

    # Use prune graph to remove nodes since they are shared by all attention nodes.
    self.prune_graph = True
    return True
|
|
437
|
+
|
|
438
|
+
def match_sam_encoder_attention_subgraph(self, node_after_output_projection, input_index=None):
    """Find the scaled-dot-product attention subgraph of a SAM2 encoder layer.

    Q, K and V are expected to come from one packed tensor through a shared Split node.

    Args:
        node_after_output_projection: node whose parent chain starts with the output
            projection Add.
        input_index: input of ``node_after_output_projection`` to follow. It is forwarded
            to ``match_parent_path``, where None presumably means "any input".

    Returns:
        ``(reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v)``
        on a match, otherwise None.
    """
    model = self.model

    # Output projection (Add <- MatMul) back to the attention-probs x V MatMul.
    out_path = model.match_parent_path(
        node_after_output_projection,
        ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
        [input_index, None, None, 0, 0],
    )
    if out_path is None:
        return None
    reshape_out = out_path[2]
    transpose_out = out_path[3]
    matmul_qk_v = out_path[4]

    # V: packed QKV -> Reshape -> Split -> Squeeze -> Transpose.
    v_path = model.match_parent_path(matmul_qk_v, ["Transpose", "Squeeze", "Split", "Reshape"], [1, 0, 0, 0])
    if v_path is None:
        logger.debug("failed to match v path")
        return None
    transpose_v = v_path[0]
    split_qkv = v_path[2]

    qk_path = model.match_parent_path(matmul_qk_v, ["Softmax", "MatMul"], [0, 0])
    if qk_path is None:
        logger.debug("failed to match qk path")
        return None
    matmul_qk = qk_path[1]

    # Q comes either straight from the split or through an extra
    # Reshape/Transpose/MaxPool/Transpose/Reshape chain.
    q_path = model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [0, None, 0, 0])
    if q_path is None:
        q_path = model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Transpose", "MaxPool", "Transpose", "Reshape", "Squeeze", "Split"],
            [0, None, 0, 0, 0, 0, 0, 0, 0],
        )
    if q_path is None:
        logger.debug("failed to match q path")
        return None
    # Q must originate from the same Split that produced V.
    if q_path[-1] != split_qkv:
        return None
    transpose_q = q_path[1]

    k_path = model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [1, None, 0, 0])
    if k_path is None:
        logger.debug("failed to match k path")
        return None
    # K must originate from the same Split that produced V.
    if k_path[-1] != split_qkv:
        return None
    transpose_k = k_path[1]

    return reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v
|
|
492
|
+
|
|
493
|
+
def create_mha_node(
    self,
    reshape_q: NodeProto,
    transpose_k: NodeProto,
    transpose_v: NodeProto,
    num_heads: int,
) -> NodeProto:
    """Build a com.microsoft MultiHeadAttention node for the SAM2 encoder.

    The node is only created here; adding it to the graph is left to the caller.

    Args:
        reshape_q (NodeProto): Reshape node for Q, output is 3D BxSxNH format
        transpose_k (NodeProto): Transpose node for K, output is BNSH format
        transpose_v (NodeProto): Transpose node for V, output is BNSH format
        num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.

    Returns:
        NodeProto: the MultiHeadAttention node created.
    """
    node_name = self.model.create_node_name("MultiHeadAttention")

    # The fused output is 3D while the tensor it replaces is 4D, so it needs its own name.
    mha_node = helper.make_node(
        "MultiHeadAttention",
        inputs=[reshape_q.output[0], transpose_k.output[0], transpose_v.output[0]],
        outputs=[f"{node_name}_out"],
        name=node_name,
        domain="com.microsoft",
        num_heads=num_heads,
    )

    self.increase_counter("MultiHeadAttention (self attention)")
    return mha_node
|