onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1251 @@
|
|
|
1
|
+
#
|
|
2
|
+
# The implementation of this file is based on:
|
|
3
|
+
# https://github.com/intel/neural-compressor/tree/master/neural_compressor
|
|
4
|
+
#
|
|
5
|
+
# Copyright (c) 2023 Intel Corporation
|
|
6
|
+
#
|
|
7
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
8
|
+
# you may not use this file except in compliance with the License.
|
|
9
|
+
# You may obtain a copy of the License at
|
|
10
|
+
#
|
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
12
|
+
#
|
|
13
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
14
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
15
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
16
|
+
# See the License for the specific language governing permissions and
|
|
17
|
+
# limitations under the License.
|
|
18
|
+
|
|
19
|
+
"""Class for ONNX model."""
|
|
20
|
+
|
|
21
|
+
import copy
|
|
22
|
+
import logging
|
|
23
|
+
import os
|
|
24
|
+
import sys
|
|
25
|
+
from collections import deque
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
|
|
28
|
+
import onnx
|
|
29
|
+
import onnx.external_data_helper
|
|
30
|
+
|
|
31
|
+
from .util import MAXIMUM_PROTOBUF, find_by_name
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger("neural_compressor")
|
|
34
|
+
|
|
35
|
+
# TODO: Check https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/onnx_model.py to see if we can integrate with it.
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ONNXModel:
|
|
39
|
+
"""Build ONNX model."""
|
|
40
|
+
|
|
41
|
+
def __init__(self, model, **kwargs):
|
|
42
|
+
"""Initialize an ONNX model.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
model (str or ModelProto): path to onnx model or loaded ModelProto model object.
|
|
46
|
+
ignore_warning (bool): ignore large model warning. Default is False.
|
|
47
|
+
load_external_data (bool): load external data for large model. Default is True.
|
|
48
|
+
"""
|
|
49
|
+
self._model = model if not isinstance(model, str) else onnx.load(model, load_external_data=False)
|
|
50
|
+
self._model_path = None if not isinstance(model, str) else model
|
|
51
|
+
|
|
52
|
+
self.check_is_large_model()
|
|
53
|
+
if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False):
|
|
54
|
+
logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize")
|
|
55
|
+
|
|
56
|
+
if self._is_large_model and isinstance(model, str) and kwargs.get("load_external_data", True):
|
|
57
|
+
onnx.external_data_helper.load_external_data_for_model(self._model, os.path.dirname(self._model_path))
|
|
58
|
+
|
|
59
|
+
self._config = None
|
|
60
|
+
if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()):
|
|
61
|
+
from transformers import AutoConfig # noqa: PLC0415
|
|
62
|
+
|
|
63
|
+
self._config = AutoConfig.from_pretrained(Path(model).parent.as_posix())
|
|
64
|
+
|
|
65
|
+
self.node_name_counter = {}
|
|
66
|
+
self._output_name_to_node = {}
|
|
67
|
+
self._input_name_to_nodes = {}
|
|
68
|
+
self._get_input_name_to_nodes(self._model.graph.node)
|
|
69
|
+
self._get_output_name_to_node(self._model.graph.node)
|
|
70
|
+
self._graph_info = {}
|
|
71
|
+
self._get_graph_info()
|
|
72
|
+
self._q_config = None
|
|
73
|
+
|
|
74
|
+
def check_is_large_model(self):
    """Decide whether the model counts as "large" (> 2GB) and cache the result.

    The result is stored in ``self._is_large_model``. The model is large when
    any initializer is marked as externally stored, when serializing an
    initializer fails with protobuf's 2GB error, or when the running total of
    serialized initializer bytes exceeds ``MAXIMUM_PROTOBUF``.
    """
    serialized_total = 0
    for init in self._model.graph.initializer:
        # An initializer stored outside the model file marks the model as large.
        if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
            self._is_large_model = True
            return
        try:
            # len() gives the payload size; sys.getsizeof() would also count the
            # bytes object's own header for every initializer.
            serialized_total += len(init.SerializeToString())
        except Exception as e:
            # protobuf refuses to serialize a message over 2GB; that is the signal.
            if "exceeds maximum protobuf size of 2GB" in str(e):
                self._is_large_model = True
                return
            raise  # pragma: no cover
        if serialized_total > MAXIMUM_PROTOBUF:
            self._is_large_model = True
            return
    self._is_large_model = False
|
|
96
|
+
|
|
97
|
+
@property
|
|
98
|
+
def is_large_model(self):
    """Return the cached large-model flag set by ``check_is_large_model``."""
    large = self._is_large_model
    return large
|
|
101
|
+
|
|
102
|
+
@property
|
|
103
|
+
def model_path(self):
    """Return the stored model path (``None`` when no path has been set)."""
    path = self._model_path
    return path
|
|
106
|
+
|
|
107
|
+
@model_path.setter
|
|
108
|
+
def model_path(self, path):
    """Store ``path`` as the model path; no validation is performed."""
    self._model_path = path
|
|
111
|
+
|
|
112
|
+
def framework(self):
    """Return the name of the backing framework, always ``"onnxruntime"``."""
    name = "onnxruntime"
    return name
|
|
115
|
+
|
|
116
|
+
@property
|
|
117
|
+
def q_config(self):
    """Return the stored quantization config (``None`` until one is set)."""
    config = self._q_config
    return config
|
|
120
|
+
|
|
121
|
+
@q_config.setter
|
|
122
|
+
def q_config(self, q_config):
    """Store ``q_config`` as the quantization config."""
    self._q_config = q_config
|
|
125
|
+
|
|
126
|
+
@property
|
|
127
|
+
def hf_config(self):
    """Return the huggingface config loaded for this model, or ``None``."""
    config = self._config
    return config
|
|
130
|
+
|
|
131
|
+
@property
|
|
132
|
+
def model(self):
    """Return the wrapped model object."""
    wrapped = self._model
    return wrapped
|
|
135
|
+
|
|
136
|
+
@model.setter
|
|
137
|
+
def model(self, model):
    """Replace the wrapped model and rebuild every cached lookup table."""
    self._model = model

    # Drop all cached state derived from the previous model.
    self._graph_info = {}
    self._output_name_to_node = {}
    self._input_name_to_nodes = {}

    # Rebuild from the new graph: node info first, then the name maps.
    self._get_graph_info()
    graph_nodes = self._model.graph.node
    self._get_input_name_to_nodes(graph_nodes)
    self._get_output_name_to_node(graph_nodes)
|
|
146
|
+
|
|
147
|
+
def input(self):
    """Return the names of the graph inputs, in graph order."""
    graph_inputs = self._model.graph.input
    return [value_info.name for value_info in graph_inputs]
|
|
150
|
+
|
|
151
|
+
def output(self):
    """Return the names of the graph outputs, in graph order."""
    graph_outputs = self._model.graph.output
    return [value_info.name for value_info in graph_outputs]
|
|
154
|
+
|
|
155
|
+
def update(self):
    """Rebuild the cached graph info and name lookup tables from the current model."""
    # Start from empty caches.
    self._graph_info = {}
    self._output_name_to_node = {}
    self._input_name_to_nodes = {}

    # Repopulate: node info first, then the input and output name maps.
    self._get_graph_info()
    graph_nodes = self._model.graph.node
    self._get_input_name_to_nodes(graph_nodes)
    self._get_output_name_to_node(graph_nodes)
|
|
163
|
+
|
|
164
|
+
@property
|
|
165
|
+
def graph_info(self):
    """Return the cached dict mapping node name to op_type."""
    info = self._graph_info
    return info
|
|
168
|
+
|
|
169
|
+
def _get_graph_info(self):
|
|
170
|
+
"""Update graph info."""
|
|
171
|
+
for node in self._model.graph.node:
|
|
172
|
+
self.graph_info.update({node.name: node.op_type})
|
|
173
|
+
|
|
174
|
+
def save(self, root):
    """Save the ONNX model to ``root``.

    Large models are written with all tensors in a single external data file
    named ``<file name of root>_data`` next to ``root``. If a huggingface config
    is held, it is written as ``config.json`` in the same directory.

    Args:
        root (str or os.PathLike): destination file path of the model.

    Raises:
        ValueError: if the parent directory of ``root`` does not exist.
    """
    root = os.fspath(root)
    out_dir = os.path.dirname(root)
    if out_dir != "" and not os.path.exists(out_dir):
        raise ValueError('"root" directory does not exists.')

    if self.is_large_model:
        # Without a source path (model built from an object) there is no
        # directory to resolve external tensors against, so loading is skipped.
        if self._model_path is not None:
            onnx.external_data_helper.load_external_data_for_model(
                self._model, os.path.dirname(self._model_path)
            )
        onnx.save_model(
            self._model,
            root,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            # basename() strips directories for both "/" and the platform
            # separator, so the location stays a bare file name on Windows too.
            location=os.path.basename(root) + "_data",
            size_threshold=1024,
            convert_attribute=False,
        )
    else:
        onnx.save(self._model, root)

    if self._config is not None:
        model_type = "" if not hasattr(self._config, "model_type") else self._config.model_type
        # NOTE(review): this sets model_type on the config *class*, not the instance.
        self._config.__class__.model_type = model_type
        output_config_file = Path(root).parent.joinpath("config.json").as_posix()
        self._config.to_json_file(output_config_file, use_diff=False)
|
|
197
|
+
|
|
198
|
+
def nodes(self):
    """Return the node container of the model graph (not a copy)."""
    graph = self._model.graph
    return graph.node
|
|
201
|
+
|
|
202
|
+
def initializer(self):
    """Return the initializer container of the model graph (not a copy)."""
    graph = self._model.graph
    return graph.initializer
|
|
205
|
+
|
|
206
|
+
def graph(self):
    """Return the graph of the wrapped model."""
    wrapped = self._model
    return wrapped.graph
|
|
209
|
+
|
|
210
|
+
def ir_version(self):
    """Return the ``ir_version`` field of the wrapped model."""
    wrapped = self._model
    return wrapped.ir_version
|
|
213
|
+
|
|
214
|
+
def opset_import(self):
    """Return the ``opset_import`` field of the wrapped model."""
    wrapped = self._model
    return wrapped.opset_import
|
|
217
|
+
|
|
218
|
+
def remove_node(self, node):
    """Remove ``node`` from the graph; do nothing if it is not present."""
    graph_nodes = self._model.graph.node
    if node not in graph_nodes:
        return
    graph_nodes.remove(node)
|
|
222
|
+
|
|
223
|
+
def remove_nodes(self, nodes_to_remove):
    """Remove every node in ``nodes_to_remove`` via ``remove_node``, in order."""
    drop = self.remove_node
    for stale_node in nodes_to_remove:
        drop(stale_node)
|
|
227
|
+
|
|
228
|
+
def add_node(self, node):
    """Append a single ``node`` to the end of the graph's node list."""
    graph_nodes = self._model.graph.node
    graph_nodes.extend([node])
|
|
231
|
+
|
|
232
|
+
def add_nodes(self, nodes_to_add):
    """Append all of ``nodes_to_add`` to the end of the graph's node list."""
    graph_nodes = self._model.graph.node
    graph_nodes.extend(nodes_to_add)
|
|
235
|
+
|
|
236
|
+
def add_initializer(self, tensor):
    """Add ``tensor`` to the graph initializers unless its name is already present."""
    initializers = self._model.graph.initializer
    if find_by_name(tensor.name, initializers) is not None:
        # An initializer with this name already exists; keep the existing one.
        return
    initializers.extend([tensor])
|
|
240
|
+
|
|
241
|
+
def add_initializers(self, tensors):
    """Add each tensor in ``tensors`` via ``add_initializer``, in order."""
    add_one = self.add_initializer
    for new_tensor in tensors:
        add_one(new_tensor)
|
|
245
|
+
|
|
246
|
+
def get_initializer(self, name):
    """Return the first initializer whose name equals ``name``, or ``None``."""
    initializers = self._model.graph.initializer
    return next((candidate for candidate in initializers if candidate.name == name), None)
|
|
252
|
+
|
|
253
|
+
def get_initializer_share_num(self, name):
    """Count the nodes that take the initializer ``name`` as an input.

    Returns 0 when no initializer with that name exists. A node listing the
    name more than once among its inputs is counted once.
    """
    if self.get_initializer(name) is None:
        return 0
    return sum(1 for consumer in self.nodes() if name in consumer.input)
|
|
263
|
+
|
|
264
|
+
def get_node(self, name):
    """Return the first graph node whose name equals ``name``, or ``None``."""
    graph_nodes = self._model.graph.node
    return next((candidate for candidate in graph_nodes if candidate.name == name), None)
|
|
270
|
+
|
|
271
|
+
def remove_initializer(self, tensor):
    """Remove ``tensor`` from the graph initializers; do nothing if absent."""
    initializers = self._model.graph.initializer
    if tensor not in initializers:
        return
    initializers.remove(tensor)
|
|
275
|
+
|
|
276
|
+
def remove_initializers(self, init_to_remove):
    """Remove each initializer in ``init_to_remove`` via ``remove_initializer``."""
    drop = self.remove_initializer
    for stale_init in init_to_remove:
        drop(stale_init)
|
|
280
|
+
|
|
281
|
+
def set_initializer(self, tensor, array, raw=False):
    """Replace the data of an existing initializer.

    The new tensor keeps the old initializer's dims and data type.

    Args:
        tensor (str): name of the initializer to replace; it is looked up with
            ``get_initializer`` and must already exist.
        array: new values; must provide ``flatten().tolist()`` and ``tobytes()``
            (presumably a numpy ndarray, to be confirmed against callers).
        raw (bool): store the values as raw bytes instead of a typed value list.
            Default is False.
    """
    old_tensor = self.get_initializer(tensor)
    # Read the metadata before detaching the old tensor from the graph.
    dims = old_tensor.dims
    data_type = old_tensor.data_type
    self.remove_initializer(old_tensor)
    if raw:
        # ndarray.tostring() was removed in NumPy 2.0; tobytes() is the same call.
        new_tensor = onnx.helper.make_tensor(tensor, data_type, dims, array.tobytes(), raw=raw)
    else:
        new_tensor = onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist())
    self.add_initializer(new_tensor)
|
|
293
|
+
|
|
294
|
+
@property
def input_name_to_nodes(self):
    """Return the cached ``_input_name_to_nodes`` mapping.

    ``_get_input_name_to_nodes`` fills this mapping with
    ``input tensor name -> list of nodes that list that name as an input``.
    """
    return self._input_name_to_nodes
|
|
298
|
+
|
|
299
|
+
def _get_input_name_to_nodes(self, nodes):
    """Record, for every non-blank input name, the nodes that consume it.

    Results are accumulated in ``self._input_name_to_nodes``
    (``input name -> list of nodes``). Subgraphs stored in node attributes are
    visited recursively before the node's own inputs are recorded.

    Args:
        nodes: iterable of nodes to scan.
    """
    for node in nodes:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                self._get_input_name_to_nodes(attr.g.node)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                # A GRAPHS attribute keeps its subgraphs in ``graphs``; ``g`` is
                # unset for this attribute type, so reading it would skip them all.
                for subgraph in attr.graphs:
                    self._get_input_name_to_nodes(subgraph.node)
        for input_name in node.input:
            # Blank names mark omitted optional inputs and are not recorded.
            if len(input_name.strip()) != 0:
                self._input_name_to_nodes.setdefault(input_name, []).append(node)
|
|
316
|
+
|
|
317
|
+
@property
def output_name_to_node(self):
    """Return the cached ``_output_name_to_node`` mapping.

    ``_get_output_name_to_node`` fills this mapping with
    ``output tensor name -> node that lists that name as an output``.
    """
    return self._output_name_to_node
|
|
321
|
+
|
|
322
|
+
def _get_output_name_to_node(self, nodes):
    """Record, for every non-blank output name, the node that produces it.

    Results are stored in ``self._output_name_to_node``
    (``output name -> node``). Subgraphs stored in node attributes are visited
    recursively before the node's own outputs are recorded.

    Args:
        nodes: iterable of nodes to scan.
    """
    for node in nodes:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                self._get_output_name_to_node(attr.g.node)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                # A GRAPHS attribute keeps its subgraphs in ``graphs``; ``g`` is
                # unset for this attribute type, so reading it would skip them all.
                for subgraph in attr.graphs:
                    self._get_output_name_to_node(subgraph.node)
        for output_name in node.output:
            # Blank names mark omitted optional outputs and are not recorded.
            if len(output_name.strip()) != 0:
                self._output_name_to_node[output_name] = node
|
|
336
|
+
|
|
337
|
+
def get_siblings(self, node):
    """Collect the nodes that share a parent with ``node``.

    Args:
        node: node whose siblings are requested.

    Returns:
        list: every child of every parent of ``node`` whose ``name`` differs
        from ``node.name``, in parent order then child order. A node reachable
        through several parents appears once per parent.
    """
    return [
        child
        for parent in self.get_parents(node)
        for child in self.get_children(parent)
        if child.name != node.name
    ]
|
|
345
|
+
|
|
346
|
+
def get_children(self, node, input_name_to_nodes=None):
    """Collect the nodes that consume any output of ``node``.

    Args:
        node: node whose consumers are requested.
        input_name_to_nodes (dict, optional): mapping from input name to the
            list of consuming nodes. ``self._input_name_to_nodes`` is used when
            this is None.

    Returns:
        list: consumers of each output of ``node``, in output order.
    """
    consumers_by_name = self._input_name_to_nodes if input_name_to_nodes is None else input_name_to_nodes

    children = []
    for output_name in node.output:
        if output_name in consumers_by_name:
            children.extend(consumers_by_name[output_name])
    return children
|
|
357
|
+
|
|
358
|
+
def get_parents(self, node, output_name_to_node=None):
    """Collect the nodes that produce the inputs of ``node``.

    Args:
        node: node whose producers are requested.
        output_name_to_node (dict, optional): mapping from output name to the
            producing node. ``self._output_name_to_node`` is used when this is
            None.

    Returns:
        list: one producer per input of ``node`` that has an entry in the
        mapping, in input order.
    """
    producer_by_name = self._output_name_to_node if output_name_to_node is None else output_name_to_node

    return [producer_by_name[input_name] for input_name in node.input if input_name in producer_by_name]
|
|
368
|
+
|
|
369
|
+
def get_parent(self, node, idx, output_name_to_node=None):
    """Return the node that produces input number ``idx`` of ``node``.

    Args:
        node: node whose producer is requested.
        idx (int): position in ``node.input``.
        output_name_to_node (dict, optional): mapping from output name to the
            producing node. ``self._output_name_to_node`` is used when this is
            None.

    Returns:
        The producing node, or None when ``idx`` is past the last input or the
        input has no entry in the mapping.
    """
    producer_by_name = self._output_name_to_node if output_name_to_node is None else output_name_to_node

    if idx >= len(node.input):
        return None

    input_name = node.input[idx]
    return producer_by_name[input_name] if input_name in producer_by_name else None
|
|
382
|
+
|
|
383
|
+
def find_node_by_name(self, node_name, new_nodes_list, graph):
    """Find a node by name among the graph's nodes and a list of extra nodes.

    Args:
        node_name (str): name to search for.
        new_nodes_list: extra nodes searched after the nodes of ``graph``.
        graph: graph whose ``node`` entries are searched first.

    Returns:
        Whatever ``find_by_name`` returns for the combined node list.
    """
    # New list (graph nodes first, then the extra nodes); the node objects are shared.
    candidates = [*graph.node, *new_nodes_list]
    return find_by_name(node_name, candidates)
|
|
389
|
+
|
|
390
|
+
def find_nodes_by_initializer(self, graph, initializer):
    """Find all nodes that take the given initializer as an input.

    Args:
        graph: graph whose ``node`` entries are scanned.
        initializer: initializer whose ``name`` is compared to node inputs.

    Returns:
        list: matching nodes in graph order. A node is listed once for each of
        its inputs that equals the initializer name.
    """
    return [
        node
        for node in graph.node
        for input_name in node.input
        if input_name == initializer.name
    ]
|
|
398
|
+
|
|
399
|
+
def get_scale_zero(self, tensor):
    """Find the scale and zero-point initializers for a quantized tensor.

    Args:
        tensor (str): tensor name; only names ending in ``"_quantized"`` are
            looked up.

    Returns:
        tuple: ``(scale_tensor, zero_point_tensor)``. ``(None, None)`` is
        returned when the name does not end in ``"_quantized"``, or when the
        tensor is the last input of a ``QLinearConv`` node or the third-from-last
        input of a ``QGemm`` node.

    Raises:
        AssertionError: if the search finds no scale or no zero point.
    """
    if not tensor.endswith("_quantized"):
        logger.debug(f"Find {tensor} in the quantized graph is not quantized.")
        return None, None

    def _searcher(tensor_name):
        """Search scale and zero point tensor recursively."""
        # First consumer of the tensor and (if any) the node that produces it.
        node = self._input_name_to_nodes[tensor_name][0]
        parent = self._output_name_to_node.get(tensor_name, None)
        direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"]
        # Choose the tensor name to derive the base name from, then strip the
        # quantization suffixes from it.
        if parent is not None and parent.op_type in direct_int8:
            # Produced by one of the listed ops: use the producer's first input.
            fp32_tensor_name = (
                parent.input[0]
                .replace("_quantized", "")
                .replace("_QuantizeLinear", "")
                .replace("_QuantizeInput", "")
            )
        elif node.op_type in ["Gather"]:  # pragma: no cover
            # Consumed by Gather: use that node's first output.
            fp32_tensor_name = (
                node.output[0]
                .replace("_quantized", "")
                .replace("_QuantizeLinear", "")
                .replace("_QuantizeInput", "")
            )
        else:
            fp32_tensor_name = (
                tensor_name.replace("_quantized", "").replace("_QuantizeLinear", "").replace("_QuantizeInput", "")
            )
        scale = fp32_tensor_name + "_scale"
        scale_tensor = self.get_initializer(scale)
        zo = fp32_tensor_name + "_zero_point"
        zo_tensor = self.get_initializer(zo)

        # Nothing found under the derived name: retry from the producer's first input.
        if scale_tensor is None or zo_tensor is None:
            if parent is not None:
                scale_tensor, zo_tensor = _searcher(parent.input[0])
        return scale_tensor, zo_tensor

    node = self._input_name_to_nodes[tensor][0]
    # TODO check if scale_tensor and zero_point is needed
    # for bias of qlinearconv, scale and zero_point is not needed
    if (node.op_type == "QLinearConv" and tensor == node.input[-1]) or (
        node.op_type == "QGemm" and tensor == node.input[-3]
    ):
        return None, None
    else:
        scale_tensor, zo_tensor = _searcher(tensor)
        assert scale_tensor, f"missing scale for tensor {tensor}"
        assert zo_tensor, f"missing zero point for tensor {tensor}"
        return scale_tensor, zo_tensor
|
|
450
|
+
|
|
451
|
+
def save_model_to_file(self, output_path, use_external_data_format=False):
    """Save the model to ``output_path``, optionally with external tensor data.

    Args:
        output_path: destination path of the model file.
        use_external_data_format (bool): when true, the model is first converted
            to external data, with all tensors in one file named after the model
            file plus ``".data"`` (needed for model size > 2GB).
    """
    if use_external_data_format:
        data_file_name = Path(output_path).name + ".data"
        onnx.external_data_helper.convert_model_to_external_data(
            self._model,
            all_tensors_to_one_file=True,
            location=data_file_name,
        )
    onnx.save_model(self._model, output_path)
|
|
458
|
+
|
|
459
|
+
@staticmethod
def replace_node_input(node, old_input_name, new_input_name):
    """Rename every occurrence of one input of a node, in place.

    Args:
        node: node whose ``input`` entries are rewritten.
        old_input_name (str): input name to replace.
        new_input_name (str): replacement name.
    """
    assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
    for position, current_name in enumerate(node.input):
        if current_name == old_input_name:
            node.input[position] = new_input_name
|
|
466
|
+
|
|
467
|
+
def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=None, black_optype=None):
    """Rename an input on the selected nodes of the model graph.

    Args:
        old_input_name (str): input name to replace.
        new_input_name (str): replacement name.
        white_optype (list, optional): when non-empty, only nodes whose
            ``op_type`` is listed are rewritten and ``black_optype`` is ignored.
        black_optype (list, optional): used only when ``white_optype`` is
            empty; nodes whose ``op_type`` is listed are left untouched.
    """
    allowed_types = [] if white_optype is None else white_optype
    blocked_types = [] if black_optype is None else black_optype
    use_allow_list = len(allowed_types) > 0

    for node in self.model.graph.node:
        if use_allow_list:
            selected = node.op_type in allowed_types
        else:
            selected = node.op_type not in blocked_types
        if selected:
            ONNXModel.replace_node_input(node, old_input_name, new_input_name)
|
|
481
|
+
|
|
482
|
+
@staticmethod
def replace_node_output(node, old_output_name, new_output_name):
    """Rename every occurrence of one output of a node, in place.

    Args:
        node: node whose ``output`` entries are rewritten.
        old_output_name (str): output name to replace.
        new_output_name (str): replacement name.
    """
    assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
    for position, current_name in enumerate(node.output):
        if current_name == old_output_name:
            node.output[position] = new_output_name
|
|
489
|
+
|
|
490
|
+
def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=None, black_optype=None):
    """Rename an output on the selected nodes of the model graph.

    Args:
        old_output_name (str): output name to replace.
        new_output_name (str): replacement name.
        white_optype (list, optional): when non-empty, only nodes whose
            ``op_type`` is listed are rewritten and ``black_optype`` is ignored.
        black_optype (list, optional): used only when ``white_optype`` is
            empty; nodes whose ``op_type`` is listed are left untouched.
    """
    allowed_types = [] if white_optype is None else white_optype
    blocked_types = [] if black_optype is None else black_optype
    use_allow_list = len(allowed_types) > 0

    for node in self.model.graph.node:
        if use_allow_list:
            selected = node.op_type in allowed_types
        else:
            selected = node.op_type not in blocked_types
        if selected:
            ONNXModel.replace_node_output(node, old_output_name, new_output_name)
|
|
504
|
+
|
|
505
|
+
def remove_unused_nodes(self):
    """Remove nodes and initializers that nothing in the graph uses.

    A node is removed when it is

    * a ``Constant`` whose first output is neither a graph output nor consumed
      by any node;
    * a ``QuantizeLinear`` with exactly one child, that child being a
      ``DequantizeLinear`` whose first output no node consumes, while the
      ``QuantizeLinear``'s first input is not produced by any node (the child
      is removed as well); or
    * any other node none of whose outputs is consumed or a graph output, and
      none of whose non-initializer inputs is produced by a node or is a graph
      input.

    Afterwards every initializer that is neither consumed by a node nor a
    graph output is removed, together with the graph input of the same name,
    and ``self.update()`` is called.
    """
    # ``graph.output`` / ``graph.input`` hold value-info entries, not strings, so
    # membership must be tested against the name lists from output() / input().
    graph_output_names = self.output()
    graph_input_names = self.input()
    consumers = self._input_name_to_nodes
    producers = self._output_name_to_node

    unused_nodes = []
    for node in self.nodes():
        if (
            node.op_type == "Constant"
            and node.output[0] not in graph_output_names
            and node.output[0] not in consumers
        ):
            unused_nodes.append(node)
            continue

        if node.op_type == "QuantizeLinear":
            children = self.get_children(node)
            if (
                len(children) == 1
                and children[0].op_type == "DequantizeLinear"
                and node.input[0] not in producers
                and children[0].output[0] not in consumers
            ):
                unused_nodes.append(node)
                unused_nodes.extend(children)
                continue

        # remove the node if it does not serve as the input or output of any other nodes
        used = any(name in consumers or name in graph_output_names for name in node.output)
        if not used:
            # Initializer inputs do not count as a connection to the rest of the graph.
            used = any(
                self.get_initializer(name) is None and (name in producers or name in graph_input_names)
                for name in node.input
            )
        if not used:
            unused_nodes.append(node)
    self.remove_nodes(unused_nodes)

    unused_weights = []
    for w in self._model.graph.initializer:
        if w.name not in consumers and w.name not in graph_output_names:
            unused_weights.append(w)
            # Remove from graph.input (iterate over a copy: the container is mutated).
            for graph_input in list(self.graph().input):
                if graph_input.name == w.name:
                    self.graph().input.remove(graph_input)

    self.remove_initializers(unused_weights)
    self.update()
|
|
553
|
+
|
|
554
|
+
def topological_sort(self, enable_subgraph=False):
    """Reorder ``self.model.graph.node`` so producers come before consumers.

    Args:
        enable_subgraph (bool): when false, the producer/consumer maps are
            rebuilt from the top-level graph nodes only; when true, the cached
            ``self._input_name_to_nodes`` / ``self._output_name_to_node`` maps
            are used.

    Raises:
        AssertionError: if the number of distinct node names after sorting
            differs from the number in the graph.
    """
    if not enable_subgraph:
        input_name_to_nodes = {}
        output_name_to_node = {}
        for node in self.model.graph.node:
            for input_name in node.input:
                if len(input_name.strip()) != 0:
                    input_name_to_nodes.setdefault(input_name, []).append(node)
            for output_name in node.output:
                if len(output_name.strip()) != 0:
                    output_name_to_node[output_name] = node
    else:  # pragma: no cover
        input_name_to_nodes = self._input_name_to_nodes
        output_name_to_node = self._output_name_to_node

    all_nodes = {}
    q = deque()
    wait = deque()
    for inp in self.model.graph.input:
        # A graph input may have no consumer in the map (e.g. an unused input);
        # that must not abort the sort with a KeyError.
        q.extend(input_name_to_nodes.get(inp.name, []))
    graph_input_names = self.input()
    for n in self.model.graph.node:
        # Nodes fed only by initializers / nothing at all are roots as well.
        if all(i not in output_name_to_node and i not in graph_input_names for i in n.input):
            q.append(n)

    while q:
        n = q.popleft()
        # Defer the node until every producer of its inputs has been emitted.
        if not all(output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node):
            if n not in wait:
                wait.append(n)
            continue

        all_nodes[n.name] = n
        for out in n.output:
            if out in input_name_to_nodes:
                q.extend([i for i in input_name_to_nodes[out] if i.name not in all_nodes and i not in q])
        if len(q) == 0 and len(wait) != 0:
            # Retry the deferred nodes; a shallow copy is enough because only the
            # queue itself (not the nodes) is modified afterwards.
            q = deque(wait)
            wait.clear()
    nodes = list(all_nodes.values())
    assert len({n.name for n in nodes}) == len({n.name for n in self.model.graph.node})
    self.model.graph.ClearField("node")
    self.model.graph.node.extend(nodes)
|
|
601
|
+
|
|
602
|
+
def get_nodes_chain(self, start, stop, result_chain=None):
    """Collect node names reachable backwards from ``start`` until ``stop``.

    Starting at the ``start`` nodes, parents are followed breadth-first; nodes
    named in ``stop`` are neither recorded nor expanded.

    Args:
        start (list): node names or ``onnx.NodeProto`` objects to start from.
        stop (list): node names or ``onnx.NodeProto`` objects to stop at.
        result_chain (list, optional): list to append to; names already in it
            are skipped. A new list is created when None.

    Returns:
        list: ``result_chain`` with the visited node names appended.
    """
    if result_chain is None:
        result_chain = []

    def _as_names(items):
        """Convert a list of node names / NodeProto objects to node names."""
        names = []
        for item in items:
            if isinstance(item, str):
                names.append(item)
            elif isinstance(item, onnx.NodeProto):
                names.append(item.name)
            else:
                assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params"  # noqa: B011
        return names

    # process start node list
    pending = deque(_as_names(start))
    # process stop node list
    stop_names = _as_names(stop)

    while pending:
        node_name = pending.popleft()
        if node_name in stop_names or node_name in result_chain:
            continue
        result_chain.append(node_name)

        node = find_by_name(node_name, list(self.model.graph.node))
        pending.extend(parent.name for parent in self.get_parents(node))

    return result_chain
|
|
640
|
+
|
|
641
|
+
def find_split_node_for_layer_wise_quantization(self):
    """Find the nodes at which the model can be split for layer-wise quantization.

    Every ``SkipLayerNormalization`` or ``Add`` node is checked against a list
    of parent-path patterns via ``self.match_parent_path``; a node is kept when
    at least one pattern matches.

    Returns:
        list: the matching nodes, in graph order.
    """
    # find split nodes of decoder blocks
    # embed -> decoder.0 -(split_node)-> ... -(split_node)-> decoder.n -(split_node)-> norm -> head
    # after split: embed -> decoder.0,
    #              decoder.1,
    #              decoder.2,
    #              ...,
    #              decoder.n,
    #              norm -> head
    start_nodes = []
    for node in self._model.graph.node:
        start_node, qkv_nodes_list = None, None
        if node.op_type == "SkipLayerNormalization":
            start_node = node
            qkv_nodes_list = [
                self.match_parent_path(
                    start_node,
                    ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                    [None, 0, 0, 0, 0],
                ),
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                    [1, 1, 0, 0, 0],
                ),
            ]
        if node.op_type == "Add":
            start_node = node
            qkv_nodes_list = [
                # match base attention structure
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                    [0, None, 0, 0, 0],
                ),
                self.match_parent_path(
                    start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0]
                ),
                # match gpt attention no past structure
                self.match_parent_path(
                    start_node,
                    ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                    [None, 0, 0, 0, 0, 0],
                    output_name_to_node=self.output_name_to_node,
                    return_indice=[],
                ),
                # match bart attention structure
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                    [0, None, 0, 0, 0, 0],
                ),
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                    [1, None, 0, 0, 0, 0],
                ),
                self.match_parent_path(
                    start_node,
                    ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"],
                    [None, 0, None, 0, None, 0],
                ),
                self.match_parent_path(
                    start_node,
                    ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"],
                    [None, 0, None, 0, 0],
                ),
            ]
        # Neither op type applied: not a candidate.
        if not start_node:
            continue
        # match_parent_path returns None for a failed match; skip when all patterns failed.
        if not any(qkv_nodes_list):
            continue
        start_nodes.append(start_node)
    return start_nodes
|
|
716
|
+
|
|
717
|
+
def find_qkv_in_attention(self, find_all=False):
    """Find qkv MatMul in Attention.

    ``Attention`` nodes are recorded directly by name. For
    ``SkipLayerNormalization`` and ``Add`` nodes, parent-path patterns are
    matched with ``self.match_parent_path`` and the three ``MatMul`` consumers
    of the block's other input are recorded.

    Args:
        find_all (bool, optional): find all qkv MatMul. Defaults to False

    Returns:
        qkv (list): qkv MatMul list
    """
    qkv = []
    for node in self._model.graph.node:
        if node.op_type == "Attention":
            qkv.append([node.name])
            continue
        start_node, qkv_nodes_list = None, None
        if node.op_type == "SkipLayerNormalization":
            start_node = node
            qkv_nodes_list = [
                self.match_parent_path(
                    start_node,
                    ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                    [None, 0, 0, 0, 0],
                ),
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                    [1, 1, 0, 0, 0],
                ),
            ]
        if node.op_type == "Add":
            start_node = node
            qkv_nodes_list = [
                # match base attention structure
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                    [0, None, 0, 0, 0],
                ),
                self.match_parent_path(
                    start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0]
                ),
                # match gpt attention no past structure
                self.match_parent_path(
                    start_node,
                    ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                    [None, 0, 0, 0, 0, 0],
                    output_name_to_node=self.output_name_to_node,
                    return_indice=[],
                ),
                # match bart attention structure
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                    [0, None, 0, 0, 0, 0],
                ),
                self.match_parent_path(
                    start_node,
                    ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
                    [1, None, 0, 0, 0, 0],
                ),
            ]
        # Neither op type applied: not a candidate.
        if not start_node:
            continue
        # match_parent_path returns None for a failed match; skip when all patterns failed.
        if not any(qkv_nodes_list):
            continue
        # Use the last pattern in the list that matched.
        qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1]
        # Inputs of the start node that are produced by a node, other than the
        # output of the first node on the matched path.
        other_inputs = []
        for input in start_node.input:
            if input not in self.output_name_to_node:
                continue
            if input == qkv_nodes[0].output[0]:
                continue
            other_inputs.append(input)
        if len(other_inputs) != 1:
            continue
        root_input = other_inputs[0]
        input_name_to_nodes = self.input_name_to_nodes
        children = input_name_to_nodes[root_input]
        children_types = [child.op_type for child in children]
        # Exactly three MatMul consumers of the root input are taken as q, k and v.
        if children_types.count("MatMul") == 3:
            qkv.append([child.name for child in children if child.op_type == "MatMul"])
            if not find_all:
                break
    return qkv
|
|
801
|
+
|
|
802
|
+
def find_ffn_matmul(self, attention_index, attention_matmul_list, block_len):
    """Find MatMul in FFN.

    For every attention position except the last, the two list entries just
    before the *next* attention position are taken. For the last position, the
    two entries at the end of a block of ``block_len`` entries starting there
    are taken, provided the list is long enough.

    Args:
        attention_index (list): index of Attention
        attention_matmul_list (list): list of Attention and MatMul nodes
        block_len (int): block length

    Returns:
        list: list of MatMul in FFN
    """
    ffn_pairs = []
    if len(attention_index) == 0:
        return ffn_pairs

    # Blocks that are followed by another attention entry.
    for next_attention in attention_index[1:]:
        if next_attention >= 2:
            ffn_pairs.append(
                [attention_matmul_list[next_attention - 2], attention_matmul_list[next_attention - 1]]
            )

    # Final block: its end is derived from the block length.
    block_end = attention_index[-1] + block_len - 1
    if block_end < len(attention_matmul_list):
        ffn_pairs.append([attention_matmul_list[block_end - 1], attention_matmul_list[block_end]])
    return ffn_pairs
|
|
826
|
+
|
|
827
|
+
def export(self, save_path, conf):
    """Export Qlinear to QDQ model.

    Args:
        save_path: destination passed to ``self.save``.
        conf: export configuration; only ``ONNXQlinear2QDQConfig`` instances
            are supported.

    Raises:
        SystemExit: with status 0, after logging a warning, when ``conf`` is
            not an ``ONNXQlinear2QDQConfig``.
    """
    import sys  # noqa: PLC0415

    from neural_compressor.config import ONNXQlinear2QDQConfig  # noqa: PLC0415
    from neural_compressor.utils.export import onnx_qlinear_to_qdq  # noqa: PLC0415

    if isinstance(conf, ONNXQlinear2QDQConfig):
        add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, self._input_name_to_nodes)
        self.add_nodes(add_nodes)
        self.remove_nodes(remove_nodes)
        self.add_initializers(inits)
        self.update()
        self.remove_unused_nodes()
        self.topological_sort()
        self.save(save_path)
    else:
        logger.warning("Unsupported config for export, only ONNXQlinear2QDQConfig is supported!")
        # The bare ``exit`` helper is injected by the ``site`` module and is not
        # guaranteed to exist; ``sys.exit`` raises the same SystemExit(0).
        # NOTE(review): terminating the interpreter with status 0 from library
        # code looks unintended -- confirm whether callers rely on it.
        sys.exit(0)
|
|
844
|
+
|
|
845
|
+
def add_tensors_to_outputs(self, tensor_names):
    """Expose additional tensors as model outputs so their values can be read.

    Names already reported by ``self.output()`` are skipped. The new entries
    carry only a name.

    Args:
        tensor_names: The names of tensors to be dumped.
    """
    missing_names = [name for name in tensor_names if name not in self.output()]

    new_outputs = []
    for name in missing_names:
        value_info = onnx.helper.ValueInfoProto()
        value_info.name = name
        new_outputs.append(value_info)
    self._model.graph.output.extend(new_outputs)  # pylint: disable=no-member
|
|
858
|
+
|
|
859
|
+
def remove_tensors_from_outputs(self, tensor_names):
    """Drop the named tensors from the model outputs.

    Names not reported by ``self.output()`` are ignored. All entries to drop
    are collected first and removed afterwards.

    Args:
        tensor_names: The names of tensors to be removed.
    """
    graph_outputs = self._model.graph.output
    to_remove = [
        graph_outputs[self.output().index(name)]
        for name in tensor_names
        if name in self.output()
    ]
    for entry in to_remove:
        graph_outputs.remove(entry)
|
|
871
|
+
|
|
872
|
+
def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=None):
    """Find the first parent of ``node`` that has the requested op_type.

    Inputs are examined in order; inputs without a producer in
    ``output_name_to_node`` are skipped.

    Args:
        node: current node.
        parent_op_type (str): constraint of parent node op_type.
        output_name_to_node (dict): dictionary with output name as key, and node as value.
        exclude (list): list of nodes that are excluded (not allowed to match as parent).

    Returns:
        parent: The matched parent node. None if not found.
        index: The input index of matched parent node. None if not found.
    """
    excluded = [] if exclude is None else exclude
    for index, input_name in enumerate(node.input):
        if input_name not in output_name_to_node:
            continue
        candidate = output_name_to_node[input_name]
        if candidate.op_type == parent_op_type and candidate not in excluded:
            return candidate, index
    return None, None
|
|
893
|
+
|
|
894
|
+
def match_parent(
    self,
    node,
    parent_op_type,
    input_index=None,
    output_name_to_node=None,
    exclude=None,
    return_indice=None,
):
    """Find a parent of ``node`` by op_type and, optionally, input index.

    Args:
        node: current node.
        parent_op_type (str): constraint of parent node op_type.
        input_index (int or None): only check the parent given input index of current node.
            With None, the first parent of the requested op_type is returned.
        output_name_to_node (dict): dictionary with output name as key, and node as value.
            ``self._output_name_to_node`` is used when None.
        exclude (list): list of nodes that are excluded (not allowed to match as parent).
        return_indice (list): a list to append the input index when input_index is None.

    Returns:
        parent: The matched parent node, or None when nothing matches.
    """
    assert node is not None
    assert input_index is None or input_index >= 0
    excluded = [] if exclude is None else exclude
    producers = self._output_name_to_node if output_name_to_node is None else output_name_to_node

    if input_index is None:
        # No index constraint: search all inputs and report where the match was found.
        parent, index = self.match_first_parent(node, parent_op_type, producers, excluded)
        if return_indice is not None:
            return_indice.append(index)
        return parent

    if input_index >= len(node.input):
        return None

    parent = self.get_parent(node, input_index, producers)
    is_match = parent is not None and parent.op_type == parent_op_type and parent not in excluded
    return parent if is_match else None
|
|
937
|
+
|
|
938
|
+
def match_parent_path(
    self,
    node,
    parent_op_types,
    parent_input_index,
    output_name_to_node=None,
    return_indice=None,
):
    """Walk up from ``node`` along a chain of parents with the given op_types.

    Step ``i`` looks, from the node matched in the previous step, for a parent
    of type ``parent_op_types[i]`` at input ``parent_input_index[i]``.

    Args:
        node: node to start from.
        parent_op_types (list): constraint of parent node op_type of each input edge.
        parent_input_index (list): constraint of input index of each input edge.
            None means no constraint.
        output_name_to_node (dict): dictionary with output name as key, and node as value.
            ``self._output_name_to_node`` is used when None.
        return_indice (list): a list to append the input index when there is
            no constraint on input index of an edge.

    Returns:
        parents: a list of matched parent node, or None as soon as one step
        fails to match.
    """
    assert len(parent_input_index) == len(parent_op_types)

    producers = self._output_name_to_node if output_name_to_node is None else output_name_to_node

    path = []
    cursor = node
    for depth, wanted_type in enumerate(parent_op_types):
        parent = self.match_parent(
            cursor,
            wanted_type,
            parent_input_index[depth],
            producers,
            exclude=[],
            return_indice=return_indice,
        )
        if parent is None:
            return None
        path.append(parent)
        cursor = parent

    return path
|
|
983
|
+
|
|
984
|
+
def is_smoothquant_model(self):
    """Tell whether the model has been smooth quantized.

    Returns:
        bool: True when the name of any initializer in the model graph
        contains ``"_smooth_scale"``, otherwise False.
    """
    return any("_smooth_scale" in init.name for init in self.model.graph.initializer)
|
|
994
|
+
|
|
995
|
+
def find_split_nodes(self):
    """Find split nodes for layer-wise quantization.

    Returns:
        The result of ``self.find_split_node_for_layer_wise_quantization()``.
    """
    return self.find_split_node_for_layer_wise_quantization()
|
|
999
|
+
|
|
1000
|
+
def split_model_with_node(
|
|
1001
|
+
self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True
|
|
1002
|
+
):
|
|
1003
|
+
"""Split model into two parts at a given node.
|
|
1004
|
+
|
|
1005
|
+
Args:
|
|
1006
|
+
split_node_name (str): name of the node where the model is split at>
|
|
1007
|
+
path_of_model_to_split (str): path of model to be split.
|
|
1008
|
+
shape_infer (bool): do shape inference. Default is True.
|
|
1009
|
+
save_both_split_models (bool): whether to save the two split models.
|
|
1010
|
+
False means only save the first split model.
|
|
1011
|
+
True means save both the two split models.
|
|
1012
|
+
Default id True.
|
|
1013
|
+
|
|
1014
|
+
Returns:
|
|
1015
|
+
tuple: the first split model, the second split model
|
|
1016
|
+
"""
|
|
1017
|
+
# origin model : ... -> node_1 -> split_node -> node_2 -> ...
|
|
1018
|
+
# split model 1: ... -> node_1 -> split_node
|
|
1019
|
+
# split model 2: node_2 -> ...
|
|
1020
|
+
|
|
1021
|
+
split_model_part_1 = onnx.ModelProto()
|
|
1022
|
+
split_model_part_1.CopyFrom(self._model)
|
|
1023
|
+
split_model_part_1.graph.ClearField("node")
|
|
1024
|
+
|
|
1025
|
+
split_model_part_2 = onnx.ModelProto()
|
|
1026
|
+
split_model_part_2.CopyFrom(self._model)
|
|
1027
|
+
split_model_part_2.graph.ClearField("node")
|
|
1028
|
+
|
|
1029
|
+
split_node_output = None
|
|
1030
|
+
part_idx = 1
|
|
1031
|
+
for node in self._model.graph.node:
|
|
1032
|
+
if part_idx == 1:
|
|
1033
|
+
split_model_part_1.graph.node.append(node)
|
|
1034
|
+
elif part_idx == 2:
|
|
1035
|
+
split_model_part_2.graph.node.append(node)
|
|
1036
|
+
|
|
1037
|
+
if node.name == split_node_name:
|
|
1038
|
+
split_node_output = node.output
|
|
1039
|
+
part_idx = 2
|
|
1040
|
+
|
|
1041
|
+
assert len(split_node_output) == 1, (
|
|
1042
|
+
f"Only support split at node with 1 output tensor, while current split node {split_node_name} has {len(split_node_output)} output tensors"
|
|
1043
|
+
)
|
|
1044
|
+
split_tensor_name = split_node_output[0]
|
|
1045
|
+
|
|
1046
|
+
# infer shape of the model to be split
|
|
1047
|
+
if shape_infer:
|
|
1048
|
+
try:
|
|
1049
|
+
from neural_compressor.adaptor.ox_utils.util import infer_shapes # noqa: PLC0415
|
|
1050
|
+
|
|
1051
|
+
self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path))
|
|
1052
|
+
except Exception as e: # pragma: no cover
|
|
1053
|
+
logger.error(
|
|
1054
|
+
"Shape infer fails for layer-wise quantization. "
|
|
1055
|
+
"We would recommend checking the graph optimization level of your model "
|
|
1056
|
+
"and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', "
|
|
1057
|
+
"as this may help avoid this error."
|
|
1058
|
+
)
|
|
1059
|
+
raise e
|
|
1060
|
+
|
|
1061
|
+
split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name)
|
|
1062
|
+
split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape)
|
|
1063
|
+
|
|
1064
|
+
split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True)
|
|
1065
|
+
split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True)
|
|
1066
|
+
|
|
1067
|
+
# remove unused input & output
|
|
1068
|
+
split_model_part_1._remove_unused_input_output()
|
|
1069
|
+
split_model_part_2._remove_unused_input_output()
|
|
1070
|
+
|
|
1071
|
+
split_model_part_1.model.graph.output.append(split_tensor)
|
|
1072
|
+
split_model_part_2.model.graph.input.append(split_tensor)
|
|
1073
|
+
|
|
1074
|
+
insert_output_for_model_1 = []
|
|
1075
|
+
insert_input_for_model_2 = []
|
|
1076
|
+
for output in split_model_part_1.output_name_to_node:
|
|
1077
|
+
if output in split_model_part_2.input_name_to_nodes:
|
|
1078
|
+
output_type, output_shape = self._get_output_type_shape_by_tensor_name(output)
|
|
1079
|
+
output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape)
|
|
1080
|
+
if output_tensor not in split_model_part_1.model.graph.output:
|
|
1081
|
+
insert_output_for_model_1.append(output_tensor)
|
|
1082
|
+
if output_tensor not in split_model_part_2.model.graph.input:
|
|
1083
|
+
insert_input_for_model_2.append(output_tensor)
|
|
1084
|
+
|
|
1085
|
+
# insert model 1 output
|
|
1086
|
+
for output in insert_output_for_model_1:
|
|
1087
|
+
split_model_part_1.model.graph.output.append(output)
|
|
1088
|
+
|
|
1089
|
+
# insert model 2 input
|
|
1090
|
+
for input in insert_input_for_model_2:
|
|
1091
|
+
split_model_part_2.model.graph.input.append(input)
|
|
1092
|
+
|
|
1093
|
+
# remove unused init
|
|
1094
|
+
split_model_part_1.remove_unused_init()
|
|
1095
|
+
split_model_part_2.remove_unused_init()
|
|
1096
|
+
|
|
1097
|
+
split_model_part_1.update()
|
|
1098
|
+
split_model_part_2.update()
|
|
1099
|
+
|
|
1100
|
+
dir_of_model_to_split = os.path.dirname(path_of_model_to_split)
|
|
1101
|
+
|
|
1102
|
+
split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split)
|
|
1103
|
+
split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx")
|
|
1104
|
+
split_model_part_1.model_path = split_model_part_1_path
|
|
1105
|
+
split_model_part_1._save_split_model(split_model_part_1_path)
|
|
1106
|
+
split_model_part_1.check_is_large_model()
|
|
1107
|
+
logger.debug(f"save split model part 1 to {split_model_part_1_path} for layer wise quantization")
|
|
1108
|
+
|
|
1109
|
+
if save_both_split_models:
|
|
1110
|
+
split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split)
|
|
1111
|
+
split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx")
|
|
1112
|
+
split_model_part_2.model_path = split_model_part_2_path
|
|
1113
|
+
split_model_part_2._save_split_model(split_model_part_2_path)
|
|
1114
|
+
split_model_part_2.check_is_large_model()
|
|
1115
|
+
logger.debug(f"save split model part 2 to {split_model_part_2_path} for layer wise quantization")
|
|
1116
|
+
return split_model_part_1, split_model_part_2
|
|
1117
|
+
else:
|
|
1118
|
+
return split_model_part_1, split_model_part_2
|
|
1119
|
+
|
|
1120
|
+
def _save_split_model(self, save_path):
    """Save split model as external data for layer wise quantization.

    The model is written to ``save_path`` and its tensor data is redirected to
    a single side-car file named ``<file name of save_path>_data`` that sits
    next to the model file.

    Args:
        save_path (str): the path to save the split model
    """
    external_data_path = save_path + "_data"
    # Drop a side-car file left over from an earlier save before writing a new
    # one (presumably so old tensor data is not kept alongside the new data).
    if os.path.exists(external_data_path):
        os.remove(external_data_path)

    # `location` must be a bare file name relative to the model file.
    # os.path.basename understands the platform's separators, whereas splitting
    # on "/" leaves the whole directory prefix in place for backslash-separated
    # paths such as the ones os.path.join produces on Windows.
    external_data_name = os.path.basename(save_path) + "_data"

    onnx.save_model(
        self._model,
        save_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=external_data_name,
        size_threshold=1024,
        convert_attribute=False,
    )
|
|
1137
|
+
|
|
1138
|
+
def _get_output_type_shape_by_tensor_name(self, tensor_name):
    """Look up the element type and shape recorded for a tensor.

    Only the entries of ``graph.value_info`` are searched, and the first entry
    whose name matches wins. When nothing matches, the result falls back to
    ``onnx.TensorProto.FLOAT`` with a shape of ``None``.

    Args:
        tensor_name (str): name of a tensor

    Returns:
        tuple: ``(elem_type, shape)``; ``shape`` is a list that holds ``-1``
            for every dimension without a ``dim_value``
    """
    for value_info in self._model.graph.value_info:
        if value_info.name != tensor_name:
            continue
        tensor_type = value_info.type.tensor_type
        found_type = tensor_type.elem_type
        dims = []
        for dim in tensor_type.shape.dim:
            # Dimensions without a concrete value are reported as -1.
            dims.append(dim.dim_value if dim.HasField("dim_value") else -1)
        return found_type, dims
    return onnx.TensorProto.FLOAT, None
|
|
1157
|
+
|
|
1158
|
+
def _remove_unused_input_output(self):
    """Remove graph inputs/outputs that the split model no longer references.

    An output is dropped when its name is missing from
    ``self.output_name_to_node``; an input is dropped when its name is missing
    from ``self.input_name_to_nodes``.
    """
    graph = self._model.graph

    # Decide what to drop before touching the containers, so nothing is
    # removed from a sequence while it is being iterated.
    dangling_outputs = [entry for entry in graph.output if entry.name not in self.output_name_to_node]
    dangling_inputs = [entry for entry in graph.input if entry.name not in self.input_name_to_nodes]

    for entry in dangling_outputs:
        graph.output.remove(entry)
    for entry in dangling_inputs:
        graph.input.remove(entry)
|
|
1174
|
+
|
|
1175
|
+
def remove_unused_init(self):
    """Remove every initializer whose name is not a key of ``input_name_to_nodes``."""
    unused_initializers = [
        tensor
        for tensor in self._model.graph.initializer
        if tensor.name not in self.input_name_to_nodes
    ]
    self.remove_initializers(unused_initializers)
|
|
1182
|
+
|
|
1183
|
+
def load_model_initializer_by_tensor(self, data_path=None):
    """Load the data of externally stored initializers, tensor by tensor.

    Only initializers whose ``data_location`` field is set to
    ``onnx.TensorProto.EXTERNAL`` are loaded; all others are left untouched.

    Args:
        data_path (str, optional): the directory of saved initializer.
            Defaults to None, in which case the directory of the model path
            is used.
    """
    base_dir = os.path.dirname(self._model_path) if data_path is None else data_path
    for tensor in self._model.graph.initializer:
        stored_externally = (
            tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
        )
        if stored_externally:
            onnx.external_data_helper.load_external_data_for_tensor(tensor, base_dir)
|
|
1194
|
+
|
|
1195
|
+
def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False):
    """Write external data of merged quantized model to new location to save memory.

    The initializers are first loaded through
    ``load_model_initializer_by_tensor``, then the model is converted to use
    ``external_data_location`` and the tensors are written under the directory
    of the model path.

    Args:
        external_data_location (str, optional): external data location of merged quantized model.
            Defaults to "external.data".
        overwrite (bool, optional): if True, remove an already existing external data file
            at that location first. Defaults to False.
    """
    if overwrite:
        previous_file = os.path.join(os.path.dirname(self._model_path), external_data_location)
        if os.path.exists(previous_file):
            os.remove(previous_file)

    self.load_model_initializer_by_tensor()
    onnx.external_data_helper.convert_model_to_external_data(self._model, location=external_data_location)
    # TODO : if init is already saved, skip write it
    model_dir = os.path.dirname(self._model_path)
    onnx.external_data_helper.write_external_data_tensors(self._model, filepath=model_dir)
|
|
1209
|
+
|
|
1210
|
+
def merge_split_models(self, to_merge_model):
    """Merge another split model into this one to form the final model.

    The nodes and initializers of ``to_merge_model`` are added to this model,
    after which the graph outputs and inputs are reconciled:

    * outputs of ``to_merge_model`` not yet present here are appended;
    * outputs of this model that are inputs of ``to_merge_model`` are removed;
    * inputs of ``to_merge_model`` are appended unless their name is already
      an input, an output, or a key of ``output_name_to_node`` of this model.

    Args:
        to_merge_model: the split model whose content is merged into this one
    """
    to_merge_model.write_external_data_to_new_location()
    self.add_nodes(list(to_merge_model.nodes()))
    self.add_initializers(list(to_merge_model.initializer()))
    self.update()

    merged_graph = self._model.graph

    # add new output
    # self.output() is re-read on every iteration so entries appended a moment
    # ago are taken into account as well.
    for candidate in to_merge_model.graph().output:
        if candidate.name in self.output():
            continue
        merged_graph.output.append(candidate)

    # remove unused output
    consumed_outputs = [entry for entry in merged_graph.output if entry.name in to_merge_model.input()]
    for entry in consumed_outputs:
        merged_graph.output.remove(entry)

    # add new input
    for candidate in to_merge_model.graph().input:
        already_available = (
            candidate.name in self.input()
            or candidate.name in self.output()
            or candidate.name in self.output_name_to_node
        )
        if not already_available:
            merged_graph.input.append(candidate)
|
|
1238
|
+
|
|
1239
|
+
def re_org_output(self, origin_output):
    """Re-org output of merged model for layer-wise quantization.

    All current graph outputs are removed and then re-added in the order
    given by ``origin_output``; outputs not named there are dropped.

    Args:
        origin_output: iterable of output names in the desired order; every
            name must belong to a current graph output, otherwise the lookup
            raises ``KeyError``
    """
    graph_outputs = self._model.graph.output

    # Snapshot the outputs first so the container can be emptied safely.
    current = list(graph_outputs)
    by_name = {entry.name: entry for entry in current}

    for entry in current:
        graph_outputs.remove(entry)

    for wanted_name in origin_output:
        graph_outputs.append(by_name[wanted_name])
|