onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/onnx_model_t5.py
@@ -0,0 +1,985 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_base import Fusion
from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class FusionT5Attention(FusionAttention):
    """
    Fuse T5 Attention subgraph into one Attention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(
            model,
            hidden_size,
            num_heads,
            attention_mask,
            use_multi_head_attention=False,
            search_op_types=["Softmax"],
        )
        # 1 when key/value are static (cross attention); 0 for decoder self attention.
        self.static_kv = 1

    def make_attention_node(
        self,
        mask_index: str | None,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        attn_bias: str | None,
        scale: float,
    ) -> NodeProto | None:
        """Create an Attention node.

        Args:
            mask_index (str | None): name of the mask input, if any
            q_matmul (NodeProto): MatMul node in the fully-connected projection for Q
            k_matmul (NodeProto): MatMul node in the fully-connected projection for K
            v_matmul (NodeProto): MatMul node in the fully-connected projection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name
            attn_bias (str | None): name of the additive attention bias input (e.g. relative position bias), if any
            scale (float): value for the scale attribute of the Attention node

        Returns:
            NodeProto | None: the node created, or None if fusion failed.
        """
        assert num_heads > 0

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])

        if q_weight is None or k_weight is None or v_weight is None:
            matmul = q_matmul if q_weight is None else k_matmul if k_weight is None else v_matmul
            logger.warning(
                f"{matmul.input[1]} is not an initializer. "
                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
            )
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        # Q and K are expected to have the same shape.
        assert qw.shape == kw.shape

        qw_in_size = qw.shape[0]
        kw_in_size = kw.shape[0]
        vw_in_size = vw.shape[0]

        assert qw_in_size == kw_in_size == vw_in_size

        if hidden_size > 0 and hidden_size != qw_in_size:
            logger.warning(
                f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
                "Please provide a correct input hidden size or pass in 0"
            )

        qw_out_size = np.prod(qw.shape[1:])
        qkv_weight = np.stack((qw, kw, vw), axis=1)
        qkv_weight_dim = 3 * qw_out_size

        attention_node_name = self.model.create_node_name("Attention")

        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=TensorProto.FLOAT,
            dims=[qw_in_size, qkv_weight_dim],
            vals=qkv_weight.tobytes(),
            raw=True,
        )

        self.model.add_initializer(weight, self.this_graph_name)

        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            "",
        ]
        if mask_index:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if attn_bias:
            attention_inputs.append("")  # no past
            attention_inputs.append(attn_bias)

        # Drop trailing empty (omitted optional) inputs.
        while attention_inputs and attention_inputs[-1] == "":
            attention_inputs.pop()

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        if scale is not None:
            attention_node.attribute.extend([helper.make_attribute("scale", scale)])

        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        return attention_node

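Editor's note: the packed QKV initializer above relies on np.stack(..., axis=1)
serializing to the [in_size, 3 * out_size] layout declared in dims. A minimal
sketch of why that holds (the array values are hypothetical):

    import numpy as np

    # Per-projection weights of shape [in_size, out_size] = [4, 2].
    qw = np.arange(8, dtype=np.float32).reshape(4, 2)
    kw = qw + 10.0
    vw = qw + 20.0

    packed = np.stack((qw, kw, vw), axis=1)    # shape (4, 3, 2): Q, K, V interleaved per input row
    flat = packed.reshape(4, 3 * 2)            # the [in_size, 3 * out_size] view
    assert packed.tobytes() == flat.tobytes()  # same raw buffer that make_tensor receives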
    def create_mha_node(
        self,
        query: str,
        key: str,
        value: str,
        mask_index: str | None,
        attn_bias: str | None,
        past_key: str | None,
        past_value: str | None,
        output: str,
        present_key: str | None,
        present_value: str | None,
        num_heads: int,
        hidden_size: int,
    ) -> NodeProto | None:
        assert num_heads > 0 and hidden_size > 0 and query and key and value

        if (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        attention_node_name = self.model.create_node_name("MultiHeadAttention")
        attention_inputs = [
            query,
            key,
            value,
            "",  # bias
        ]

        if mask_index:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if attn_bias:
            attention_inputs.append(attn_bias)
        else:
            attention_inputs.append("")

        if past_key:
            assert past_value
            attention_inputs.append(past_key)
            attention_inputs.append(past_value)

        # Drop trailing empty (omitted optional) inputs.
        while attention_inputs and attention_inputs[-1] == "":
            attention_inputs.pop()

        attention_outputs = [output]
        if present_key:
            assert present_value
            attention_outputs.append(present_key)
            attention_outputs.append(present_value)

        logger.debug(f"{attention_inputs=}, {attention_outputs=}, {attention_node_name=}")
        attention_node = helper.make_node(
            "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=attention_outputs,
            name=attention_node_name,
        )

        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
        attention_node.attribute.extend([helper.make_attribute("scale", 1.0)])
        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        self.increase_counter("MultiHeadAttention")
        return attention_node

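Editor's note: both node builders use the ONNX convention that an omitted
optional input is encoded as an empty string, and that trailing empty strings
may be dropped entirely. A minimal check of that trimming idiom:

    inputs = ["query", "key", "value", "", "mask_index", "", ""]
    while inputs and inputs[-1] == "":
        inputs.pop()
    assert inputs == ["query", "key", "value", "", "mask_index"]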
    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        if self.fuse_t5_encoder(node, input_name_to_nodes, output_name_to_node):
            return

        self.fuse_t5_decoder(node, input_name_to_nodes, output_name_to_node)

    def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node):
        assert softmax_node.op_type == "Softmax"
        qkv_nodes = self.model.match_child_path(
            softmax_node,
            ["MatMul", "Transpose", "Reshape"],
            edges=[(0, 0), (0, 0), (0, 0)],
            input_name_to_nodes=input_name_to_nodes,
        )
        if qkv_nodes is None:
            return False
        matmul_qkv, _, reshape_qkv = qkv_nodes

        qkv_shape_nodes = self.model.match_parent_path(
            reshape_qkv,
            ["Concat", "Unsqueeze", "Gather", "Shape"],
            [1, 0, 0, 0],
            output_name_to_node,
        )
        if qkv_shape_nodes is None:
            return False
        input_shape_node = qkv_shape_nodes[-1]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
            output_name_to_node,
        )
        if v_nodes is None:
            return False
        _, reshape_v, matmul_v = v_nodes
        # todo: check reshape_v parent nodes

        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "MatMul"],
            [0, 0, 0],
            output_name_to_node,
        )
        if qk_nodes is None:
            return False
        _, add_qk, matmul_qk = qk_nodes

        mask_nodes = self.model.match_parent_path(
            add_qk,
            ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
            [1, 1, 0, 1, 0, 0],
            output_name_to_node,
        )

        is_pattern_for_one_graph_input = mask_nodes is None
        if mask_nodes is not None:
            mul_node = mask_nodes[1]
        else:
            # Pattern for SD3 and Flux.
            mask_nodes = self.model.match_parent_path(
                add_qk,
                ["Add", "Slice", "Mul", "Sub", "Unsqueeze", "Unsqueeze"],
                [1, 1, 0, 0, 1, 0],
                output_name_to_node,
            )

            # If the model is not optimized by ORT, there might be an additional Cast node.
            if mask_nodes is None:
                mask_nodes = self.model.match_parent_path(
                    add_qk,
                    ["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
                    [1, 1, 0, 0, 1, 0, 0],
                    output_name_to_node,
                )
                if mask_nodes is None:
                    return False
            mul_node = mask_nodes[2]

        _, mul_val = self.model.get_constant_input(mul_node)
        if mul_val is None:
            return False

        if mul_val != -10000:
            self.mask_filter_value = float(mul_val)

        # If the mask is derived from shape of input_ids, it means there is no padding mask.
        mask_nodes_2 = self.model.match_parent_path(
            mask_nodes[-1],
            ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"],
            [0, 0, 0, 0, 0],
            output_name_to_node,
        )
        mask_nodes_3 = self.model.match_parent_path(
            mask_nodes[-1],
            ["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"],
            [0, 0, 1, 0, 0],
            output_name_to_node,
        )
        if (
            mask_nodes_2 is not None
            and any(input.name == mask_nodes_2[-1].input[0] for input in self.model.graph().input)
            and mask_nodes_3 is not None
            and mask_nodes_2[-1].input[0] == mask_nodes_3[-1].input[0]
            and len(mask_nodes_2[1].input) == 2
        ):
            mask_index = ""
        else:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

        res_pos_bias = None
        rpb_nodes = self.model.match_parent_path(
            add_qk,
            ["Add", "RelativePositionBias"],
            [1, 0],
        )
        if rpb_nodes is None and is_pattern_for_one_graph_input:
            # Pattern for SD3 and Flux.
            rpb_nodes = self.model.match_parent_path(
                add_qk,
                ["Add", "Slice", "RelativePositionBias"],
                [1, 0, 0],
            )
        if rpb_nodes is None:
            return False

        res_pos_bias = rpb_nodes[-1].output[0]

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        if k_nodes is None:
            return False
        _, _, matmul_k = k_nodes
        # todo: check reshape_k parent nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [0, 0, 0],
        )
        if q_nodes is None:
            return False

        _, reshape_q, matmul_q = q_nodes
        # todo: check reshape_q parent nodes

        if matmul_q.input[0] != input_shape_node.input[0]:
            return False

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        new_node = self.make_attention_node(
            mask_index,
            matmul_q,
            matmul_k,
            matmul_v,
            num_heads=q_num_heads,
            hidden_size=q_hidden_size,
            input=input_shape_node.input[0],
            output=reshape_qkv.output[0],
            attn_bias=res_pos_bias,
            scale=1.0,
        )
        if new_node is None:
            return False

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.append(reshape_qkv)
        self.prune_graph = True
        return True

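Editor's note: the fusion walks the graph with match_parent_path(node, op_types,
input_indices), which follows input[i] back to its producing node and checks the
expected op type at each hop. A toy re-implementation of that idea (an editor's
sketch with dict-based nodes, not the library code):

    def match_parent_path(node, op_types, input_indices, producers):
        """Follow input[idx] to its producer for each (op_type, idx); return the chain or None."""
        path, current = [], node
        for op_type, idx in zip(op_types, input_indices):
            parent = producers.get(current["inputs"][idx])
            if parent is None or parent["op"] != op_type:
                return None
            path.append(parent)
            current = parent
        return path

    # Tiny graph: Shape -> Gather -> Unsqueeze, producers keyed by output name.
    producers = {
        "s": {"op": "Shape", "inputs": ["x"], "outputs": ["s"]},
        "g": {"op": "Gather", "inputs": ["s", "i"], "outputs": ["g"]},
    }
    unsqueeze = {"op": "Unsqueeze", "inputs": ["g"], "outputs": ["u"]}
    path = match_parent_path(unsqueeze, ["Gather", "Shape"], [0, 0], producers)
    assert [n["op"] for n in path] == ["Gather", "Shape"]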
    def fuse_t5_decoder(self, softmax_node, input_name_to_nodes, output_name_to_node):
        assert softmax_node.op_type == "Softmax"

        qkv_nodes = self.model.match_child_path(
            softmax_node,
            ["MatMul", "Transpose", "Reshape"],
            edges=[(0, 0), (0, 0), (0, 0)],
            input_name_to_nodes=input_name_to_nodes,
        )
        if qkv_nodes is None:
            return
        matmul_qkv, _transpose_qkv, reshape_qkv = qkv_nodes

        qkv_shape_nodes = self.model.match_parent_path(
            reshape_qkv,
            ["Concat", "Unsqueeze", "Gather", "Shape"],
            [1, 0, 0, 0],
        )
        if qkv_shape_nodes is None:
            return
        input_shape_node = qkv_shape_nodes[-1]

        value = None
        past_value = None
        present_value = None
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "MatMul"],
            [1, 1, 0, 0],
        )
        if v_nodes is None:
            v_nodes = self.model.match_parent_path(
                matmul_qkv,
                ["Transpose", "Reshape", "MatMul"],
                [1, 0, 0],
            )
            if v_nodes is not None:
                # Value is computed in this subgraph (no concatenated past).
                transpose_v, reshape_v, matmul_v = v_nodes
                value = reshape_v.input[0]
                present_value = transpose_v.output[0]
                if "present_value" not in present_value:
                    return
                if matmul_v.input[0] != input_shape_node.input[0]:
                    self.static_kv = 1
                else:
                    self.static_kv = 0
            else:
                # Cross attention that consumes the past value directly.
                past_value = matmul_qkv.input[1]
                if past_value in output_name_to_node:
                    return
                if "past_value_cross" not in past_value:
                    return
                self.static_kv = 1
        else:
            # Decoder self attention: Concat appends the new value to the KV cache.
            concat_v, _, reshape_v, _ = v_nodes
            past_value = concat_v.input[0]
            if past_value in output_name_to_node:
                return
            if "past_value_self" not in past_value:
                return
            present_value = concat_v.output[0]
            if "present_value_self" not in present_value:
                return
            value = reshape_v.input[0]
            self.static_kv = 0

        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "MatMul"],
            [0, 0, 0],
        )
        if qk_nodes is None:
            return
        _, add_qk, matmul_qk = qk_nodes

        mask_index = None
        res_pos_bias = None
        if self.static_kv == 1:
            mask_nodes = self.model.match_parent_path(
                add_qk,
                ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
                [1, 1, 0, 1, 0, 0],
            )
            if mask_nodes is not None:
                mul_node = mask_nodes[1]
            else:
                mask_nodes = self.model.match_parent_path(
                    add_qk,
                    ["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
                    [1, 1, 0, 0, 1, 0, 0],
                )
                if mask_nodes is None:
                    return
                mul_node = mask_nodes[2]

            _, mul_val = self.model.get_constant_input(mul_node)
            if mul_val != -10000:
                self.mask_filter_value = mul_val

            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
        else:
            matched_path_index, _, _ = self.model.match_parent_paths(
                add_qk,
                [
                    (["Add", "Slice"], [1, 0]),
                    (["Add", "RelativePositionBias"], [1, 0]),
                ],
                output_name_to_node,
            )
            if matched_path_index < 0:
                logger.debug("Skip MultiHeadAttention fusion since attention bias pattern not matched")
                return

            res_pos_bias = add_qk.input[1]

        key = None
        past_key = None
        present_key = None
        if self.static_kv == 1:
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Reshape", "MatMul"],
                [1, 0, 0],
            )
            if k_nodes is not None:
                transpose_k, reshape_k, _ = k_nodes
                key = reshape_k.input[0]
                present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
                for present_key_transpose_node in present_key_transpose_nodes:
                    present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
                    if present_key_candidate is not None:
                        present_key = present_key_candidate.name
                        break
                if present_key is None:
                    return
                if "present_key_cross" not in present_key:
                    return
            else:
                k_nodes = self.model.match_parent_path(
                    matmul_qk,
                    ["Transpose"],
                    [1],
                )
                if k_nodes is None:
                    return
                transpose_k = k_nodes[0]

                past_key = transpose_k.input[0]
                if past_key in output_name_to_node:
                    return
                if "past_key_cross" not in past_key:
                    return
        else:
            idx, k_nodes, _ = self.model.match_parent_paths(
                matmul_qk,
                [
                    (["Transpose", "Concat", "Reshape", "MatMul"], [1, 0, 1, 0]),
                    (["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0]),
                ],
                output_name_to_node,
            )
            past_key_transpose_node = None
            present_key_transpose_nodes = None
            if k_nodes is not None:
                concat_k, reshape_k = k_nodes[1], k_nodes[-2]
                key = reshape_k.input[0]

                if idx == 0:
                    past_key_transpose_node = output_name_to_node[concat_k.input[0]]
                    past_key = past_key_transpose_node.input[0]
                else:
                    past_key = concat_k.input[0]
                if past_key in output_name_to_node:
                    return
                if "past_key_self" not in past_key:
                    return

                if idx == 0:
                    present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
                    for present_key_transpose_node in present_key_transpose_nodes:
                        present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
                        if present_key_candidate is not None:
                            present_key = present_key_candidate.name
                            break
                else:
                    present_key = concat_k.output[0]
                if present_key is None:
                    return
                if "present_key_self" not in present_key:
                    return
            else:
                k_nodes = self.model.match_parent_path(
                    matmul_qk,
                    ["Transpose", "Reshape", "MatMul"],
                    [1, 0, 0],
                )
                if k_nodes is None:
                    return
                _, reshape_k, _ = k_nodes
                key = reshape_k.input[0]
                present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
                for present_key_transpose_node in present_key_transpose_nodes:
                    present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
                    if present_key_candidate is not None:
                        present_key = present_key_candidate.name
                        break
                if present_key is None:
                    return
                if "present_key_self" not in present_key:
                    return

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [0, 0, 0],
        )
        if q_nodes is None:
            return

        transpose_q, reshape_q, matmul_q = q_nodes

        if matmul_q.input[0] != input_shape_node.input[0]:
            return

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        if self.static_kv == 1 and past_key is not None:
            key = past_key
            value = past_value
            past_key = None
            past_value = None

        if not (key and value and q_num_heads > 0 and q_hidden_size > 0):
            return

        new_node = self.create_mha_node(
            query=matmul_q.output[0],
            key=key,
            value=value,
            mask_index=mask_index,
            attn_bias=res_pos_bias,
            past_key=past_key,
            past_value=past_value,
            output=reshape_qkv.output[0],
            present_key=present_key,
            present_value=present_value,
            num_heads=q_num_heads,
            hidden_size=q_hidden_size,
        )

        if new_node:
            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

            # Since present_* is a graph output, we need to update the graph to avoid a cycle.
            if present_key or present_value:
                for graph_output in [present_key, present_value]:
                    if not (graph_output and self.model.find_graph_output(graph_output)):
                        logger.warning(f"{graph_output=} does not exist in graph output")
                        return
                    assert graph_output in output_name_to_node
                    output_name_to_node[graph_output].output[0] = graph_output + "_copy"
                    self.model.replace_input_of_all_nodes(graph_output, graph_output + "_copy")

            self.nodes_to_remove.append(reshape_qkv)
            self.prune_graph = False

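Editor's note: when present_key/present_value are already graph outputs,
fuse_t5_decoder renames the old producer's output to "<name>_copy" and
re-points every consumer at the copy, so the fused MultiHeadAttention node can
take over the graph-output name without creating a cycle. A sketch with
hypothetical node and tensor names:

    from onnx import helper

    # The old producer currently writes the graph output "present_value_self_0".
    old_producer = helper.make_node("Transpose", ["v"], ["present_value_self_0"])
    consumer = helper.make_node("Identity", ["present_value_self_0"], ["y"])

    old_producer.output[0] = "present_value_self_0_copy"  # producer now writes the copy
    consumer.input[0] = "present_value_self_0_copy"       # consumers read the copy
    # The fused node is then free to emit "present_value_self_0" itself.
    fused = helper.make_node(
        "MultiHeadAttention", ["q", "k", "v"], ["out", "present_key_self_0", "present_value_self_0"]
    )
    fused.domain = "com.microsoft"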
class FusionRelativePositionBiasBlock(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "RelativePositionBias", ["Softmax"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        compute_bias_nodes = self.model.match_parent_path(
            node,
            ["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Where"],
            [0, 1, 0, 0, 0, 0, 1],
            output_name_to_node,
        )

        if compute_bias_nodes is None:
            compute_bias_nodes = self.model.match_parent_path(
                node,
                ["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Add", "Where"],
                [0, 1, 0, 0, 0, 0, 1, 1],
                output_name_to_node,
            )
            if compute_bias_nodes is None:
                return

        gather = compute_bias_nodes[5]
        where = compute_bias_nodes[-1]
        slice_node = compute_bias_nodes[2]
        unsqueeze = compute_bias_nodes[3]

        # The current fusion does not remove the node until the whole graph is processed.
        # This avoids fusing it again when it is shared by multiple layers.
        if unsqueeze in self.nodes_to_remove:
            return

        compute_buckets_nodes = self.model.match_parent_path(
            where,
            ["Min", "ConstantOfShape", "Shape", "Add", "Cast", "Mul", "Div", "Log", "Div"],
            [2, 1, 0, 0, 0, 0, 0, 0, 0],
            output_name_to_node,
        )
        if compute_buckets_nodes is None:
            return

        # This value is used to compute max_distance later.
        log_max = self.model.get_constant_value(compute_buckets_nodes[-3].input[1])

        div = compute_buckets_nodes[-1]

        range_nodes = self.model.match_parent_path(
            div,
            ["Cast", "Neg", "Min", "ConstantOfShape", "Shape", "Sub", "Unsqueeze", "Range"],
            [0, 0, 0, 1, 0, 0, 0, 0],
            output_name_to_node,
        )

        is_bidirectional = False
        if range_nodes is None:
            range_nodes = self.model.match_parent_path(
                div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0], output_name_to_node
            )
            is_bidirectional = True
            if range_nodes is None:
                return
        range_node = range_nodes[-1]

        # Double check the constants related to max_distance and relative_attention_num_buckets.
        # Most T5 models use max_distance=128, so we hardcode it until we see a model with a different value.
        #
        # log_max is the value of the formula
        #   math.log(max_distance / (relative_attention_num_buckets // (4 if is_bidirectional else 2)))
        # See https://github.com/huggingface/transformers/blob/608e163b527eaee41e650ffb9eb4c422d2679902/src/transformers/models/t5/modeling_t5.py#L397.
        # The line below inverts it assuming relative_attention_num_buckets=32 (a worked check follows this class):
        max_distance = int(np.round(np.exp(log_max) * (32 // (4 if is_bidirectional else 2))))
        if max_distance != 128:
            logger.warning(
                f"max_distance is {max_distance}, which is different from the default value 128. "
                "Please double check the model configuration."
            )

        node_name = self.model.create_node_name(
            "RelativePositionBias", name_prefix="RelPosBias_" + ("encoder" if is_bidirectional else "decoder")
        )

        table_weight_i = self.model.get_initializer(gather.input[0])
        if table_weight_i is None:
            return
        table_weight = NumpyHelper.to_array(table_weight_i)
        table_weight_t = np.transpose(table_weight)
        bias_table = helper.make_tensor(
            name=node_name + "_bias_table_weight",
            data_type=TensorProto.FLOAT,
            dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]],
            vals=table_weight_t.tobytes(),
            raw=True,
        )
        self.model.add_initializer(bias_table, self.this_graph_name)

        # Relative position is computed like the following in the encoder:
        #
        #              seq_len
        #                 |
        #            Range(0, *)
        #            /         \
        #   Unsqueeze(axes=0)  Unsqueeze(axes=1)
        #            \         /
        #                Sub
        #                 |
        #                Abs
        #
        # Relative position is computed like the following in the decoder:
        #
        #   past_seq_len   seq_len
        #             \      /
        #               Add
        #             /      \
        #     Range(0, *)   Range(0, *)
        #             \      /
        #               Sub
        #
        # Note that the graph will slice the attention bias to get the last seq_len rows.
        #
        # In newer versions of transformers, the decoder pattern is changed as follows:
        #
        #   total_seq_len        Range(start=past_seq_len, end=total_seq_len)
        #         |                        |
        #    Range(0, *)            Unsqueeze(axes=1)
        #         |                        |
        #   Unsqueeze(axes=0)       Cast(to=int64)
        #              \                  /
        #                      Sub
        #
        # Currently there is still a Slice to get the last seq_len rows, so the end result is the same.
        # But be careful: the shape of the bias tensor changes before the Slice.
        #
        # The RelativePositionBias operator requires query_length == key_length, so we shall pass in total_seq_len.
        # Here we use the end value of the Range node as the length to pass to the RelativePositionBias node.

        # TODO: Optimization opportunity: change the RelativePositionBias op to support query_length != key_length and
        # only compute seq_len rows; then we can remove the Slice after the RelativePositionBias node.
        inputs = [bias_table.name, range_node.input[1], range_node.input[1]]

        # Use a new tensor name since the shape might be different, as mentioned above.
        bias_output = node_name + "_rel_pos_bias"
        slice_node.input[0] = bias_output

        rpb_node = helper.make_node(
            "RelativePositionBias",
            inputs=inputs,
            outputs=[bias_output],
            name=node_name,
        )
        rpb_node.domain = "com.microsoft"
        rpb_node.attribute.extend([helper.make_attribute("max_distance", max_distance)])
        rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", is_bidirectional)])
        self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
        self.nodes_to_add.append(rpb_node)
        self.prune_graph = True

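Editor's worked check of the log_max relation above, assuming the HuggingFace
defaults max_distance=128 and relative_attention_num_buckets=32:

    import numpy as np

    for is_bidirectional, divisor in ((True, 32 // 4), (False, 32 // 2)):
        log_max = np.log(128 / divisor)  # log(16) for the encoder, log(8) for the decoder
        assert int(np.round(np.exp(log_max) * divisor)) == 128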
class T5OnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads: int = 0, hidden_size: int = 0):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)

        # When the model has only one input (input_ids), there is no padding mask.
        if len(self.model.graph.input) == 1:
            from fusion_options import AttentionMaskFormat  # noqa: PLC0415

            self.attention_mask.mask_format = AttentionMaskFormat.NoMask

        self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self)
        self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
        self.rpb_fusion = FusionRelativePositionBiasBlock(self)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def fuse_layer_norm(self):
        self.layer_norm_fusion.apply()

    def fuse_skip_layer_norm(self, shape_infer=True):
        self.skip_layer_norm_fusion.apply()

    def adjust_rel_pos_bis_length_input(self):
        # The T5 encoder uses complex logic to compute the query and key length when there is
        # only one graph input (input_ids). We can get the length directly from the shape
        # (the 2nd dimension) of input_ids.
        for node in self.nodes():
            if node.op_type == "RelativePositionBias":
                nodes = self.match_parent_path(
                    node,
                    [
                        "Gather",
                        "Shape",
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "SimplifiedLayerNormalization",
                        "Gather",
                    ],
                    [1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
                )
                # TODO: more validation on node attributes
                if nodes is not None:
                    graph_input_names = [input.name for input in self.model.graph.input]
                    if nodes[-1].input[1] in graph_input_names:
                        node_name = self.create_node_name("Shape", name_prefix="Added_Shape_")
                        shape_node = helper.make_node(
                            "Shape",
                            inputs=[nodes[-1].input[1]],
                            outputs=[node_name + "_Output"],
                            name=node_name,
                        )

                        indices_1 = helper.make_tensor(
                            name="Constant_Index_1",
                            data_type=TensorProto.INT64,
                            dims=[1],  # Shape of the tensor
                            vals=[1],  # Tensor values
                        )
                        self.add_initializer(indices_1)

                        gather = helper.make_node(
                            "Gather",
                            inputs=[node_name + "_Output", "Constant_Index_1"],
                            outputs=[node_name + "_Output_Gather_1"],
                            name=self.create_node_name("Gather", name_prefix="Added_Gather_"),
                            axis=0,
                        )

                        self.add_node(shape_node)
                        self.add_node(gather)
                        node.input[1] = node_name + "_Output_Gather_1"
                        node.input[2] = node_name + "_Output_Gather_1"

                break

    # Remove get_extended_attention_mask() since it generates all zeros.
    def remove_extended_mask_decoder_init(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == "Add":
                extended_mask_nodes = self.match_parent_path(
                    node,
                    [
                        "Mul",
                        "Sub",
                        "Mul",
                        "Unsqueeze",
                        "Cast",
                        "LessOrEqual",
                        "Tile",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                )
                if extended_mask_nodes is None:
                    continue

                rpb_nodes = self.match_parent_path(node, ["RelativePositionBias"], [0])
                if rpb_nodes is None:
                    continue

                rpb_node = rpb_nodes[0]
                rpb_node.output[0] = node.output[0]

                nodes_to_remove.extend(extended_mask_nodes)
                nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def remove_extended_mask_decoder(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == "Add":
                extended_mask_nodes = self.match_parent_path(
                    node,
                    [
                        "Mul",
                        "Sub",
                        "Mul",
                        "Unsqueeze",
                        "Concat",
                        "Cast",
                        "LessOrEqual",
                        "Tile",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0],
                )
                if extended_mask_nodes is None:
                    continue

                rpb_nodes = self.match_parent_path(node, ["Slice", "RelativePositionBias"], [0, 0])
                if rpb_nodes is None:
                    continue

                rpb_node = rpb_nodes[0]
                rpb_node.output[0] = node.output[0]

                nodes_to_remove.extend(extended_mask_nodes)
                nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def preprocess(self):
        self.adjust_reshape_and_expand()
        self.rpb_fusion.apply()

    def postprocess(self):
        # Remove get_extended_attention_mask() since it generates all zeros.
        self.remove_extended_mask_decoder_init()
        self.remove_extended_mask_decoder()
        self.adjust_rel_pos_bis_length_input()

        self.prune_graph()