onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,580 @@
|
|
|
1
|
+
# --------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import onnx
|
|
8
|
+
import onnx.helper as onnx_helper
|
|
9
|
+
import onnx.numpy_helper as onnx_numpy_helper
|
|
10
|
+
from onnx.onnx_pb import ModelProto
|
|
11
|
+
|
|
12
|
+
from .quant_utils import attribute_to_kwarg, find_by_name
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _clean_initializers_helper(graph, model):
    """Clean unused initializers from graph.

    Recursively rebuilds nodes that carry subgraph attributes so their
    subgraphs are cleaned too, then drops every initializer (and its matching
    graph input, for IR version >= 4 models) that no node or graph output
    consumes.

    Args:
        graph: GraphProto to clean in place.
        model: Owning ModelProto; only ``model.ir_version`` is read, to decide
            whether a missing graph input for an initializer is worth warning about.

    Returns:
        A cleaned graph without unused initializers
        A list of tensor names, which are not produced by this graph and its subgraphs
    """
    # Every tensor name this graph needs: node inputs plus graph outputs.
    requesting_tensor_names = set()
    requesting_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name)
    requesting_tensor_names.update(g_out.name for g_out in graph.output if g_out.name)

    new_nodes = []
    for node in graph.node:
        new_node = node
        graph_attrs = [
            attr
            for attr in node.attribute
            if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
        ]
        if graph_attrs:
            # Node has subgraph attribute(s): rebuild it with cleaned subgraphs,
            # and fold the subgraphs' unresolved tensor names into ours.
            kwargs = {}
            for attr in node.attribute:
                new_attribute = {}
                if attr.type == onnx.AttributeProto.GRAPH:
                    (
                        cleaned_sub_graph,
                        sub_requesting_tensor_names,
                    ) = _clean_initializers_helper(attr.g, model)
                    new_attribute = {attr.name: cleaned_sub_graph}
                    requesting_tensor_names.update(sub_requesting_tensor_names)
                elif attr.type == onnx.AttributeProto.GRAPHS:
                    cleaned_graphes = []
                    for subgraph in attr.graphs:
                        (
                            cleaned_sub_graph,
                            sub_requesting_tensor_names,
                        ) = _clean_initializers_helper(subgraph, model)
                        cleaned_graphes.append(cleaned_sub_graph)
                        requesting_tensor_names.update(sub_requesting_tensor_names)
                    new_attribute = {attr.name: cleaned_graphes}
                else:
                    # Non-graph attribute: pass through unchanged as a kwarg.
                    new_attribute = attribute_to_kwarg(attr)
                kwargs.update(new_attribute)
            new_node = onnx_helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
        new_nodes.append(new_node)

    # Replace the node list wholesale with the (possibly rebuilt) nodes.
    graph.ClearField("node")
    graph.node.extend(new_nodes)

    # Names produced by this graph's own nodes are satisfied locally.
    requesting_tensor_names.difference_update(output for node in graph.node for output in node.output)

    unused_initializer = []
    for initializer in graph.initializer:
        if initializer.name in requesting_tensor_names:
            requesting_tensor_names.remove(initializer.name)
        else:
            # mark it to remove; removing here directly would misbehave
            # (mutating the repeated field while iterating it)
            unused_initializer.append(initializer)

    name_to_input = {input.name: input for input in graph.input}
    for initializer in unused_initializer:
        graph.initializer.remove(initializer)
        if initializer.name in name_to_input:
            try:
                graph.input.remove(name_to_input[initializer.name])
            # NOTE(review): protobuf repeated-field remove raises ValueError,
            # not StopIteration, so this handler looks unreachable — likely a
            # leftover from an earlier next()-based implementation; confirm.
            except StopIteration:
                if model.ir_version < 4:
                    print(f"Warning: invalid weight name {initializer.name} found in the graph (not a graph input)")

    # Names satisfied by remaining graph inputs are resolved; what is left
    # must come from an outer scope.
    requesting_tensor_names.difference_update(input.name for input in graph.input)

    return graph, requesting_tensor_names
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class ONNXModel:
|
|
90
|
+
def __init__(self, model: ModelProto):
    """Wrap an ONNX ModelProto with graph query/edit helpers.

    Args:
        model: The model protobuf to operate on. It is held by reference,
            so mutations made through this wrapper modify it in place.
    """
    self.model = model
|
|
92
|
+
|
|
93
|
+
def nodes(self):
|
|
94
|
+
return self.model.graph.node
|
|
95
|
+
|
|
96
|
+
def initializer(self):
|
|
97
|
+
return self.model.graph.initializer
|
|
98
|
+
|
|
99
|
+
def initializer_extend(self, inits):
|
|
100
|
+
if len(inits) == 0:
|
|
101
|
+
raise ValueError("Can add an empty list.")
|
|
102
|
+
for init in self.initializer():
|
|
103
|
+
self._check_init(init, "gain")
|
|
104
|
+
for init in inits:
|
|
105
|
+
self._check_init(init)
|
|
106
|
+
self.model.graph.initializer.append(init)
|
|
107
|
+
|
|
108
|
+
def graph(self):
|
|
109
|
+
return self.model.graph
|
|
110
|
+
|
|
111
|
+
def ir_version(self):
|
|
112
|
+
return self.model.ir_version
|
|
113
|
+
|
|
114
|
+
def opset_import(self):
|
|
115
|
+
return self.model.opset_import
|
|
116
|
+
|
|
117
|
+
def set_opset_import(self, domain, version):
|
|
118
|
+
for opset in self.model.opset_import:
|
|
119
|
+
if opset.domain == domain:
|
|
120
|
+
opset.version = version
|
|
121
|
+
return
|
|
122
|
+
|
|
123
|
+
self.model.opset_import.extend([onnx_helper.make_opsetid(domain, version)])
|
|
124
|
+
|
|
125
|
+
def remove_node(self, node):
|
|
126
|
+
if node in self.model.graph.node:
|
|
127
|
+
self.model.graph.node.remove(node)
|
|
128
|
+
|
|
129
|
+
def remove_nodes(self, nodes_to_remove):
|
|
130
|
+
for node in nodes_to_remove:
|
|
131
|
+
self.remove_node(node)
|
|
132
|
+
|
|
133
|
+
def add_node(self, node):
    """Validate `node` via `_check_node` and append it to the top-level graph."""
    checked = self._check_node(node)
    self.model.graph.node.extend([checked])
|
|
135
|
+
|
|
136
|
+
def add_nodes(self, nodes_to_add):
    """Validate and append every node in `nodes_to_add` to the top-level graph."""
    for candidate in nodes_to_add:
        self.add_node(candidate)
|
|
139
|
+
|
|
140
|
+
def add_initializer(self, tensor):
    """Append `tensor` to the graph initializers unless one with its name already exists."""
    existing = find_by_name(tensor.name, self.model.graph.initializer)
    if existing is None:
        self._check_init(tensor)
        self.model.graph.initializer.extend([tensor])
|
|
144
|
+
|
|
145
|
+
def get_initializer(self, name):
|
|
146
|
+
for tensor in self.model.graph.initializer:
|
|
147
|
+
if tensor.name == name:
|
|
148
|
+
return tensor
|
|
149
|
+
return None
|
|
150
|
+
|
|
151
|
+
def find_graph_input(self, input_name):
|
|
152
|
+
for input in self.model.graph.input:
|
|
153
|
+
if input.name == input_name:
|
|
154
|
+
return input
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
def find_graph_output(self, output_name):
|
|
158
|
+
for output in self.model.graph.output:
|
|
159
|
+
if output.name == output_name:
|
|
160
|
+
return output
|
|
161
|
+
return None
|
|
162
|
+
|
|
163
|
+
def get_tensor_type(self, tensor_name: str):
|
|
164
|
+
tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info}
|
|
165
|
+
|
|
166
|
+
if tensor_name in tensor_type_map:
|
|
167
|
+
return tensor_type_map[tensor_name].tensor_type
|
|
168
|
+
|
|
169
|
+
g_input = self.find_graph_input(tensor_name)
|
|
170
|
+
if g_input:
|
|
171
|
+
return g_input.type.tensor_type
|
|
172
|
+
|
|
173
|
+
g_output = self.find_graph_output(tensor_name)
|
|
174
|
+
if g_output:
|
|
175
|
+
return g_output.type.tensor_type
|
|
176
|
+
|
|
177
|
+
return None
|
|
178
|
+
|
|
179
|
+
def get_constant_value(self, output_name):
|
|
180
|
+
for node in self.model.graph.node:
|
|
181
|
+
if node.op_type == "Constant":
|
|
182
|
+
if node.output[0] == output_name:
|
|
183
|
+
for attr in node.attribute:
|
|
184
|
+
if attr.name == "value":
|
|
185
|
+
return onnx_numpy_helper.to_array(attr.t)
|
|
186
|
+
|
|
187
|
+
# Fallback to initializer since constant folding may have been applied.
|
|
188
|
+
initializer = self.get_initializer(output_name)
|
|
189
|
+
if initializer is not None:
|
|
190
|
+
return onnx_numpy_helper.to_array(initializer)
|
|
191
|
+
|
|
192
|
+
return None
|
|
193
|
+
|
|
194
|
+
def get_initializer_name_set(self):
|
|
195
|
+
return {initializer.name for initializer in self.model.graph.initializer}
|
|
196
|
+
|
|
197
|
+
def remove_initializer(self, tensor):
|
|
198
|
+
if tensor in self.model.graph.initializer:
|
|
199
|
+
self.model.graph.initializer.remove(tensor)
|
|
200
|
+
for input in self.model.graph.input:
|
|
201
|
+
if input.name == tensor.name:
|
|
202
|
+
self.model.graph.input.remove(input)
|
|
203
|
+
break
|
|
204
|
+
|
|
205
|
+
def remove_initializers(self, init_to_remove):
|
|
206
|
+
for initializer in init_to_remove:
|
|
207
|
+
self.remove_initializer(initializer)
|
|
208
|
+
|
|
209
|
+
def get_non_initializer_inputs(self):
|
|
210
|
+
initializer_names = self.get_initializer_name_set()
|
|
211
|
+
non_initializer_inputs = set()
|
|
212
|
+
for input in self.model.graph.input:
|
|
213
|
+
if input.name not in initializer_names:
|
|
214
|
+
non_initializer_inputs.add(input.name)
|
|
215
|
+
return non_initializer_inputs
|
|
216
|
+
|
|
217
|
+
def input_name_to_nodes(self):
|
|
218
|
+
input_name_to_nodes = {}
|
|
219
|
+
for node in self.model.graph.node:
|
|
220
|
+
for input_name in node.input:
|
|
221
|
+
if input_name: # Could be empty when it is optional
|
|
222
|
+
if input_name not in input_name_to_nodes:
|
|
223
|
+
input_name_to_nodes[input_name] = [node]
|
|
224
|
+
else:
|
|
225
|
+
input_name_to_nodes[input_name].append(node)
|
|
226
|
+
return input_name_to_nodes
|
|
227
|
+
|
|
228
|
+
def output_name_to_node(self):
|
|
229
|
+
output_name_to_node = {}
|
|
230
|
+
for node in self.model.graph.node:
|
|
231
|
+
for output_name in node.output:
|
|
232
|
+
if output_name: # Could be empty when it is optional
|
|
233
|
+
output_name_to_node[output_name] = node
|
|
234
|
+
return output_name_to_node
|
|
235
|
+
|
|
236
|
+
def get_children(self, node, input_name_to_nodes=None):
|
|
237
|
+
if input_name_to_nodes is None:
|
|
238
|
+
input_name_to_nodes = self.input_name_to_nodes()
|
|
239
|
+
|
|
240
|
+
children = []
|
|
241
|
+
for output in node.output:
|
|
242
|
+
if output in input_name_to_nodes:
|
|
243
|
+
for node in input_name_to_nodes[output]:
|
|
244
|
+
children.append(node) # noqa: PERF402
|
|
245
|
+
return children
|
|
246
|
+
|
|
247
|
+
def get_parents(self, node, output_name_to_node=None):
|
|
248
|
+
if output_name_to_node is None:
|
|
249
|
+
output_name_to_node = self.output_name_to_node()
|
|
250
|
+
|
|
251
|
+
parents = []
|
|
252
|
+
for input in node.input:
|
|
253
|
+
if input in output_name_to_node:
|
|
254
|
+
parents.append(output_name_to_node[input])
|
|
255
|
+
return parents
|
|
256
|
+
|
|
257
|
+
def get_parent(self, node, idx, output_name_to_node=None):
|
|
258
|
+
if output_name_to_node is None:
|
|
259
|
+
output_name_to_node = self.output_name_to_node()
|
|
260
|
+
|
|
261
|
+
if len(node.input) <= idx:
|
|
262
|
+
return None
|
|
263
|
+
|
|
264
|
+
input = node.input[idx]
|
|
265
|
+
if input not in output_name_to_node:
|
|
266
|
+
return None
|
|
267
|
+
|
|
268
|
+
return output_name_to_node[input]
|
|
269
|
+
|
|
270
|
+
def find_node_by_name(self, node_name, new_nodes_list, graph):
    """Find out if a node exists in a graph or a node is in the
    new set of nodes created during quantization.

    Returns:
        The node found or None.
    """
    # Shallow copy of the graph's node list, so extending it below does
    # not touch the graph itself.
    candidates = list(graph.node)
    candidates.extend(new_nodes_list)
    return find_by_name(node_name, candidates)
|
|
281
|
+
|
|
282
|
+
def get_largest_node_name_suffix(self, node_name_prefix):
|
|
283
|
+
"""
|
|
284
|
+
Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`.
|
|
285
|
+
Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3.
|
|
286
|
+
"""
|
|
287
|
+
suffix = -1
|
|
288
|
+
|
|
289
|
+
for node in self.model.graph.node:
|
|
290
|
+
if node.name and node.name.startswith(node_name_prefix):
|
|
291
|
+
try:
|
|
292
|
+
index = int(node.name[len(node_name_prefix) :])
|
|
293
|
+
suffix = max(index, suffix)
|
|
294
|
+
except ValueError:
|
|
295
|
+
continue
|
|
296
|
+
|
|
297
|
+
return suffix
|
|
298
|
+
|
|
299
|
+
def find_nodes_by_initializer(self, graph, initializer):
|
|
300
|
+
"""
|
|
301
|
+
Find all nodes with given initializer as an input.
|
|
302
|
+
"""
|
|
303
|
+
nodes = []
|
|
304
|
+
for node in graph.node:
|
|
305
|
+
for node_input in node.input:
|
|
306
|
+
if node_input == initializer.name:
|
|
307
|
+
nodes.append(node)
|
|
308
|
+
return nodes
|
|
309
|
+
|
|
310
|
+
@staticmethod
def __get_initializer(name, graph_path):
    """Resolve an initializer by name, searching the innermost graph first.

    Returns:
        (tensor, owning_graph) for the first match, else (None, None).
    """
    # walk from the current (innermost) graph outward
    for graph in reversed(graph_path):
        for tensor in graph.initializer:
            if tensor.name == name:
                return tensor, graph
    return None, None
|
|
318
|
+
|
|
319
|
+
@staticmethod
def __replace_gemm_with_matmul(graph_path):
    """Recursively rewrite Gemm nodes as MatMul (optionally followed by Add).

    `graph_path` is a stack of graphs from the outermost graph down to the
    one currently being processed (`graph_path[-1]`); subgraphs found in
    node attributes are pushed before recursing and popped on return.

    Only Gemm nodes with alpha == 1.0, beta == 1.0 and transA == 0 are
    rewritten; any other Gemm is kept as-is. Returns the processed graph.
    """
    new_nodes = []
    graph = graph_path[-1]
    for node in graph.node:
        # attr.type 5 / 10 are single-graph / graph-list attributes (the
        # code reads attr.g and attr.graphs below) — recurse into them first.
        graph_attrs = [attr for attr in node.attribute if attr.type == 5 or attr.type == 10]
        if len(graph_attrs):
            kwargs = {}
            for attr in node.attribute:
                if attr.type == 5:
                    # single subgraph attribute (e.g. a control-flow body)
                    graph_path.append(attr.g)
                    kv = {attr.name: ONNXModel.__replace_gemm_with_matmul(graph_path)}
                elif attr.type == 10:
                    # list-of-subgraphs attribute
                    value = []
                    for subgraph in attr.graphs:
                        graph_path.append(subgraph)
                        value.extend([ONNXModel.__replace_gemm_with_matmul(graph_path)])
                    kv = {attr.name: value}
                else:
                    kv = attribute_to_kwarg(attr)
                kwargs.update(kv)
            # rebuild the node so its graph attributes hold the rewritten subgraphs
            node = onnx_helper.make_node(  # noqa: PLW2901
                node.op_type, node.input, node.output, name=node.name, **kwargs
            )

        if node.op_type == "Gemm":
            # Gemm attribute defaults per the ONNX spec
            alpha = 1.0
            beta = 1.0
            transA = 0  # noqa: N806
            transB = 0  # noqa: N806
            for attr in node.attribute:
                if attr.name == "alpha":
                    alpha = onnx_helper.get_attribute_value(attr)
                elif attr.name == "beta":
                    beta = onnx_helper.get_attribute_value(attr)
                elif attr.name == "transA":
                    transA = onnx_helper.get_attribute_value(attr)  # noqa: N806
                elif attr.name == "transB":
                    transB = onnx_helper.get_attribute_value(attr)  # noqa: N806
            # only the trivial-scaling, non-transposed-A case maps onto MatMul
            if alpha == 1.0 and beta == 1.0 and transA == 0:
                inputB = node.input[1]  # noqa: N806
                if transB == 1:
                    # Fold the B transpose into its initializer when B is a
                    # constant; otherwise insert an explicit Transpose node.
                    B, Bs_graph = ONNXModel.__get_initializer(node.input[1], graph_path)  # noqa: N806
                    if B:
                        # assume B is not used by any other node
                        B_array = onnx_numpy_helper.to_array(B)  # noqa: N806
                        B_trans = onnx_numpy_helper.from_array(B_array.T)  # noqa: N806
                        B_trans.name = B.name
                        Bs_graph.initializer.remove(B)
                        # drop the matching graph-input entry, if present
                        # (safe: `break` right after the single removal)
                        for input in Bs_graph.input:
                            if input.name == inputB:
                                Bs_graph.input.remove(input)
                                break
                        Bs_graph.initializer.extend([B_trans])
                    else:
                        inputB += "_Transposed"  # noqa: N806
                        transpose_node = onnx_helper.make_node(
                            "Transpose",
                            inputs=[node.input[1]],
                            outputs=[inputB],
                            name=node.name + "_Transpose" if node.name else "",
                        )
                        new_nodes.append(transpose_node)

                # When a bias (third input) exists, the MatMul writes to an
                # intermediate "_MatMul" output that the Add consumes.
                matmul_node = onnx_helper.make_node(
                    "MatMul",
                    inputs=[node.input[0], inputB],
                    outputs=[node.output[0] + ("_MatMul" if len(node.input) > 2 else "")],
                    name=node.name + "_MatMul" if node.name else "",
                )
                new_nodes.append(matmul_node)

                if len(node.input) > 2:
                    add_node = onnx_helper.make_node(
                        "Add",
                        inputs=[node.output[0] + "_MatMul", node.input[2]],
                        outputs=node.output,
                        name=node.name + "_Add" if node.name else "",
                    )
                    new_nodes.append(add_node)

            # unsupported
            else:
                new_nodes.append(node)

        # not GEMM
        else:
            new_nodes.append(node)

    # replace the graph's node list in place with the rewritten sequence
    graph.ClearField("node")
    graph.node.extend(new_nodes)
    graph_path.pop()
    return graph
|
|
412
|
+
|
|
413
|
+
def replace_gemm_with_matmul(self):
    """Rewrite every eligible Gemm node in the model as MatMul (+ Add)."""
    ONNXModel.__replace_gemm_with_matmul([self.graph()])
|
|
416
|
+
|
|
417
|
+
def save_model_to_file(self, output_path, use_external_data_format=False):
    """
    Persist the model to `output_path`; optionally store tensor data in an
    external ".data" file, which is needed for model size > 2GB.
    """
    self.topological_sort()

    if use_external_data_format:
        external_location = Path(output_path).name + ".data"
        onnx.external_data_helper.convert_model_to_external_data(
            self.model,
            all_tensors_to_one_file=True,
            location=external_location,
            convert_attribute=True,
        )

    # validate initializers (float8 NaN check) before serialization
    for init in self.model.graph.initializer:
        self._check_init(init, "end")

    onnx.save_model(self.model, output_path)
|
|
432
|
+
|
|
433
|
+
@staticmethod
|
|
434
|
+
def replace_node_input(node, old_input_name, new_input_name):
|
|
435
|
+
assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
|
|
436
|
+
for j in range(len(node.input)):
|
|
437
|
+
if node.input[j] == old_input_name:
|
|
438
|
+
node.input[j] = new_input_name
|
|
439
|
+
|
|
440
|
+
def replace_input_of_all_nodes(self, old_input_name, new_input_name):
    """Rename an input edge on every node of the graph."""
    for graph_node in self.model.graph.node:
        ONNXModel.replace_node_input(graph_node, old_input_name, new_input_name)
|
|
443
|
+
|
|
444
|
+
def replace_input_of_nodes(self, old_input_name, new_input_name, node_names_set):
    """Rename an input edge, but only on nodes named in `node_names_set`."""
    selected = (n for n in self.model.graph.node if n.name in node_names_set)
    for graph_node in selected:
        ONNXModel.replace_node_input(graph_node, old_input_name, new_input_name)
|
|
448
|
+
|
|
449
|
+
@staticmethod
|
|
450
|
+
def replace_node_output(node, old_output_name, new_output_name):
|
|
451
|
+
assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
|
|
452
|
+
for j in range(len(node.output)):
|
|
453
|
+
if node.output[j] == old_output_name:
|
|
454
|
+
node.output[j] = new_output_name
|
|
455
|
+
|
|
456
|
+
def replace_output_of_all_nodes(self, old_output_name, new_output_name):
    """Rename an output edge on every node of the graph."""
    for graph_node in self.model.graph.node:
        ONNXModel.replace_node_output(graph_node, old_output_name, new_output_name)
|
|
459
|
+
|
|
460
|
+
def replace_output_of_nodes(self, old_output_name, new_output_name, node_names_set):
    """Rename an output edge, but only on nodes named in `node_names_set`."""
    selected = (n for n in self.model.graph.node if n.name in node_names_set)
    for graph_node in selected:
        ONNXModel.replace_node_output(graph_node, old_output_name, new_output_name)
|
|
464
|
+
|
|
465
|
+
def remove_unused_constant(self):
    """Remove Constant nodes and initializers that no node consumes.

    A Constant node is dropped when its output is neither a graph output
    nor an input of any node. An unused initializer is dropped together
    with any matching `graph.input` entry.
    """
    input_name_to_nodes = self.input_name_to_nodes()

    # remove unused Constant nodes
    unused_nodes = [
        node
        for node in self.nodes()
        if node.op_type == "Constant"
        and not self.is_graph_output(node.output[0])
        and node.output[0] not in input_name_to_nodes
    ]
    self.remove_nodes(unused_nodes)

    unused_weights = []  # fixed typo: was `ununsed_weights`
    for w in self.initializer():
        if w.name not in input_name_to_nodes and not self.is_graph_output(w.name):
            unused_weights.append(w)
            # Remove matching entries from graph.input.
            # BUGFIX: collect first, then remove — the original removed
            # elements from the sequence while iterating it, which can
            # skip entries.
            stale_inputs = [gi for gi in self.graph().input if gi.name == w.name]
            for graph_input in stale_inputs:
                self.graph().input.remove(graph_input)

    self.remove_initializers(unused_weights)
|
|
491
|
+
|
|
492
|
+
def is_graph_output(self, output_name):
    """Return True when `output_name` is one of the model's graph outputs."""
    for graph_output in self.model.graph.output:
        if graph_output.name == output_name:
            return True
    return False
|
|
494
|
+
|
|
495
|
+
def is_graph_input(self, tensor_name: str) -> bool:
    """Return True when `tensor_name` names one of the model's graph inputs."""
    input_names = {entry.name for entry in self.model.graph.input}
    return tensor_name in input_names
|
|
497
|
+
|
|
498
|
+
# TODO:use OnnxModel.graph_topological_sort(self.model.graph) from transformers.onnx_model
# Currently it breaks Openvino/Linux training gpu pipeline so hold off for 1.8 release
def topological_sort(self):
    """Topologically sort the graph's nodes in place (Kahn's algorithm).

    Node order is deterministic: zero-dependency nodes keep their original
    relative order, then nodes are released as their inputs (initializers,
    graph inputs in sorted name order, then produced outputs) are satisfied.

    Raises:
        AssertionError: when the graph is not a DAG (a cycle remains).
    """
    nodes = self.nodes()  # hoisted: was re-called inside every loop below
    deps_count = [0] * len(nodes)  # dependency count of each node
    deps_to_nodes = {}  # input name -> indices of nodes consuming it
    sorted_nodes = []
    for node_idx, node in enumerate(nodes):
        # CANNOT use len(node.input) directly because input can be optional
        deps_count[node_idx] = sum(1 for inp in node.input if inp)
        if deps_count[node_idx] == 0:  # Constant doesn't depend on any inputs
            sorted_nodes.append(nodes[node_idx])
            continue

        for input_name in node.input:
            if input_name:
                deps_to_nodes.setdefault(input_name, []).append(node_idx)

    # Seed the ready set: initializers and graph inputs are always available.
    # sorted(set(...)) reproduces the original sort-then-skip-duplicates order.
    initializer_names = [init.name for init in self.initializer()]
    graph_input_names = [graph_input.name for graph_input in self.model.graph.input]
    for input_name in sorted(set(initializer_names + graph_input_names)):
        for node_idx in deps_to_nodes.get(input_name, ()):
            deps_count[node_idx] -= 1
            if deps_count[node_idx] == 0:
                sorted_nodes.append(nodes[node_idx])

    # Release nodes as their producers land in sorted order.
    start = 0
    end = len(sorted_nodes)
    while start < end:
        for output in sorted_nodes[start].output:
            for node_idx in deps_to_nodes.get(output, ()):
                deps_count[node_idx] -= 1
                if deps_count[node_idx] == 0:
                    sorted_nodes.append(nodes[node_idx])
                    end += 1
        start += 1

    assert end == len(self.graph().node), "Graph is not a DAG"
    self.graph().ClearField("node")
    self.graph().node.extend(sorted_nodes)
|
|
551
|
+
|
|
552
|
+
def clean_initializers(self):
    """Delegate initializer cleanup to the module-level helper; returns its result."""
    graph = self.graph()
    return _clean_initializers_helper(graph, self.model)
|
|
554
|
+
|
|
555
|
+
def _check_init(self, init, test=None):
    """Reject float8 (E4M3FN) initializers containing NaN; returns `init`.

    A float8e4m3fn byte encodes NaN when its low 7 bits are all ones
    (the mask below matches both sign variants, 0x7F and 0xFF).
    """
    is_float8 = init.data_type == onnx.TensorProto.FLOAT8E4M3FN
    if is_float8 and init.HasField("raw_data"):
        for byte in init.raw_data:
            if byte & 0x7F == 0x7F:
                raise ValueError(f"Initializer {init.name!r} has nan.")
    return init
|
|
562
|
+
|
|
563
|
+
def _check_node(self, node):
    """
    A quantization to float 8 does not use quantized bias but float 16 bias.
    This function checks that DequantizeLinear is not used to dequantize
    from a floating-point zero-point type (float16/float/double/bfloat16).
    """
    if node.op_type != "DequantizeLinear":
        return node

    float_types = (
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.DOUBLE,
        onnx.TensorProto.BFLOAT16,
    )
    # the zero point (third input) determines the source type
    zero_point_init = self.get_initializer(node.input[2])
    dtype = zero_point_init.data_type
    if dtype in float_types:
        raise RuntimeError(f"Unsupported DequantizeLinear operator, dequantization from {dtype}.")
    return node
|