onnxruntime-directml 1.20.0 (onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
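Beyond the quantization and transformer tooling listed above, the wheel's core payload is the DirectML build of the runtime itself (onnxruntime/capi/onnxruntime.dll plus DirectML.dll). A minimal usage sketch, with the model path and the zero-filled input feed as placeholders:

# Minimal sketch: run a model on the DirectML execution provider shipped in
# this wheel. "model.onnx" and the input shape are placeholders.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})
print([o.shape for o in outputs])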
onnxruntime/transformers/fusion_skiplayernorm.py

@@ -0,0 +1,209 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionSkipLayerNormalization(Fusion):
    """
    Fuse Add + LayerNormalization into one node: SkipLayerNormalization.
    Note: This fusion does not check the input shapes of Add and LayerNormalization.
    """

    def __init__(
        self,
        model: OnnxModel,
        fused_op_type: str = "SkipLayerNormalization",
        search_op_types: str = "LayerNormalization",
        shape_infer: bool = True,
    ):
        super().__init__(model, fused_op_type, search_op_types)
        if shape_infer:
            # Updating shape inference is needed since other fusions might add new edges that do not have shape info yet.
            self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
            if self.shape_infer_helper is None:
                # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
                logger.warning("symbolic shape inference disabled or failed.")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        add = self.model.get_parent(node, 0, output_name_to_node)

        # In some models there is input_ids->Gather->Add->LayerNorm, and one input of the
        # Add node is an initializer with fixed shape; such an Add should not be fused into SkipLayerNorm.
        if add is None or add.op_type != "Add":
            return

        # The number of inputs of Add should be 2.
        if len(add.input) != 2:
            return

        for add_input in add.input:
            if self.model.get_initializer(add_input) is not None:
                return

        # To avoid an Add node having two LayerNormalization children, fuse only one SkipLayerNormalization.
        if add in self.nodes_to_remove:
            return

        # Root Mean Square Layer Normalization
        simplified = node.op_type == "SimplifiedLayerNormalization"

        if hasattr(self, "shape_infer_helper"):
            if self.shape_infer_helper is not None:
                if (
                    self.shape_infer_helper.get_edge_shape(add.input[0])
                    and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3
                ):
                    logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
                    return

                # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
                if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
                    logger.debug(
                        "skip SkipLayerNormalization fusion since shapes of inputs (%s, %s) are not the same",
                        add.input[0],
                        add.input[1],
                    )
                    return
            else:
                logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
                return

        gather_path = self.model.match_parent_path(add, ["Gather"], [None])
        if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
            if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
                return

        # The residual Add before the LayerNormalization may produce an output that is consumed
        # by other nodes or by a graph output besides the LayerNormalization itself.
        # We can still go ahead with the SkipLayerNormalization fusion, but we need to
        # preserve the output of Add, which then has to be produced by SkipLayerNormalization.
        add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None
        residual_add_has_multiple_consumers = (
            add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1
        )

        outputs_to_keep = node.output

        if residual_add_has_multiple_consumers:
            outputs_to_keep.extend([add.output[0]])

        outputs = [node.output[0]]

        # Skip the other optional outputs of SkipLayerNormalization before adding the Add's output.
        if residual_add_has_multiple_consumers:
            outputs.extend(["", "", add.output[0]])

        if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node):
            self.nodes_to_remove.extend([add, node])

            inputs = (
                [add.input[0], add.input[1], node.input[1], node.input[2]]
                if not simplified
                else [add.input[0], add.input[1], node.input[1]]
            )
            normalize_node = helper.make_node(
                self.fused_op_type,
                inputs=inputs,
                outputs=outputs,
                name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
            )
            normalize_node.domain = "com.microsoft"

            # Pass the "epsilon" attribute from the LayerNormalization node to SkipLayerNormalization.
            for att in node.attribute:
                if att.name == "epsilon":
                    normalize_node.attribute.extend([att])

            # Set a default epsilon if the LayerNormalization node has none.
            if len(normalize_node.attribute) == 0:
                normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

            self.nodes_to_add.append(normalize_node)
            self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


class FusionBiasSkipLayerNormalization(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        if len(node.input) != 4:
            return

        return_indice = []
        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice)
        if nodes is not None:
            (add, _matmul) = nodes
        else:
            # In case of fp16, we could have a Cast between the MatMul and the bias Add.
            return_indice = []
            nodes = self.model.match_parent_path(
                node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice
            )
            if nodes is not None:
                (add, _cast, _matmul) = nodes
            else:
                return

        assert len(return_indice) == 2 or len(return_indice) == 3
        add_input_index = return_indice[0]
        if add_input_index >= 2:
            return
        sln_input = add.input[return_indice[1]]
        bias_input = add.input[1 - return_indice[1]]
        skip_input = node.input[1 - add_input_index]

        # The bias should be one-dimensional.
        initializer = self.model.get_initializer(bias_input)
        if initializer is None:
            return
        bias_weight = NumpyHelper.to_array(initializer)
        if bias_weight is None:
            logger.debug("Bias weight not found")
            return
        if len(bias_weight.shape) != 1:
            logger.debug("Bias weight is not 1D")
            return

        subgraph_nodes = [node, add]
        if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node):
            logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe")
            return

        self.nodes_to_remove.extend(subgraph_nodes)
        inputs = [
            sln_input,
            skip_input,
            node.input[2],
            node.input[3],
            bias_input,
        ]
        new_node = helper.make_node(
            "SkipLayerNormalization",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"),
        )
        new_node.domain = "com.microsoft"

        # Pass the "epsilon" attribute from the SkipLayerNormalization node to the fused node.
        for att in node.attribute:
            if att.name == "epsilon":
                new_node.attribute.extend([att])

        # Set a default epsilon if the SkipLayerNormalization node has none.
        if len(new_node.attribute) == 0:
            new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
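The fusion above matches a residual Add feeding a LayerNormalization. A minimal sketch (not part of this package; names, shapes, and the output path are illustrative) that builds exactly this pattern, which FusionSkipLayerNormalization would rewrite into a single com.microsoft SkipLayerNormalization node:

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

hidden = 16
gamma = numpy_helper.from_array(np.ones(hidden, dtype=np.float32), "gamma")
beta = numpy_helper.from_array(np.zeros(hidden, dtype=np.float32), "beta")

graph = helper.make_graph(
    nodes=[
        # Residual Add with two 3-D activation inputs (no initializer input).
        helper.make_node("Add", ["x", "skip"], ["add_out"]),
        # LayerNormalization whose epsilon attribute is copied to the fused node.
        helper.make_node("LayerNormalization", ["add_out", "gamma", "beta"], ["y"], epsilon=1e-12),
    ],
    name="skip_layer_norm_pattern",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT, ["batch", "seq", hidden]),
        helper.make_tensor_value_info("skip", TensorProto.FLOAT, ["batch", "seq", hidden]),
    ],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, ["batch", "seq", hidden])],
    initializer=[gamma, beta],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
onnx.checker.check_model(model)
onnx.save(model, "skip_layer_norm_pattern.onnx")

The bias variant is sound because addition is associative and commutative: (x @ W + b) + skip equals (x @ W) + skip + b, so the 1-D bias of the upstream MatMul's Add can be moved into the fifth input of SkipLayerNormalization.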
onnxruntime/transformers/fusion_transpose.py

@@ -0,0 +1,168 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List

from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionTranspose(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "Transpose", "Transpose")

    def fuse(
        self,
        transpose_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        Note that onnxruntime will do comprehensive transpose optimization after loading a model.
        The purpose of this fusion is to clean up the graph before it is handed to onnxruntime.

        Case 1:
            (input)-->Transpose(perm=a)-->Transpose(perm=b)-->
        After:
            (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
                   |
                   +----->Transpose(perm=a*b)-->

        Case 2 (Cast has only one child):
            (input)-->Transpose(perm=a)-->Cast-->Transpose(perm=b)-->
        After:
            (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
                   |
                   +----->Cast-->Transpose(perm=a*b)-->
        """
        transpose_b = transpose_node
        if transpose_b.input[0] not in output_name_to_node:
            return

        transpose_a = output_name_to_node[transpose_b.input[0]]
        if transpose_a.op_type != "Cast":
            cast_node = None
        else:
            cast_node = transpose_a

            cast_children = self.model.get_children(cast_node, input_name_to_nodes)
            if cast_children and len(cast_children) > 1:
                return

            if cast_node.input[0] not in output_name_to_node:
                return

            transpose_a = output_name_to_node[cast_node.input[0]]

        if transpose_a.op_type != "Transpose":
            return

        permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
        assert isinstance(permutation, list)

        parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
        assert isinstance(parent_permutation, list)

        assert len(parent_permutation) == len(permutation)

        output_permutation = []
        for _j, index in enumerate(permutation):
            output_permutation.append(parent_permutation[index])

        if cast_node is None:
            if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
                self.nodes_to_remove.append(transpose_a)
        else:
            if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
                self.nodes_to_remove.append(transpose_a)
        transpose_b.ClearField("attribute")
        transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])


class FusionInsertTranspose(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "", "GroupNorm")

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Append a Transpose node after an input."""
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name
        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
        return transpose_node

    def fuse(
        self,
        group_norm_node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        This optimization inserts a Transpose that the onnxruntime transpose optimizer can later remove
        together with another Transpose, so the net effect is one fewer Transpose after onnxruntime optimization.
        Before:
            --> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
        After:
            --> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) --> Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
        """
        gemm_path = self.model.match_parent_path(
            group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
        )
        if gemm_path is None:
            return
        transpose, add, unsqueeze_3, unsqueeze_2, gemm = gemm_path
        if self.model.find_graph_output(unsqueeze_3.output[0]):
            return

        permutation = OnnxModel.get_node_attribute(transpose, "perm")
        assert isinstance(permutation, list)
        if permutation != [0, 2, 3, 1]:
            return

        if not (
            len(unsqueeze_3.input) == 2
            and self.model.get_constant_value(unsqueeze_3.input[1]) == 3
            and len(unsqueeze_2.input) == 2
            and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
            and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
            and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
            and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
        ):
            return

        # Hard-coded names are used here so the initializers can be shared across the whole model.
        axes_1 = "ort_const_unsqueeze_axes_1"
        if self.model.get_initializer(axes_1) is None:
            self.add_initializer(
                name=axes_1,
                data_type=TensorProto.INT64,
                dims=[1],
                vals=[1],
                raw=False,
            )

        axes_2 = "ort_const_unsqueeze_axes_2"
        if self.model.get_initializer(axes_2) is None:
            self.add_initializer(
                name=axes_2,
                data_type=TensorProto.INT64,
                dims=[1],
                vals=[2],
                raw=False,
            )

        unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
        unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
        transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
        self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
        new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
        self.model.add_node(new_transpose, self.this_graph_name)
        self.increase_counter("Insert Transpose")
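Both rewrites above are semantics-preserving. An illustrative numpy check (not part of the file) of the two identities they rely on:

import numpy as np

# 1) FusionTranspose: Transpose(perm=a) followed by Transpose(perm=b) equals a
#    single Transpose with perm_ab[j] = a[b[j]], matching output_permutation above.
x = np.arange(24).reshape(2, 3, 4)
perm_a = [2, 0, 1]
perm_b = [1, 2, 0]
perm_ab = [perm_a[j] for j in perm_b]
assert np.array_equal(np.transpose(np.transpose(x, perm_a), perm_b), np.transpose(x, perm_ab))

# 2) FusionInsertTranspose: Unsqueeze(axes=[1]), Unsqueeze(axes=[2]), then
#    Transpose([0,3,1,2]) yields the same NCHW tensor as the original
#    Unsqueeze(axes=[2]) followed by Unsqueeze(axes=[3]).
g = np.arange(6, dtype=np.float32).reshape(2, 3)  # Gemm output of shape (N, C)
before = np.expand_dims(np.expand_dims(g, 2), 3)  # (N, C, 1, 1)
after = np.transpose(np.expand_dims(np.expand_dims(g, 1), 2), (0, 3, 1, 2))
assert np.array_equal(before, after)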
onnxruntime/transformers/fusion_utils.py

@@ -0,0 +1,307 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Optional, Tuple

import numpy
from numpy import array_equal, ndarray
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx import onnx_pb as onnx_proto
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionUtils:
    def __init__(self, model: OnnxModel):
        self.model: OnnxModel = model

    def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
            cast_output, cast_node = self.cast_input_to_int32(input_name)
            logger.debug(f"Casted graph input {input_name} to int32")
            return True, cast_output

        logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
        return False, input_name

    def cast_input(self, input_name: str, target_type="int32"):
        output_name = input_name + "_" + target_type

        if target_type == "int32":
            to_type = int(TensorProto.INT32)
        elif target_type == "float32":
            to_type = int(TensorProto.FLOAT)
        elif target_type == "float16":
            to_type = int(TensorProto.FLOAT16)
        else:
            raise ValueError(f"Invalid target_type: {target_type}")

        cast_node = self.add_cast_node(input_name, to_type, output_name)

        return output_name, cast_node

    def add_cast_node(
        self,
        input_name: str,
        to_type: int,
        output_name: Optional[str] = None,
        output_name_to_node=None,
        graph_name: Optional[str] = None,
    ):
        if output_name is None:
            output_name = input_name + f"_cast_to_{to_type}"

        # Avoid consecutive Cast nodes.
        inputs = [input_name]
        if output_name_to_node is None:
            output_name_to_node = self.model.output_name_to_node()
        if input_name in output_name_to_node:
            parent_node = output_name_to_node[input_name]
            if parent_node and parent_node.op_type == "Cast":
                inputs = [parent_node.input[0]]

        cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name])

        cast_node.attribute.extend([helper.make_attribute("to", to_type)])
        self.model.add_node(cast_node, graph_name=graph_name)

        return cast_node

    def cast_input_to_int32(self, input_name: str):
        return self.cast_input(input_name, "int32")

    def remove_cast_int32(self, input_name: str):
        input_name_to_nodes = self.model.input_name_to_nodes()
        nodes = input_name_to_nodes[input_name]
        for node in nodes:
            if node.op_type == "Cast":
                is_int32 = False
                for att in node.attribute:
                    if att.name == "to" and att.i == int(TensorProto.INT32):
                        is_int32 = True
                        break
                if is_int32:
                    output_name = node.output[0]
                    self.model.remove_node(node)
                    self.model.replace_input_of_all_nodes(output_name, input_name)

    @staticmethod
    def update_node_input(node, i, new_input_name, input_name_to_nodes):
        old_input_reference = 0
        if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]:
            input_name_to_nodes[node.input[i]].remove(node)
            old_input_reference = len(input_name_to_nodes[node.input[i]])

        node.input[i] = new_input_name

        if new_input_name in input_name_to_nodes:
            input_name_to_nodes[new_input_name].append(node)
        else:
            input_name_to_nodes[new_input_name] = [node]

        return old_input_reference

    @staticmethod
    def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0):
        """
        Before:
            (input)-->parent-->node-->(output)
        After:
            (input)-->parent-->
                   |
                   +----->node-->(output)

        This function returns a flag indicating whether the parent node can be removed.
        """

        old_input_name = node.input[node_input_index]
        new_input_name = parent_node.input[parent_input_index]
        old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes)

        # The parent can be removed if its output is not used (by a graph output or other nodes) anymore.
        parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name)

        return parent_can_be_removed

    @staticmethod
    def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
        """Verify that a node has the expected value for an attribute.

        Args:
            node (NodeProto): a node to check
            attribute_name (str): name of the attribute
            expected_value (Any): expected value of the attribute
            default_value (Any, optional): default value if the attribute does not exist. Defaults to None.

        Returns:
            bool: whether the check passed or not
        """
        value = default_value
        for attr in node.attribute:
            if attr.name == attribute_name:
                value = helper.get_attribute_value(attr)

        if isinstance(expected_value, list):
            return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
        else:
            return value == expected_value

    @staticmethod
    def transpose_2d_int8_tensor(tensor: onnx_proto.TensorProto):
        """Transpose a 2-D INT8 TensorProto.
        Args:
            tensor (TensorProto): tensor to be transposed
        Returns:
            tensor (TensorProto): transposed tensor
        """
        if not isinstance(tensor, onnx_proto.TensorProto):
            raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")

        if len(tensor.dims) != 2 or tensor.data_type != onnx_proto.TensorProto.INT8:
            raise ValueError("Only INT8 2-D tensors can be transposed")

        if tensor.raw_data:
            int32_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims)
            int32_transposed_data = numpy.transpose(int32_data, [1, 0])
            tensor.raw_data = int32_transposed_data.tobytes()
        else:
            raise ValueError("only raw buffer supported")

        return tensor

    @staticmethod
    def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True):
        """Verify whether a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion.
        It is a good candidate for fusion if:
            (1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True`
            (2) The Q/DQ node should have a constant scale
            (3) The Q/DQ node should have a zero point of 0
        Args:
            node (NodeProto): a Q/DQ node to check
        Returns:
            bool: whether the check passed or not
        """
        if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}:
            logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
            return False

        scale = model.get_constant_value(node.input[1])

        # Scale is not constant
        if scale is None:
            return False

        # Not per-tensor quantization
        scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
        if allow_per_tensor_quantization_only and not scale_has_single_element:
            return False

        # If the Q/DQ node has no zero point input, it is assumed to be 0 (per the ONNX spec)
        if len(node.input) == 2:
            return True

        # Zero point should be constant
        zero_point = model.get_constant_value(node.input[2])
        if zero_point is None:
            return False

        # Zero point and scale should have the same number of dims
        if scale.ndim != zero_point.ndim:
            return False

        # Zero point should have a value of 0
        return numpy.all(zero_point == 0)

    def check_node_input_value(self, node, input_index: int, expected_value):
        """Verify that a node has the expected input value.

        Args:
            node (NodeProto): a node to check
            input_index (int): index of its input to be verified
            expected_value (Any): expected value of the input

        Returns:
            bool: whether the check passed or not
        """
        assert len(node.input) > input_index

        value = self.model.get_constant_value(node.input[input_index])

        if isinstance(expected_value, list):
            return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
        else:
            return value == expected_value

    def remove_identity_nodes(self):
        """Remove Identity nodes, except those right before a graph output."""
        nodes_to_remove = []
        graph_output_names = self.model.get_graphs_output_names()
        for node in self.model.nodes():
            if node.op_type == "Identity":
                if node.output[0] not in graph_output_names:
                    self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
                    nodes_to_remove.append(node)

        if nodes_to_remove:
            self.model.remove_nodes(nodes_to_remove)
            logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")

    def remove_cascaded_cast_nodes(self):
        self.model.remove_cascaded_cast_nodes()

    def remove_useless_cast_nodes(self):
        self.model.remove_useless_cast_nodes()

    def remove_useless_reshape_nodes(self):
        """Remove Reshape nodes that are not needed based on symbolic shape inference: input and output have the same shape."""
        shape_infer = self.model.infer_runtime_shape(update=True)
        if shape_infer is None:
            return

        nodes_to_remove = []
        for node in self.model.nodes():
            if node.op_type == "Reshape":
                input_shape = shape_infer.get_edge_shape(node.input[0])
                output_shape = shape_infer.get_edge_shape(node.output[0])
                if input_shape and output_shape and input_shape == output_shape:
                    logger.info(
                        f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}"
                    )
                    nodes_to_remove.append(node)

        if nodes_to_remove:
            graph_input_names = set(self.model.get_graphs_input_names())
            graph_output_names = set(self.model.get_graphs_output_names())
            for node in nodes_to_remove:
                if bool(set(node.output) & graph_output_names):
                    if (
                        not bool(set(node.input) & graph_input_names)
                        and len(self.model.input_name_to_nodes()[node.input[0]]) == 1  # parent has only one child
                    ):
                        self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
                    else:
                        continue
                else:
                    self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
                self.model.remove_node(node)


class NumpyHelper:
    @staticmethod
    def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
        # When weights are in external data format but not present, we can still test the optimizer with two changes:
        # (1) set fill_zeros = True  (2) change load_external_data=False in optimizer.py
        if fill_zeros:
            from onnx import mapping

            return ndarray(
                shape=tensor.dims,
                dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type],
            )

        return numpy_helper.to_array(tensor)