onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
from logging import getLogger
|
|
7
|
+
from typing import Dict, List, Union
|
|
8
|
+
|
|
9
|
+
from fusion_base import Fusion
|
|
10
|
+
from fusion_utils import FusionUtils
|
|
11
|
+
from numpy import ndarray
|
|
12
|
+
from onnx import NodeProto, TensorProto
|
|
13
|
+
from onnx_model import OnnxModel
|
|
14
|
+
|
|
15
|
+
logger = getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class FusionShape(Fusion):
    """Replace a Concat that rebuilds a tensor's shape element by element with a Shape output."""

    def __init__(self, model: OnnxModel):
        # Fused op type "Shape"; fuse() is invoked for "Concat" nodes.
        super().__init__(model, "Shape", "Concat")
        self.utils = FusionUtils(model)
        # get_dimensions() runs runtime shape inference lazily, at most once, and caches
        # the result here (the cached result itself may be None).
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        """Return the rank recorded in a value-info message, or None if it carries no shape.

        NOTE(review): despite the ``TensorProto`` annotation, the argument is accessed as
        ``.type.tensor_type.shape``, which is the layout of an ONNX ValueInfoProto.
        """
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        """Return the rank of tensor ``input_name``, or None when it cannot be determined.

        The shape stored in the model is tried first; otherwise runtime shape inference is
        run (once per instance) and its value info for ``input_name`` is consulted.
        """
        shape = self.model.get_shape(input_name)
        if shape is not None:
            return len(shape)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

        if self.shape_infer is None:
            return None

        # Inference does not necessarily record every tensor; a missing entry means
        # "unknown rank" rather than an error.
        value_info = self.shape_infer.known_vi_.get(input_name)
        if value_info is None:
            return None
        return self.get_dimensions_from_tensor_proto(value_info)

    def fuse(
        self,
        concat_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """Rewire consumers of ``concat_node`` to a Shape output when the two are equivalent.

        The Concat qualifies when its i-th input is Unsqueeze(axes=0)(Gather(Shape(root), i))
        for every i, all Shape nodes read the same ``root``, and the rank of ``root`` equals
        the number of Concat inputs. Nothing is changed when the Concat output is a graph
        output. Dangling nodes are left for graph pruning (``self.prune_graph``).
        """
        #
        # Simplify subgraph like
        #
        #                (2d_input)
        #                 /       \
        #             Shape       Shape
        #             /             \
        #     Gather(indices=0)    Gather(indices=1)
        #         |                   |
        #     Unsqueeze(axes=0)   Unsqueeze(axes=0)
        #            \          /
        #               Concat
        #                 |
        #
        # into        (2d_input) --> Shape -->
        #
        opset_version = self.model.get_opset_version()

        num_inputs = len(concat_node.input)
        if num_inputs == 0:
            # Nothing to rebuild; also avoids rewiring consumers to a None tensor name.
            return

        root = None
        shape_output = None
        for i in range(num_inputs):
            path = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Gather", "Shape"],
                [i, 0, 0],
                output_name_to_node,
            )
            if path is None:
                return

            unsqueeze, gather, shape = path
            if i == 0:
                shape_output = shape.output[0]
                root = shape.input[0]
                # The Concat must rebuild the complete shape, so its input count has to
                # match the rank of the tensor whose shape is taken.
                if self.get_dimensions(root) != num_inputs:
                    return
            elif shape.input[0] != root:
                return

            if not FusionUtils.check_node_attribute(unsqueeze, "axis", 0, default_value=0):
                return

            # Unsqueeze axes moved from an attribute to an input in opset 13.
            if opset_version < 13:
                if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
                    return
            else:
                if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                    return

            # The Gather index must be the scalar i. A 1-D index such as [i] makes Gather
            # produce a [1] tensor, which Unsqueeze turns into [1, 1], so the Concat output
            # would not be the 1-D Shape output.
            value = self.model.get_constant_value(gather.input[1])
            if not (isinstance(value, ndarray) and value.ndim == 0 and value.item() == i):
                return

        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
            self.increase_counter("Shape")
            self.prune_graph = True
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict
|
|
3
|
+
|
|
4
|
+
from fusion_base import Fusion
|
|
5
|
+
from fusion_skiplayernorm import FusionSkipLayerNormalization
|
|
6
|
+
from onnx import helper
|
|
7
|
+
from onnx_model import OnnxModel
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FusionSimplifiedLayerNormalization(Fusion):
    """Fuse a decomposed normalization chain ending in a Mul into SimplifiedLayerNormalization.

    SimplifiedLayerNorm calculation (notation from
    https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary):
        DD = Pow(D, 2)
        Var = ReduceMean(DD)
        VarEps = Add(Var, epsilon)
        StdDev = Sqrt(VarEps)
        InvStdDev = Div(1, StdDev)
        Normalized = Mul(D, InvStdDev)
        NormalizedScaled = Mul(Normalized, Scale)
    """

    # Candidate parent paths, tried in this priority order from the final Mul (``node`` in
    # ``fuse``); the first path that matches is used. Each entry is
    #   (parent op types, parent input indices,
    #    index of ``node``'s input carrying the scale weight,
    #    whether the chain must start at a graph input).
    # In every path, position 0 is the Normalized Mul, 1 the Div, 3 the epsilon Add and 5 the Pow.
    _PARENT_PATTERNS = (
        # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul(node)
        (("Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"), (1, 1, 1, 0, 0, 0, 0), 0, False),
        # Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul(node)
        (("Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"), (1, 1, 1, 0, 0, 0, 0), 0, False),
        # For LLaMA from Microsoft custom export: like the first pattern, but the normalized
        # value enters the final Mul at input 0 (scale at input 1).
        (("Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"), (0, 1, 1, 0, 0, 0, 0), 1, False),
        # graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul(node)
        (("Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"), (0, 1, 1, 0, 0, 0), 1, True),
        # For Gemma from Microsoft custom export, which has a Multiply after the Gather:
        # Mul --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul(node)
        (("Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Mul"), (1, 1, 1, 0, 0, 0, 0), 0, False),
    )

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SimplifiedLayerNormalization", "Mul")

    def _match_parent_pattern(self, node, output_name_to_node: Dict):
        """Return ``(path, weight_index, starts_with_graph_input)`` for the first matching pattern, else None."""
        for op_types, input_indices, weight_index, starts_with_graph_input in self._PARENT_PATTERNS:
            # Fresh lists per call, and the caller's output map so it is not rebuilt per match.
            path = self.model.match_parent_path(node, list(op_types), list(input_indices), output_name_to_node)
            if path is not None:
                return path, weight_index, starts_with_graph_input
        return None

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """Replace the chain ending at Mul ``node`` with one SimplifiedLayerNormalization node.

        The fused node reads the chain's root input D and the scale input of ``node``, and
        writes ``node.output[0]``. It carries attributes epsilon (taken from the Add
        constant), axis=-1 and stash_type=1. The fusion is skipped in these cases:
        - Pow's exponent is not 2.0;
        - Div's numerator is not 1.0;
        - the Normalized Mul does not read the same D as Pow;
        - epsilon is not a single value in (0, 1e-4].

        NOTE(review): ReduceMean's axes are not verified to be the last axis even though the
        fused node uses axis=-1.
        """
        import numpy as np  # function-scope: numpy is not imported at module level here

        if node.op_type != "Mul":
            return

        matched = self._match_parent_pattern(node, output_name_to_node)
        if matched is None:
            return
        sim_ln_nodes, layernorm_weight_index, starts_with_graph_input = matched
        div_node, add_node, pow_node = sim_ln_nodes[1], sim_ln_nodes[3], sim_ln_nodes[5]

        # Verify that parent input to Pow node is graph_input. A match of this pattern is
        # final: later patterns are not tried.
        if starts_with_graph_input and pow_node.input[0] not in self.model.get_graphs_input_names():
            return

        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        # InvStdDev = Div(1, StdDev). Any other numerator would scale the output, which the
        # fused operator cannot express.
        if self.model.find_constant_input(div_node, 1.0) != 0:
            return

        root_input = pow_node.input[0]
        if root_input != sim_ln_nodes[0].input[0]:
            return

        _, add_weight = self.model.get_constant_input(add_node)
        epsilon = None
        if add_weight is not None:
            eps_array = np.asarray(add_weight)
            # Only a single epsilon value can become the scalar attribute.
            if eps_array.size == 1:
                epsilon = float(eps_array.item())
        # The chained comparison also rejects NaN.
        if epsilon is None or not 0.0 < epsilon <= 1.0e-4:
            logger.warning("epsilon value is not expected: %s", add_weight)
            return

        # The last matched node produces root_input and must stay, unless the chain starts at
        # a graph input, in which case the last matched node is Pow itself.
        self.nodes_to_remove.extend(sim_ln_nodes if starts_with_graph_input else sim_ln_nodes[:-1])
        self.nodes_to_remove.append(node)

        normalize_node = helper.make_node(
            "SimplifiedLayerNormalization",
            inputs=[root_input, node.input[layernorm_weight_index]],
            outputs=[node.output[0]],
            name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", epsilon)])
        normalize_node.attribute.extend([helper.make_attribute("axis", -1)])
        normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
    """Variant of FusionSkipLayerNormalization for SimplifiedLayerNormalization.

    All matching and rewriting logic is inherited. The only difference is the pair of op
    type names handed to the parent constructor.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): presumably (fused op type, searched op type), mirroring the
        # Fusion(model, "<fused>", "<searched>") calls elsewhere in this package; confirm
        # against FusionSkipLayerNormalization.__init__.
        fused_op_type = "SkipSimplifiedLayerNormalization"
        search_op_type = "SimplifiedLayerNormalization"
        super().__init__(model, fused_op_type, search_op_type)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Run the inherited FusionSkipLayerNormalization.fuse unchanged."""
        super().fuse(node, input_name_to_nodes, output_name_to_node)
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
from typing import List
|
|
7
|
+
|
|
8
|
+
from fusion_base import Fusion
|
|
9
|
+
from fusion_utils import NumpyHelper
|
|
10
|
+
from onnx import helper
|
|
11
|
+
from onnx_model import OnnxModel
|
|
12
|
+
|
|
13
|
+
logger = getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FusionSkipGroupNorm(Fusion):
|
|
17
|
+
"""
|
|
18
|
+
Fuse Add + GroupNorm into one node: SkipGroupNorm.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
def __init__(self, model: OnnxModel):
    """Create the SkipGroupNorm fusion and refresh runtime shape information for the model."""
    super().__init__(model, "SkipGroupNorm", "GroupNorm")

    # Shape inference is re-run because other fusions may have added edges that do not
    # carry shape information yet.
    shape_helper = self.model.infer_runtime_shape(update=True)
    self.shape_infer_helper = shape_helper

    if shape_helper is None:
        logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.")
|
|
28
|
+
|
|
29
|
+
def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
    """Build a Transpose node that reads ``input_name``; the node is not added to the graph.

    Args:
        input_name: tensor to transpose.
        perm: value of the Transpose ``perm`` attribute.
        output_name: output tensor name. When omitted, it is derived from the generated
            node name and ``input_name``.

    Returns:
        The new Transpose node. Inserting it into the model is up to the caller.
    """
    name = self.model.create_node_name("Transpose")
    out = output_name if output_name is not None else name + "_out" + "-" + input_name
    node = helper.make_node("Transpose", inputs=[input_name], outputs=[out], name=name)
    perm_attribute = helper.make_attribute("perm", perm)
    node.attribute.extend([perm_attribute])
    return node
|
|
37
|
+
|
|
38
|
+
def get_skip_index(self, add, is_channel_last: bool):
|
|
39
|
+
"""Add has two inputs. This classifies which input is skip based on shape info (skip allows broadcast)."""
|
|
40
|
+
skip = -1
|
|
41
|
+
broadcast = False
|
|
42
|
+
|
|
43
|
+
assert self.shape_infer_helper is not None
|
|
44
|
+
shape_a = self.shape_infer_helper.get_edge_shape(add.input[0])
|
|
45
|
+
shape_b = self.shape_infer_helper.get_edge_shape(add.input[1])
|
|
46
|
+
assert shape_a is not None and shape_b is not None
|
|
47
|
+
|
|
48
|
+
if len(shape_a) == 4 and len(shape_b) == 4:
|
|
49
|
+
if shape_a == shape_b:
|
|
50
|
+
skip = 1
|
|
51
|
+
else:
|
|
52
|
+
c = 3 if is_channel_last else 1
|
|
53
|
+
h = 1 if is_channel_last else 2
|
|
54
|
+
w = 2 if is_channel_last else 3
|
|
55
|
+
if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]:
|
|
56
|
+
if shape_b[h] == 1 and shape_b[w] == 1:
|
|
57
|
+
skip = 1
|
|
58
|
+
broadcast = True
|
|
59
|
+
elif shape_a[h] == 1 and shape_a[w] == 1:
|
|
60
|
+
skip = 0
|
|
61
|
+
broadcast = True
|
|
62
|
+
|
|
63
|
+
if skip < 0:
|
|
64
|
+
logger.debug(
|
|
65
|
+
"skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected",
|
|
66
|
+
add.input[0],
|
|
67
|
+
add.input[1],
|
|
68
|
+
)
|
|
69
|
+
return skip, broadcast
|
|
70
|
+
|
|
71
|
+
def has_multiple_consumers(self, output_name, input_name_to_nodes):
|
|
72
|
+
"""Whether an output has multiple consumers (like graph output or more than one children nodes)"""
|
|
73
|
+
return self.model.find_graph_output(output_name) is not None or (
|
|
74
|
+
output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
def remove_if_safe(self, node, input_name_to_nodes):
|
|
78
|
+
"""Remove a node if it is safe (only one children, and not graph output)"""
|
|
79
|
+
if not self.has_multiple_consumers(node.output[0], input_name_to_nodes):
|
|
80
|
+
self.nodes_to_remove.extend([node])
|
|
81
|
+
|
|
82
|
+
def is_bias_1d(self, bias_name: str):
|
|
83
|
+
"""Whether bias is an initializer of one dimension"""
|
|
84
|
+
initializer = self.model.get_initializer(bias_name)
|
|
85
|
+
if initializer is None:
|
|
86
|
+
return False
|
|
87
|
+
|
|
88
|
+
bias_weight = NumpyHelper.to_array(initializer)
|
|
89
|
+
if bias_weight is None:
|
|
90
|
+
logger.debug("Bias weight not found")
|
|
91
|
+
return False
|
|
92
|
+
|
|
93
|
+
if len(bias_weight.shape) != 1:
|
|
94
|
+
logger.debug("Bias weight is not 1D")
|
|
95
|
+
return False
|
|
96
|
+
return True
|
|
97
|
+
|
|
98
|
+
def match_bias_path(self, node, input_name_to_nodes, output_name_to_node):
|
|
99
|
+
"""
|
|
100
|
+
Match the bias graph pattern from an Transpose node after Reshape node like in below example.
|
|
101
|
+
It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape.
|
|
102
|
+
"""
|
|
103
|
+
# Before Fusion:
|
|
104
|
+
# MatMul (bias)
|
|
105
|
+
# \ / (shape)
|
|
106
|
+
# Add /
|
|
107
|
+
# \ /
|
|
108
|
+
# (a) Reshape
|
|
109
|
+
# \ |
|
|
110
|
+
# Transpose([0, 3, 1, 2]) Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes.
|
|
111
|
+
# \ /
|
|
112
|
+
# Add
|
|
113
|
+
# / \
|
|
114
|
+
# (c) Transpose([0,2,3,1])
|
|
115
|
+
# |
|
|
116
|
+
# GroupNorm
|
|
117
|
+
# |
|
|
118
|
+
# (d)
|
|
119
|
+
#
|
|
120
|
+
# After Fusion (the nodes below Reshape is handled in the fuse function):
|
|
121
|
+
# MatMul (shape)
|
|
122
|
+
# \ /
|
|
123
|
+
# (a) Reshape
|
|
124
|
+
# \ /
|
|
125
|
+
# SkipGroupNorm
|
|
126
|
+
# / \
|
|
127
|
+
# (d) Transpose([0, 3, 1, 2])
|
|
128
|
+
# \
|
|
129
|
+
# (c)
|
|
130
|
+
|
|
131
|
+
add_input_index = []
|
|
132
|
+
bias_nodes = self.model.match_parent_path(
|
|
133
|
+
node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index
|
|
134
|
+
)
|
|
135
|
+
if bias_nodes is None:
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
(reshape, add_bias, matmul) = bias_nodes
|
|
139
|
+
bias = bias_nodes[1].input[1 - add_input_index[0]]
|
|
140
|
+
if not self.is_bias_1d(bias):
|
|
141
|
+
return None
|
|
142
|
+
|
|
143
|
+
reshape.input[0] = matmul.output[0]
|
|
144
|
+
self.remove_if_safe(add_bias, input_name_to_nodes)
|
|
145
|
+
|
|
146
|
+
return bias
|
|
147
|
+
|
|
148
|
+
def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node):
|
|
149
|
+
"""Match whether an output is from a Transpose(perm=[0,3,1,2]) node."""
|
|
150
|
+
parent = output_name_to_node.get(output_name, None)
|
|
151
|
+
if parent is not None and parent.op_type == "Transpose":
|
|
152
|
+
permutation = OnnxModel.get_node_attribute(parent, "perm")
|
|
153
|
+
if permutation == [0, 3, 1, 2]:
|
|
154
|
+
self.remove_if_safe(parent, input_name_to_nodes)
|
|
155
|
+
return parent
|
|
156
|
+
return None
|
|
157
|
+
|
|
158
|
+
def fuse(self, node, input_name_to_nodes, output_name_to_node):
|
|
159
|
+
# This fusion requires shape information, so skip it if shape is not available.
|
|
160
|
+
if self.shape_infer_helper is None:
|
|
161
|
+
return
|
|
162
|
+
|
|
163
|
+
# Before Fusion:
|
|
164
|
+
# (a) (b)
|
|
165
|
+
# \ /
|
|
166
|
+
# Add
|
|
167
|
+
# /\
|
|
168
|
+
# (c) Transpose([0,2,3,1])
|
|
169
|
+
# \
|
|
170
|
+
# GroupNorm
|
|
171
|
+
# |
|
|
172
|
+
# (d)
|
|
173
|
+
#
|
|
174
|
+
# After Fusion:
|
|
175
|
+
# (a) (b)
|
|
176
|
+
# \ /
|
|
177
|
+
# Transpose([0,2,3,1]) Transpose([0,2,3,1])
|
|
178
|
+
# \ /
|
|
179
|
+
# SkipGroupNorm
|
|
180
|
+
# / \
|
|
181
|
+
# / Transpose([0, 3, 1, 2])
|
|
182
|
+
# / \
|
|
183
|
+
# (d) (c)
|
|
184
|
+
nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node)
|
|
185
|
+
if nodes is None:
|
|
186
|
+
return
|
|
187
|
+
|
|
188
|
+
(transpose, add) = nodes
|
|
189
|
+
if transpose in self.nodes_to_remove or add in self.nodes_to_remove:
|
|
190
|
+
return
|
|
191
|
+
|
|
192
|
+
if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes):
|
|
193
|
+
return
|
|
194
|
+
|
|
195
|
+
permutation = OnnxModel.get_node_attribute(transpose, "perm")
|
|
196
|
+
if permutation != [0, 2, 3, 1]:
|
|
197
|
+
return
|
|
198
|
+
|
|
199
|
+
inputs = []
|
|
200
|
+
bias = None
|
|
201
|
+
for i in range(2):
|
|
202
|
+
matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node)
|
|
203
|
+
if matched_transpose:
|
|
204
|
+
# When there is an Transpose node before Add (see examples in match_bias_path), we do not need to
|
|
205
|
+
# insert another Transpose node. The existing Transpose node will be removed in prune_graph if it
|
|
206
|
+
# has only one consumer.
|
|
207
|
+
inputs.append(matched_transpose.input[0])
|
|
208
|
+
# See whether it match bias pattern.
|
|
209
|
+
if bias is None:
|
|
210
|
+
bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node)
|
|
211
|
+
else:
|
|
212
|
+
# Otherwise, insert a Transpose node before Add.
|
|
213
|
+
new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1])
|
|
214
|
+
self.model.add_node(new_transpose, self.this_graph_name)
|
|
215
|
+
inputs.append(new_transpose.output[0])
|
|
216
|
+
|
|
217
|
+
skip, broadcast = self.get_skip_index(add, is_channel_last=False)
|
|
218
|
+
if skip < 0:
|
|
219
|
+
return
|
|
220
|
+
|
|
221
|
+
inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]]
|
|
222
|
+
if bias:
|
|
223
|
+
inputs = [*inputs, bias]
|
|
224
|
+
|
|
225
|
+
outputs = node.output
|
|
226
|
+
|
|
227
|
+
new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm")
|
|
228
|
+
if self.has_multiple_consumers(add.output[0], input_name_to_nodes):
|
|
229
|
+
add_out_name = new_node_name + "_add_out"
|
|
230
|
+
outputs.append(add_out_name)
|
|
231
|
+
|
|
232
|
+
# Insert a Transpose node after add output.
|
|
233
|
+
add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0])
|
|
234
|
+
self.model.add_node(add_out_transpose, self.this_graph_name)
|
|
235
|
+
|
|
236
|
+
skip_group_norm = helper.make_node(
|
|
237
|
+
self.fused_op_type,
|
|
238
|
+
inputs=inputs,
|
|
239
|
+
outputs=outputs,
|
|
240
|
+
name=new_node_name,
|
|
241
|
+
)
|
|
242
|
+
skip_group_norm.domain = "com.microsoft"
|
|
243
|
+
|
|
244
|
+
self.increase_counter(
|
|
245
|
+
f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
# Pass attributes from GroupNorm node to SkipGroupNorm
|
|
249
|
+
for att in node.attribute:
|
|
250
|
+
skip_group_norm.attribute.extend([att])
|
|
251
|
+
|
|
252
|
+
self.nodes_to_remove.extend([add, transpose, node])
|
|
253
|
+
self.nodes_to_add.append(skip_group_norm)
|
|
254
|
+
self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name
|
|
255
|
+
self.prune_graph = True
|