onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
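
The capi files above bundle DirectML.dll together with the DirectML build of onnxruntime.dll, so the wheel can run models on the DirectML execution provider without further setup. A minimal, illustrative sketch (the model path "model.onnx" is only a placeholder):

    import onnxruntime as ort

    # DirectML builds register "DmlExecutionProvider"; keep the CPU EP as a fallback.
    session = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider", "CPUExecutionProvider"])
    print(session.get_providers())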

onnxruntime/quantization/quant_utils.py
@@ -0,0 +1,1051 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import copy
import logging
import os
import tempfile
from enum import Enum
from pathlib import Path

import numpy
import onnx
from ml_dtypes import float8_e4m3fn, int4, uint4
from onnx import ModelProto, TensorProto, external_data_helper
from onnx import onnx_pb as onnx_proto
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from onnx.reference import ReferenceEvaluator

from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

try:
    from onnx.reference.op_run import to_array_extended
except ImportError:
    # old version of onnx.
    to_array_extended = None


__producer__ = "onnx.quantize"
__version__ = "0.1.0"
onnx_domain = "ai.onnx"
ms_domain = "com.microsoft"
QUANT_OP_NAME = "QuantizeLinear"
QUANT_INPUT_SUFFIX = "_QuantizeLinear_Input"
DEQUANT_OP_NAME = "DequantizeLinear"
DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
TENSOR_NAME_QUANT_SUFFIX = "_quantized"
MODEL_SIZE_THRESHOLD = 2147483648  # Quant model should use external data if >= 2GB

FLOAT8_DISTRIBUTIONS = {}

type_to_name = {getattr(TensorProto, k): k for k in dir(TensorProto) if isinstance(getattr(TensorProto, k), int)}

# Quantization mode
# IntegerOps: Use IntegerOps in quantized model. Only ConvInteger and MatMulInteger ops are supported now.
# QLinearOps: Use QLinearOps in quantized model. Only QLinearConv and QLinearMatMul ops are supported now.


class QuantizationMode(Enum):
    IntegerOps = 0
    QLinearOps = 1

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(mode):
        try:
            return QuantizationMode[mode]
        except KeyError:
            raise ValueError()  # noqa: B904


class QuantizedValueType(Enum):
    Input = 0
    Initializer = 1

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(v):
        try:
            return QuantizedValueType[v]
        except KeyError:
            raise ValueError()  # noqa: B904


class QuantType(Enum):
    QInt8 = 0
    QUInt8 = 1
    QFLOAT8E4M3FN = 2
    QInt16 = 3
    QUInt16 = 4
    QInt4 = 5
    QUInt4 = 6

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(t):
        try:
            return QuantType[t]
        except KeyError:
            raise ValueError()  # noqa: B904

    @property
    def tensor_type(self):
        if self == QuantType.QInt8:
            return TensorProto.INT8
        if self == QuantType.QUInt8:
            return TensorProto.UINT8
        if self == QuantType.QUInt16:
            return TensorProto.UINT16
        if self == QuantType.QInt16:
            return TensorProto.INT16
        if self == QuantType.QFLOAT8E4M3FN:
            return TensorProto.FLOAT8E4M3FN
        if self == QuantType.QUInt4:
            return TensorProto.UINT4
        if self == QuantType.QInt4:
            return TensorProto.INT4
        raise ValueError(f"Unexpected value qtype={self!r}.")


class QuantFormat(Enum):
    QOperator = 0
    QDQ = 1

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(format):
        try:
            return QuantFormat[format]
        except KeyError:
            raise ValueError()  # noqa: B904


ONNX_TYPE_TO_NP_TYPE = {
    onnx_proto.TensorProto.INT8: numpy.dtype("int8"),
    onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"),
    onnx_proto.TensorProto.INT16: numpy.dtype("int16"),
    onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"),
    onnx_proto.TensorProto.FLOAT8E4M3FN: float8_e4m3fn,
    onnx_proto.TensorProto.INT4: int4,
    onnx_proto.TensorProto.UINT4: uint4,
}

ONNX_INT_TYPE_RANGE = {
    onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(255, dtype=numpy.uint8)),
    onnx_proto.TensorProto.INT8: (numpy.array(-128, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
    onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65535, dtype=numpy.uint16)),
    onnx_proto.TensorProto.INT16: (numpy.array(-32768, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
    onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(15, dtype=uint4)),
    onnx_proto.TensorProto.INT4: (numpy.array(-8, dtype=int4), numpy.array(7, dtype=int4)),
}

ONNX_INT_TYPE_SYMMETRIC_RANGE = {
    onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)),
    onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)),
}

ONNX_INT_TYPE_REDUCED_RANGE = {
    onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(127, dtype=numpy.uint8)),
    onnx_proto.TensorProto.INT8: (numpy.array(-64, dtype=numpy.int8), numpy.array(64, dtype=numpy.int8)),
    onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(32767, dtype=numpy.uint16)),
    onnx_proto.TensorProto.INT16: (numpy.array(-16384, dtype=numpy.int16), numpy.array(16384, dtype=numpy.int16)),
    onnx_proto.TensorProto.UINT4: (numpy.array(0, dtype=uint4), numpy.array(7, dtype=uint4)),
    onnx_proto.TensorProto.INT4: (numpy.array(-4, dtype=int4), numpy.array(3, dtype=int4)),
}


def _check_type(*args, zero_point_index=-1):
    new_args = []
    for i, a in enumerate(args):
        if numpy.issubdtype(type(a), numpy.number):
            new_args.append(numpy.array(a))
        elif isinstance(a, numpy.ndarray):
            new_args.append(a)
        else:
            raise TypeError(f"arg {i} is not an array: {a}")
        if i == zero_point_index:
            v = new_args[-1]
            if v.dtype == numpy.float32 or v.dtype == numpy.float16:
                raise TypeError(f"zero_point cannot be {v.dtype}")
    return tuple(new_args) if len(new_args) > 1 else new_args[0]


def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
    assert qType in ONNX_TYPE_TO_NP_TYPE, (
        f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported."
    )
    if qType in (
        onnx_proto.TensorProto.FLOAT8E4M3FN,
        onnx_proto.TensorProto.FLOAT8E4M3FNUZ,
        onnx_proto.TensorProto.FLOAT8E5M2,
        onnx_proto.TensorProto.FLOAT8E5M2FNUZ,
    ):
        if zero_point != 0:
            raise NotImplementedError(f"zero_point is expected to be null for float 8 not {zero_point!r}.")
        if arr.dtype == numpy.float32:
            onnx_type = TensorProto.FLOAT
        elif arr.dtype == numpy.float16:
            onnx_type = TensorProto.FLOAT16
        else:
            raise ValueError(f"Unexpected dtype {arr.dtype}.")
        onnx_model = make_model(
            make_graph(
                [
                    make_node(
                        "Constant", [], ["zero_point"], value=onnx.helper.make_tensor("zero_point", qType, [], [0])
                    ),
                    make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]),
                ],
                "qu",
                [
                    make_tensor_value_info("X", onnx_type, None),
                    make_tensor_value_info("scale", onnx_type, None),
                ],
                [make_tensor_value_info("Y", qType, None)],
            )
        )
        ref = ReferenceEvaluator(onnx_model)
        return _check_type(ref.run(None, {"X": arr, "scale": scale})[0])
    else:
        # Quantizes data for all integer types.
        #
        # For int4 types, the quantized data is returned as either np.int8 or np.uint8,
        # which matches the python reference ONNX implementation of QuantizeLinear.
        # This data can be packed into 4-bit elements by using pack_bytes_to_4bit().
        dtype = ONNX_TYPE_TO_NP_TYPE[qType]
        qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False)

        cliplow = max(qmin, low) if low is not None else qmin
        cliphigh = min(qmax, high) if high is not None else qmax
        arr_fp32 = numpy.asarray((arr.astype(numpy.float32) / scale).round() + zero_point)
        numpy.clip(arr_fp32, cliplow, cliphigh, out=arr_fp32)
        return _check_type(arr_fp32.astype(dtype))
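
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# Minimal sketch: quantize a float32 vector to uint8 with an assumed scale and
# zero point.  quantize_nparray rounds, adds the zero point, and clips to [0, 255].
import numpy
from onnx import TensorProto
from onnxruntime.quantization.quant_utils import quantize_nparray

x = numpy.array([-1.0, 0.0, 0.5, 2.5], dtype=numpy.float32)
scale = numpy.array(0.02, dtype=numpy.float32)
zero_point = numpy.array(64, dtype=numpy.uint8)
q = quantize_nparray(TensorProto.UINT8, x, scale, zero_point)
# round(x / 0.02) + 64 -> array([ 14,  64,  89, 189], dtype=uint8)
# ---------------------------------------------------------------------------------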


def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=None):
    """Calculate the scale s and zero point z for the quantization relation
    r = s(q-z), where r are the original values and q are the corresponding
    quantized values.

    r and z are calculated such that every value within [rmin,rmax] has an
    approximate representation within [qmin,qmax]. In addition, qmin <= z <=
    qmax is enforced. If the symmetric flag is set to True, the interval
    [rmin,rmax] is symmetrized to [-absmax, +absmax], where
    absmax = max(abs(rmin), abs(rmax)).

    :parameter rmin: minimum value of r
    :parameter rmax: maximum value of r
    :parameter qmin: minimum value representable by the target quantization data type
    :parameter qmax: maximum value representable by the target quantization data type
    :parameter symmetric: True if the floating-point range should be made symmetric. Defaults to False.
    :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
    :return: zero and scale [z, s]

    """
    if qmin > 0 or qmax < 0:
        raise ValueError(f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while qmin:{qmin}, qmmax:{qmax}")

    # Adjust rmin and rmax such that 0 is included in the range. This is
    # required to make sure zero can be represented by the quantization data
    # type (i.e. to make sure qmin <= zero_point <= qmax)
    rmin = numpy.minimum(rmin, numpy.array(0, dtype=rmin.dtype))
    rmax = numpy.maximum(rmax, numpy.array(0, dtype=rmax.dtype))

    # Ensure a minimum float-point range if specified.
    if min_real_range is not None:
        rmax = max(rmax, rmin + numpy.asarray(min_real_range, dtype=rmin.dtype))

    if symmetric:
        absmax = numpy.maximum(numpy.abs(rmin), numpy.abs(rmax))
        rmin = -absmax
        rmax = +absmax

    assert qmin <= qmax, f"qmin={rmin} > qmax={rmax}"
    dr = numpy.array(rmax - rmin, dtype=numpy.float64)
    dq = numpy.array(qmax, dtype=numpy.float64) - numpy.array(qmin, dtype=numpy.float64)
    scale = numpy.array(dr / dq)
    assert scale >= 0, "scale issue"
    if scale < numpy.finfo(rmax.dtype).tiny:
        scale = numpy.array(1.0, dtype=rmax.dtype)
        zero_point = numpy.array(0, dtype=qmin.dtype)
    else:
        if symmetric:
            # When symmetric (i.e., rmax == -rmin), the zero_point formula reduces to round((qmax + qmin) / 2.0).
            # This simpler formula doesn't depend on scale and guarantees that the zero point values
            # for int8, uint8, int16, and uint16 are always 0, 128, 0, and 32768, respectively.
            # This is important for per-channel/symmetric QLinearConv on CPU EP, which requires all channels to have
            # the exact same zero_point values.
            zero_point = numpy.array(
                numpy.round((qmin + qmax) / numpy.array(2.0, dtype=numpy.float64)), dtype=qmin.dtype
            )
        else:
            zero_point = numpy.array(numpy.round(qmin - rmin / scale), dtype=qmin.dtype)
        scale = scale.astype(rmax.dtype)

    return [zero_point, scale]
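
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# Sketch of the scale/zero-point math above for an assumed float range [-1.5, 3.0]
# mapped onto asymmetric int8.
import numpy
from onnxruntime.quantization.quant_utils import compute_scale_zp

rmin = numpy.array(-1.5, dtype=numpy.float32)
rmax = numpy.array(3.0, dtype=numpy.float32)
qmin = numpy.array(-128, dtype=numpy.int8)
qmax = numpy.array(127, dtype=numpy.int8)
zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False)
# scale = (3.0 - (-1.5)) / 255 ~= 0.0176, zero_point = round(qmin - rmin / scale) = -43
# ---------------------------------------------------------------------------------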


def compute_scale_zp_float8(element_type, std):
    """Calculate the scale s for a float8 type (E4M3FN).
    The function assumes the coefficient distribution and the float 8
    distribution are similar to two gaussian laws.

    :return: zero and scale [z, s]

    More details in notebook `quantization_fp8.ipynb
    <https://github.com/microsoft/onnxruntime/blob/main/docs/python/notebooks/quantization_fp8.ipynb>`_.
    """
    zp_dtype = None
    if element_type not in FLOAT8_DISTRIBUTIONS:
        if element_type == TensorProto.FLOAT8E4M3FN:
            from ml_dtypes import float8_e4m3fn  # noqa: PLC0415

            zp_dtype = float8_e4m3fn
            all_values = [float(i) for i in range(256)]
            values = numpy.array(
                [f for f in all_values if not numpy.isnan(f) and not numpy.isinf(f)], dtype=numpy.float32
            )
        else:
            raise ValueError(f"Quantization to element_type={element_type} not implemented.")
        FLOAT8_DISTRIBUTIONS[element_type] = values
    elif element_type == TensorProto.FLOAT8E4M3FN:
        from ml_dtypes import float8_e4m3fn  # noqa: PLC0415

        zp_dtype = float8_e4m3fn

    if zp_dtype is None:
        raise TypeError(f"Unexpected element_type {element_type}.")
    std_f8 = numpy.std(FLOAT8_DISTRIBUTIONS[element_type])
    zero = numpy.array(0, dtype=zp_dtype)
    scale = numpy.array(std / std_f8, dtype=std.dtype)
    return [zero, scale]


def compute_data_quant_params(
    data: numpy.ndarray,
    quant_type: onnx.TensorProto.DataType,
    symmetric: bool,
    reduce_range: bool = False,
    min_real_range: float | None = None,
    rmin_override: float | None = None,
    rmax_override: float | None = None,
) -> tuple[numpy.ndarray, numpy.ndarray]:
    """
    Returns the zero_point and scale for the given data.

    :param data: The data for which to compute quantization parameters.
    :param quant_type: The quantization data type.
    :param symmetric: whether symmetric quantization is used or not.
    :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
    :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
    :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
    :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
    :return: zero point and scale
    """
    if not isinstance(data, numpy.ndarray):
        raise TypeError(f"Weight must be given as an array not {type(data)}.")
    if rmin_override is not None:
        rmin = rmin_override
    else:
        rmin = data.min() if len(data) else 0.0

    if rmax_override is not None:
        rmax = rmax_override
    else:
        rmax = data.max() if len(data) else 0.0

    rmin = numpy.array(rmin, dtype=data.dtype)
    rmax = numpy.array(rmax, dtype=data.dtype)
    scale = numpy.array(1.0, dtype=data.dtype)

    if quant_type == TensorProto.FLOAT8E4M3FN:
        if reduce_range:
            raise RuntimeError("Unsupported option reduce_range=True for float 8.")
        std = numpy.std(data)
        zero_point, scale = compute_scale_zp_float8(quant_type, std)
        return _check_type(zero_point, scale, zero_point_index=0)

    if quant_type in (
        TensorProto.INT8,
        TensorProto.UINT8,
        TensorProto.INT16,
        TensorProto.UINT16,
        TensorProto.INT4,
        TensorProto.UINT4,
    ):
        qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range, symmetric=symmetric)
        if len(data):
            zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range)
        else:
            zero_point = numpy.array(0, dtype=qmin.dtype)
        return _check_type(zero_point, scale, zero_point_index=0)

    raise ValueError(f"Unexpected value for quant_type={quant_type}.")
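
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# Sketch: derive per-tensor symmetric int8 parameters straight from a made-up
# weight array; the symmetric int8 range is [-127, 127], so the zero point is 0.
import numpy
from onnx import TensorProto
from onnxruntime.quantization.quant_utils import compute_data_quant_params

weights = numpy.random.uniform(-0.8, 0.8, size=(64,)).astype(numpy.float32)
zero_point, scale = compute_data_quant_params(weights, TensorProto.INT8, symmetric=True)
# ---------------------------------------------------------------------------------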


def quantize_data(
    data, qType, symmetric, reduce_range=False, min_real_range=None, rmin_override=None, rmax_override=None
) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
    """
    :param data: data to quantize
    :param qType: data type to quantize to.
    :param symmetric: whether symmetric quantization is used or not.
    :parameter reduce_range: True if the quantization range should be reduced. Defaults to False.
    :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None.
    :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data).
    :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data).
    :return: minimum, maximum, zero point, scale, and quantized weights

    To pack weights, we compute a linear transformation

    - when data `type == uint8` mode, from `[rmin, rmax]` -> :math:`[0, 2^{b-1}]` and
    - when data `type == int8`, from `[-m , m]` -> :math:`[-(2^{b-1}-1), 2^{b-1}-1]` where
      `m = max(abs(rmin), abs(rmax))`

    and add necessary intermediate nodes to transform quantized weight to full weight using the equation

    :math:`r = S(q-z)`, where

    - *r*: real original value
    - *q*: quantized value
    - *S*: scale
    - *z*: zero point
    """
    zero_point, scale = compute_data_quant_params(
        data,
        qType,
        symmetric,
        reduce_range,
        min_real_range,
        rmin_override,
        rmax_override,
    )
    if qType == TensorProto.FLOAT8E4M3FN:
        quantized_data = quantize_nparray(qType, data, scale, zero_point)
        if any((quantized_data.view(numpy.uint8).ravel() & 127) == 127):
            np_data = numpy.asarray(data)
            raise RuntimeError(
                f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], "
                f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]."
            )
        return zero_point, scale, quantized_data

    if qType in (
        TensorProto.INT8,
        TensorProto.UINT8,
        TensorProto.INT16,
        TensorProto.UINT16,
        TensorProto.INT4,
        TensorProto.UINT4,
    ):
        quantized_data = quantize_nparray(qType, data, scale, zero_point)
        return zero_point, scale, quantized_data

    raise ValueError(f"Unexpected value for qType={qType}.")
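
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# Sketch of a full quantize/dequantize round trip using the relation r = s * (q - z).
# The reconstruction error of in-range values is bounded by roughly scale / 2.
import numpy
from onnx import TensorProto
from onnxruntime.quantization.quant_utils import quantize_data

data = numpy.array([-0.4, 0.0, 0.3, 1.1], dtype=numpy.float32)
zero_point, scale, q = quantize_data(data, TensorProto.UINT8, symmetric=False)
reconstructed = (q.astype(numpy.float32) - zero_point) * scale
max_error = numpy.abs(reconstructed - data).max()
# ---------------------------------------------------------------------------------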


def quantize_onnx_initializer(
    weight: onnx.TensorProto,
    quant_type: onnx.TensorProto.DataType,
    zero_point: numpy.ndarray,
    scale: numpy.ndarray,
    axis: int | None = None,
    quant_weight_name: str | None = None,
) -> onnx.TensorProto:
    """
    Returns a quantized version of the given ONNX initializer.

    :param weight: The ONNX initializer to quantize.
    :param quant_type: The final quantized data type.
    :param zero_point: The zero-point value to use for quantization.
    :param scale: The scale value to use for quantization.
    :param axis: The quantization axis if quantizing per-channel. Defaults to None.
    :param quant_weight_name: The name of the quantized initializer.
                              If not specified, the quantized name is generated.
    :return: The quantized ONNX initializer.
    """
    weight_data = tensor_proto_to_array(weight)
    q_weight_data: numpy.ndarray | None = None

    if axis is None:  # Per-tensor quantization
        q_weight_data = quantize_nparray(quant_type, weight_data.ravel(), scale, zero_point)
    else:  # Per-channel quantization
        channel_count = weight_data.shape[axis]
        channel_dims = list(weight_data.shape)  # deep copy
        channel_dims[axis] = 1  # only one per channel for reshape
        quantized_channel_data_list = []

        for i in range(channel_count):
            channel_data = weight_data.take(i, axis)
            channel_scale = scale[i]
            channel_zero_point = zero_point[i]
            quantized_channel_data = quantize_nparray(
                quant_type, channel_data.ravel(), channel_scale, channel_zero_point
            )
            quantized_channel_data_list.append(numpy.asarray(quantized_channel_data).reshape(channel_dims))

        q_weight_data = numpy.concatenate(quantized_channel_data_list, axis)

    q_weight_name = quant_weight_name if quant_weight_name else f"{weight.name}{TENSOR_NAME_QUANT_SUFFIX}"

    if quant_type == onnx.TensorProto.FLOAT8E4M3FN:
        q_weight_initializer = onnx.TensorProto()
        q_weight_initializer.data_type = quant_type
        q_weight_initializer.dims.extend(weight.dims)
        q_weight_initializer.name = q_weight_name
        # Do not remove .flatten().copy() numpy is not clear about data persistence.
        q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes()
        if to_array_extended is not None:
            # This test should not be needed but it helped catch some issues
            # with data persistence and tobytes.
            check = to_array_extended(q_weight_initializer)
            if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes():
                raise RuntimeError(
                    f"The initializer of shape {weight_data.shape} could not be created, expecting "
                    f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}"
                    f"\nraw={str(q_weight_initializer)[:200]}."
                )
    elif quant_type in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
        if q_weight_data.dtype not in (int4, uint4):
            raise RuntimeError(f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values.")

        # We do not use onnx.helper.pack_float32_to_4bit() due to performance.
        # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
        packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))

        # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
        q_weight_initializer = onnx.helper.make_tensor(q_weight_name, quant_type, weight.dims, packed_data, raw=True)
    else:
        quant_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(quant_type)
        q_weight_data = numpy.asarray(q_weight_data, dtype=quant_np_dtype).reshape(weight.dims)
        q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name)

    return q_weight_initializer
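
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# Sketch: per-tensor int8 quantization of a small initializer.  The scale and zero
# point come from compute_data_quant_params on the same data; the quantized
# initializer name gets the "_quantized" suffix (TENSOR_NAME_QUANT_SUFFIX).
import numpy
from onnx import TensorProto, numpy_helper
from onnxruntime.quantization.quant_utils import compute_data_quant_params, quantize_onnx_initializer

w = numpy.random.randn(4, 8).astype(numpy.float32)
weight = numpy_helper.from_array(w, name="fc.weight")
zero_point, scale = compute_data_quant_params(w, TensorProto.INT8, symmetric=True)
q_init = quantize_onnx_initializer(weight, TensorProto.INT8, zero_point, scale)
assert q_init.name == "fc.weight_quantized"
# ---------------------------------------------------------------------------------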


def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False):  # noqa: N802
    """
    Return qmin and qmax, the minimum and maximum value representable by the given qType
    :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.UINT8
    :return: qmin, qmax
    """
    if qType == onnx_proto.TensorProto.FLOAT8E4M3FN:
        raise NotImplementedError("This function is not implemented for float 8 as not needed.")

    qrange = None

    if reduce_range:
        qrange = ONNX_INT_TYPE_REDUCED_RANGE.get(qType)
    elif symmetric and qType in ONNX_INT_TYPE_SYMMETRIC_RANGE:
        qrange = ONNX_INT_TYPE_SYMMETRIC_RANGE[qType]
    else:
        qrange = ONNX_INT_TYPE_RANGE.get(qType)

    if not qrange:
        raise ValueError(f"Unexpected data type {qType} requested. Only INT8, UINT8, INT16, and UINT16 are supported.")

    qmin, qmax = qrange
    if qmin > 0 or qmax < 0:
        raise ValueError(
            f"qmin and qmax must meet requirement: qmin <= 0 <= qmax while "
            f"qmin:{qmin}, qmmax:{qmax}, dtype={qmin.dtype}, reduce_range={reduce_range}, "
            f"symmetric={symmetric}, qType={qType}"
        )

    return qrange
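
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# The range tables above in action for int8: full, symmetric, and reduced ranges.
from onnx import TensorProto
from onnxruntime.quantization.quant_utils import get_qmin_qmax_for_qType

print(get_qmin_qmax_for_qType(TensorProto.INT8))                      # (-128, 127)
print(get_qmin_qmax_for_qType(TensorProto.INT8, symmetric=True))     # (-127, 127)
print(get_qmin_qmax_for_qType(TensorProto.INT8, reduce_range=True))  # (-64, 64)
# ---------------------------------------------------------------------------------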


def get_qrange_for_qType(qType, reduce_range=False, symmetric=False):  # noqa: N802
    """
    Helper function to get the quantization range for a type.
    parameter qType: quantization type.
    return: quantization range.
    """
    qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric)
    return qmax - qmin


def normalize_axis(axis: int, rank: int) -> tuple[bool, int]:
    """
    Helper function that tries to return a normalized axis in the range [0, rank - 1].
    :parameter axis: The axis to normalize.
    :parameter rank: The tensor rank (number of dimensions).
    :return (is_valid, axis_norm)
    """
    axis_norm = axis + rank if axis < 0 else axis
    is_valid = axis_norm >= 0 and axis_norm < rank
    return is_valid, axis_norm
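
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# normalize_axis reports validity instead of raising, e.g. for a rank-4 tensor:
from onnxruntime.quantization.quant_utils import normalize_axis

assert normalize_axis(-1, 4) == (True, 3)   # negative axis is wrapped around
assert normalize_axis(5, 4) == (False, 5)   # out-of-range axis is flagged
# ---------------------------------------------------------------------------------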


def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray:
    """
    Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values.
    Assumes that the source values are already in the appropriate int4 range.
    :parameter src_8bit: The 8-bit element values to pack.
    :return A bytearray with every two 8-bit src elements packed into a single byte.
    """
    num_elems = len(src_8bit)
    if num_elems == 0:
        return bytearray()

    dst_size = (num_elems + 1) // 2  # Ex: 5 8-bit elems packed into 3 bytes
    dst = bytearray(dst_size)

    src_i: int = 0
    dst_i: int = 0

    # Pack two 8-bit elements into a single byte in each iteration.
    while src_i < num_elems - 1:
        dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF)
        dst_i += 1
        src_i += 2

    if src_i < num_elems:
        # Odd number of elements.
        dst[dst_i] = src_8bit[src_i] & 0xF

    return dst
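
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# Packing the int4 values 1, 2, 3: the first element lands in the low nibble and
# an odd trailing element keeps a zero high nibble.
from onnxruntime.quantization.quant_utils import pack_bytes_to_4bit

packed = pack_bytes_to_4bit(bytes([1, 2, 3]))
assert packed == bytearray([0x21, 0x03])  # (2 << 4) | 1, then the lone 3
# ---------------------------------------------------------------------------------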


class QuantizedInitializer:
    """
    Represents a linearly quantized weight input from ONNX operators
    """

    def __init__(
        self,
        name,
        initializer,
        rmins,
        rmaxs,
        zero_points,
        scales,
        data=[],  # noqa: B006
        quantized_data=[],  # noqa: B006
        axis=None,
    ):
        self.name = name
        self.initializer = initializer  # TensorProto initializer in ONNX graph
        self.rmins = rmins  # List of minimum range for each axis
        self.rmaxs = rmaxs  # List of maximum range for each axis
        # 1D tensor of zero points computed for each axis. scalar if axis is empty
        self.zero_points = zero_points
        self.scales = scales  # 1D tensor of scales computed for each axis. scalar if axis is empty
        self.data = data  # original data from initializer TensorProto
        self.quantized_data = quantized_data  # weight-packed data from data
        # Scalar to specify which dimension in the initializer to weight pack.
        self.axis = axis
        # If empty, single zero point and scales computed from a single rmin and rmax


class QuantizedValue:
    """
    Represents a linearly quantized value (input\\output\\intializer)
    """

    def __init__(
        self,
        name,
        new_quantized_name,
        scale_name,
        zero_point_name,
        quantized_value_type,
        axis=None,
        node_type=None,
        node_qtype=None,
        scale_type=None,
    ):
        self.original_name = name
        self.q_name = new_quantized_name
        self.scale_name = scale_name
        self.zp_name = zero_point_name
        self.value_type = quantized_value_type
        self.axis = axis
        self.node_type = node_type
        self.node_qtype = node_qtype
        self.scale_type = scale_type


class BiasToQuantize:
    """
    Represents a bias to be quantized
    """

    def __init__(self, bias_name, input_name, weight_name):
        self.bias_name = bias_name
        self.input_name = input_name
        self.weight_name = weight_name


def attribute_to_kwarg(attribute):
    """
    Convert attribute to kwarg format for use with onnx.helper.make_node.
    :parameter attribute: attribute in AttributeProto format.
    :return: attribute in {key: value} format.
    """
    if attribute.type == 0:
        raise ValueError(f"attribute {attribute.name} does not have type specified.")

    # Based on attribute type definitions from AttributeProto
    # definition in https://github.com/onnx/onnx/blob/main/onnx/onnx.proto
    if attribute.type == 1:
        value = attribute.f
    elif attribute.type == 2:
        value = attribute.i
    elif attribute.type == 3:
        value = attribute.s
    elif attribute.type == 4:
        value = attribute.t
    elif attribute.type == 5:
        value = attribute.g
    elif attribute.type == 6:
        value = attribute.floats
    elif attribute.type == 7:
        value = attribute.ints
    elif attribute.type == 8:
        value = attribute.strings
    elif attribute.type == 9:
        value = attribute.tensors
    elif attribute.type == 10:
        value = attribute.graphs
    else:
        raise ValueError(f"attribute {attribute.name} has unsupported type {attribute.type}.")

    return {attribute.name: value}
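
# --- Illustrative example (inserted; not part of the packaged quant_utils.py) ---
# Sketch: round-trip a node's attributes into kwargs so a modified copy of the
# node can be rebuilt with onnx.helper.make_node.
import onnx
from onnxruntime.quantization.quant_utils import attribute_to_kwarg

node = onnx.helper.make_node("Gemm", ["A", "B"], ["Y"], alpha=1.0, transB=1)
kwargs = {}
for attr in node.attribute:
    kwargs.update(attribute_to_kwarg(attr))
clone = onnx.helper.make_node(node.op_type, node.input, node.output, **kwargs)
# ---------------------------------------------------------------------------------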


def find_by_name(item_name, item_list):
    """
    Helper function to find item by name in a list.
    parameter item_name: name of the item.
    parameter item_list: list of items.
    return: item if found. None otherwise.
    """
    items = [item for item in item_list if item.name == item_name]
    return items[0] if len(items) > 0 else None


def get_elem_index(elem_name, elem_list):
    """
    Helper function to return index of an item in a node list
    """
    elem_idx = -1
    for i in range(len(elem_list)):
        if elem_list[i] == elem_name:
            elem_idx = i
    return elem_idx


def get_mul_node(inputs, output, name):
    """
    Helper function to create a Mul node.
    parameter inputs: list of input names.
    parameter output: output name.
    parameter name: name of the node.
    return: Mul node in NodeProto format.
    """
    return onnx.helper.make_node("Mul", inputs, [output], name)


def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    Helper function to generate a identifiable filepath by concatenating the given identifier as a suffix.
    """
    return filename.parent.joinpath(filename.stem + identifier + filename.suffix)


def apply_plot(hist, hist_edges):
    import sys  # noqa: PLC0415

    import matplotlib.pyplot as plt  # noqa: PLC0415
    import numpy  # noqa: PLC0415

    numpy.set_printoptions(threshold=sys.maxsize)
    print("Histogram:")
    print(hist)
    print("Histogram Edges:")
    print(hist_edges)
    plt.stairs(hist, hist_edges, fill=True)
    plt.xlabel("Tensor value")
    plt.ylabel("Counts")
    plt.title("Tensor value V.S. Counts")
    plt.show()


def write_calibration_table(calibration_cache, dir="."):
    """
    Helper function to write calibration table to files.
    """

    import json  # noqa: PLC0415

    import flatbuffers  # noqa: PLC0415
    import numpy as np  # noqa: PLC0415

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue  # noqa: PLC0415
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable  # noqa: PLC0415
    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData  # noqa: PLC0415

    logging.info(f"calibration cache: {calibration_cache}")

    class MyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, (TensorData, TensorsData)):
                return obj.to_dict()
            if isinstance(obj, np.ndarray):
                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
            if isinstance(obj, CalibrationMethod):
                return {"CLS": obj.__class__.__name__, "value": str(obj)}
            return json.JSONEncoder.default(self, obj)

    json_data = json.dumps(calibration_cache, cls=MyEncoder)

    with open(os.path.join(dir, "calibration.json"), "w") as file:
        file.write(json_data)  # use `json.loads` to do the reverse

    # Serialize data using FlatBuffers
    zero = np.array(0)
    builder = flatbuffers.Builder(1024)
    key_value_list = []
    for key in sorted(calibration_cache.keys()):
        values = calibration_cache[key]
        d_values = values.to_dict()
        floats = [
            float(d_values.get("highest", zero).item()),
            float(d_values.get("lowest", zero).item()),
        ]
        value = str(max(floats))

        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value = KeyValue.KeyValueEnd(builder)

        key_value_list.append(key_value)

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
        file.write(buf)

    # Deserialize data (for validation)
    if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"):
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        dict_len = cal_table.DictLength()
        for i in range(dict_len):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # write plain text
    with open(os.path.join(dir, "calibration.cache"), "w") as file:
        for key in sorted(calibration_cache.keys()):
            values = calibration_cache[key]
            d_values = values.to_dict()
            floats = [
                float(d_values.get("highest", zero).item()),
                float(d_values.get("lowest", zero).item()),
            ]
            value = key + " " + str(max(floats))
            file.write(value)
            file.write("\n")
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
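`write_calibration_table` expects a mapping from tensor name to a range object whose `to_dict()` exposes `lowest`/`highest` as NumPy values, normally the output of a calibrator from calibrate.py. A minimal usage sketch follows; the tensor names and values are invented, the `quant_utils` module path is an assumption, and it assumes `TensorData` accepts `lowest`/`highest` keyword arguments as defined in calibrate.py:

```python
import numpy as np

from onnxruntime.quantization.calibrate import TensorData
from onnxruntime.quantization.quant_utils import write_calibration_table  # module path assumed

# Hypothetical per-tensor ranges standing in for real calibrator output.
# 0-d float32 arrays are used so that both the JSON encoder (ndarray branch)
# and the FlatBuffers path (`.item()`) above can handle them.
calibration_cache = {
    "conv1_output": TensorData(lowest=np.array(-0.42, dtype=np.float32), highest=np.array(3.17, dtype=np.float32)),
    "fc1_output": TensorData(lowest=np.array(-1.05, dtype=np.float32), highest=np.array(0.98, dtype=np.float32)),
}

# Writes calibration.json, calibration.flatbuffers and calibration.cache into the given directory.
write_calibration_table(calibration_cache, dir=".")
```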
+def smooth_distribution(p, eps=0.0001):
+    """Given a discrete distribution (may have not been normalized to 1),
+    smooth it by replacing zeros with eps multiplied by a scaling factor
+    and taking the corresponding amount off the non-zero values.
+    Ref: http://web.engr.illinois.edu/~hanj/cs412/bk3/KL-divergence.pdf
+         https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
+    """
+    is_zeros = (p == 0).astype(numpy.float32)
+    is_nonzeros = (p != 0).astype(numpy.float32)
+    n_zeros = is_zeros.sum()
+    n_nonzeros = p.size - n_zeros
+
+    if not n_nonzeros:
+        # raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
+        return None
+    eps1 = eps * float(n_zeros) / float(n_nonzeros)
+    assert eps1 < 1.0, f"n_zeros={n_zeros}, n_nonzeros={n_nonzeros}, eps1={eps1}"
+
+    hist = p.astype(numpy.float32)
+    hist += eps * is_zeros + (-eps1) * is_nonzeros
+    assert (hist <= 0).sum() == 0
+
+    return hist
+
+
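A short worked example of the smoothing rule above: every zero bin gains `eps`, and every non-zero bin gives up `eps1 = eps * n_zeros / n_nonzeros`, so the total mass of the histogram is unchanged. The numbers below are illustrative only:

```python
import numpy as np

p = np.array([0.0, 2.0, 0.0, 6.0], dtype=np.float32)  # toy histogram with two empty bins
eps = 0.0001

is_zeros = (p == 0).astype(np.float32)
is_nonzeros = (p != 0).astype(np.float32)
n_zeros = is_zeros.sum()                         # 2 empty bins
n_nonzeros = p.size - n_zeros                    # 2 occupied bins
eps1 = eps * float(n_zeros) / float(n_nonzeros)  # 0.0001 * 2 / 2 = 0.0001

smoothed = p + eps * is_zeros - eps1 * is_nonzeros
print(smoothed)                 # approximately [0.0001, 1.9999, 0.0001, 5.9999]
print(p.sum(), smoothed.sum())  # both ~8.0: the redistribution preserves total mass
```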
+def model_has_external_data(model_path: Path):
+    model = onnx.load(model_path.as_posix(), load_external_data=False)
+    return any(external_data_helper.uses_external_data(intializer) for intializer in model.graph.initializer)
+
+
+def optimize_model(model_path: Path, opt_model_path: Path):
+    """
+    Generate model that applies graph optimization (constant folding, etc.)
+    parameter model_path: path to the original onnx model
+    parameter opt_model_path: path to the optimized onnx model
+    :return: optimized onnx model
+    """
+    sess_option = SessionOptions()
+    sess_option.optimized_model_filepath = opt_model_path.as_posix()
+    sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
+    kwargs = {}
+    # This will rename constant initializer names, disable it to make test pass.
+    kwargs["disabled_optimizers"] = ["ConstantSharing"]
+    _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"], **kwargs)
+
+
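`optimize_model` works entirely through side effects: constructing the `InferenceSession` with `optimized_model_filepath` set makes ONNX Runtime write the basic-level optimized graph to disk, so the caller reloads the file afterwards. A hypothetical driver snippet; the file paths are placeholders and the `quant_utils` import path is an assumption:

```python
from pathlib import Path

import onnx

from onnxruntime.quantization.quant_utils import optimize_model  # module path assumed

model_path = Path("model.onnx")          # placeholder: original model
opt_model_path = Path("model-opt.onnx")  # placeholder: where the optimized graph is written

# Applies ORT_ENABLE_BASIC optimizations (constant folding etc.) and serializes the result.
optimize_model(model_path, opt_model_path)

optimized = onnx.load(opt_model_path.as_posix())
print(f"{len(optimized.graph.node)} nodes after basic graph optimization")
```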
+def add_pre_process_metadata(model: ModelProto):
+    """Tag the model that it went through quantization pre-processing"""
+    metadata_props = {"onnx.quant.pre_process": "onnxruntime.quant"}
+    if model.metadata_props:
+        for prop in model.metadata_props:
+            metadata_props.update({prop.key: prop.value})
+    onnx.helper.set_model_props(model, metadata_props)
+
+
+def model_has_pre_process_metadata(model: ModelProto) -> bool:
+    """Check the model whether it went through quantization pre-processing"""
+    if model.metadata_props:
+        for prop in model.metadata_props:
+            if prop.key == "onnx.quant.pre_process" and prop.value == "onnxruntime.quant":
+                return True
+    return False
+
+
+def add_infer_metadata(model: ModelProto):
+    metadata_props = {"onnx.infer": "onnxruntime.quant"}
+    if model.metadata_props:
+        for p in model.metadata_props:
+            metadata_props.update({p.key: p.value})
+    onnx.helper.set_model_props(model, metadata_props)
+
+
+def model_has_infer_metadata(model: ModelProto) -> bool:
+    if model.metadata_props:
+        for p in model.metadata_props:
+            if p.key == "onnx.infer" and p.value == "onnxruntime.quant":
+                return True
+    return False
+
+
+def get_opset_version(model: ModelProto) -> int:
+    ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
+    if len(ai_onnx_domain) != 1:
+        raise ValueError("Failed to find proper ai.onnx domain")
+    opset_version = ai_onnx_domain[0].version
+
+    return opset_version
+
+
+def update_opset_version(model: ModelProto, weight_type: QuantType) -> ModelProto:
+    opset_version = get_opset_version(model)
+    target_opset_version = opset_version
+    weight_quant_type = getattr(weight_type, "tensor_type", weight_type)
+
+    if opset_version < 19 and weight_quant_type == onnx.TensorProto.FLOAT8E4M3FN:
+        logging.warning(
+            f"The original model opset version is {opset_version}, which does not support quantization to float 8. "
+            "Please update the model to opset >= 19. Automatically update the model to opset 19. "
+            "Please verify the quantized model."
+        )
+        target_opset_version = 19
+
+    elif opset_version == 10:
+        logging.warning(
+            f"The original model opset version is {opset_version}, which does not support node fusions. "
+            "Please update the model to opset >= 11 for better performance."
+        )
+
+    elif opset_version < 10:
+        logging.warning(
+            f"The original model opset version is {opset_version}, which does not support quantization. "
+            "Please update the model to opset >= 11. Automatically update the model to opset 11. "
+            "Please verify the quantized model."
+        )
+        target_opset_version = 11
+
+    if target_opset_version != opset_version:
+        model = onnx.version_converter.convert_version(model, target_opset_version)
+        # Additional nodes may be added to the model during the opset version conversion. Run shape inference
+        # to ensure all nodes are included in model.graph.value_info.
+        model = save_and_reload_model_with_shape_infer(model)
+
+    return model
+
+
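Because of the `getattr(weight_type, "tensor_type", weight_type)` fallback, `update_opset_version` accepts either a `QuantType` or a raw `TensorProto` element type. A hypothetical call that upgrades an older model before float-8 weight quantization; the model path is a placeholder and the `quant_utils` import path is an assumption:

```python
from pathlib import Path

import onnx

from onnxruntime.quantization.quant_utils import (  # module path assumed
    get_opset_version,
    update_opset_version,
)

model = onnx.load(Path("model.onnx").as_posix())  # placeholder path
print("opset before:", get_opset_version(model))

# Models below opset 19 are converted (with a warning) so FLOAT8E4M3FN weights can be used;
# the TensorProto enum can be passed directly thanks to the getattr fallback.
model = update_opset_version(model, onnx.TensorProto.FLOAT8E4M3FN)
print("opset after:", get_opset_version(model))
```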
+def load_model_with_shape_infer(model_path: Path) -> ModelProto:
+    inferred_model_path = generate_identified_filename(model_path, "-inferred")
+    onnx.shape_inference.infer_shapes_path(str(model_path), str(inferred_model_path))
+    model = onnx.load(inferred_model_path.as_posix())
+    add_infer_metadata(model)
+    inferred_model_path.unlink()
+    return model
+
+
+def save_and_reload_model_with_shape_infer(model: ModelProto) -> ModelProto:
+    with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir:
+        model_copy = copy.deepcopy(model)
+        model_path = Path(quant_tmp_dir).joinpath("model.onnx")
+        onnx.save_model(model_copy, model_path.as_posix(), save_as_external_data=True)
+        return load_model_with_shape_infer(model_path)
+
+
+def tensor_proto_to_array(initializer: TensorProto) -> numpy.ndarray:
+    if initializer.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16):
+        return onnx.numpy_helper.to_array(initializer)
+
+    raise ValueError(
+        f"Only float type is supported. Weights {initializer.name} is {type_to_name[initializer.data_type]}"
+    )
+
+
+def add_quant_suffix(tensor_name: str) -> str:
+    return tensor_name + "_QuantizeLinear"
+
+
+def add_quant_input_suffix(tensor_name: str) -> str:
+    return tensor_name + QUANT_INPUT_SUFFIX
+
+
+def add_quant_output_suffix(tensor_name) -> str:
+    return tensor_name + "_QuantizeLinear_Output"
+
+
+def add_dequant_suffix(tensor_name) -> str:
+    return tensor_name + "_DequantizeLinear"
+
+
+def add_dequant_input_suffix(tensor_name) -> str:
+    return tensor_name + "_DequantizeLinear_Input"
+
+
+def add_dequant_output_suffix(tensor_name) -> str:
+    return tensor_name + DEQUANT_OUTPUT_SUFFIX