onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
from logging import getLogger
|
|
7
|
+
|
|
8
|
+
from fusion_base import Fusion
|
|
9
|
+
from fusion_utils import FusionUtils
|
|
10
|
+
from numpy import ndarray
|
|
11
|
+
from onnx import NodeProto, TensorProto
|
|
12
|
+
from onnx_model import OnnxModel
|
|
13
|
+
|
|
14
|
+
logger = getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FusionShape(Fusion):
    """Replace a Concat that rebuilds a tensor's full shape, one dimension at a time, with Shape.

    Matches Concat(Unsqueeze(Gather(Shape(x), i)) for i = 0 .. rank(x) - 1) and rewires all
    consumers of the Concat output to the output of the first matched Shape node.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "Shape", "Concat")
        self.utils = FusionUtils(model)
        # Symbolic shape inference result, computed lazily by get_dimensions(). The
        # shape_infer_done flag makes sure inference is attempted at most once, because
        # infer_runtime_shape() may return None and None alone cannot mark "already tried".
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> int | None:
        """Return the rank recorded in ``tensor_proto.type.tensor_type``, or None if it has no shape.

        NOTE(review): the argument is read through ``.type.tensor_type``, i.e. the layout of a
        ValueInfoProto (entries of ``known_vi_``); the annotation is kept for compatibility.
        """
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        return None

    def get_dimensions(self, input_name: str) -> int | None:
        """Return the rank of ``input_name``, or None when it cannot be determined.

        The rank is read from the model first. Symbolic shape inference is the fallback:
        it runs once per fusion instance and updates the model.
        """
        shape = self.model.get_shape(input_name)
        if shape is not None:
            return len(shape)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            # Inference does not necessarily produce value info for every edge; a missing
            # entry means "rank unknown" rather than an error.
            value_info = self.shape_infer.known_vi_.get(input_name)
            if value_info is not None:
                return self.get_dimensions_from_tensor_proto(value_info)

        return None

    def _is_full_shape(self, shape_node: NodeProto) -> bool:
        """Return True if ``shape_node`` is not sliced by Shape-15 ``start``/``end`` attributes.

        A sliced Shape output does not start at dimension 0 of its input, so Gather index i
        would not correspond to dimension i and the fusion would be incorrect.
        """
        start = self.model.get_node_attribute(shape_node, "start")
        end = self.model.get_node_attribute(shape_node, "end")
        return (start is None or start == 0) and end is None

    def fuse(
        self,
        concat_node: NodeProto,
        input_name_to_nodes: dict[str, list[NodeProto]],
        output_name_to_node: dict[str, NodeProto],
    ):
        """Try to replace ``concat_node`` by the output of an equivalent Shape node.

        Args:
            concat_node: candidate Concat node.
            input_name_to_nodes: tensor name -> consumer nodes (unused here).
            output_name_to_node: tensor name -> producer node.
        """
        #
        # Simplify subgraph like
        #
        #                   (2d_input)
        #                    /       \
        #                Shape       shape
        #                /             \
        #    Gather(indices=0)    Gather(indices=1)
        #           |                   |
        #    Unsqueeze(axes=0)    Unsqueeze(axes=0)
        #                   \          /
        #                      Concat
        #                        |
        #
        # into  (2d_input) --> Shape -->
        #
        opset_version = self.model.get_opset_version()

        inputs = len(concat_node.input)
        if inputs == 0:
            return

        root = None
        shape_output = None
        for i in range(inputs):
            path = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Gather", "Shape"],
                [i, 0, 0],
                output_name_to_node,
            )
            if path is None:
                return

            unsqueeze, gather, shape = path
            if not self._is_full_shape(shape):
                return

            if i == 0:
                # Every Shape must read the same tensor, and that tensor's rank must equal
                # the number of Concat inputs so that each dimension is gathered exactly once.
                shape_output = shape.output[0]
                root = shape.input[0]
                if self.get_dimensions(root) != inputs:
                    return
            elif shape.input[0] != root:
                return

            # Unsqueeze must insert a leading axis. Before opset 13 axes is an attribute;
            # from opset 13 it is the second input.
            if opset_version < 13:
                if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
                    return
            elif not self.utils.check_node_input_value(unsqueeze, 1, [0]):
                return

            # The i-th Concat input must gather exactly dimension i.
            value = self.model.get_constant_value(gather.input[1])
            if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i):
                return

        # A graph output cannot be rewired away, so the fusion is skipped in that case.
        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
            self.increase_counter("Shape")
            self.prune_graph = True
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from fusion_base import Fusion
|
|
4
|
+
from fusion_skiplayernorm import FusionSkipLayerNormalization
|
|
5
|
+
from onnx import helper
|
|
6
|
+
from onnx_model import OnnxModel
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FusionSimplifiedLayerNormalization(Fusion):
    """Fuse an RMSNorm subgraph into a single SimplifiedLayerNormalization node.

    The match is anchored at the final Mul (multiplication by the scale) and walks upward
    through Mul/Div/Reciprocal, Sqrt, Add(epsilon) and ReduceMean to the squaring node
    (Pow(X, 2) or Mul(X, X)).
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SimplifiedLayerNormalization", "Mul")

    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
        """Try to fuse the RMSNorm subgraph ending at ``node``.

        Args:
            node: candidate final Mul node (multiplication by the scale).
            input_name_to_nodes: tensor name -> consumer nodes.
            output_name_to_node: tensor name -> producer node.

        On success the fused node is appended to ``self.nodes_to_add`` and the matched nodes
        to ``self.nodes_to_remove``; otherwise nothing is changed.
        """
        if node.op_type != "Mul":
            return

        # RMSNorm formula:
        #   S = Pow(X, 2) or S = Mul(X, X)
        #   MS = ReduceMean(S)
        #   MSEps = Add(MS, epsilon)
        #   RMS = Sqrt(MSEps)
        #   InvRMS = Div(1, RMS) or InvRMS = Reciprocal(RMS)
        #   Normalized = Mul(D, InvRMS)
        #   Y = Mul(Normalized, Scale)
        #
        #  (root_input) ----------------------------------------+
        #       |                                               |
        #       v                                               v
        #      Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node)
        #      (B=2)                  (A/B=eps)         (A=1)           (A/B=scale)
        #
        #  (root_input) ----------------------------------------+
        #       |  |                                            |
        #       v  v                                            v
        #       Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node)
        #                              (A/B=eps)         (A=1)           (A/B=scale)
        #
        return_indice = []
        sim_ln_nodes = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean"],
            [None, 1, 1, 0, None],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice,
        )

        if sim_ln_nodes:
            mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
            # Div must compute 1 / RMS.
            if not self.model.has_constant_input(div_node, 1.0):
                return
            node_parent = mul_node
        else:
            # Div(1, RMS) can also be represented as Reciprocal(RMS):
            #
            #  (root_input) -----------------------------------------------+
            #       |                                                      |
            #       v                                                      v
            #      Pow --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node)
            #      (B=2)                  (A/B=eps)                                (A/B=scale)
            #
            #  (root_input) -----------------------------------------------+
            #       |  |                                                   |
            #       v  v                                                   v
            #       Mul --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node)
            #                              (A/B=eps)                                (A/B=scale)
            #
            return_indice = []
            sim_ln_nodes = self.model.match_parent_path(
                node,
                ["Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean"],
                [None, 1, 0, 0, None],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )
            if sim_ln_nodes is not None:
                mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
                node_parent = mul_node
            else:
                # The normalization and scaling can also be folded into one Div:
                #
                #  (root_input) --------------------------------+
                #       |                                       |
                #       v                                       v
                #      Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node)
                #      (B=2)                  (A/B=eps)                 (A/B=scale)
                #
                #  (root_input) --------------------------------+
                #       |  |                                    |
                #       v  v                                    v
                #       Mul --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node)
                #                              (A/B=eps)                 (A/B=scale)
                #
                return_indice = []
                sim_ln_nodes = self.model.match_parent_path(
                    node,
                    ["Div", "Sqrt", "Add", "ReduceMean"],
                    [None, 1, 0, None],
                    output_name_to_node=output_name_to_node,
                    return_indice=return_indice,
                )
                if sim_ln_nodes is None:
                    return
                div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
                node_parent = div_node

        reduce_mean_parent = self.model.get_parent(reduce_mean_node, 0, output_name_to_node)
        if reduce_mean_parent is None or reduce_mean_parent.op_type not in ["Pow", "Mul"]:
            return

        if reduce_mean_parent.op_type == "Pow":
            # S = Pow(X, 2): the exponent must be the constant 2 at input index 1.
            if self.model.find_constant_input(reduce_mean_parent, 2.0) != 1:
                return
        elif reduce_mean_parent.input[0] != reduce_mean_parent.input[1]:
            # S = Mul(X, X): both operands must be the same tensor (compare input names).
            return

        # The tensor being squared must also be the one being normalized.
        root_input = reduce_mean_parent.input[0]
        if root_input not in node_parent.input:
            return

        # Only a single finite epsilon in (0, 1e-4] is accepted. Checking .size first avoids
        # ambiguous comparisons on multi-element constants, and the chained comparison also
        # rejects NaN.
        _i, epsilon = self.model.get_constant_input(add_node)
        if epsilon is None or epsilon.size != 1 or not (0 < epsilon.item() <= 1.0e-4):
            logger.warning(f"epsilon value is not expected: {epsilon}")
            return
        epsilon = float(epsilon.item())

        # ReduceMean must keep the reduced dimension. The ONNX default for keepdims is 1,
        # so an absent attribute is accepted.
        keepdims = self.model.get_node_attribute(reduce_mean_node, "keepdims")
        if keepdims is not None and keepdims == 0:
            return

        axis = self._get_reduce_axis(reduce_mean_node)
        if axis is None:
            return

        # SimplifiedLayerNormalization normalizes over dimensions [axis, rank), so the single
        # reduced axis must be the last one. A non-negative axis is only verified when the
        # rank of the input is known.
        if axis < 0:
            if axis != -1:
                return
        else:
            shape = self.model.get_shape(root_input)
            if shape is not None and axis != len(shape) - 1:
                return

        subgraph_nodes = [*sim_ln_nodes, reduce_mean_parent, node]
        # Intermediate outputs consumed outside the subgraph would be left dangling.
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        normalize_node = helper.make_node(
            "SimplifiedLayerNormalization",
            # return_indice[0] is the input slot of `node` fed by the matched path; the
            # other slot holds the scale.
            inputs=[root_input, node.input[1 - return_indice[0]]],
            outputs=[node.output[0]],
            name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="RMSNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", epsilon)])
        normalize_node.attribute.extend([helper.make_attribute("axis", axis)])
        normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name

    def _get_reduce_axis(self, reduce_mean_node):
        """Return the single reduction axis of ``reduce_mean_node`` as an int, or None.

        Axes is an attribute before opset 18 and an (optional) constant input afterwards.
        None is returned when axes is unknown, empty, or has more than one element, as
        required by the SimplifiedLayerNormalization spec.
        """
        axes = self.model.get_node_attribute(reduce_mean_node, "axes")
        if not axes and len(reduce_mean_node.input) > 1:
            axes_value = self.model.get_constant_value(reduce_mean_node.input[1])
            # Convert to a plain list so truthiness/len checks never hit ndarray ambiguity.
            axes = None if axes_value is None else axes_value.ravel().tolist()
        if not axes or len(axes) != 1:
            return None
        return int(axes[0])
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
    """Fuse Add + SimplifiedLayerNormalization into SkipSimplifiedLayerNormalization.

    The matching logic is the one in FusionSkipLayerNormalization; this subclass only
    passes different op-type names to the base-class constructor.
    """

    def __init__(self, model: OnnxModel):
        fused_type = "SkipSimplifiedLayerNormalization"
        searched_type = "SimplifiedLayerNormalization"
        super().__init__(model, fused_type, searched_type)

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Delegate matching and fusion of ``node`` to the base class."""
        super().fuse(node, input_name_to_nodes, output_name_to_node)
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
|
|
7
|
+
from fusion_base import Fusion
|
|
8
|
+
from fusion_utils import NumpyHelper
|
|
9
|
+
from onnx import helper
|
|
10
|
+
from onnx_model import OnnxModel
|
|
11
|
+
|
|
12
|
+
logger = getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FusionSkipGroupNorm(Fusion):
    """
    Fuse Add + GroupNorm into one node: SkipGroupNorm.

    The matched pattern is Add -> Transpose(perm=[0, 2, 3, 1]) -> GroupNorm. The Add inputs are treated
    as channel-first tensors (``get_skip_index(..., is_channel_last=False)``), and the Transpose moves the
    channel axis last before GroupNorm. The fused SkipGroupNorm node (domain "com.microsoft") therefore
    reads channel-last inputs; ``fuse`` transposes (or bypasses existing Transposes on) both Add inputs.

    Shape information is required to tell which Add input is the (possibly broadcast) skip input, so the
    fusion is disabled when runtime shape inference is unavailable.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipGroupNorm", "GroupNorm")
        # Request updated shape inference (update=True) since other fusions might have added new edges
        # which do not have shape info yet.
        self.shape_infer_helper = self.model.infer_runtime_shape(update=True)

        if self.shape_infer_helper is None:
            logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.")

    def create_transpose_node(self, input_name: str, perm: list[int], output_name=None):
        """Create a Transpose node that reads ``input_name``.

        The node is only created; the caller is responsible for adding it to the graph.

        Args:
            input_name: name of the tensor to transpose.
            perm: value of the Transpose ``perm`` attribute.
            output_name: name of the output tensor. When None, it is derived from the generated node
                name and ``input_name``.

        Returns:
            The new Transpose node.
        """
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name
        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
        return transpose_node

    def get_skip_index(self, add, is_channel_last: bool):
        """Classify which input of ``add`` is the skip input, based on shape info.

        The skip input may be broadcast over the two spatial dimensions (H and W equal to 1).

        Args:
            add: the Add node whose two inputs are classified.
            is_channel_last: True if the inputs are laid out as (N, H, W, C), False for (N, C, H, W).

        Returns:
            A tuple ``(skip, broadcast)``. ``skip`` is the index (0 or 1) of the skip input in
            ``add.input``, or -1 when shapes are unavailable or do not match a supported pattern.
            ``broadcast`` is True when the skip input has H == W == 1.
        """
        skip = -1
        broadcast = False

        # Missing shape info means "cannot fuse", not a fatal error: this fusion is optional, so bail out
        # gracefully instead of asserting (asserts are also stripped under -O, which would turn this into
        # a TypeError on len(None) below).
        shape_a = shape_b = None
        if self.shape_infer_helper is not None:
            shape_a = self.shape_infer_helper.get_edge_shape(add.input[0])
            shape_b = self.shape_infer_helper.get_edge_shape(add.input[1])

        if shape_a is not None and shape_b is not None and len(shape_a) == 4 and len(shape_b) == 4:
            if shape_a == shape_b:
                skip = 1
            else:
                c = 3 if is_channel_last else 1
                h = 1 if is_channel_last else 2
                w = 2 if is_channel_last else 3
                # Batch and channel must agree; the skip input is the one broadcast over H and W.
                if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]:
                    if shape_b[h] == 1 and shape_b[w] == 1:
                        skip = 1
                        broadcast = True
                    elif shape_a[h] == 1 and shape_a[w] == 1:
                        skip = 0
                        broadcast = True

        if skip < 0:
            logger.debug(
                "skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected",
                add.input[0],
                add.input[1],
            )
        return skip, broadcast

    def has_multiple_consumers(self, output_name, input_name_to_nodes):
        """Whether an output has multiple consumers (it is a graph output, or more than one node reads it)."""
        return self.model.find_graph_output(output_name) is not None or (
            output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1
        )

    def remove_if_safe(self, node, input_name_to_nodes):
        """Schedule ``node`` for removal if it is safe (at most one consumer, and not a graph output)."""
        if not self.has_multiple_consumers(node.output[0], input_name_to_nodes):
            self.nodes_to_remove.append(node)

    def is_bias_1d(self, bias_name: str):
        """Whether ``bias_name`` names an initializer with exactly one dimension."""
        initializer = self.model.get_initializer(bias_name)
        if initializer is None:
            return False

        bias_weight = NumpyHelper.to_array(initializer)
        if bias_weight is None:
            logger.debug("Bias weight not found")
            return False

        if len(bias_weight.shape) != 1:
            logger.debug("Bias weight is not 1D")
            return False
        return True

    def match_bias_path(self, node, input_name_to_nodes, output_name_to_node):
        """
        Match the bias graph pattern above a Transpose node that follows a Reshape node (see example below).

        If the bias of the Add feeding the Reshape is a 1D initializer, the graph is modified: the Reshape
        is rewired to read the MatMul output directly, and the bias Add is scheduled for removal when it has
        no other consumer. The bias is then expected to be applied by the fused SkipGroupNorm node.

        Returns:
            The bias tensor name when the pattern matched, otherwise None (graph left untouched).
        """
        # Before Fusion:
        #    MatMul  (bias)
        #         \   /   (shape)
        #          Add   /
        #            \  /
        #     (a)    Reshape
        #       \       |
        # Transpose([0, 3, 1, 2])  Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes.
        #          \   /
        #           Add
        #          /   \
        #    (c)   Transpose([0,2,3,1])
        #               |
        #            GroupNorm
        #               |
        #              (d)
        #
        # After Fusion (the nodes below Reshape are handled in the fuse function):
        #    MatMul  (shape)
        #         \   /
        #    (a)  Reshape
        #      \   /
        #   SkipGroupNorm
        #     /       \
        #   (d)    Transpose([0, 3, 1, 2])
        #                  \
        #                  (c)

        add_input_index = []
        bias_nodes = self.model.match_parent_path(
            node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index
        )
        if bias_nodes is None:
            return None

        (reshape, add_bias, matmul) = bias_nodes
        # The bias is the Add input that is not fed by the MatMul.
        bias = add_bias.input[1 - add_input_index[0]]
        if not self.is_bias_1d(bias):
            return None

        reshape.input[0] = matmul.output[0]
        self.remove_if_safe(add_bias, input_name_to_nodes)

        return bias

    def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node):
        """Return the Transpose(perm=[0,3,1,2]) node producing ``output_name``, or None.

        When matched, the Transpose is also scheduled for removal if it has no other consumer.
        """
        parent = output_name_to_node.get(output_name, None)
        if parent is not None and parent.op_type == "Transpose":
            permutation = OnnxModel.get_node_attribute(parent, "perm")
            if permutation == [0, 3, 1, 2]:
                self.remove_if_safe(parent, input_name_to_nodes)
                return parent
        return None

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the Add -> Transpose -> GroupNorm pattern ending at GroupNorm ``node``.

        All checks that can abandon the fusion run before any graph modification.
        """
        # This fusion requires shape information, so skip it if shape is not available.
        if self.shape_infer_helper is None:
            return

        # Before Fusion:
        #      (a)  (b)
        #        \  /
        #        Add
        #        /\
        #     (c)  Transpose([0,2,3,1])
        #             \
        #           GroupNorm
        #              |
        #             (d)
        #
        # After Fusion:
        #             (a)                      (b)
        #               \                      /
        #     Transpose([0,2,3,1])   Transpose([0,2,3,1])
        #                      \    /
        #                   SkipGroupNorm
        #                    /    \
        #                   /    Transpose([0, 3, 1, 2])
        #                  /        \
        #                (d)        (c)
        nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node)
        if nodes is None:
            return

        (transpose, add) = nodes
        if transpose in self.nodes_to_remove or add in self.nodes_to_remove:
            return

        if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes):
            return

        permutation = OnnxModel.get_node_attribute(transpose, "perm")
        if permutation != [0, 2, 3, 1]:
            return

        # Classify the skip input BEFORE touching the graph. The loop below rewires Reshape inputs,
        # schedules node removals and inserts Transpose nodes; none of that may happen if the fusion is
        # abandoned, otherwise the bias would be silently dropped and the surviving Add left dangling.
        skip, broadcast = self.get_skip_index(add, is_channel_last=False)
        if skip < 0:
            return

        inputs = []
        bias = None
        for add_input in (add.input[0], add.input[1]):
            matched_transpose = self.match_transpose_from_nhwc(add_input, input_name_to_nodes, output_name_to_node)
            if matched_transpose is not None:
                # When there is a Transpose node before Add (see examples in match_bias_path), we do not need to
                # insert another Transpose node. The existing Transpose node will be removed in prune_graph if it
                # has only one consumer.
                inputs.append(matched_transpose.input[0])
                # See whether it matches the bias pattern.
                if bias is None:
                    bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node)
            else:
                # Otherwise, insert a Transpose node before Add.
                new_transpose = self.create_transpose_node(add_input, [0, 2, 3, 1])
                self.model.add_node(new_transpose, self.this_graph_name)
                inputs.append(new_transpose.output[0])

        # SkipGroupNorm input order: input, GroupNorm's inputs 1 and 2 (presumably gamma and beta per the
        # com.microsoft schema), skip, then optional bias.
        inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]]
        if bias is not None:
            inputs.append(bias)

        # Copy the names so that appending the optional Add output below does not mutate the GroupNorm node.
        outputs = list(node.output)

        new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm")
        if self.has_multiple_consumers(add.output[0], input_name_to_nodes):
            # Other nodes still read the Add result: expose it as a second (channel-last) output and
            # transpose it back so that add.output[0] keeps its original layout for those consumers.
            add_out_name = new_node_name + "_add_out"
            outputs.append(add_out_name)

            # Insert a Transpose node after add output.
            add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0])
            self.model.add_node(add_out_transpose, self.this_graph_name)

        skip_group_norm = helper.make_node(
            self.fused_op_type,
            inputs=inputs,
            outputs=outputs,
            name=new_node_name,
        )
        skip_group_norm.domain = "com.microsoft"

        self.increase_counter(
            f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
        )

        # Pass attributes from GroupNorm node to SkipGroupNorm
        for att in node.attribute:
            skip_group_norm.attribute.extend([att])

        self.nodes_to_remove.extend([add, transpose, node])
        self.nodes_to_add.append(skip_group_norm)
        self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name
        self.prune_graph = True
|