onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import onnx
|
|
9
|
+
|
|
10
|
+
from ...fusions import Fusion
|
|
11
|
+
from ...onnx_model import ONNXModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FusionLpNormalization(Fusion):
    """Fuses a ReduceL2 -> Clip -> Expand -> Div subgraph into a single LpNormalization node (p=2, last axis)."""

    def __init__(self, model: ONNXModel, epsilon: float = 1e-12):
        """
        Params:
            model: The ONNXModel to search for fusible patterns.
            epsilon: Expected Clip 'min' value that guards against division by zero.
                     The pattern only matches when the Clip min is within 1e-13 of this value.
        """
        super().__init__(model, "LpNormalization", "ReduceL2")
        self.epsilon = epsilon

    def fuse(
        self,
        reduce_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing a ReduceL2 node into a single
        LpNormalization node.

        Pattern 1:
        [root] --> ReduceL2 -----> Clip  --> Expand ----> Div -->
           |      (axis=-1)    (min=epsilon) (shape=root)  ^
           |      (keepdims=True)                          |
           |                                               |
           +-----------------------------------------------+
        Notes:
          - ReduceL2 must use the last axis, and keepdims == True
          - Clip must only have a min attribute that is ~1e-12
          - Expand must restore the shape to root.shape
          - The output of Expand must be the second input to Div.

        Params:
            reduce_node: The candidate ReduceL2 node (the fusion's anchor op type).
            input_name_to_nodes: Maps a tensor name to the nodes that consume it.
            output_name_to_node: Maps a tensor name to the node that produces it.

        Side effects:
            On a match, appends the four matched nodes to self.nodes_to_remove and
            appends the fused LpNormalization node to self.nodes_to_add.
        """
        # Anchor output must feed at least one node; otherwise there is no downstream pattern.
        if reduce_node.output[0] not in input_name_to_nodes:
            return

        # ReduceL2 must have one Clip child
        children = input_name_to_nodes[reduce_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Clip":
            return

        # ReduceL2 must have keepdims == True
        # NOTE(review): per the ONNX spec, an absent keepdims attribute defaults to 1 (True),
        # but this check rejects a node whose attribute is absent — presumably intentional
        # conservatism; confirm against Fusion.get_node_attribute's return for missing attrs.
        keepdims = self.get_node_attribute(reduce_node, "keepdims")
        if not keepdims:
            return

        # ReduceL2 axes must refer only to the last dimension.
        # Axes became an input in opset 18. Before then, axes was an attribute
        reduce_input_ttype = self.model.get_tensor_type(reduce_node.input[0])
        if not reduce_input_ttype:
            return

        # Shape (rank) is needed to translate a positive axis into "last dimension".
        reduce_input_shape = self.tensor_shape_to_list(reduce_input_ttype)
        if not reduce_input_shape:
            return

        # Attribute form first (opset < 18), then constant-input form (opset >= 18).
        axes = self.get_node_attribute(reduce_node, "axes")
        if not axes and len(reduce_node.input) > 1:
            axes = self.model.get_constant_value(reduce_node.input[1])

        # Exactly one reduction axis is required for LpNormalization.
        if not axes or len(axes) != 1:
            return

        # Accept the last axis expressed either negatively (-1) or positively (rank - 1).
        last_dim = len(reduce_input_shape) - 1
        if axes[0] != -1 and axes[0] != last_dim:
            return

        # Clip node must have a min attribute approximately equal to 1e-12
        clip_node = children[0]
        clip_min = self.get_node_attribute(clip_node, "min")
        if clip_min is None and len(clip_node.input) > 1:
            clip_min = self.model.get_constant_value(clip_node.input[1])

        clip_max = self.get_node_attribute(clip_node, "max")  # TODO: clip_max could be FLOAT_MAX
        if clip_max is None and len(clip_node.input) > 2:
            clip_max = self.model.get_constant_value(clip_node.input[2])

        # Require: no max bound, and a positive min within 1e-13 of the configured epsilon.
        if not (clip_max is None and clip_min is not None and clip_min > 0 and abs(clip_min - self.epsilon) < 1e-13):
            return

        if clip_node.output[0] not in input_name_to_nodes:
            return

        # Clip must have a single Expand child.
        children = input_name_to_nodes[clip_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Expand":
            return

        expand_node = children[0]
        if expand_node.output[0] not in input_name_to_nodes:
            return

        # Expand must have a single Div child
        children = input_name_to_nodes[expand_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Div":
            return

        div_node = children[0]

        # The first input to Div must be the root of the subgraph (i.e., reduce_node.input[0])
        # The second input to Div must be the output of the Expand.
        # As long as these two inputs go to the same Div node, then ONNX validation will ensure that
        # their shapes match.
        if div_node.input[0] != reduce_node.input[0]:
            return
        if div_node.input[1] != expand_node.output[0]:
            return

        # The fused node replaces the whole subgraph: root in, Div's output out.
        subgraph_input = reduce_node.input[0]
        subgraph_output = div_node.output[0]

        subgraph_nodes = [reduce_node, clip_node, expand_node, div_node]
        # Intermediate outputs must not be consumed outside the subgraph (they disappear after fusion).
        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
            return

        self.nodes_to_remove.extend(subgraph_nodes)
        # p=2 matches ReduceL2 (Euclidean norm); axis=-1 matches the verified last-axis reduction.
        fused_node = onnx.helper.make_node(
            self.fused_op_type,
            name=self.create_unique_node_name(),
            inputs=[subgraph_input],
            outputs=[subgraph_output],
            p=2,
            axis=-1,
        )
        self.nodes_to_add.append(fused_node)
|
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
import onnx
|
|
12
|
+
|
|
13
|
+
from ...quant_utils import QuantType
|
|
14
|
+
from ...tensor_quant_overrides import QuantTypeInfo, TensorQuantOverridesHelper
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class TensorTypeRequest:
    """
    Bundles desired quantization type requests for a tensor. A distinction is made between the
    produced type and the consumed type.
    """

    # The tensor's quant type at the producer end. If None, assumed to be the default activation quant type.
    producer: QuantTypeInfo | None

    # The tensor's quant type received by a set of consumer nodes.
    # If None, assumed to be the default activation quant type for all consumers.
    # consumers[0] is the requested QuantTypeInfo; consumers[1] is a set of consumer node names.
    consumers: tuple[QuantTypeInfo, set[str]] | None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MixedPrecisionTensorQuantOverridesFixer:
|
|
34
|
+
"""
|
|
35
|
+
Helper that generates tensor quantization overrides for mixed-precision QDQ models.
|
|
36
|
+
|
|
37
|
+
Specifically, this helper fixes an initial set of quantization overrides that assign a non-default
|
|
38
|
+
activation quantization type to one or more tensors by doing the following:
|
|
39
|
+
- Inferring which other tensors need to be overridden to the non-default activation quantization type.
|
|
40
|
+
- Inserting quantization data type conversions.
|
|
41
|
+
|
|
42
|
+
Example:
|
|
43
|
+
--------
|
|
44
|
+
|
|
45
|
+
Float model:
|
|
46
|
+
|
|
47
|
+
input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0
|
|
48
|
+
^
|
|
49
|
+
|
|
|
50
|
+
input_1 --> Op2 -+-> Op4 ----+
|
|
51
|
+
|
|
|
52
|
+
+-> Op7 --> output_1
|
|
53
|
+
|
|
|
54
|
+
+-> Op8 --> output_2
|
|
55
|
+
|
|
56
|
+
If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out"
|
|
57
|
+
is quantized to 16-bit, then we would specify the following initial tensor quantization overrides:
|
|
58
|
+
|
|
59
|
+
```
|
|
60
|
+
init_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]}
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output
|
|
64
|
+
to be the same type (e.g., uint16). This helper fixes the overrides so that input/output data types
|
|
65
|
+
are valid:
|
|
66
|
+
|
|
67
|
+
```
|
|
68
|
+
overrides = TensorQuantOverridesHelper(init_overrides)
|
|
69
|
+
|
|
70
|
+
fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, QuantType.QUInt8)
|
|
71
|
+
fixer.apply(
|
|
72
|
+
default_activation_qtype=QuantType.QUInt8,
|
|
73
|
+
default_activation_symmetric=False,
|
|
74
|
+
)
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
The above snippet generates the following "fixed" overrides (get via overrides.get_dict()):
|
|
78
|
+
|
|
79
|
+
{
|
|
80
|
+
"Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}],
|
|
81
|
+
"Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}],
|
|
82
|
+
"Op4_out": [{"quant_type": QUInt16}],
|
|
83
|
+
"Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}]
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
How to interpret the fixed overrides:
|
|
87
|
+
- Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type,
|
|
88
|
+
but Op7 and Op8 consume the original u8 type.
|
|
89
|
+
- Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type.
|
|
90
|
+
- Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type.
|
|
91
|
+
- Op5's output is converted from u16 to u8. Op6 consumes the u8 type.
|
|
92
|
+
"""
|
|
93
|
+
|
|
94
|
+
def __init__(
|
|
95
|
+
self,
|
|
96
|
+
overrides: TensorQuantOverridesHelper,
|
|
97
|
+
producers: dict[str, onnx.NodeProto],
|
|
98
|
+
consumers: dict[str, list[onnx.NodeProto]],
|
|
99
|
+
value_infos: dict[str, onnx.ValueInfoProto],
|
|
100
|
+
initializers: dict[str, onnx.TensorProto],
|
|
101
|
+
):
|
|
102
|
+
"""
|
|
103
|
+
Params:
|
|
104
|
+
overrides: The initial tensor quantization overrides to fix.
|
|
105
|
+
producers: Dictionary that maps a tensor name to the producer node that generates the tensor.
|
|
106
|
+
consumers: Dictionary that maps a tensor name to the consumer nodes that take the tensor as input.
|
|
107
|
+
value_infos: Dictionary that maps a tensor name to its onnx.ValueInfoProto.
|
|
108
|
+
initializers: Dictionary that maps an initializer name to its onnx.TensorProto.
|
|
109
|
+
"""
|
|
110
|
+
self.overrides = overrides
|
|
111
|
+
self.consumers = consumers
|
|
112
|
+
self.producers = producers
|
|
113
|
+
self.value_infos = value_infos
|
|
114
|
+
self.initializers = initializers
|
|
115
|
+
|
|
116
|
+
@staticmethod
|
|
117
|
+
def create_from_model(
|
|
118
|
+
overrides: TensorQuantOverridesHelper, model: onnx.ModelProto, default_activation_qtype: QuantType
|
|
119
|
+
) -> MixedPrecisionTensorQuantOverridesFixer:
|
|
120
|
+
"""
|
|
121
|
+
Helper function that creates an instance of this class from a loaded ONNX model.
|
|
122
|
+
|
|
123
|
+
Params:
|
|
124
|
+
overrides: The initial tensor quantization overrides to fix.
|
|
125
|
+
model: Loaded ONNX model
|
|
126
|
+
default_activation_qtype: The intended default activation quantization type.
|
|
127
|
+
Used to validate the initial overrides.
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
Initialized MixedPrecisionTensorQuantOverridesFixer object
|
|
131
|
+
"""
|
|
132
|
+
model = onnx.shape_inference.infer_shapes(model) # Need to infer shapes to get value_infos
|
|
133
|
+
|
|
134
|
+
# Build dictionaries that enable convenient lookups of initializers and value_infos by name.
|
|
135
|
+
initializers = {initializer.name: initializer for initializer in model.graph.initializer}
|
|
136
|
+
value_infos = {vi.name: vi for vi in model.graph.value_info}
|
|
137
|
+
value_infos.update({ot.name: ot for ot in model.graph.output})
|
|
138
|
+
value_infos.update({it.name: it for it in model.graph.input})
|
|
139
|
+
|
|
140
|
+
# Ensure that the user-provided initial overrides are actually valid.
|
|
141
|
+
valid, err = overrides.is_valid(initializers, set(value_infos), default_activation_qtype)
|
|
142
|
+
if not valid:
|
|
143
|
+
pprint_overrides = overrides.pprint_str(indent=4)
|
|
144
|
+
logging.error(f"Provided invalid tensor quantization overrides:\n{pprint_overrides}")
|
|
145
|
+
raise ValueError(err)
|
|
146
|
+
|
|
147
|
+
consumers = {}
|
|
148
|
+
producers = {}
|
|
149
|
+
|
|
150
|
+
# Build dictionaries that map a tensor name to the consumer or producer nodes.
|
|
151
|
+
for node in model.graph.node:
|
|
152
|
+
for input_name in node.input:
|
|
153
|
+
if input_name:
|
|
154
|
+
if input_name not in consumers:
|
|
155
|
+
consumers[input_name] = []
|
|
156
|
+
|
|
157
|
+
consumers[input_name].append(node)
|
|
158
|
+
|
|
159
|
+
for output_name in node.output:
|
|
160
|
+
producers[output_name] = node
|
|
161
|
+
|
|
162
|
+
return MixedPrecisionTensorQuantOverridesFixer(overrides, producers, consumers, value_infos, initializers)
|
|
163
|
+
|
|
164
|
+
    def apply(
        self,
        default_activation_qtype: QuantType,
        default_activation_symmetric: bool,
    ):
        """
        Fixes the initial tensor quantization overrides (in-place) for use in mixed-precision QDQ models.

        For every tensor with a requested quantization type (see get_desired_tensor_types), adds
        "convert" entries to self.overrides so that producers and consumers that disagree on a
        tensor's quantization type are bridged by an explicit type conversion.

        Params:
            default_activation_qtype: The intended default activation quantization type.
            default_activation_symmetric: The intended default symmetry used to quantize activations.

        Raises:
            ValueError: If a requested type conflicts with an existing override, a type request has
                neither a producer nor consumers, or the final overrides fail validation.
        """
        type_requests = self.get_desired_tensor_types(default_activation_qtype, default_activation_symmetric)

        # Use type requests to "fix" tensor quantization overrides by adding
        # quantization type conversions where necessary.
        for tensor_name, type_req in type_requests.items():
            all_consumers = set([node.name for node in self.consumers.get(tensor_name, [])])
            has_producer_req = type_req.producer is not None
            has_consumer_req = bool(type_req.consumers)

            # Only producer type: Add conversion back to default activation type
            if has_producer_req and not has_consumer_req:
                self._update_converted_tensor(
                    tensor_name, type_req.producer, QuantTypeInfo(default_activation_qtype), all_consumers
                )
            # Only consumers
            elif not has_producer_req and has_consumer_req:
                prod_type_info = self.overrides.get_node_output_qtype_info(tensor_name, default_activation_qtype)
                consumer_type_info = type_req.consumers[0]

                if prod_type_info != consumer_type_info:
                    self._update_converted_tensor(
                        tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
                    )
                else:
                    # Producer already emits the requested type; make sure no pre-existing
                    # "convert" entry re-routes these consumers to a different type.
                    if not self._check_nodes_are_not_convert_consumers(tensor_name, type_req.consumers[1]):
                        raise ValueError(
                            f"Tensor override for '{tensor_name}' converts the type for consumers that need the original type."
                        )
            # Both producer and consumers
            elif has_producer_req and has_consumer_req:
                prod_type_info = type_req.producer
                consumer_type_info = type_req.consumers[0]

                if prod_type_info != consumer_type_info:
                    self._update_converted_tensor(
                        tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
                    )
                else:
                    # Consumers not explicitly requested keep the default activation type.
                    consumers_for_original_type = all_consumers.difference(type_req.consumers[1])

                    if len(consumers_for_original_type) == 0:
                        # All consumers want the overridden type, so no need for convert nodes!
                        # Just add the override entry if not already present.
                        if tensor_name not in self.overrides:
                            self.overrides[tensor_name] = [{}]
                        prod_type_info.save_to_dict(self.overrides[tensor_name][0])

                        assert "convert" not in self.overrides[tensor_name][0]
                    else:
                        # Some consumers don't want the overridden type.
                        self._update_converted_tensor(
                            tensor_name,
                            prod_type_info,
                            QuantTypeInfo(default_activation_qtype),
                            consumers_for_original_type,
                        )
            else:
                raise ValueError(f"TypeRequest for tensor {tensor_name} has no producer or consumers.")

        # Done. Check if the overrides are valid.
        valid, err = self.overrides.is_valid(self.initializers, set(self.value_infos), default_activation_qtype)
        if not valid:
            pprint_overrides = self.overrides.pprint_str(indent=4)
            logging.error(
                f"Generated invalid tensor quantization overrides for mixed-precision QDQ model:\n{pprint_overrides}"
            )
            raise ValueError(err)
|
|
243
|
+
|
|
244
|
+
def get_desired_tensor_types(
|
|
245
|
+
self,
|
|
246
|
+
default_activation_qtype: QuantType,
|
|
247
|
+
default_activation_symmetric: bool,
|
|
248
|
+
) -> dict[str, TensorTypeRequest]:
|
|
249
|
+
"""
|
|
250
|
+
Iterates through the initial tensor quantization overrides and builds a set of TensorTypeRequests objects
|
|
251
|
+
that describe the quantization types required at each tensor. These TensorTypeRequests objects are ultimately
|
|
252
|
+
used to generated the "fixed" overrides.
|
|
253
|
+
|
|
254
|
+
Params:
|
|
255
|
+
default_activation_qtype: The intended default activation quantization type.
|
|
256
|
+
default_activation_symmetric: The intended default symmetry used to quantize activations.
|
|
257
|
+
|
|
258
|
+
Returns:
|
|
259
|
+
TensorTypeRequest objects as a dict that maps a tensor name to its requested types.
|
|
260
|
+
"""
|
|
261
|
+
type_requests = {}
|
|
262
|
+
default_activation_type_info = QuantTypeInfo(default_activation_qtype, default_activation_symmetric)
|
|
263
|
+
|
|
264
|
+
# Scan tensor overrides for type conversion requests.
|
|
265
|
+
for tensor_name, override_list in self.overrides.items():
|
|
266
|
+
if not self.__is_tensor_quantizable(tensor_name):
|
|
267
|
+
continue # Skip non-quantizable tensors (e.g., not a float)
|
|
268
|
+
|
|
269
|
+
if tensor_name in self.initializers:
|
|
270
|
+
continue # Skip initializers
|
|
271
|
+
|
|
272
|
+
if not override_list or len(override_list) > 1:
|
|
273
|
+
continue # Skip per-channel stuff
|
|
274
|
+
|
|
275
|
+
override_dict = override_list[0]
|
|
276
|
+
quant_type_info = QuantTypeInfo.load_from_dict(override_dict, default_activation_type_info.quant_type)
|
|
277
|
+
producer_node = self.producers.get(tensor_name) # None if this is a model input
|
|
278
|
+
|
|
279
|
+
if quant_type_info != default_activation_type_info and "convert" not in override_dict:
|
|
280
|
+
if producer_node is not None:
|
|
281
|
+
self._add_type_requests_for_node(type_requests, quant_type_info, producer_node)
|
|
282
|
+
|
|
283
|
+
# Find all consumer nodes of `tensor_name` and update their inputs/outputs to the new type.
|
|
284
|
+
for consumer_node in self.consumers.get(tensor_name, []):
|
|
285
|
+
self._add_type_requests_for_node(type_requests, quant_type_info, consumer_node)
|
|
286
|
+
|
|
287
|
+
return type_requests
|
|
288
|
+
|
|
289
|
+
def _add_type_requests_for_node(
|
|
290
|
+
self,
|
|
291
|
+
type_requests: dict[str, TensorTypeRequest],
|
|
292
|
+
quant_type_info: QuantTypeInfo,
|
|
293
|
+
node: onnx.NodeProto,
|
|
294
|
+
):
|
|
295
|
+
"""
|
|
296
|
+
Adds TensorTypeRequest objects for a given node, assuming that we want all its inputs and outputs
|
|
297
|
+
to have the same quantization type (as specified by the `quant_type_info` parameter).
|
|
298
|
+
|
|
299
|
+
Params:
|
|
300
|
+
type_requests: Dictionary of type requests to append to for this node.
|
|
301
|
+
quant_type_info: The quantization type to use for inputs and outputs.
|
|
302
|
+
node: The node for which the TensorTypeRequest objects are created and added to type_requests.
|
|
303
|
+
"""
|
|
304
|
+
# Add output side
|
|
305
|
+
for output_name in node.output:
|
|
306
|
+
if not self.__is_tensor_quantizable(output_name):
|
|
307
|
+
continue
|
|
308
|
+
|
|
309
|
+
if output_name not in type_requests:
|
|
310
|
+
type_requests[output_name] = TensorTypeRequest(quant_type_info, None)
|
|
311
|
+
else:
|
|
312
|
+
if (
|
|
313
|
+
type_requests[output_name].producer is not None
|
|
314
|
+
and type_requests[output_name].producer != quant_type_info
|
|
315
|
+
):
|
|
316
|
+
raise ValueError(f"Tensor {output_name} has multiple types.")
|
|
317
|
+
|
|
318
|
+
type_requests[output_name].producer = quant_type_info
|
|
319
|
+
|
|
320
|
+
# Add the consumer side
|
|
321
|
+
for input_name in node.input:
|
|
322
|
+
if input_name and input_name not in self.initializers and self.__is_tensor_quantizable(input_name):
|
|
323
|
+
if input_name not in type_requests:
|
|
324
|
+
type_requests[input_name] = TensorTypeRequest(None, None)
|
|
325
|
+
|
|
326
|
+
if type_requests[input_name].consumers is None:
|
|
327
|
+
type_requests[input_name].consumers = (quant_type_info, set())
|
|
328
|
+
|
|
329
|
+
if type_requests[input_name].consumers[0] != quant_type_info:
|
|
330
|
+
raise ValueError(f"Tensor {input_name} has consumers requesting different types.")
|
|
331
|
+
|
|
332
|
+
if not node.name:
|
|
333
|
+
raise ValueError(
|
|
334
|
+
f"Node of type {node.op_type} with output 0 {node.output[0]} does not have a name!"
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
type_requests[input_name].consumers[1].add(node.name)
|
|
338
|
+
|
|
339
|
+
def _update_converted_tensor(
|
|
340
|
+
self,
|
|
341
|
+
tensor_name: str,
|
|
342
|
+
producer_type_info: QuantTypeInfo,
|
|
343
|
+
consumer_type_info: QuantTypeInfo,
|
|
344
|
+
consumer_names: set[str],
|
|
345
|
+
):
|
|
346
|
+
"""
|
|
347
|
+
Updates the tensor quantization overrides for a tensor that is converted from one type to another.
|
|
348
|
+
|
|
349
|
+
Params:
|
|
350
|
+
tensor_name: The name of the tensor for which to update overrides.
|
|
351
|
+
producer_type_info: Info for the tensor's produced type.
|
|
352
|
+
consumer_type_info: Info for the tensor's consumed (i.e., converted) type.
|
|
353
|
+
consumer_names: Nodes names of consumers that consume the converted type.
|
|
354
|
+
"""
|
|
355
|
+
if tensor_name not in self.overrides or not self.overrides[tensor_name]:
|
|
356
|
+
self.overrides[tensor_name] = [{}]
|
|
357
|
+
producer_type_info.save_to_dict(self.overrides[tensor_name][0])
|
|
358
|
+
|
|
359
|
+
overrides = self.overrides[tensor_name][0]
|
|
360
|
+
if producer_type_info != QuantTypeInfo.load_from_dict(overrides):
|
|
361
|
+
raise ValueError(f"Desired producer quant_type for {tensor_name} doesn't match existing type.")
|
|
362
|
+
|
|
363
|
+
if consumer_names:
|
|
364
|
+
if "convert" not in overrides:
|
|
365
|
+
overrides["convert"] = {}
|
|
366
|
+
consumer_type_info.save_to_dict(overrides["convert"])
|
|
367
|
+
|
|
368
|
+
convert_dict = overrides["convert"]
|
|
369
|
+
if consumer_type_info != QuantTypeInfo.load_from_dict(convert_dict):
|
|
370
|
+
raise ValueError(f"Desired consumer quant_type for {tensor_name} doesn't match existing type.")
|
|
371
|
+
|
|
372
|
+
if "recv_nodes" not in convert_dict:
|
|
373
|
+
convert_dict["recv_nodes"] = set()
|
|
374
|
+
|
|
375
|
+
convert_dict["recv_nodes"].update(consumer_names)
|
|
376
|
+
|
|
377
|
+
def _check_nodes_are_not_convert_consumers(self, tensor_name: str, node_names: set[str]):
|
|
378
|
+
"""
|
|
379
|
+
Returns true if the given nodes do not consume/receive a converted quantization type.
|
|
380
|
+
|
|
381
|
+
Params:
|
|
382
|
+
tensor_name: The name of the tensor to check.
|
|
383
|
+
node_names: Set of node names that should not be consumers of the converted type.
|
|
384
|
+
"""
|
|
385
|
+
if tensor_name not in self.overrides or not self.overrides[tensor_name]:
|
|
386
|
+
return True
|
|
387
|
+
|
|
388
|
+
overrides = self.overrides[tensor_name][0]
|
|
389
|
+
|
|
390
|
+
if "convert" not in overrides:
|
|
391
|
+
return True
|
|
392
|
+
|
|
393
|
+
convert_dict = overrides["convert"]
|
|
394
|
+
|
|
395
|
+
if "recv_nodes" not in convert_dict:
|
|
396
|
+
return False
|
|
397
|
+
|
|
398
|
+
return not convert_dict["recv_nodes"].intersection(node_names)
|
|
399
|
+
|
|
400
|
+
def __is_tensor_quantizable(self, tensor_name):
|
|
401
|
+
weight = self.initializers.get(tensor_name)
|
|
402
|
+
if weight is not None:
|
|
403
|
+
if weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16):
|
|
404
|
+
return True
|
|
405
|
+
elif tensor_name in self.value_infos:
|
|
406
|
+
vi = self.value_infos[tensor_name]
|
|
407
|
+
if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type in (
|
|
408
|
+
onnx.TensorProto.FLOAT,
|
|
409
|
+
onnx.TensorProto.FLOAT16,
|
|
410
|
+
):
|
|
411
|
+
return True
|
|
412
|
+
|
|
413
|
+
return False
|