onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,640 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from fusion_attention import AttentionMask, FusionAttention
|
|
9
|
+
from onnx import TensorProto, helper
|
|
10
|
+
from onnx_model import OnnxModel
|
|
11
|
+
|
|
12
|
+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FusionBartAttention(FusionAttention):
|
|
16
|
+
"""
|
|
17
|
+
Fuse Bart Attention subgraph into one Attention node.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask):
    """Initialise the BART attention fusion.

    This subclass adds no state of its own: every argument is forwarded, in the
    same positional order, to the base class initialiser (``FusionAttention``).

    Args:
        model: Model wrapper (``OnnxModel``) the fusion operates on.
        hidden_size: Hidden size, passed through unchanged.
        num_heads: Number of attention heads, passed through unchanged.
        attention_mask: ``AttentionMask`` helper, passed through unchanged.
    """
    base_args = (model, hidden_size, num_heads, attention_mask)
    super().__init__(*base_args)
|
|
28
|
+
|
|
29
|
+
def check_runtime_shape_path(
    self,
    reshape_qkv_2,
    reshape_qkv_1,
    reshape_q_2,
    reshape_k_2,
    reshape_v_2,
    root_input,
):
    """Check that the attention Reshape shapes are derived at runtime from ``root_input``.

    Every check walks a parent path with ``self.model.match_parent_path``. The op
    types and per-hop input indices are exactly those listed in the calls below.
    Tensor names and node names are then compared:

    * ``reshape_qkv_2`` (input 1) must reach a Concat whose inputs 0 and 1 each lead
      through Unsqueeze -> Gather -> Shape, and both Shape nodes must read ``root_input``.
    * ``reshape_qkv_1`` (input 1 -> Concat, inputs 0 and 2 -> Unsqueeze -> Gather) must
      end at Gather nodes with the same names as the two Gathers found above.
    * ``reshape_q_2``, ``reshape_k_2`` and ``reshape_v_2`` must each reach
      Concat -> Unsqueeze -> Mul, and every Mul's first input must be the output of
      the first Gather.

    Returns:
        True if every pattern matches, otherwise False.
    """
    graph = self.model

    concat_match = graph.match_parent_path(reshape_qkv_2, ["Concat"], [1])
    if concat_match is None:
        return False
    shape_concat = concat_match[0]

    dim_ops = ["Unsqueeze", "Gather", "Shape"]
    dim0_match = graph.match_parent_path(shape_concat, dim_ops, [0, 0, 0])
    dim1_match = graph.match_parent_path(shape_concat, dim_ops, [1, 0, 0])
    if dim0_match is None or dim1_match is None:
        return False
    gather_dim0, shape_dim0 = dim0_match[1:]
    gather_dim1, shape_dim1 = dim1_match[1:]

    # Both Shape nodes must read the attention's root input.
    if any(shape.input[0] != root_input for shape in (shape_dim0, shape_dim1)):
        return False

    gather_ops = ["Concat", "Unsqueeze", "Gather"]
    qkv1_dim0 = graph.match_parent_path(reshape_qkv_1, gather_ops, [1, 0, 0])
    qkv1_dim1 = graph.match_parent_path(reshape_qkv_1, gather_ops, [1, 2, 0])
    if qkv1_dim0 is None or qkv1_dim1 is None:
        return False
    # reshape_qkv_1 must reuse the very same Gather nodes (compared by node name).
    if (qkv1_dim0[-1].name, qkv1_dim1[-1].name) != (gather_dim0.name, gather_dim1.name):
        return False

    mul_matches = [
        graph.match_parent_path(reshape, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
        for reshape in (reshape_q_2, reshape_k_2, reshape_v_2)
    ]
    if any(match is None for match in mul_matches):
        return False

    # Each Q/K/V Mul must consume the first Gather's output as its first input.
    dim0_value = gather_dim0.output[0]
    return all(match[-1].input[0] == dim0_value for match in mul_matches)
|
|
76
|
+
|
|
77
|
+
def check_runtime_shape_path_openai(
    self,
    reshape_qkv_2,
    matmul_qkv,
    add_qk,
    matmul_qk,
    add_q,
):
    """Check the runtime shape subgraphs of an OpenAI-implementation attention block.

    Every check walks a parent path with ``self.model.match_parent_path``. The op
    types and per-hop input indices are exactly those listed in the calls below.
    Tensor names are then compared:

    1. ``reshape_qkv_2`` must reach Concat -> Slice -> Gather -> Shape, and that Shape
       must read ``matmul_qkv``'s output.
    2. Both inputs of ``matmul_qk`` must reach Mul -> Pow -> Cast -> Div -> Gather -> Shape.
       The two Mul nodes must share the same second input, and at least one of the two
       Shape nodes must read ``add_q``'s output.
    3. Only when ``add_qk`` is not None (the decoder case): its input 1 must be a Slice
       from which BOTH the long path (Slice, Unsqueeze, Gather, Shape) and the short path
       (Unsqueeze, Gather, Shape) match. The two Unsqueeze nodes must share their first
       input, and at least one of the two Shape nodes must read ``add_q``'s output.

    Args:
        reshape_qkv_2: Reshape node whose shape input is checked in step 1.
        matmul_qkv: Node whose first output the step-1 Shape must read.
        add_qk: Optional Add node checked in step 3; pass None to skip step 3.
        matmul_qk: MatMul node whose two inputs are checked in step 2.
        add_q: Node whose first output the Shape nodes in steps 2 and 3 are compared to.

    Returns:
        True when every check passes, otherwise False.
    """
    reshape_qkv_2_path = self.model.match_parent_path(
        reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0]
    )
    if reshape_qkv_2_path is None:
        return False
    if reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]:
        return False

    matmul_qk_path_1 = self.model.match_parent_path(
        matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0]
    )
    matmul_qk_path_2 = self.model.match_parent_path(
        matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0]
    )
    if matmul_qk_path_1 is None or matmul_qk_path_2 is None:
        return False

    mul_1 = matmul_qk_path_1[0]
    mul_2 = matmul_qk_path_2[0]
    # Both Mul nodes must take the same second input (the matched Pow output).
    if mul_1.input[1] != mul_2.input[1]:
        return False
    # At least one of the two Shape nodes must read add_q's output.
    if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]:
        return False

    # For decoder attentions only
    if add_qk is not None:
        add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1])
        if add_qk_path is None:
            return False
        slice_q_path_1 = self.model.match_parent_path(
            add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0]
        )
        slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
        # Both paths are unpacked and compared below, so BOTH must have matched.
        # Checking with `and` let a single None through and crashed on unpacking.
        if slice_q_path_1 is None or slice_q_path_2 is None:
            return False
        _, unsqueeze_1, _, _ = slice_q_path_1
        unsqueeze_2, _, _ = slice_q_path_2
        if unsqueeze_1.input[0] != unsqueeze_2.input[0]:
            return False
        if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]:
            return False

    return True
|
|
129
|
+
|
|
130
|
+
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
|
|
131
|
+
# Track if fusion is occurring for OpenAI implementation of Whisper
|
|
132
|
+
model_impl_openai = False
|
|
133
|
+
|
|
134
|
+
# SkipLayerNormalization has two inputs, and one of them is the root input for attention.
|
|
135
|
+
qkv_nodes = self.model.match_parent_path(
|
|
136
|
+
normalize_node,
|
|
137
|
+
["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
|
|
138
|
+
[1, 1, 0, 0, 0, 0],
|
|
139
|
+
)
|
|
140
|
+
qkv_nodes_openai = self.model.match_parent_path(
|
|
141
|
+
normalize_node,
|
|
142
|
+
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
|
|
143
|
+
[1, 1, 0, 0, 0],
|
|
144
|
+
)
|
|
145
|
+
if qkv_nodes is not None:
|
|
146
|
+
(
|
|
147
|
+
add_out,
|
|
148
|
+
matmul_out,
|
|
149
|
+
reshape_qkv_2,
|
|
150
|
+
transpose_qkv,
|
|
151
|
+
reshape_qkv_1,
|
|
152
|
+
matmul_qkv,
|
|
153
|
+
) = qkv_nodes
|
|
154
|
+
elif qkv_nodes_openai is not None:
|
|
155
|
+
qkv_nodes = qkv_nodes_openai
|
|
156
|
+
(
|
|
157
|
+
add_out,
|
|
158
|
+
matmul_out,
|
|
159
|
+
reshape_qkv_2,
|
|
160
|
+
transpose_qkv,
|
|
161
|
+
matmul_qkv,
|
|
162
|
+
) = qkv_nodes
|
|
163
|
+
# Set model implementation to openai
|
|
164
|
+
model_impl_openai = True
|
|
165
|
+
else:
|
|
166
|
+
return
|
|
167
|
+
|
|
168
|
+
other_inputs = []
|
|
169
|
+
for input in normalize_node.input:
|
|
170
|
+
if input not in output_name_to_node:
|
|
171
|
+
continue
|
|
172
|
+
if input == qkv_nodes[0].output[0]:
|
|
173
|
+
continue
|
|
174
|
+
other_inputs.append(input)
|
|
175
|
+
if len(other_inputs) != 1:
|
|
176
|
+
return
|
|
177
|
+
root_input = other_inputs[0]
|
|
178
|
+
|
|
179
|
+
# Sometimes the input name to the attention MatMul nodes does not match the input name to the end
|
|
180
|
+
# SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
|
|
181
|
+
# nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
|
|
182
|
+
# children nodes for each of its output names.
|
|
183
|
+
"""
|
|
184
|
+
root_input
|
|
185
|
+
+---------------------------------------------------+
|
|
186
|
+
| |
|
|
187
|
+
| |
|
|
188
|
+
SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
|
|
189
|
+
"""
|
|
190
|
+
skip_layernorm = output_name_to_node[root_input]
|
|
191
|
+
# For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose
|
|
192
|
+
# child is the LayerNormalization node.
|
|
193
|
+
if skip_layernorm.op_type == "Add":
|
|
194
|
+
skip_layernorm = self.model.get_children(skip_layernorm)[0]
|
|
195
|
+
for output in skip_layernorm.output:
|
|
196
|
+
if not output:
|
|
197
|
+
continue
|
|
198
|
+
children = input_name_to_nodes[output]
|
|
199
|
+
children_types = [child.op_type for child in children]
|
|
200
|
+
if children_types.count("MatMul") >= 1:
|
|
201
|
+
root_input = output
|
|
202
|
+
break
|
|
203
|
+
|
|
204
|
+
graph_input_names = set([node.name for node in self.model.graph().input])
|
|
205
|
+
graph_output_names = set([node.name for node in self.model.graph().output])
|
|
206
|
+
|
|
207
|
+
v_nodes = self.model.match_parent_path(
|
|
208
|
+
matmul_qkv,
|
|
209
|
+
["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
|
|
210
|
+
[1, 0, 0, 0, None],
|
|
211
|
+
)
|
|
212
|
+
v_nodes_openai = self.model.match_parent_path(
|
|
213
|
+
matmul_qkv,
|
|
214
|
+
["Transpose", "Reshape", "Add", "MatMul"],
|
|
215
|
+
[1, 0, 0, None],
|
|
216
|
+
)
|
|
217
|
+
v_nodes_with_past_self_attn = self.model.match_parent_path(
|
|
218
|
+
# Decoder attention with past value concatenated before MatMul
|
|
219
|
+
matmul_qkv,
|
|
220
|
+
["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
|
|
221
|
+
[1, 0, 1, 0, 0, None],
|
|
222
|
+
)
|
|
223
|
+
v_nodes_with_past_cross_attn = self.model.match_parent_path(
|
|
224
|
+
# Decoder attention with past value directly used in MatMul
|
|
225
|
+
matmul_qkv,
|
|
226
|
+
["Reshape"],
|
|
227
|
+
[1],
|
|
228
|
+
)
|
|
229
|
+
v_nodes_with_past_cross_attn_openai = self.model.match_parent_path(
|
|
230
|
+
matmul_qkv,
|
|
231
|
+
["Transpose", "Reshape", "Reshape", "Transpose"],
|
|
232
|
+
[1, 0, 0, 0],
|
|
233
|
+
)
|
|
234
|
+
past_v, present_v = "", ""
|
|
235
|
+
reshape_v_2, add_v = None, None
|
|
236
|
+
if v_nodes is not None:
|
|
237
|
+
(reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
|
|
238
|
+
# For initial pass through encoder-decoder_with_past to get starting past values (beam search)
|
|
239
|
+
present_v = transpose_v.output[0]
|
|
240
|
+
elif v_nodes_openai is not None:
|
|
241
|
+
v_nodes = v_nodes_openai
|
|
242
|
+
(transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
|
|
243
|
+
# For initial pass through encoder-decoder_with_past to get starting past values (beam search)
|
|
244
|
+
|
|
245
|
+
# Find the child path to access the correct present_v values
|
|
246
|
+
# Openai impl provides present/past v values in 3D format
|
|
247
|
+
# whereas ort MultiHeadAttention expects v values in 4D, hence the
|
|
248
|
+
# additional Reshape and Transpose nodes are added
|
|
249
|
+
# For encoder attention types
|
|
250
|
+
# Add -> Reshape -> Transpose -> Present_V
|
|
251
|
+
reshape_path = self.model.match_child_path(
|
|
252
|
+
add_v,
|
|
253
|
+
["Reshape", "Transpose"],
|
|
254
|
+
exclude=[reshape_v_1],
|
|
255
|
+
)
|
|
256
|
+
# For decoder attention types
|
|
257
|
+
# add_v_node Reshape <- Transpose <-Past_V
|
|
258
|
+
# \ /
|
|
259
|
+
# \ /
|
|
260
|
+
# -> Concat <-
|
|
261
|
+
# |
|
|
262
|
+
# |--> Reshape -> Transpose -> Present_V
|
|
263
|
+
concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"])
|
|
264
|
+
if reshape_path is not None:
|
|
265
|
+
(_, transpose_add_v) = reshape_path
|
|
266
|
+
if transpose_add_v.output[0] in graph_output_names:
|
|
267
|
+
present_v = transpose_add_v.output[0]
|
|
268
|
+
if concat_path is not None:
|
|
269
|
+
(concat_v, _, transpose_concat_v) = concat_path
|
|
270
|
+
if transpose_concat_v.output[0] in graph_output_names:
|
|
271
|
+
present_v = transpose_concat_v.output[0]
|
|
272
|
+
concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0])
|
|
273
|
+
_, transpose_concat_v_in = concat_nodes
|
|
274
|
+
past_v = transpose_concat_v_in.input[0]
|
|
275
|
+
elif v_nodes_with_past_self_attn is not None:
|
|
276
|
+
(reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn
|
|
277
|
+
v_nodes = v_nodes_with_past_self_attn
|
|
278
|
+
past_v = concat_v.input[0]
|
|
279
|
+
present_v = concat_v.output[0]
|
|
280
|
+
elif (
|
|
281
|
+
v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names
|
|
282
|
+
):
|
|
283
|
+
v_nodes = v_nodes_with_past_cross_attn
|
|
284
|
+
past_v = v_nodes[-1].input[0]
|
|
285
|
+
present_v = v_nodes[-1].output[0]
|
|
286
|
+
if present_v not in graph_output_names:
|
|
287
|
+
identity_node_v = list(
|
|
288
|
+
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v])
|
|
289
|
+
)
|
|
290
|
+
present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else ""
|
|
291
|
+
elif (
|
|
292
|
+
v_nodes_with_past_cross_attn_openai is not None
|
|
293
|
+
and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names
|
|
294
|
+
):
|
|
295
|
+
v_nodes = v_nodes_with_past_cross_attn_openai
|
|
296
|
+
past_v = v_nodes[-1].input[0]
|
|
297
|
+
present_v = v_nodes[-1].output[0]
|
|
298
|
+
if present_v not in graph_output_names:
|
|
299
|
+
identity_node_v = list(
|
|
300
|
+
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v])
|
|
301
|
+
)
|
|
302
|
+
present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else ""
|
|
303
|
+
else:
|
|
304
|
+
logger.debug("fuse_attention: failed to match v path")
|
|
305
|
+
return
|
|
306
|
+
past_v = past_v if past_v in graph_input_names else ""
|
|
307
|
+
present_v = present_v if present_v in graph_output_names else ""
|
|
308
|
+
|
|
309
|
+
qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
|
|
310
|
+
qk_nodes_2 = self.model.match_parent_path(
|
|
311
|
+
matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0]
|
|
312
|
+
)
|
|
313
|
+
qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
|
|
314
|
+
add_qk = None
|
|
315
|
+
if qk_nodes_1 is not None:
|
|
316
|
+
_, matmul_qk = qk_nodes_1
|
|
317
|
+
qk_nodes = qk_nodes_1
|
|
318
|
+
elif qk_nodes_2 is not None:
|
|
319
|
+
_, _, add_qk, _, matmul_qk = qk_nodes_2
|
|
320
|
+
qk_nodes = qk_nodes_2
|
|
321
|
+
elif qk_nodes_2_openai is not None:
|
|
322
|
+
_, add_qk, matmul_qk = qk_nodes_2_openai
|
|
323
|
+
qk_nodes = qk_nodes_2_openai
|
|
324
|
+
else:
|
|
325
|
+
return
|
|
326
|
+
|
|
327
|
+
q_nodes = self.model.match_parent_path(
|
|
328
|
+
matmul_qk,
|
|
329
|
+
["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
|
|
330
|
+
[0, 0, 0, 0, 0, 1],
|
|
331
|
+
)
|
|
332
|
+
q_nodes_openai = self.model.match_parent_path(
|
|
333
|
+
matmul_qk,
|
|
334
|
+
["Mul", "Transpose", "Reshape", "Add", "MatMul"],
|
|
335
|
+
[0, 0, 0, 0, 1],
|
|
336
|
+
)
|
|
337
|
+
reshape_q_2 = None
|
|
338
|
+
if q_nodes is not None:
|
|
339
|
+
reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes
|
|
340
|
+
elif q_nodes_openai is not None:
|
|
341
|
+
q_nodes = q_nodes_openai
|
|
342
|
+
mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes
|
|
343
|
+
else:
|
|
344
|
+
return
|
|
345
|
+
|
|
346
|
+
k_nodes_with_bias = self.model.match_parent_path(
|
|
347
|
+
matmul_qk,
|
|
348
|
+
["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
|
|
349
|
+
[1, 0, 0, 0, 0, 1],
|
|
350
|
+
)
|
|
351
|
+
k_nodes_with_bias_openai = self.model.match_parent_path(
|
|
352
|
+
matmul_qk,
|
|
353
|
+
["Mul", "Transpose", "Reshape", "MatMul"],
|
|
354
|
+
[1, 0, 0, 0],
|
|
355
|
+
)
|
|
356
|
+
k_nodes_no_bias = self.model.match_parent_path(
|
|
357
|
+
matmul_qk,
|
|
358
|
+
["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"],
|
|
359
|
+
[1, 0, 0, 0, 0],
|
|
360
|
+
)
|
|
361
|
+
k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path(
|
|
362
|
+
# Decoder attention with past key concatenated before MatMul
|
|
363
|
+
matmul_qk,
|
|
364
|
+
["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"],
|
|
365
|
+
[1, 0, 0, 1, 0, 0],
|
|
366
|
+
)
|
|
367
|
+
k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path(
|
|
368
|
+
# Decoder attention with past key directly used in MatMul
|
|
369
|
+
matmul_qk,
|
|
370
|
+
["Transpose", "Reshape"],
|
|
371
|
+
[1, 0],
|
|
372
|
+
)
|
|
373
|
+
k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path(
|
|
374
|
+
# Decoder attention with past key directly used in MatMul
|
|
375
|
+
matmul_qk,
|
|
376
|
+
["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
|
|
377
|
+
[1, 0, 0, 0, 0],
|
|
378
|
+
)
|
|
379
|
+
past_k, present_k = "", ""
|
|
380
|
+
reshape_k_2, reshape_k_1, matmul_k = None, None, None
|
|
381
|
+
if k_nodes_with_bias is not None:
|
|
382
|
+
_, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias
|
|
383
|
+
k_nodes = k_nodes_with_bias
|
|
384
|
+
elif k_nodes_with_bias_openai is not None:
|
|
385
|
+
mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias_openai
|
|
386
|
+
k_nodes = k_nodes_with_bias_openai
|
|
387
|
+
present_k = matmul_k.output[0]
|
|
388
|
+
|
|
389
|
+
# Find the child path to access the correct present_k values
|
|
390
|
+
# Openai impl provides present/past k values in 3D format
|
|
391
|
+
# whereas ort MultiHeadAttention expects k values in 4D, hence the
|
|
392
|
+
# additional Reshape and Transpose nodes are added
|
|
393
|
+
# For encoder attention types
|
|
394
|
+
# Matmul -> Reshape -> Transpose -> Present_K
|
|
395
|
+
reshape_path = self.model.match_child_path(
|
|
396
|
+
matmul_k,
|
|
397
|
+
["Reshape", "Transpose"],
|
|
398
|
+
exclude=[reshape_k_1],
|
|
399
|
+
)
|
|
400
|
+
# For decoder attention types
|
|
401
|
+
# matmul_k_node Reshape <- Transpose <- Past_K
|
|
402
|
+
# \ /
|
|
403
|
+
# \ /
|
|
404
|
+
# -> Concat <-
|
|
405
|
+
# |
|
|
406
|
+
# |--> Reshape -> Transpose -> Present_K
|
|
407
|
+
concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"])
|
|
408
|
+
if reshape_path is not None:
|
|
409
|
+
(_, transpose_matmul_k) = reshape_path
|
|
410
|
+
if transpose_matmul_k.output[0] in graph_output_names:
|
|
411
|
+
present_k = transpose_matmul_k.output[0]
|
|
412
|
+
if concat_path is not None:
|
|
413
|
+
(concat_k, _, transpose_concat_k) = concat_path
|
|
414
|
+
if transpose_concat_k.output[0] in graph_output_names:
|
|
415
|
+
present_k = transpose_concat_k.output[0]
|
|
416
|
+
concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0])
|
|
417
|
+
_, transpose_concat_k_in = concat_nodes
|
|
418
|
+
past_k = transpose_concat_k_in.input[0]
|
|
419
|
+
elif k_nodes_no_bias is not None:
|
|
420
|
+
_, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias
|
|
421
|
+
k_nodes = k_nodes_no_bias
|
|
422
|
+
# For initial pass through encoder-decoder_with_past to get starting past values (beam search)
|
|
423
|
+
present_k = transpose_k_1.output[0]
|
|
424
|
+
elif k_nodes_no_bias_with_past_self_attn is not None:
|
|
425
|
+
_, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn
|
|
426
|
+
k_nodes = k_nodes_no_bias_with_past_self_attn
|
|
427
|
+
past_k = concat_k.input[0]
|
|
428
|
+
present_k = concat_k.output[0]
|
|
429
|
+
elif (
|
|
430
|
+
k_nodes_no_bias_with_past_cross_attn is not None
|
|
431
|
+
and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names
|
|
432
|
+
):
|
|
433
|
+
k_nodes = k_nodes_no_bias_with_past_cross_attn
|
|
434
|
+
past_k = k_nodes[-1].input[0]
|
|
435
|
+
present_k = k_nodes[-1].output[0]
|
|
436
|
+
if present_k not in graph_output_names:
|
|
437
|
+
identity_node_k = list(
|
|
438
|
+
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k])
|
|
439
|
+
)
|
|
440
|
+
present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else ""
|
|
441
|
+
elif (
|
|
442
|
+
k_nodes_no_bias_with_past_cross_attn_openai is not None
|
|
443
|
+
and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names
|
|
444
|
+
):
|
|
445
|
+
k_nodes = k_nodes_no_bias_with_past_cross_attn_openai
|
|
446
|
+
past_k = k_nodes[-1].input[0]
|
|
447
|
+
present_k = k_nodes[-1].output[0]
|
|
448
|
+
if present_k not in graph_output_names:
|
|
449
|
+
identity_node_k = list(
|
|
450
|
+
filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k])
|
|
451
|
+
)
|
|
452
|
+
present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else ""
|
|
453
|
+
else:
|
|
454
|
+
return
|
|
455
|
+
past_k = past_k if past_k in graph_input_names else ""
|
|
456
|
+
present_k = present_k if present_k in graph_output_names else ""
|
|
457
|
+
|
|
458
|
+
if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn):
|
|
459
|
+
# Create empty Add node for attention graph
|
|
460
|
+
bias_dim = self.model.get_initializer(add_v.input[0]).dims[0]
|
|
461
|
+
empty_bias_name = "empty_bias"
|
|
462
|
+
empty_tensor = self.model.get_initializer(empty_bias_name)
|
|
463
|
+
if empty_tensor is None:
|
|
464
|
+
self.add_initializer(
|
|
465
|
+
empty_bias_name,
|
|
466
|
+
TensorProto.FLOAT,
|
|
467
|
+
dims=[bias_dim],
|
|
468
|
+
vals=np.array([0.0] * bias_dim, dtype=np.float32),
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
add_name = self.model.create_node_name("Add")
|
|
472
|
+
add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)
|
|
473
|
+
|
|
474
|
+
if (
|
|
475
|
+
model_impl_openai
|
|
476
|
+
and not past_k
|
|
477
|
+
and not self.check_runtime_shape_path_openai(
|
|
478
|
+
reshape_qkv_2,
|
|
479
|
+
matmul_qkv,
|
|
480
|
+
add_qk,
|
|
481
|
+
matmul_qk,
|
|
482
|
+
add_q,
|
|
483
|
+
)
|
|
484
|
+
):
|
|
485
|
+
return
|
|
486
|
+
elif (
|
|
487
|
+
not model_impl_openai
|
|
488
|
+
and not past_k
|
|
489
|
+
and not self.check_runtime_shape_path(
|
|
490
|
+
reshape_qkv_2,
|
|
491
|
+
reshape_qkv_1,
|
|
492
|
+
reshape_q_2,
|
|
493
|
+
reshape_k_2,
|
|
494
|
+
reshape_v_2,
|
|
495
|
+
root_input,
|
|
496
|
+
)
|
|
497
|
+
):
|
|
498
|
+
return
|
|
499
|
+
|
|
500
|
+
three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals()
|
|
501
|
+
one_root_input = (
|
|
502
|
+
not three_root_inputs
|
|
503
|
+
and matmul_k.input[0] == root_input
|
|
504
|
+
and matmul_q.input[0] == root_input
|
|
505
|
+
and matmul_v.input[0] == root_input
|
|
506
|
+
)
|
|
507
|
+
two_root_inputs = (
|
|
508
|
+
not three_root_inputs
|
|
509
|
+
and matmul_q.input[0] == root_input
|
|
510
|
+
and matmul_k.input[0] == matmul_v.input[0]
|
|
511
|
+
and matmul_k.input[0] != matmul_q.input[0]
|
|
512
|
+
)
|
|
513
|
+
|
|
514
|
+
# There are 5 types of attention:
|
|
515
|
+
# 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1
|
|
516
|
+
# 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2
|
|
517
|
+
# 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value
|
|
518
|
+
# 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1
|
|
519
|
+
# 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1
|
|
520
|
+
encoder_attention = one_root_input and qk_nodes == qk_nodes_1
|
|
521
|
+
decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai)
|
|
522
|
+
decoder_attention_with_past = (
|
|
523
|
+
(encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v
|
|
524
|
+
)
|
|
525
|
+
decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1
|
|
526
|
+
decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1
|
|
527
|
+
|
|
528
|
+
# For decoder_attention, the attention mask needs to be included in the attention node
|
|
529
|
+
mask_index = None
|
|
530
|
+
if decoder_attention:
|
|
531
|
+
mask_nodes_bart = self.model.match_parent_path(
|
|
532
|
+
add_qk,
|
|
533
|
+
["Where"],
|
|
534
|
+
[1],
|
|
535
|
+
)
|
|
536
|
+
mask_nodes_whisper = self.model.match_parent_path(
|
|
537
|
+
add_qk,
|
|
538
|
+
["Expand", "Unsqueeze", "Unsqueeze", "Where"],
|
|
539
|
+
[1, 0, 0, 0],
|
|
540
|
+
)
|
|
541
|
+
if mask_nodes_whisper is not None:
|
|
542
|
+
mask_index = mask_nodes_whisper[0].output[-1]
|
|
543
|
+
elif mask_nodes_bart is not None:
|
|
544
|
+
mask_index = mask_nodes_bart[0].output[-1]
|
|
545
|
+
|
|
546
|
+
if (
|
|
547
|
+
encoder_attention
|
|
548
|
+
or decoder_attention
|
|
549
|
+
or decoder_attention_with_past
|
|
550
|
+
or decoder_cross_attention
|
|
551
|
+
or decoder_cross_attention_with_past
|
|
552
|
+
):
|
|
553
|
+
attention_last_node = reshape_qkv_2
|
|
554
|
+
num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1)
|
|
555
|
+
|
|
556
|
+
if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
|
|
557
|
+
logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
|
|
558
|
+
return
|
|
559
|
+
|
|
560
|
+
new_node = None
|
|
561
|
+
if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
|
|
562
|
+
# Note: Decoder attention with past key and past value is fused as multihead attention
|
|
563
|
+
# rather than attention because multihead attention supports separate past key and past
|
|
564
|
+
# value whereas attention supports concatenated past key and past value.
|
|
565
|
+
new_node = (
|
|
566
|
+
self.create_multihead_attention_node(
|
|
567
|
+
matmul_q,
|
|
568
|
+
matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k,
|
|
569
|
+
matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v,
|
|
570
|
+
add_q,
|
|
571
|
+
add_k if decoder_cross_attention or decoder_attention_with_past else None,
|
|
572
|
+
add_v if decoder_cross_attention or decoder_attention_with_past else None,
|
|
573
|
+
num_heads,
|
|
574
|
+
hidden_size,
|
|
575
|
+
attention_last_node.output[0],
|
|
576
|
+
past_k=past_k if decoder_attention_with_past else "",
|
|
577
|
+
past_v=past_v if decoder_attention_with_past else "",
|
|
578
|
+
present_k=present_k,
|
|
579
|
+
present_v=present_v,
|
|
580
|
+
packed_qkv=decoder_attention_with_past,
|
|
581
|
+
)
|
|
582
|
+
if self.use_multi_head_attention
|
|
583
|
+
else None
|
|
584
|
+
)
|
|
585
|
+
else:
|
|
586
|
+
# Temporarily set multihead attention flag to false
|
|
587
|
+
use_multi_head_attention_ground_truth = self.use_multi_head_attention
|
|
588
|
+
self.use_multi_head_attention = False
|
|
589
|
+
new_node = self.create_attention_node(
|
|
590
|
+
None,
|
|
591
|
+
matmul_q,
|
|
592
|
+
matmul_k,
|
|
593
|
+
matmul_v,
|
|
594
|
+
add_q,
|
|
595
|
+
add_k,
|
|
596
|
+
add_v,
|
|
597
|
+
num_heads,
|
|
598
|
+
hidden_size,
|
|
599
|
+
root_input,
|
|
600
|
+
attention_last_node.output[0],
|
|
601
|
+
add_qk_str=mask_index if decoder_attention else None,
|
|
602
|
+
past_k=past_k,
|
|
603
|
+
past_v=past_v,
|
|
604
|
+
present_k=present_k,
|
|
605
|
+
present_v=present_v,
|
|
606
|
+
)
|
|
607
|
+
self.use_multi_head_attention = use_multi_head_attention_ground_truth
|
|
608
|
+
if new_node is None:
|
|
609
|
+
return
|
|
610
|
+
|
|
611
|
+
self.nodes_to_add.append(new_node)
|
|
612
|
+
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
|
613
|
+
|
|
614
|
+
self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
|
|
615
|
+
self.nodes_to_remove.extend(qk_nodes)
|
|
616
|
+
|
|
617
|
+
# When using multihead attention, keep MatMul nodes in original graph
|
|
618
|
+
if decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
|
|
619
|
+
if q_nodes[-1].op_type == "MatMul":
|
|
620
|
+
q_nodes.pop()
|
|
621
|
+
if k_nodes[-1].op_type == "MatMul":
|
|
622
|
+
k_nodes.pop()
|
|
623
|
+
if v_nodes[-1].op_type == "MatMul":
|
|
624
|
+
v_nodes.pop()
|
|
625
|
+
if self.disable_multi_head_attention_bias and (
|
|
626
|
+
decoder_cross_attention or decoder_cross_attention_with_past
|
|
627
|
+
):
|
|
628
|
+
if q_nodes[-1].op_type == "Add":
|
|
629
|
+
q_nodes.pop()
|
|
630
|
+
if k_nodes[-1].op_type == "Add":
|
|
631
|
+
k_nodes.pop()
|
|
632
|
+
if v_nodes[-1].op_type == "Add":
|
|
633
|
+
v_nodes.pop()
|
|
634
|
+
|
|
635
|
+
self.nodes_to_remove.extend(q_nodes)
|
|
636
|
+
self.nodes_to_remove.extend(k_nodes)
|
|
637
|
+
self.nodes_to_remove.extend(v_nodes)
|
|
638
|
+
|
|
639
|
+
# Use prune graph to remove mask nodes since they are shared by all attention nodes.
|
|
640
|
+
self.prune_graph = True
|