onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
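
Of the files above, the diff below shows `onnxruntime/quantization/base_quantizer.py` (+532 lines). For context, here is a minimal sketch of running inference with this wheel; the `model.onnx` path and the `(1, 3, 224, 224)` input shape are hypothetical placeholders, while `DmlExecutionProvider` is the execution provider this DirectML build registers:

```python
# Minimal sketch, assuming a local model.onnx with a single float32 input.
import numpy as np
import onnxruntime as ort

# The onnxruntime-directml wheel registers DmlExecutionProvider.
session = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider"])
input_name = session.get_inputs()[0].name
# Hypothetical input shape; replace with the model's actual shape.
outputs = session.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})
print([o.shape for o in outputs])
```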
@@ -0,0 +1,532 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+from typing import Any, Dict
+
+import numpy as np
+import onnx
+import onnx.numpy_helper
+
+try:
+    from onnx.reference.op_run import to_array_extended
+except ImportError:
+    # old version of onnx.
+    to_array_extended = None
+
+from .calibrate import TensorData
+from .onnx_model import ONNXModel
+from .quant_utils import (
+    ONNX_TYPE_TO_NP_TYPE,
+    TENSOR_NAME_QUANT_SUFFIX,
+    QuantType,
+    find_by_name,
+    model_has_infer_metadata,
+    normalize_axis,
+    pack_bytes_to_4bit,
+    quantize_data,
+    quantize_nparray,
+    save_and_reload_model_with_shape_infer,
+    tensor_proto_to_array,
+)
+from .tensor_quant_overrides import TensorQuantOverridesHelper
+
+
+class QuantizationParams:
+    def __init__(self, **data: Dict[str, Any]):
+        self.data = {}
+        for k, v in data.items():
+            if not isinstance(k, str):
+                raise TypeError(f"Keys must be strings, not {type(k)}, for k={k!r}.")
+            if not isinstance(v, (int, str, np.ndarray)):
+                raise TypeError(f"Values must be numpy arrays, int, float, or str, not {type(v)}, for k={k!r}.")
+            if k == "scale" and v.dtype not in (np.float32, np.float16):
+                raise ValueError(f"scale must be a float32 or float16 numpy element but is {v.dtype} for k={k!r}")
+            self.data[k] = v
+
+    def __iter__(self):
+        yield from self.data
+
+    def __getitem__(self, key):
+        return self.data[key]
+
+    def __len__(self):
+        return len(self.data)
+
+
+class BaseQuantizer:
+    def __init__(
+        self,
+        model,
+        per_channel,
+        reduce_range,
+        weight_qType,
+        activation_qType,
+        tensors_range,
+        nodes_to_quantize,
+        nodes_to_exclude,
+        op_types_to_quantize,
+        extra_options=None,
+    ):
+        if not model_has_infer_metadata(model):
+            model = save_and_reload_model_with_shape_infer(model)
+        self.value_infos = {vi.name: vi for vi in model.graph.value_info}
+        self.value_infos.update({ot.name: ot for ot in model.graph.output})
+        self.value_infos.update({it.name: it for it in model.graph.input})
+
+        self.model = ONNXModel(model)
+        self.per_channel = per_channel  # weight-pack per channel
+        self.reduce_range = reduce_range
+
+        self.extra_options = extra_options if extra_options else {}
+        self.enable_subgraph_quantization = (
+            "EnableSubgraph" in self.extra_options and self.extra_options["EnableSubgraph"]
+        )
+        self.parent = None
+        self.force_quantize_no_input_check = (
+            "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"]
+        )
+        self.is_weight_symmetric = self.extra_options.get(
+            "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN)
+        )
+        self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False)
+        self.min_real_range = self.extra_options.get("MinimumRealRange")
+
+        self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType)
+        self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType)
+
+        """
+        Dictionary specifying the min and max values for tensors. It has the following format:
+            {
+                "param_name": [min, max]
+            }
+        example:
+            {
+                'Conv_3:0': [np.float32(0), np.float32(0.5)],
+                'Conv_4:0': [np.float32(1), np.float32(3.5)]
+            }
+        """
+        if tensors_range is not None and any(map(lambda t: not isinstance(t, TensorData), tensors_range.values())):
+            raise TypeError(
+                f"tensors_range contains unexpected types {set(type(v) for v in tensors_range.values())}, not TensorData."
+            )
+        self.tensors_range = tensors_range
+        self.nodes_to_quantize = nodes_to_quantize  # specific nodes to quantize
+        self.nodes_to_exclude = nodes_to_exclude  # specific nodes to exclude
+        self.op_types_to_quantize = op_types_to_quantize
+
+        self.opset_version = self.check_opset_version()
+
+        # Get tensor-level quantization overrides and ensure they are valid.
+        self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {}))
+
+        self.initializers = {initzer.name: initzer for initzer in self.model.initializer()}
+        overrides_valid, overrides_err = self.tensor_quant_overrides.is_valid(
+            self.initializers, self.value_infos.keys(), activation_qType
+        )
+        if not overrides_valid:
+            raise ValueError(overrides_err)
+
+        self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types()
+
+    def quantize_model(self):
+        raise NotImplementedError
+
+    def is_input_a_initializer(self, input_name):
+        initializer = find_by_name(input_name, self.model.initializer())
+        return initializer is not None
+
+    def is_per_channel(self):
+        return self.per_channel
+
+    def is_valid_quantize_weight(self, weight_name):
+        weight = find_by_name(weight_name, self.model.initializer())
+        if weight is not None:
+            return weight.data_type in (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16)
+        if (not self.enable_subgraph_quantization) or (self.parent is None):
+            return False
+        return self.parent.is_valid_quantize_weight(weight_name)
+
+    def should_quantize_node(self, node):
+        if (
+            self.nodes_to_quantize is not None
+            and len(self.nodes_to_quantize) != 0
+            and node.name not in self.nodes_to_quantize
+        ):
+            return False
+
+        if node.op_type not in self.op_types_to_quantize:
+            return False
+
+        if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude:
+            return False
+
+        return True
+
+    def check_opset_version(self):
+        ai_onnx_domain = [
+            opset for opset in self.model.model.opset_import if not opset.domain or opset.domain == "ai.onnx"
+        ]
+        if len(ai_onnx_domain) != 1:
+            raise ValueError("Failed to find proper ai.onnx domain")
+        opset_version = ai_onnx_domain[0].version
+
+        if opset_version == 10:
+            logging.warning(
+                f"The original model opset version is {opset_version}, which does not support node fusions. Please update the model to opset >= 11 for better performance."
+            )
+            return 10
+
+        if opset_version < 10:
+            logging.warning(
+                f"The original model opset version is {opset_version}, which does not support quantization. Please update the model to opset >= 11. Updating the model automatically to opset 11. Please verify the quantized model."
+            )
+            self.model.model.opset_import.remove(ai_onnx_domain[0])
+            self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 11)])
+            opset_version = 11
+
+        if opset_version < 19 and self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            logging.warning(
+                f"The original model opset version is {opset_version}, which does not support quantization to float 8. "
+                "Please update the model to opset >= 19. Updating the model automatically to opset 19. "
+                "Please verify the quantized model."
+            )
+            self.model.model.opset_import.remove(ai_onnx_domain[0])
+            self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 19)])
+            self.model.model.ir_version = 9
+            opset_version = 19
+
+        return opset_version
+
+    def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0):
+        """
+        Quantize the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
+        """
+
+        # get bias
+        bias_initializer = find_by_name(bias_name, self.model.initializer())
+        bias_data = tensor_proto_to_array(bias_initializer)
+        quantized_bias_name = bias_name + TENSOR_NAME_QUANT_SUFFIX
+
+        # quantize bias
+        if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            data = np.asarray(bias_data)
+            if data.dtype == np.float16:
+                node_qtype = onnx.TensorProto.FLOAT16
+            elif data.dtype == np.float32:
+                node_qtype = onnx.TensorProto.FLOAT
+            else:
+                raise TypeError(f"Only float16 or float32 are supported with float 8 but bias dtype is {data.dtype}.")
+            quantized_data = data.astype(np.float32)
+            bias_scale = np.array([1], dtype=quantized_data.dtype)
+            bias_scale_data = bias_scale.reshape(-1)
+            packed_bias_initializer = onnx.numpy_helper.from_array(quantized_data, quantized_bias_name)
+            self.model.initializer_extend([packed_bias_initializer])
+            node_type = "Cast"
+        else:
+            # calculate scale for bias
+            # TODO: This formula should be explained including why the scale is not estimated for the bias as well.
+            bias_scale = input_scale * weight_scale * beta
+
+            quantized_data = (np.asarray(bias_data) / bias_scale).round()
+            quantized_data = np.clip(quantized_data, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
+            quantized_data = quantized_data.astype(np.int32)
+
+            # update bias initializer
+            bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
+            packed_bias_initializer = onnx.numpy_helper.from_array(bias_np_data, quantized_bias_name)
+            self.model.initializer_extend([packed_bias_initializer])
+
+            # Bias's scale dtype should match the original bias data's unquantized type (float32 or float16).
+            bias_scale_data = np.asarray(bias_scale, dtype=bias_data.dtype).reshape(-1)
+            node_type = "DequantizeLinear"
+            node_qtype = self.weight_qType
+
+        # update scale initializer
+        quantized_bias_scale_name = quantized_bias_name + "_scale"
+        packed_bias_scale_initializer = onnx.numpy_helper.from_array(bias_scale_data, quantized_bias_scale_name)
+        self.model.initializer_extend([packed_bias_scale_initializer])
+
+        # update zero initializer
+        if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            tensor_type = self.weight_qType
+        else:
+            tensor_type = onnx.TensorProto.INT32
+
+        quantized_bias_zp_name = quantized_bias_name + "_zero_point"
+        if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+            packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, self.weight_qType, [1], [0.0])
+        elif bias_scale.size > 1:
+            bias_zp_data = np.zeros(bias_scale.shape, dtype=np.int32).reshape(-1)
+            packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name)
+        else:
+            packed_bias_zp_initializer = onnx.helper.make_tensor(quantized_bias_zp_name, tensor_type, [], [0])
+        self.model.initializer_extend([packed_bias_zp_initializer])
+
+        return (
+            quantized_bias_name,
+            quantized_bias_scale_name,
+            quantized_bias_zp_name,
+            bias_scale_data,
+            node_type,
+            node_qtype,
+        )
+
+    def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_float_weight=False):
+        """
+        :param weight: TensorProto initializer
+        :param qType: type to quantize to
+        :param keep_float_weight: Whether to keep the weight in float. In some cases, we only want to quantize scale and zero point.
+                                  If keep_float_weight is False, quantize the weight; otherwise, leave it in float.
+        :return: quantized weight name, zero point name, scale name
+        """
+        q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
+        zp_name = weight.name + "_zero_point"
+        scale_name = weight.name + "_scale"
+
+        # Quantize weight data. Use quantization overrides if provided by the user.
+        weight_data = tensor_proto_to_array(weight)
+        quant_overrides = self.tensor_quant_overrides.get_per_tensor_overrides(weight.name, default_val={})
+        if "quant_type" in quant_overrides:
+            qType = quant_overrides["quant_type"].tensor_type  # noqa: N806
+
+        if "scale" in quant_overrides and "zero_point" in quant_overrides:
+            zero_point = np.array(quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[qType])
+            scale = np.array(quant_overrides["scale"])
+            q_weight_data = quantize_nparray(qType, weight_data.flatten(), scale, zero_point)
+            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+            assert (
+                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+            ), f"Unexpected dtype {zero_point.dtype}"
+            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
+        else:
+            _, _, zero_point, scale, q_weight_data = quantize_data(
+                weight_data.flatten(),
+                qType,
+                quant_overrides.get("symmetric", self.is_weight_symmetric),
+                reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
+                min_real_range=self.min_real_range,
+                rmin_override=quant_overrides.get("rmin"),
+                rmax_override=quant_overrides.get("rmax"),
+            )
+
+            assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+            assert (
+                zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+            ), f"Unexpected dtype {zero_point.dtype}"
+            assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+
+        scale_dtype = weight.data_type
+        scale_initializer = onnx.helper.make_tensor(scale_name, scale_dtype, [], scale.reshape((-1,)).tolist())
+        zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], zero_point.reshape((-1,)).tolist())
+        self.model.initializer_extend([scale_initializer, zero_initializer])
+
+        if not keep_float_weight:
+            if self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN:
+                q_weight_initializer = onnx.TensorProto()
+                q_weight_initializer.data_type = self.weight_qType
+                q_weight_initializer.dims.extend(weight.dims)
+                q_weight_initializer.name = q_weight_name
+                # Do not remove .flatten().copy(); numpy is not clear about data persistence.
+                q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
+                if to_array_extended is not None:
+                    # This test should not be needed but it helped catch some issues
+                    # with data persistence and tobytes.
+                    check = to_array_extended(q_weight_initializer)
+                    if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
+                        raise RuntimeError(
+                            f"The initializer of shape {weight_data.shape} could not be created, expecting "
+                            f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
+                            f"\nraw={str(q_weight_initializer)[:200]}."
+                        )
+            elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
+                if q_weight_data.dtype not in (np.int8, np.uint8):
+                    raise RuntimeError(
+                        f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
+                    )
+
+                # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
+                # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
+                packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
+
+                # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
+                q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True)
+            else:
+                q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
+                    weight.dims
+                )
+                q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)
+            self.model.initializer_extend([q_weight_initializer])
+
+        return q_weight_name, zp_name, scale_name
+
+    def quantize_weight_per_channel_impl(
+        self,
+        weight_name,
+        weight_qType,
+        channel_axis,
+        reduce_range=True,
+        keep_float_weight=False,
+    ):
+        initializer = find_by_name(weight_name, self.model.initializer())
+        if initializer is None:
+            raise ValueError(f"{weight_name} is not an initializer")
+
+        weights = tensor_proto_to_array(initializer)
+        weights_rank = len(weights.shape)
+        is_axis_valid, axis_norm = normalize_axis(channel_axis, weights_rank)
+        if not is_axis_valid:
+            raise ValueError(
+                f"Weight {weight_name} has a per-channel axis with value {channel_axis} that is "
+                f"out-of-bounds for rank {weights_rank}"
+            )
+
+        channel_axis = axis_norm
+        channel_count = weights.shape[channel_axis]
+        quant_overrides_for_channels = self.tensor_quant_overrides.get_per_channel_overrides(
+            weight_name, default_val=[{"axis": channel_axis}]
+        )
+
+        num_channel_overrides = len(quant_overrides_for_channels)
+        if num_channel_overrides != 1 and num_channel_overrides != channel_count:
+            raise ValueError(
+                f"Per-channel tensor quantization overrides for {weight_name} must have "
+                f"either 1 or {channel_count} elements in the list of dictionaries."
+            )
+
+        is_axis_override_valid, axis_override = normalize_axis(quant_overrides_for_channels[0]["axis"], weights_rank)
+        if not is_axis_override_valid or axis_override != channel_axis:
+            raise ValueError(
+                f"Tensor quantization overrides for {weight_name} specify an unexpected axis. "
+                f"Expected {channel_axis}, but got {quant_overrides_for_channels[0]['axis']}."
+            )
+
+        # If user provides per-channel quantization overrides, all channels must use the same quant_type,
+        # axis, symmetric, and reduce_range values. So, just use the first channel's values.
+        if "quant_type" in quant_overrides_for_channels[0]:
+            weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type  # noqa: N806
+
+        symmetric = quant_overrides_for_channels[0].get(
+            "symmetric",
+            (
+                self.is_weight_symmetric
+                or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4)
+            ),
+        )
+        reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range)
+        zero_point_list = []
+        scale_list = []
+        quantized_per_channel_data_list = []
+        weights_shape = list(weights.shape)
+        reshape_dims = list(weights_shape)  # deep copy
+        reshape_dims[channel_axis] = 1  # only one per channel for reshape
+        for i in range(channel_count):
+            per_channel_data = weights.take(i, channel_axis)
+            channel_override_index = i if i < num_channel_overrides else 0
+            channel_quant_overrides = quant_overrides_for_channels[channel_override_index]
+
+            if "scale" in channel_quant_overrides and "zero_point" in channel_quant_overrides:
+                zero_point = np.array(channel_quant_overrides["zero_point"], dtype=ONNX_TYPE_TO_NP_TYPE[weight_qType])
+                scale = np.array(channel_quant_overrides["scale"])
+                quantized_per_channel_data = quantize_nparray(
+                    weight_qType, per_channel_data.flatten(), scale, zero_point
+                )
+                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                assert (
+                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+                ), f"Unexpected dtype {zero_point.dtype}"
+                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                assert isinstance(
+                    quantized_per_channel_data, np.ndarray
+                ), f"Unexpected type {type(quantized_per_channel_data)}"
+
+            else:
+                _, _, zero_point, scale, quantized_per_channel_data = quantize_data(
+                    per_channel_data.flatten(),
+                    weight_qType,
+                    symmetric,
+                    reduce_range=reduce_range,
+                    min_real_range=self.min_real_range,
+                    rmin_override=channel_quant_overrides.get("rmin"),
+                    rmax_override=channel_quant_overrides.get("rmax"),
+                )
+
+                assert isinstance(zero_point, np.ndarray), f"Unexpected type {type(zero_point)}"
+                assert (
+                    zero_point.dtype != np.float32 and zero_point.dtype != np.float16
+                ), f"Unexpected dtype {zero_point.dtype}"
+                assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"
+                assert isinstance(
+                    quantized_per_channel_data, np.ndarray
+                ), f"Unexpected type {type(quantized_per_channel_data)}"
+
+            zero_point_list.append(zero_point)
+            scale_list.append(scale)
+            quantized_per_channel_data_list.append(np.asarray(quantized_per_channel_data).reshape(reshape_dims))
+
+        # combine per_channel_data into one
+        quantized_weights = np.concatenate(quantized_per_channel_data_list, channel_axis)
+        q_weight_name = weight_name + TENSOR_NAME_QUANT_SUFFIX
+        zp_name = weight_name + "_zero_point"
+        scale_name = weight_name + "_scale"
+
+        # Update packed weight, zero point, and scale initializers
+        zero_scale_shape = [initializer.dims[channel_axis]]
+        scale_initializer = onnx.helper.make_tensor(
+            scale_name, initializer.data_type, zero_scale_shape, np.hstack(scale_list).tolist()
+        )
+        zero_initializer = onnx.helper.make_tensor(
+            zp_name, weight_qType, zero_scale_shape, np.hstack(zero_point_list).tolist()
+        )
+
+        self.model.initializer_extend([scale_initializer, zero_initializer])
+
+        if not keep_float_weight:
+            if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
+                if quantized_weights.dtype not in (np.int8, np.uint8):
+                    raise RuntimeError(
+                        f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
+                    )
+
+                # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
+                # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
+                packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes()))
+
+                # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
+                q_weight_initializer = onnx.helper.make_tensor(
+                    q_weight_name, weight_qType, weights_shape, packed_data, raw=True
+                )
+                self.model.initializer_extend([q_weight_initializer])
+            else:
+                quantized_weights = np.asarray(
+                    quantized_weights,
+                    dtype=onnx.helper.tensor_dtype_to_np_dtype(weight_qType),
+                ).reshape(initializer.dims)
+                q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name)
+                self.model.initializer_extend([q_weight_initializer])
+
+        return q_weight_name, zp_name, scale_name
+
+    def adjust_tensor_ranges(self):
+        if self.tensors_range is None:
+            return
+
+        for node in self.model.nodes():
+            # adjust tensor_ranges for input of Clip and Relu node
+            if node.op_type in ["Clip", "Relu"]:
+                if not self.should_quantize_node(node):
+                    continue
+                if len(self.model.input_name_to_nodes()[node.input[0]]) != 1:
+                    continue
+                if node.input[0] not in self.tensors_range or node.output[0] not in self.tensors_range:
+                    continue
+                td = self.tensors_range[node.output[0]]
+                if not isinstance(td, TensorData):
+                    raise TypeError(f"Unexpected type {type(td)} for {node.output[0]!r}.")
+                self.tensors_range[node.input[0]] = td
+            # Adjust Softmax to range from 0.0 to 1.0
+            elif node.op_type == "Softmax":
+                self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))