onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from logging import getLogger
|
|
7
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from onnx import NodeProto, helper
|
|
11
|
+
from onnx_model import OnnxModel
|
|
12
|
+
|
|
13
|
+
logger = getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Fusion:
    """
    Base class for Graph Fusion.

    A fusion scans the model for nodes of `search_op_types` and, when a
    subclass's `fuse` recognizes a pattern, records nodes to remove and the
    fused replacement nodes to add. `apply` drives the whole process.
    """

    def __init__(
        self,
        model: OnnxModel,
        fused_op_type: str,
        search_op_types: Union[str, List[str]],
        description: str = "",
    ):
        """Initialize the fusion.

        Args:
            model: OnnxModel wrapper whose graph is searched and rewritten.
            fused_op_type: op type of the fused node this fusion creates.
            search_op_types: op type(s) whose nodes are fusion starting points.
            description: optional suffix appended to fused_op_type for logging.
        """
        self.search_op_types: List[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types
        self.fused_op_type: str = fused_op_type
        self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
        self.model: OnnxModel = model
        self.nodes_to_remove: List = []
        self.nodes_to_add: List = []
        self.prune_graph: bool = False
        # Maps a newly added node's name to the (sub)graph it belongs to.
        self.node_name_to_graph_name: dict = {}
        self.this_graph_name: Optional[str] = None
        # It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter.
        self.fused_count: defaultdict = defaultdict(int)

    def increase_counter(self, fused_op_name: str):
        """
        Increase counter of a fused operator.
        """
        self.fused_count[fused_op_name] += 1

    def fuse(
        self,
        node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """Interface for fusion that starts from a node. Subclasses must override."""
        raise NotImplementedError

    def apply(self):
        """
        Apply graph fusion on the whole model graph.
        It searches nodes of the given operators, and starts fusion on each of those nodes.
        """
        # Lazy %-style args avoid string formatting when the log level is disabled.
        logger.debug("start %s fusion...", self.description)
        input_name_to_nodes = self.model.input_name_to_nodes()
        output_name_to_node = self.model.output_name_to_node()

        # This assumes that two search ops will not be fused at same time!
        for search_op_type in self.search_op_types:
            for node in self.model.get_nodes_by_op_type(search_op_type):
                graph = self.model.get_graph_by_node(node)
                if graph is None:
                    # RuntimeError is a subclass of Exception, so existing
                    # `except Exception` callers still catch it.
                    raise RuntimeError("Can not find node in any graph")
                self.this_graph_name = graph.name
                self.fuse(node, input_name_to_nodes, output_name_to_node)

        op_list = [node.op_type for node in self.nodes_to_add]
        if self.fused_count:
            # Subclass maintained explicit counters; report those.
            for key, value in self.fused_count.items():
                if value:
                    logger.info("Fused %s: %s", key, value)
        else:
            # Fall back to counting fused nodes queued for addition.
            count = op_list.count(self.fused_op_type)
            if count > 0:
                logger.info("Fused %s: %s", self.description, count)

        self.model.remove_nodes(self.nodes_to_remove)
        self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)

        if self.prune_graph:
            self.model.prune_graph()
        elif self.nodes_to_remove or self.nodes_to_add:
            self.model.update_graph()

    def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
        """Create a tensor initializer and register it on the current graph.

        Args:
            name: initializer name.
            data_type: ONNX tensor data type (TensorProto enum value).
            dims: tensor dimensions.
            vals: values as a numpy array or any sequence convertible to one.
            raw: store values as raw bytes (default) instead of typed fields.

        Returns:
            The created TensorProto.
        """
        if raw:
            np_type = helper.tensor_dtype_to_np_dtype(data_type)
            # `data_bytes` (not `bytes`) to avoid shadowing the builtin type.
            if not isinstance(vals, np.ndarray):
                data_bytes = np.array(vals, dtype=np_type).tobytes()
            else:
                data_bytes = vals.astype(np_type).tobytes()
            tensor = helper.make_tensor(
                name=name,
                data_type=data_type,
                dims=dims,
                vals=data_bytes,
                raw=True,
            )
        else:
            tensor = helper.make_tensor(
                name=name,
                data_type=data_type,
                dims=dims,
                vals=vals,
                raw=False,
            )

        self.model.add_initializer(tensor, self.this_graph_name)
        return tensor

    def add_nodes_to_remove(self, nodes: List[NodeProto]):
        """Queue nodes for removal, skipping duplicates already queued.

        Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths).
        When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B
        is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are
        iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first.
        Since path A's shared nodes are removed, path B's shared nodes are not removed because they
        were previously removed for path A. This causes an error to print in remove_node that a node
        has failed to be removed.

        To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`.
        We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could
        be scenarios where the nodes need to be removed in a specific order and converting to a set would
        lose this order.
        """
        for node in nodes:
            if node not in self.nodes_to_remove:
                self.nodes_to_remove.append(node)

    def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]):
        """Queue nodes for removal like add_nodes_to_remove, but never queue nodes in `nodes_to_keep`."""
        for node in nodes:
            if node not in self.nodes_to_remove and node not in nodes_to_keep:
                self.nodes_to_remove.append(node)
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
from typing import Dict
|
|
7
|
+
|
|
8
|
+
from fusion_base import Fusion
|
|
9
|
+
from numpy import ndarray
|
|
10
|
+
from onnx import helper
|
|
11
|
+
from onnx_model import OnnxModel
|
|
12
|
+
|
|
13
|
+
logger = getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FusionBiasAdd(Fusion):
    """Fuses a constant bias Add plus a skip-connection Add into one BiasAdd node."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "BiasAdd", "Add")

    def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Add bias and Add skip connection into BiasAdd
        """
        # Walk upward from the final Add through the expected producer chain.
        matched = self.model.match_parent_path(
            add_node,
            ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
            [0, None, 0, 0, 0],
            output_name_to_node,
        )
        if matched is None:
            return

        bias_add, skip_norm = matched[0], matched[-1]

        # The skip connection (second input of add_node) must come straight
        # from the SkipLayerNormalization output.
        if add_node.input[1] not in skip_norm.output:
            return

        bias_index, bias_value = self.model.get_constant_input(bias_add)
        bias_is_valid = isinstance(bias_index, int) and bias_value is not None and isinstance(bias_value, ndarray)
        # Only a 1-D constant bias qualifies; short-circuit guards ndim access.
        if not bias_is_valid or bias_value.ndim != 1:
            return

        self.nodes_to_remove.extend([add_node, bias_add])

        fused_name = self.model.create_node_name("BiasAdd")
        fused_node = helper.make_node(
            "BiasAdd",
            inputs=[bias_add.input[1 - bias_index], bias_add.input[bias_index], add_node.input[1]],
            outputs=[add_node.output[0]],
            name=fused_name,
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_name] = self.this_graph_name
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
from logging import getLogger
|
|
7
|
+
|
|
8
|
+
from fusion_base import Fusion
|
|
9
|
+
from fusion_utils import NumpyHelper
|
|
10
|
+
from onnx import helper
|
|
11
|
+
from onnx_model import OnnxModel
|
|
12
|
+
|
|
13
|
+
logger = getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FusionBiasGelu(Fusion):
    """Fuses Add(MatMul output, 1-D bias) followed by Gelu/FastGelu into BiasGelu/FastGelu."""

    def __init__(self, model: OnnxModel, is_fastgelu):
        """Initialize for FastGelu (is_fastgelu truthy) or Gelu fusion."""
        if is_fastgelu:
            super().__init__(model, "FastGelu", "FastGelu", "add bias")
        else:
            super().__init__(model, "BiasGelu", "Gelu")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the bias Add feeding `node` (Gelu or FastGelu) into one fused op.

        The Add must have exactly one 1-D initializer input (the bias), and the
        Gelu/Add pair must be safe to remove (no other consumers of intermediate
        outputs). On success the fused node replaces both.
        """
        gelu_op_type = node.op_type
        fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu"

        if len(node.input) != 1:
            return

        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
        if nodes is None:
            return
        (add, matmul) = nodes

        bias_weight = None
        # bias should be one dimension
        bias_index = -1
        # `input_name` instead of `input` to avoid shadowing the builtin.
        for i, input_name in enumerate(add.input):
            initializer = self.model.get_initializer(input_name)
            if initializer is None:
                continue
            bias_index = i
            bias_weight = NumpyHelper.to_array(initializer)
            break
        if bias_weight is None:
            return
        if len(bias_weight.shape) != 1:
            return

        subgraph_nodes = [node, add]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        fused_node = helper.make_node(
            fuse_op_type,
            inputs=[matmul.output[0], add.input[bias_index]],
            outputs=node.output,
            name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
from typing import Dict
|
|
7
|
+
|
|
8
|
+
from fusion_base import Fusion
|
|
9
|
+
from onnx import helper
|
|
10
|
+
from onnx_model import OnnxModel
|
|
11
|
+
|
|
12
|
+
logger = getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FusionBiasSplitGelu(Fusion):
    # NOTE(review): the rest of this class (its `fuse` method) is truncated in this view.
    def __init__(self, model: OnnxModel):
        """Initialize a fusion that searches Gelu nodes to build fused BiasSplitGelu nodes."""
        super().__init__(model, "BiasSplitGelu", "Gelu")
|
|
18
|
+
|
|
19
|
+
def fuse(self, gelu_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
|
|
20
|
+
"""
|
|
21
|
+
[root] --->Add --------------------> Slice ---------------> Mul -->
|
|
22
|
+
| ^ ^
|
|
23
|
+
| | |
|
|
24
|
+
+----------------------------+---Slice --> Gelu---+
|
|
25
|
+
| | ^
|
|
26
|
+
| |-----|
|
|
27
|
+
| | |
|
|
28
|
+
| Mul Mul
|
|
29
|
+
| ^ ^
|
|
30
|
+
v | |
|
|
31
|
+
Shape ---> Gather --> Add --> Div --+
|
|
32
|
+
"""
|
|
33
|
+
if gelu_node.output[0] not in input_name_to_nodes:
|
|
34
|
+
return
|
|
35
|
+
children = input_name_to_nodes[gelu_node.output[0]]
|
|
36
|
+
if len(children) != 1 or children[0].op_type != "Mul":
|
|
37
|
+
return
|
|
38
|
+
mul_after_gelu = children[0]
|
|
39
|
+
|
|
40
|
+
slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node)
|
|
41
|
+
if slice_before_gelu is None:
|
|
42
|
+
return
|
|
43
|
+
|
|
44
|
+
if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3:
|
|
45
|
+
return
|
|
46
|
+
|
|
47
|
+
add_output = slice_before_gelu.input[0]
|
|
48
|
+
|
|
49
|
+
start_index_nodes = self.model.match_parent_path(
|
|
50
|
+
slice_before_gelu,
|
|
51
|
+
["Div", "Add", "Gather", "Shape", "Add"],
|
|
52
|
+
[1, 0, 0, 0, 0],
|
|
53
|
+
output_name_to_node, # Mul(1) is optional
|
|
54
|
+
)
|
|
55
|
+
if start_index_nodes is None:
|
|
56
|
+
start_index_nodes = self.model.match_parent_path(
|
|
57
|
+
slice_before_gelu,
|
|
58
|
+
["Mul", "Div", "Add", "Gather", "Shape", "Add"],
|
|
59
|
+
[1, 0, 0, 0, 0, 0],
|
|
60
|
+
output_name_to_node,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output:
|
|
64
|
+
return
|
|
65
|
+
|
|
66
|
+
end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node)
|
|
67
|
+
|
|
68
|
+
if (
|
|
69
|
+
end_index_nodes is None or end_index_nodes[1] not in start_index_nodes
|
|
70
|
+
): # the Div is parent of both two Mul nodes
|
|
71
|
+
return
|
|
72
|
+
|
|
73
|
+
slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node)
|
|
74
|
+
if slice_before_mul is None:
|
|
75
|
+
return
|
|
76
|
+
|
|
77
|
+
if (
|
|
78
|
+
slice_before_mul.input[2] != slice_before_gelu.input[1]
|
|
79
|
+
): # end index of slice_before_mul is start index of slice_before_gelu
|
|
80
|
+
return
|
|
81
|
+
|
|
82
|
+
subgraph_nodes = [
|
|
83
|
+
*start_index_nodes,
|
|
84
|
+
end_index_nodes[0],
|
|
85
|
+
mul_after_gelu,
|
|
86
|
+
gelu_node,
|
|
87
|
+
slice_before_mul,
|
|
88
|
+
slice_before_gelu,
|
|
89
|
+
]
|
|
90
|
+
subgraph_output = mul_after_gelu.output[0]
|
|
91
|
+
if not self.model.is_safe_to_fuse_nodes(
|
|
92
|
+
subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
|
|
93
|
+
):
|
|
94
|
+
logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.")
|
|
95
|
+
return
|
|
96
|
+
|
|
97
|
+
add_node = start_index_nodes[-1]
|
|
98
|
+
bias_index, _value = self.model.get_constant_input(add_node)
|
|
99
|
+
if not isinstance(bias_index, int):
|
|
100
|
+
return
|
|
101
|
+
self.nodes_to_remove.extend(subgraph_nodes)
|
|
102
|
+
node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu")
|
|
103
|
+
fused_node = helper.make_node(
|
|
104
|
+
"BiasSplitGelu",
|
|
105
|
+
inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]],
|
|
106
|
+
outputs=[subgraph_output],
|
|
107
|
+
name=node_name,
|
|
108
|
+
)
|
|
109
|
+
fused_node.domain = "com.microsoft"
|
|
110
|
+
self.nodes_to_add.append(fused_node)
|
|
111
|
+
self.node_name_to_graph_name[node_name] = self.this_graph_name
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from fusion_attention import AttentionMask, FusionAttention
|
|
8
|
+
from onnx_model import OnnxModel
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FusionConformerAttention(FusionAttention):
    """
    Fuse Conformer Attention subgraph into one MultiHeadAttention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Match the conformer attention pattern rooted at normalize_node and
        replace it with a single MultiHeadAttention node (with past/present KV)."""
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match qkv path")
            return
        reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[2], qkv_nodes[3], qkv_nodes[4]

        # Value path: projection followed by concat with the past value cache.
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )
        if v_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match v path")
            return
        concat_v, add_v, matmul_v = v_nodes[0], v_nodes[3], v_nodes[4]
        present_v = concat_v.output[0]
        past_v = self.model.get_parent(concat_v, 0, None).output[0]

        # Attention probabilities: Softmax over (QK^T + mask).
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        if qk_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match qk path")
            return
        add_qk, matmul_qk = qk_nodes[1], qk_nodes[2]

        # Query path: projection, reshape to heads, scale by Div.
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        if q_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match q path")
            return
        reshape_q, add_q, matmul_q = q_nodes[2], q_nodes[3], q_nodes[4]

        # Key path: projection, concat with past key cache, transposed for QK^T.
        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 1],
        )
        if k_nodes is None:
            logger.debug("fuse_conformer_attention: failed to match k path")
            return
        concat_k, add_k, matmul_k = k_nodes[1], k_nodes[4], k_nodes[5]
        past_k = self.model.get_parent(concat_k, 0, None).output[0]
        present_k = concat_k.output[0]

        attention_last_node = reshape_qkv
        # Derive head count and hidden size from the query reshape node.
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
            return

        new_node = self.create_multihead_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            add_q,
            add_k,
            add_v,
            num_heads,
            hidden_size,
            attention_last_node.output[0],
            add_qk=add_qk.input[1],
            past_k=past_k,
            past_v=past_v,
            present_k=present_k,
            present_v=present_v,
        )
        if new_node is None:
            logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph
        for path in (q_nodes, k_nodes, v_nodes):
            if path[-1].op_type == "MatMul":
                path.pop()

        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
|