onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1267 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# -------------------------------------------------------------------------
|
|
3
|
+
# Copyright (c) Microsoft, Intel Corporation. All rights reserved.
|
|
4
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
5
|
+
# license information.
|
|
6
|
+
# --------------------------------------------------------------------------
|
|
7
|
+
import abc
|
|
8
|
+
import copy
|
|
9
|
+
import itertools
|
|
10
|
+
import os
|
|
11
|
+
import uuid
|
|
12
|
+
from collections.abc import Sequence
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import onnx
|
|
18
|
+
from onnx import ModelProto, TensorProto, helper, numpy_helper
|
|
19
|
+
|
|
20
|
+
import onnxruntime
|
|
21
|
+
|
|
22
|
+
from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distribution
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray:
    """
    Elementwise relative entropy; a pure-numpy replacement for
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html#scipy.special.rel_entr.

    Follows the scipy convention:
      * ``pk * log(pk / qk)`` where ``pk > 0`` and ``qk > 0``
      * ``0``                 where ``pk == 0`` and ``qk >= 0``
      * ``inf``               otherwise

    :param pk: array of (unnormalized) probabilities.
    :param qk: array of the same shape as ``pk``.
    :return: array of the same shape and dtype as ``pk``.
    """
    res = np.empty(pk.shape, dtype=pk.dtype)
    res[:] = pk[:] * np.log(pk[:] / qk[:])
    c1 = (pk > 0) & (qk > 0)
    res[~c1] = np.inf
    # NOTE: the zero fill must come AFTER the inf fill. `c2` is a subset of
    # `~c1` (pk == 0 implies not pk > 0), and scipy defines
    # rel_entr(0, qk >= 0) == 0; filling in the other order clobbered the
    # zeros with inf.
    c2 = (pk == 0) & (qk >= 0)
    res[c2] = 0
    return res
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def entropy(
    pk: np.ndarray,
    qk: np.ndarray,
    base: float | None = None,
    axis: int = 0,
) -> np.ndarray:
    """
    Simplified version of entropy.
    Source: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html.
    This avoids taking a dependency on scipy just for this function.

    :param pk: observed distribution (normalized along `axis` internally).
    :param qk: reference distribution; broadcast against `pk`, then normalized.
    :param base: logarithm base; `None` means natural log.
    :param axis: axis along which the distributions are normalized and summed.
    :return: the Kullback-Leibler divergence sum(rel_entr(pk, qk)) along `axis`,
        as a float32 array (inputs are cast to float32).
    """
    # Bug fix: the message was a plain string "base={base} ..."; it needed the
    # f-prefix to interpolate the offending value.
    assert base is None or base > 0, f"base={base} must be a positive number or `None`."
    assert qk is not None, "qk is None"

    pk = np.asarray(pk).astype(np.float32)
    pk = 1.0 * pk / np.sum(pk, axis=axis, keepdims=True)

    qk = np.asarray(qk).astype(np.float32)
    pk, qk = np.broadcast_arrays(pk, qk)
    qk = 1.0 * qk / np.sum(qk, axis=axis, keepdims=True)
    vec = rel_entr(pk, qk)

    s = np.sum(vec, axis=axis)
    if base is not None:
        # Change of logarithm base: log_b(x) = ln(x) / ln(b).
        s /= np.log(base)
    return s.astype(pk.dtype)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TensorData:
    """Statistics recorded for one tensor during calibration.

    Accepted statistic names are listed in ``_allowed``; those in ``_floats``
    must be numpy values with a float16 or float32 dtype. Each keyword given
    to the constructor becomes an attribute of the instance.
    """

    # Every statistic name the constructor accepts.
    _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
    # Subset that must carry a float16/float32 numpy dtype.
    _floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])

    def __init__(self, **kwargs):
        self._attrs = list(kwargs.keys())
        for name, value in kwargs.items():
            if name not in TensorData._allowed:
                raise ValueError(f"Unexpected value {name!r} not in {TensorData._allowed}.")
            if name in TensorData._floats:
                if not hasattr(value, "dtype"):
                    raise ValueError(f"Unexpected type {type(value)} for k={name!r}")
                if value.dtype not in (np.float16, np.float32):
                    raise ValueError(f"Unexpected dtype {value.dtype} for k={name!r}")
            setattr(self, name, value)

    @property
    def range_value(self):
        """Return ``(lowest, highest)``; raise AttributeError if either was not provided."""
        if hasattr(self, "lowest") and hasattr(self, "highest"):
            return (self.lowest, self.highest)
        raise AttributeError(f"Attributes 'lowest' and/or 'highest' missing in {dir(self)}.")

    @property
    def avg_std(self):
        """Return ``(avg, std)``; raise AttributeError if either was not provided."""
        if hasattr(self, "avg") and hasattr(self, "std"):
            return (self.avg, self.std)
        raise AttributeError(f"Attributes 'avg' and/or 'std' missing in {dir(self)}.")

    def to_dict(self):
        """Serialize the stored statistics into a plain dict (for JSON),
        tagging the class name under key ``"CLS"``."""
        return {**{attr: getattr(self, attr) for attr in self._attrs}, "CLS": self.__class__.__name__}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class TensorsData:
    """Result container produced by a calibrater: maps tensor names to their
    :class:`TensorData` statistics, remembering which calibration method made them.

    Behaves like a read-mostly mapping: iteration yields tensor names, and
    ``__setitem__`` only allows overwriting an existing entry.
    """

    def __init__(self, calibration_method, data: "dict[str, TensorData | tuple]"):
        """
        :param calibration_method: the CalibrationMethod that produced `data`.
        :param data: maps a tensor name to a TensorData, or to a tuple
            shorthand: ``(lowest, highest)`` (MinMax only) or
            ``(lowest, highest, hist, bins)``.
        :raises TypeError: on non-string keys or unrecognized values.
        """
        self.calibration_method = calibration_method
        self.data = {}
        for k, v in data.items():
            if not isinstance(k, str):
                raise TypeError(f"Keys must be strings not {type(k)}.")
            if isinstance(v, tuple):
                if calibration_method == CalibrationMethod.MinMax and len(v) == 2:
                    self.data[k] = TensorData(lowest=v[0], highest=v[1])
                    continue
                if len(v) == 4:
                    self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3])
                    continue
                # Bug fix: original used `{k:r}` -- an invalid format spec that
                # raised ValueError at runtime instead of this TypeError message.
                raise TypeError(f"Unexpected tuple for {k!r}, it has {len(v)} elements: {v}.")
            if not isinstance(v, TensorData):
                raise TypeError(f"Values must be TensorData not {type(v)}.")
            self.data[k] = v

    def __iter__(self):
        # Iterate over tensor names, mirroring dict behavior.
        yield from self.data

    def __contains__(self, key):
        return key in self.data

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        # Guard against typos: only tensors already present can be updated.
        if key not in self.data:
            raise RuntimeError(f"Only an existing tensor can be modified, {key!r} is not.")
        self.data[key] = value

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def to_dict(self):
        # This is needed to serialize the data into JSON.
        data = {
            "CLS": self.__class__.__name__,
            "data": self.data,
            "calibration_method": self.calibration_method,
        }
        return data
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class CalibrationMethod(Enum):
    """Calibration algorithms selectable by the calibraters in this module.

    ``MinMax`` also enables the two-element ``(lowest, highest)`` tuple
    shorthand accepted by ``TensorsData``. The semantics of the other members
    are implemented by their respective calibrater classes.
    """

    MinMax = 0
    Entropy = 1
    Percentile = 2
    Distribution = 3
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class CalibrationDataReader(metaclass=abc.ABCMeta):
|
|
162
|
+
@classmethod
|
|
163
|
+
def __subclasshook__(cls, subclass):
|
|
164
|
+
return (hasattr(subclass, "get_next") and callable(subclass.get_next)) or NotImplemented
|
|
165
|
+
|
|
166
|
+
@abc.abstractmethod
|
|
167
|
+
def get_next(self) -> dict:
|
|
168
|
+
"""generate the input data dict for ONNXinferenceSession run"""
|
|
169
|
+
raise NotImplementedError
|
|
170
|
+
|
|
171
|
+
def __iter__(self):
|
|
172
|
+
return self
|
|
173
|
+
|
|
174
|
+
def __next__(self):
|
|
175
|
+
result = self.get_next()
|
|
176
|
+
if result is None:
|
|
177
|
+
raise StopIteration
|
|
178
|
+
return result
|
|
179
|
+
|
|
180
|
+
def __len__(self):
|
|
181
|
+
raise NotImplementedError
|
|
182
|
+
|
|
183
|
+
def set_range(self, start_index: int, end_index: int):
|
|
184
|
+
raise NotImplementedError
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class CalibraterBase:
    """Base class for calibration.

    A calibrater augments an ONNX model so intermediate tensors can be
    observed, runs it over calibration data via an InferenceSession, and
    computes per-tensor statistics (as a ``TensorsData``). Subclasses must
    implement ``augment_graph``, ``collect_data`` and ``compute_data``.
    """

    def __init__(
        self,
        model_path: str | Path,
        op_types_to_calibrate: Sequence[str] | None = None,
        augmented_model_path="augmented_model.onnx",
        symmetric=False,
        use_external_data_format=False,
        per_channel=False,
    ):
        """
        :param model_path: ONNX model to calibrate. It should be a model file path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param use_external_data_format: use external data format to store model which size is >= 2Gb.
        :param per_channel: whether to compute ranges per each channel.
        :raises ValueError: if `model_path` is neither a str nor a Path.
        """
        # Both branches run shape inference while loading; only the argument
        # type differs.
        if isinstance(model_path, str):
            self.model = load_model_with_shape_infer(Path(model_path))
        elif isinstance(model_path, Path):
            self.model = load_model_with_shape_infer(model_path)
        else:
            raise ValueError("model_path should be model path.")

        self.op_types_to_calibrate = op_types_to_calibrate
        self.augmented_model_path = augmented_model_path
        self.symmetric = symmetric
        self.use_external_data_format = use_external_data_format
        self.per_channel = per_channel

        # Populated later: `infer_session` by create_inference_session();
        # `augment_model` is left None here (subclasses appear to manage it).
        self.augment_model = None
        self.infer_session = None
        self.execution_providers = ["CPUExecutionProvider"]

    # NOTE(review): mutable default argument is deliberate and lint-suppressed
    # (B006); the list is never mutated here.
    def set_execution_providers(self, execution_providers=["CPUExecutionProvider"]):  # noqa: B006
        """
        reset the execution providers to execute the collect_data. It triggers to re-creating inference session.
        """
        self.execution_providers = execution_providers
        self.create_inference_session()

    def create_inference_session(self):
        """
        create an OnnxRuntime InferenceSession.
        """
        sess_options = onnxruntime.SessionOptions()
        # All graph optimizations are disabled -- presumably so the augmented
        # model's observed tensors are not fused or folded away; confirm
        # against the subclasses' augment_graph implementations.
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
        self.infer_session = onnxruntime.InferenceSession(
            self.augmented_model_path,
            sess_options=sess_options,
            providers=self.execution_providers,
        )

    def select_tensors_to_calibrate(self, model: ModelProto):
        """
        select input/output tensors of candidate nodes to calibrate.
        returns:
            tensors (set): set of tensor name.
            value_infos (dict): tensor name to value info.
        """
        # Graph inputs/outputs are merged in last so they win over
        # value_info entries with the same name.
        value_infos = {vi.name: vi for vi in model.graph.value_info}
        value_infos.update({ot.name: ot for ot in model.graph.output})
        value_infos.update({it.name: it for it in model.graph.input})
        initializer = {init.name for init in model.graph.initializer}

        tensors_to_calibrate = set()
        # Only float32/float16 tensors are considered for calibration.
        tensor_type_to_calibrate = {TensorProto.FLOAT, TensorProto.FLOAT16}

        for node in model.graph.node:
            # An empty/None op_types_to_calibrate means: consider every node.
            if not self.op_types_to_calibrate or node.op_type in self.op_types_to_calibrate:
                for tensor_name in itertools.chain(node.input, node.output):
                    if tensor_name in value_infos:
                        vi = value_infos[tensor_name]
                        # Keep only float tensors that are not initializers
                        # (weights do not need runtime observation).
                        if (
                            vi.type.HasField("tensor_type")
                            and (vi.type.tensor_type.elem_type in tensor_type_to_calibrate)
                            and (tensor_name not in initializer)
                        ):
                            tensors_to_calibrate.add(tensor_name)

        return tensors_to_calibrate, value_infos

    def get_augment_model(self):
        """
        return: augmented onnx model. Call after calling augment_graph
        """
        # NOTE(review): returns self.model, not self.augment_model -- subclasses
        # appear to mutate self.model in place during augment_graph; verify.
        return self.model

    def augment_graph(self):
        """
        abstract method: augment the input model to prepare for collecting data. It will:
            1. augment the model to be able to collect desired statistics data
            2. save augmented model to augmented_model_paths
        """
        raise NotImplementedError

    def collect_data(self, data_reader: CalibrationDataReader):
        """
        abstract method: collect the tensors that will be used for range computation. It can be called multiple times.
        """
        raise NotImplementedError

    def compute_data(self) -> TensorsData:
        """
        abstract method: compute data based on the calibration method stored in TensorsData
        """
        raise NotImplementedError
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class MinMaxCalibrater(CalibraterBase):
    """Calibrater that derives per-tensor (optionally per-channel) min/max ranges
    by augmenting the model with ReduceMin/ReduceMax nodes and running inference
    over the calibration data."""

    def __init__(
        self,
        model_path: str | Path,
        op_types_to_calibrate: Sequence[str] | None = None,
        augmented_model_path="augmented_model.onnx",
        symmetric=False,
        use_external_data_format=False,
        moving_average=False,
        averaging_constant=0.01,
        max_intermediate_outputs=None,
        per_channel=False,
    ):
        """
        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param moving_average: compute the moving average of the minimum and maximum values instead of the global minimum and maximum.
        :param averaging_constant: constant smoothing factor to use when computing the moving average.
        :param max_intermediate_outputs: maximum number of intermediate outputs before an intermediate range is computed.
        :param per_channel: whether to compute ranges per each channel.
        """
        super().__init__(
            model_path,
            op_types_to_calibrate=op_types_to_calibrate,
            augmented_model_path=augmented_model_path,
            symmetric=symmetric,
            use_external_data_format=use_external_data_format,
            per_channel=per_channel,
        )
        self.intermediate_outputs = []
        self.calibrate_tensors_range = None
        self.num_model_outputs = len(self.model.graph.output)
        self.model_original_outputs = {output.name for output in self.model.graph.output}
        self.moving_average = moving_average
        # The smoothing factor is only meaningful when moving_average is enabled.
        if moving_average and (averaging_constant < 0 or averaging_constant > 1):
            raise ValueError("Invalid averaging constant, which should not be < 0 or > 1.")
        self.averaging_constant = averaging_constant
        self.max_intermediate_outputs = max_intermediate_outputs

    def augment_graph(self):
        """
        Adds ReduceMin and ReduceMax nodes to all quantization_candidates op type nodes in
        model and ensures their outputs are stored as part of the graph output
        :return: augmented ONNX model
        """
        tensors, _ = self.select_tensors_to_calibrate(self.model)
        # Shared Reshape target: flatten every reduce result to 1-D.
        reshape_shape_name = str(uuid.uuid4())
        reshape_shape = numpy_helper.from_array(np.array([-1], dtype=np.int64), reshape_shape_name)
        self.model.graph.initializer.append(reshape_shape)

        def get_op_version(op_type, model):
            # Resolve the opset version that defines op_type in this model.
            for opset_import in model.opset_import:
                if onnx.defs.has(op_type, opset_import.domain):
                    return opset_import.version
            raise RuntimeError(f"Model does not contain a version for '{op_type}'.")

        def insert_nodes(tensor_name, new_nodes):
            # Insert new_nodes just before the first consumer of tensor_name
            # (or append at the end if nothing consumes it) to keep a valid topo order.
            index = next(
                (i for i, x in enumerate(self.model.graph.node) if tensor_name in x.input), len(self.model.graph.node)
            )
            for node in new_nodes:
                self.model.graph.node.insert(index, node)
                index += 1

        def add_reduce_min_max(tensor_name, reduce_op_name):
            # When doing ReduceMax/ReduceMin, ORT can't reduce on dim with value of 0 if 'keepdims' is false.
            # To make the code simple, we always let keepdims to be 1.
            keepdims = 1

            # Adding ReduceMin/ReduceMax nodes: ReduceMin/ReduceMax -> Reshape-> (output)
            reduce_output = tensor_name + "_" + reduce_op_name
            intermediate_output = reduce_output + "_Reshape"
            reduce_node = onnx.helper.make_node(
                reduce_op_name, [tensor_name], [intermediate_output], keepdims=keepdims, name=reduce_output
            )

            reshape_node = onnx.helper.make_node(
                "Reshape",
                inputs=[intermediate_output, reshape_shape_name],
                outputs=[reduce_output],
                name=intermediate_output,
            )

            value_infos = {vi.name: vi for vi in self.model.graph.value_info}
            value_infos.update({o.name: o for o in self.model.graph.output})
            value_infos.update({i.name: i for i in self.model.graph.input})
            if tensor_name in value_infos:
                onnx_type = value_infos[tensor_name].type.tensor_type.elem_type
            else:
                raise ValueError(
                    f"Unable to guess tensor type for tensor {tensor_name!r}, "
                    "running shape inference before quantization may resolve this issue."
                )

            # Include axes in reduce_op when per_channel, always keeping axis=1
            if self.per_channel:
                tensor_rank = len(value_infos[tensor_name].type.tensor_type.shape.dim)
                reduced_axes = [0, *range(2, tensor_rank)]
                # Depending on opset version, axes in ReduceMin/ReduceMax are in attribute or inputs
                if get_op_version(reduce_op_name, self.model) < 18:
                    reduce_node.attribute.append(helper.make_attribute("axes", reduced_axes))
                else:
                    reduce_axes_name = str(uuid.uuid4())
                    reduce_axes = numpy_helper.from_array(np.array(reduced_axes, dtype=np.int64), reduce_axes_name)
                    reduce_node.input.append(reduce_axes_name)
                    self.model.graph.initializer.append(reduce_axes)

            insert_nodes(tensor_name, [reduce_node, reshape_node])
            self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, onnx_type, [None]))

        for tensor in tensors:
            add_reduce_min_max(tensor, "ReduceMin")
            add_reduce_min_max(tensor, "ReduceMax")

        onnx.save(
            self.model,
            self.augmented_model_path,
            save_as_external_data=self.use_external_data_format,
        )

    def clear_collected_data(self):
        # Drop the raw per-batch outputs; computed ranges are kept separately.
        self.intermediate_outputs = []

    def collect_data(self, data_reader: CalibrationDataReader):
        """Run the augmented model over every batch from data_reader and fold the
        ReduceMin/ReduceMax outputs into calibrate_tensors_range."""
        while True:
            inputs = data_reader.get_next()
            if not inputs:
                break
            self.intermediate_outputs.append(
                [
                    # Original model outputs are not calibrated; store None to
                    # avoid keeping (potentially large) activations alive.
                    value if sess_o.name not in self.model_original_outputs else None
                    for sess_o, value in zip(
                        self.infer_session.get_outputs(), self.infer_session.run(None, inputs), strict=False
                    )
                ]
            )
            if (
                self.max_intermediate_outputs is not None
                and len(self.intermediate_outputs) == self.max_intermediate_outputs
            ):
                # BUGFIX: fold the collected batch into calibrate_tensors_range
                # BEFORE discarding it. Clearing alone silently dropped the
                # statistics of every full chunk, keeping only the final partial
                # chunk — contradicting the documented max_intermediate_outputs
                # semantics ("before an intermediate range is computed").
                self.compute_data()
                self.clear_collected_data()

        if len(self.intermediate_outputs) == 0 and self.calibrate_tensors_range is None:
            raise ValueError("No data is collected.")

        t = self.compute_data()
        if not isinstance(t, TensorsData):
            raise TypeError(f"compute_data must return a TensorsData not {type(t)}.")
        self.clear_collected_data()

    def merge_range(self, old_range, new_range):
        """Merge a previously computed range dict into new_range, key by key,
        using either the moving-average or the global min/max rule."""
        if not old_range:
            return new_range

        for key, value in old_range.items():
            # Handling for structured data types with TensorData
            if isinstance(value, TensorData):
                old_min = value.range_value[0]
                old_max = value.range_value[1]
            else:
                old_min, old_max = value

            if isinstance(new_range[key], TensorData):
                new_min = new_range[key].range_value[0]
                new_max = new_range[key].range_value[1]
            else:
                new_min, new_max = new_range[key]

            if self.moving_average:
                # Exponential moving average with smoothing factor averaging_constant.
                min_value = old_min + self.averaging_constant * (new_min - old_min)
                max_value = old_max + self.averaging_constant * (new_max - old_max)
            else:
                min_value = min(old_min, new_min)
                max_value = max(old_max, new_max)

            # If structured as TensorData, wrap the result accordingly
            if isinstance(value, TensorData) or isinstance(new_range[key], TensorData):
                new_range[key] = TensorData(lowest=min_value, highest=max_value)
            else:
                new_range[key] = (min_value, max_value)

        return new_range

    def compute_data(self) -> TensorsData:
        """
        Compute the min-max range of tensor
        :return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
        """
        if len(self.intermediate_outputs) == 0:
            return self.calibrate_tensors_range

        output_names = [self.infer_session.get_outputs()[i].name for i in range(len(self.intermediate_outputs[0]))]
        output_dicts_list = [
            dict(zip(output_names, intermediate_output, strict=False))
            for intermediate_output in self.intermediate_outputs
        ]

        # Group every batch's value per output name.
        merged_output_dict = {}
        for d in output_dicts_list:
            for k, v in d.items():
                merged_output_dict.setdefault(k, []).append(v)
        # Outputs appended by augment_graph come after the original model outputs,
        # in (ReduceMin, ReduceMax) pairs per calibrated tensor.
        added_output_names = output_names[self.num_model_outputs :]
        calibrate_tensor_names = [
            added_output_names[i].rpartition("_")[0] for i in range(0, len(added_output_names), 2)
        ]  # output names

        merged_added_output_dict = {
            i: merged_output_dict[i] for i in merged_output_dict if i not in self.model_original_outputs
        }

        pairs = []
        for i in range(0, len(added_output_names), 2):
            if self.moving_average:
                # Average over batches; merge_range then applies the EMA update.
                min_value_array = np.nanmean(merged_added_output_dict[added_output_names[i]], axis=0)
                max_value_array = np.nanmean(merged_added_output_dict[added_output_names[i + 1]], axis=0)
            else:
                min_value_array = np.nanmin(merged_added_output_dict[added_output_names[i]], axis=0)
                max_value_array = np.nanmax(merged_added_output_dict[added_output_names[i + 1]], axis=0)

            if self.symmetric:
                max_absolute_value = np.nanmax([np.abs(min_value_array), np.abs(max_value_array)], axis=0)
                pairs.append((-max_absolute_value, max_absolute_value))
            else:
                pairs.append((min_value_array, max_value_array))

        new_calibrate_tensors_range = TensorsData(
            CalibrationMethod.MinMax, dict(zip(calibrate_tensor_names, pairs, strict=False))
        )
        if self.calibrate_tensors_range:
            self.calibrate_tensors_range = self.merge_range(self.calibrate_tensors_range, new_calibrate_tensors_range)
        else:
            self.calibrate_tensors_range = new_calibrate_tensors_range

        return self.calibrate_tensors_range
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
class HistogramCalibrater(CalibraterBase):
    # Base class for histogram-driven calibraters (entropy / percentile /
    # distribution). Collects per-tensor histograms via HistogramCollector and
    # turns them into a TensorsData result in compute_data().
    def __init__(
        self,
        model_path: str | Path,
        op_types_to_calibrate: Sequence[str] | None = None,
        augmented_model_path="augmented_model.onnx",
        use_external_data_format=False,
        method="percentile",
        symmetric=False,
        num_bins=128,
        num_quantized_bins=2048,
        percentile=99.999,
        scenario="same",
    ):
        """
        :param model_path: ONNX model to calibrate. It is a model path.
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile'].
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param num_bins: number of bins to create a new histogram for collecting tensor values.
        :param num_quantized_bins: number of quantized bins. Default 128.
        :param percentile: A float number between [0, 100]. Default 99.99.
        :param scenario: see :class:`DistributionCalibrater`
        """
        super().__init__(
            model_path,
            op_types_to_calibrate=op_types_to_calibrate,
            augmented_model_path=augmented_model_path,
            symmetric=symmetric,
            use_external_data_format=use_external_data_format,
        )
        # Raw per-batch session outputs; cleared after each collect_data() call.
        self.intermediate_outputs = []
        self.calibrate_tensors_range = None
        self.num_model_outputs = len(self.model.graph.output)
        self.model_original_outputs = {output.name for output in self.model.graph.output}
        # Lazily created HistogramCollector (first collect_data() call).
        self.collector = None
        self.method = method
        self.num_bins = num_bins
        self.num_quantized_bins = num_quantized_bins
        self.percentile = percentile
        # Populated by augment_graph(); collect_data() filters outputs against it.
        self.tensors_to_calibrate = None
        self.scenario = scenario

    def augment_graph(self):
        """
        make all quantization_candidates op type nodes as part of the graph output.
        :return: augmented ONNX model
        """
        self.tensors_to_calibrate, value_infos = self.select_tensors_to_calibrate(self.model)
        for tensor in self.tensors_to_calibrate:
            # Tensors that are already model outputs must not be duplicated.
            if tensor not in self.model_original_outputs:
                self.model.graph.output.append(value_infos[tensor])

        onnx.save(
            self.model,
            self.augmented_model_path,
            save_as_external_data=self.use_external_data_format,
        )

    def clear_collected_data(self):
        # Drop raw batch outputs; histograms already merged into self.collector.
        self.intermediate_outputs = []

    def collect_data(self, data_reader: CalibrationDataReader):
        """
        Entropy Calibrator collects operators' tensors as well as generates tensor histogram for each operator.
        """
        input_names_set = {node_arg.name for node_arg in self.infer_session.get_inputs()}
        output_names = [node_arg.name for node_arg in self.infer_session.get_outputs()]

        while True:
            inputs = data_reader.get_next()
            if not inputs:
                break
            outputs = self.infer_session.run(None, inputs)

            # Copy np.ndarray only for graph outputs that are also graph inputs to workaround bug:
            # https://github.com/microsoft/onnxruntime/issues/21922
            fixed_outputs = []
            for output_index, output in enumerate(outputs):
                if output_names[output_index] in input_names_set:
                    fixed_outputs.append(copy.copy(output))
                else:
                    fixed_outputs.append(output)

            self.intermediate_outputs.append(fixed_outputs)

        if len(self.intermediate_outputs) == 0:
            raise ValueError("No data is collected.")

        # One {output name: value} dict per batch.
        output_dicts_list = [
            dict(zip(output_names, intermediate_output, strict=False))
            for intermediate_output in self.intermediate_outputs
        ]

        # Group all batches' values per output name.
        merged_dict = {}
        for d in output_dicts_list:
            for k, v in d.items():
                merged_dict.setdefault(k, []).append(v)

        # Keep only the tensors selected by augment_graph().
        clean_merged_dict = {i: merged_dict[i] for i in merged_dict if i in self.tensors_to_calibrate}

        if not self.collector:
            self.collector = HistogramCollector(
                method=self.method,
                symmetric=self.symmetric,
                num_bins=self.num_bins,
                num_quantized_bins=self.num_quantized_bins,
                percentile=self.percentile,
                scenario=self.scenario,
            )
        self.collector.collect(clean_merged_dict)

        self.clear_collected_data()

    def compute_data(self) -> TensorsData:
        """
        Compute the min-max range of tensor
        :return: dictionary mapping: {tensor name: (min value, max value)}
        """
        if not self.collector:
            raise ValueError("No collector created and can't generate calibration data.")

        # The calibration-method tag is inferred from the concrete subclass.
        if isinstance(self, EntropyCalibrater):
            cal = CalibrationMethod.Entropy
        elif isinstance(self, PercentileCalibrater):
            cal = CalibrationMethod.Percentile
        elif isinstance(self, DistributionCalibrater):
            cal = CalibrationMethod.Distribution
        else:
            raise TypeError(f"Unknown calibrater {type(self)}. This method must be overwritten.")
        return TensorsData(cal, self.collector.compute_collection_result())
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
class EntropyCalibrater(HistogramCalibrater):
    def __init__(
        self,
        model_path: str | Path,
        op_types_to_calibrate: Sequence[str] | None = None,
        augmented_model_path="augmented_model.onnx",
        use_external_data_format=False,
        method="entropy",
        symmetric=False,
        num_bins=128,
        num_quantized_bins=128,
    ):
        """
        Histogram calibrater configured for the entropy (KL-divergence) criterion.

        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile', 'distribution'].
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param num_bins: number of bins to create a new histogram for collecting tensor values.
        :param num_quantized_bins: number of quantized bins. Default 128.
        """
        # Gather the histogram-specific settings, then delegate to the base class.
        histogram_settings = {
            "method": method,
            "symmetric": symmetric,
            "num_bins": num_bins,
            "num_quantized_bins": num_quantized_bins,
        }
        super().__init__(
            model_path,
            op_types_to_calibrate,
            augmented_model_path,
            use_external_data_format,
            **histogram_settings,
        )
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
class PercentileCalibrater(HistogramCalibrater):
    def __init__(
        self,
        model_path: str | Path,
        op_types_to_calibrate: Sequence[str] | None = None,
        augmented_model_path="augmented_model.onnx",
        use_external_data_format=False,
        method="percentile",
        symmetric=False,
        num_bins=2048,
        percentile=99.999,
    ):
        """
        Histogram calibrater configured for the percentile criterion.

        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile', 'distribution'].
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param num_quantized_bins: number of quantized bins. Default 128.
        :param percentile: A float number between [0, 100]. Default 99.99.
        """
        # Gather the histogram-specific settings, then delegate to the base class.
        histogram_settings = {
            "method": method,
            "symmetric": symmetric,
            "num_bins": num_bins,
            "percentile": percentile,
        }
        super().__init__(
            model_path,
            op_types_to_calibrate,
            augmented_model_path,
            use_external_data_format,
            **histogram_settings,
        )
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
class DistributionCalibrater(HistogramCalibrater):
    def __init__(
        self,
        model_path: str | Path,
        op_types_to_calibrate: Sequence[str] | None = None,
        augmented_model_path="augmented_model.onnx",
        use_external_data_format=False,
        method="distribution",
        num_bins=128,
        scenario="same",
    ):
        """
        Histogram calibrater configured for the distribution criterion.

        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile', 'distribution'].
        :param num_bins: number of bins to create a new histogram for collecting tensor values.
        :param scenario: for float 8 only, if `scenario="same"`,
            the algorithm weights and float 8 follow the same distribution,
            if `scenario="p3"`, it assumes the weights follow
            a gaussian law and float 8 ~ X^3 where X is a gaussian law
        """
        # Gather the histogram-specific settings, then delegate to the base class.
        histogram_settings = {
            "method": method,
            "num_bins": num_bins,
            "scenario": scenario,
        }
        super().__init__(
            model_path,
            op_types_to_calibrate,
            augmented_model_path,
            use_external_data_format,
            **histogram_settings,
        )
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
class CalibrationDataCollector(metaclass=abc.ABCMeta):
    """
    Abstract interface for accumulating tensor data during calibration-based
    quantization and producing a final collection result from it.
    """

    @abc.abstractmethod
    def collect(self, name_to_arr):
        """
        Generate informative data based on given data.

        name_to_arr : dict
            tensor name to NDArray data
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compute_collection_result(self):
        """
        Get the optimal result among collection data.
        """
        raise NotImplementedError
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
class HistogramCollector(CalibrationDataCollector):
|
|
798
|
+
"""
|
|
799
|
+
Collecting histogram for each tensor. Percentile and Entropy method are supported.
|
|
800
|
+
|
|
801
|
+
ref: https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
|
|
802
|
+
ref: https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/_modules/
|
|
803
|
+
pytorch_quantization/calib/histogram.html
|
|
804
|
+
"""
|
|
805
|
+
|
|
806
|
+
    def __init__(self, method, symmetric, num_bins, num_quantized_bins, percentile, scenario):
        # Store the collection configuration; histograms start empty.
        # histogram_dict maps tensor name -> histogram tuple
        # (hist, hist_edges, min, max[, threshold] depending on collection mode).
        self.histogram_dict = {}
        # One of 'entropy', 'percentile', 'distribution' (validated in collect()).
        self.method = method
        # Symmetric percentile collects |x| histograms instead of signed values.
        self.symmetric = symmetric
        # Bin count used when a tensor's histogram is first created.
        self.num_bins = num_bins
        # Used by the entropy method's quantized-distribution search.
        self.num_quantized_bins = num_quantized_bins
        # Percentile in [0, 100] for the percentile method.
        self.percentile = percentile
        # Float-8 distribution scenario ('same' or 'p3'); see DistributionCalibrater.
        self.scenario = scenario
|
|
814
|
+
|
|
815
|
+
def get_histogram_dict(self):
|
|
816
|
+
return self.histogram_dict
|
|
817
|
+
|
|
818
|
+
def collect(self, name_to_arr):
|
|
819
|
+
print("Collecting tensor data and making histogram ...")
|
|
820
|
+
|
|
821
|
+
# TODO: Currently we have different collect() for entropy and percentile method respectively.
|
|
822
|
+
# Need unified collect in the future.
|
|
823
|
+
if self.method in {"distribution", "entropy"}:
|
|
824
|
+
return self.collect_value(name_to_arr)
|
|
825
|
+
elif self.method == "percentile":
|
|
826
|
+
if self.symmetric:
|
|
827
|
+
return self.collect_absolute_value(name_to_arr)
|
|
828
|
+
else:
|
|
829
|
+
return self.collect_value(name_to_arr)
|
|
830
|
+
else:
|
|
831
|
+
raise ValueError("Only 'entropy', 'percentile' or 'distribution' methods are supported")
|
|
832
|
+
|
|
833
|
+
    def collect_absolute_value(self, name_to_arr):
        """
        Collect histogram on absolute value

        Used by the symmetric percentile method: values are folded to |x| before
        histogramming, and repeated calls extend an existing histogram's edges
        rather than rebinning.
        """
        for tensor, data_arr in name_to_arr.items():
            if isinstance(data_arr, list):
                # A list of per-batch arrays: all entries must be ndarrays of one dtype.
                for arr in data_arr:
                    assert isinstance(arr, np.ndarray), f"Unexpected type {type(arr)} for tensor={tensor!r}"
                dtypes = {a.dtype for a in data_arr}
                assert len(dtypes) == 1, (
                    f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
                )
                data_arr_np = np.asarray(data_arr)
            elif not isinstance(data_arr, np.ndarray):
                raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
            else:
                data_arr_np = data_arr
            data_arr_np = data_arr_np.flatten()
            # Signed min/max are tracked (NaN-safe) before taking absolute values.
            if data_arr_np.size > 0:
                min_value = np.nanmin(data_arr_np)
                max_value = np.nanmax(data_arr_np)
            else:
                # Empty batch: degenerate zero statistics in the tensor's dtype.
                min_value = np.array(0, dtype=data_arr_np.dtype)
                max_value = np.array(0, dtype=data_arr_np.dtype)

            data_arr_np = np.absolute(data_arr_np)  # only consider absolute value

            if tensor not in self.histogram_dict:
                # first time it uses num_bins to compute histogram.
                hist, hist_edges = np.histogram(data_arr_np, bins=self.num_bins)
                hist_edges = hist_edges.astype(data_arr_np.dtype)
                assert data_arr_np.dtype != np.float64, (
                    "only float32 or float16 is supported, every constant must be explicitly typed"
                )
                self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value)
            else:
                # Merge into the existing histogram, widening edges if needed.
                old_histogram = self.histogram_dict[tensor]
                old_min = old_histogram[2]
                old_max = old_histogram[3]
                assert hasattr(old_min, "dtype"), f"old_min should be a numpy array but is {type(old_min)}"
                assert hasattr(old_max, "dtype"), f"old_min should be a numpy array but is {type(old_max)}"
                old_hist = old_histogram[0]
                old_hist_edges = old_histogram[1]
                temp_amax = np.nanmax(data_arr_np)
                if temp_amax > old_hist_edges[-1]:
                    # increase the number of bins
                    width = old_hist_edges[1] - old_hist_edges[0]
                    # NOTE: np.arange may create an extra bin after the one containing temp_amax
                    new_bin_edges = np.arange(old_hist_edges[-1] + width, temp_amax + width, width)
                    old_hist_edges = np.hstack((old_hist_edges, new_bin_edges))
                # Re-histogram the new batch on the (possibly widened) edges and
                # add the previous counts back in; new bins start at zero.
                hist, hist_edges = np.histogram(data_arr_np, bins=old_hist_edges)
                hist_edges = hist_edges.astype(data_arr_np.dtype)
                hist[: len(old_hist)] += old_hist
                assert data_arr_np.dtype != np.float64, (
                    "only float32 or float16 is supported, every constant must be explicitly typed"
                )
                self.histogram_dict[tensor] = (hist, hist_edges, min(old_min, min_value), max(old_max, max_value))
|
|
890
|
+
|
|
891
|
+
def collect_value(self, name_to_arr):
|
|
892
|
+
"""
|
|
893
|
+
Collect histogram on real value
|
|
894
|
+
"""
|
|
895
|
+
for tensor, data_arr in name_to_arr.items():
|
|
896
|
+
data_arr = np.asarray(data_arr) # noqa: PLW2901
|
|
897
|
+
data_arr = data_arr.flatten() # noqa: PLW2901
|
|
898
|
+
|
|
899
|
+
if data_arr.size > 0:
|
|
900
|
+
min_value = np.nanmin(data_arr)
|
|
901
|
+
max_value = np.nanmax(data_arr)
|
|
902
|
+
else:
|
|
903
|
+
min_value = np.array(0, dtype=data_arr.dtype)
|
|
904
|
+
max_value = np.array(0, dtype=data_arr.dtype)
|
|
905
|
+
|
|
906
|
+
threshold = np.array(max(abs(min_value), abs(max_value)), dtype=data_arr.dtype)
|
|
907
|
+
|
|
908
|
+
if tensor in self.histogram_dict:
|
|
909
|
+
old_histogram = self.histogram_dict[tensor]
|
|
910
|
+
self.histogram_dict[tensor] = self.merge_histogram(
|
|
911
|
+
old_histogram, data_arr, min_value, max_value, threshold
|
|
912
|
+
)
|
|
913
|
+
else:
|
|
914
|
+
hist, hist_edges = np.histogram(data_arr, self.num_bins, range=(-threshold, threshold))
|
|
915
|
+
self.histogram_dict[tensor] = (
|
|
916
|
+
hist,
|
|
917
|
+
hist_edges,
|
|
918
|
+
min_value,
|
|
919
|
+
max_value,
|
|
920
|
+
threshold,
|
|
921
|
+
)
|
|
922
|
+
|
|
923
|
+
    def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_threshold):
        """Merge a new batch into an existing signed histogram, widening the
        symmetric range when the new threshold exceeds the old one.

        old_histogram is the 5-tuple (hist, hist_edges, min, max, threshold)
        previously stored by collect_value(); returns a tuple of the same shape.
        """
        (old_hist, old_hist_edges, old_min, old_max, old_threshold) = old_histogram

        if new_threshold <= old_threshold:
            # New data fits inside the old range: histogram it on the same bins
            # and add the counts.
            new_hist, _ = np.histogram(data_arr, len(old_hist), range=(-old_threshold, old_threshold))
            return (
                new_hist + old_hist,
                old_hist_edges,
                min(old_min, new_min),
                max(old_max, new_max),
                old_threshold,
            )
        else:
            if old_threshold == 0:
                # Degenerate old range (all zeros so far): all old counts live in
                # one point, so they can simply be added bin-wise.
                hist, hist_edges = np.histogram(data_arr, len(old_hist), range=(-new_threshold, new_threshold))
                hist += old_hist
            else:
                # Grow the range symmetrically by whole old-size bins so the old
                # counts map exactly onto the middle of the new histogram.
                old_num_bins = len(old_hist)
                old_stride = 2 * old_threshold / old_num_bins
                half_increased_bins = int((new_threshold - old_threshold) // old_stride + 1)
                new_num_bins = old_num_bins + 2 * half_increased_bins
                # Snap the threshold up to a bin boundary.
                new_threshold = half_increased_bins * old_stride + old_threshold
                hist, hist_edges = np.histogram(data_arr, new_num_bins, range=(-new_threshold, new_threshold))
                hist[half_increased_bins : new_num_bins - half_increased_bins] += old_hist
            return (
                hist,
                hist_edges,
                min(old_min, new_min),
                max(old_max, new_max),
                new_threshold,
            )
|
|
954
|
+
|
|
955
|
+
def compute_collection_result(self):
|
|
956
|
+
if not self.histogram_dict or len(self.histogram_dict) == 0:
|
|
957
|
+
raise ValueError("Histogram has not been collected. Please run collect() first.")
|
|
958
|
+
print(f"Finding optimal threshold for each tensor using {self.method!r} algorithm ...")
|
|
959
|
+
|
|
960
|
+
if self.method == "entropy":
|
|
961
|
+
return self.compute_entropy()
|
|
962
|
+
elif self.method == "percentile":
|
|
963
|
+
return self.compute_percentile()
|
|
964
|
+
elif self.method == "distribution":
|
|
965
|
+
return self.compute_distribution()
|
|
966
|
+
else:
|
|
967
|
+
raise ValueError("Only 'entropy', 'percentile' or 'distribution' methods are supported")
|
|
968
|
+
|
|
969
|
+
def compute_percentile(self):
    """Compute per-tensor (low, high) thresholds using the percentile method.

    In symmetric mode the cut is taken at ``self.percentile`` of the CDF and
    mirrored around zero; otherwise half of ``100 - percentile`` percent is
    trimmed from each tail independently. Thresholds are clamped to the
    observed min/max, and the first two histogram counts are appended to
    each result tuple.

    Returns:
        dict: tensor name -> (low, high, hist[0], hist[1]).

    Raises:
        ValueError: if ``self.percentile`` lies outside [0, 100].
    """
    pct = self.percentile
    if not 0 <= pct <= 100:
        raise ValueError("Invalid percentile. Must be in range 0 <= percentile <= 100.")

    collected = self.histogram_dict
    results = {}  # per tensor thresholds

    print(f"Number of tensors : {len(collected)}")
    print(f"Number of histogram bins : {self.num_bins}")
    print(f"Percentile : ({100.0 - pct},{pct})")

    for name, record in collected.items():
        counts = record[0]
        edges = record[1]
        cdf = np.cumsum(counts / counts.sum())

        if self.symmetric:
            cut = np.searchsorted(cdf, pct / 100.0)
            upper = np.array(edges[cut], dtype=edges.dtype)
            lower = -np.array(edges[cut], dtype=edges.dtype)
        else:
            # Split the trimmed mass evenly between the two tails.
            tail = (100.0 - pct) / 200.0
            upper = np.array(edges[np.searchsorted(cdf, 1.0 - tail)], dtype=edges.dtype)
            lower = np.array(edges[np.searchsorted(cdf, tail)], dtype=edges.dtype)

        # Never report a range wider than what was actually observed.
        observed_min = record[2]
        observed_max = record[3]
        if lower < observed_min:
            lower = observed_min
        if upper > observed_max:
            upper = observed_max

        results[name] = (lower, upper, *counts[:2])

        # Plot histogram for debug only
        if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"):
            apply_plot(counts, edges)

    return results
|
|
1014
|
+
|
|
1015
|
+
def compute_entropy(self):
    """Compute per-tensor thresholds by minimizing KL divergence.

    For each tensor, delegates to ``get_entropy_threshold`` and stores the
    resulting (low, high) pair together with the first two entries of the
    stored histogram record (counts and edges).

    Returns:
        dict: tensor name -> (low, high, histogram[0], histogram[1]).
    """
    histogram_dict = self.histogram_dict
    num_quantized_bins = self.num_quantized_bins

    thresholds_dict = {}  # per tensor thresholds

    print(f"Number of tensors : {len(histogram_dict)}")
    print(f"Number of histogram bins : {self.num_bins} (The number may increase depends on the data it collects)")
    print(f"Number of quantized bins : {self.num_quantized_bins}")

    for tensor, histogram in histogram_dict.items():
        optimal_threshold = self.get_entropy_threshold(histogram, num_quantized_bins)
        # Single assignment; the previous intermediate store of the bare
        # threshold pair was dead code (immediately overwritten).
        thresholds_dict[tensor] = (*optimal_threshold, *histogram[:2])

        # Plot histogram for debug only
        if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"):
            apply_plot(histogram[0], histogram[1])

    return thresholds_dict
|
|
1035
|
+
|
|
1036
|
+
@staticmethod
|
|
1037
|
+
def _avg_std(hist, hist_edges, power=1):
|
|
1038
|
+
if power <= 0:
|
|
1039
|
+
raise ValueError(f"power={power} <= 0 is invalid.")
|
|
1040
|
+
values = (hist_edges[:-1] + hist_edges[1:]) * 0.5
|
|
1041
|
+
if power == 1:
|
|
1042
|
+
avg = (hist * values).sum() / hist.sum()
|
|
1043
|
+
std = ((hist * values**2).sum() / hist.sum() - avg**2) ** 0.5
|
|
1044
|
+
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)
|
|
1045
|
+
if int(power) == power and int(power) % 2 == 1:
|
|
1046
|
+
avg = (hist * values**power).sum() / hist.sum()
|
|
1047
|
+
std = ((hist * (values**power - avg) ** 2).sum() / hist.sum()) ** 0.5
|
|
1048
|
+
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)
|
|
1049
|
+
|
|
1050
|
+
fact = np.abs(values) / values
|
|
1051
|
+
fact[np.isnan(fact)] = 1
|
|
1052
|
+
fact[np.isinf(fact)] = 1
|
|
1053
|
+
values = np.abs(values) ** power * fact
|
|
1054
|
+
avg = (hist * values).sum() / hist.sum()
|
|
1055
|
+
std = ((hist * values**2).sum() / hist.sum() - avg**2) ** 0.5
|
|
1056
|
+
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)
|
|
1057
|
+
|
|
1058
|
+
def compute_distribution(self):
    """Compute per-tensor distribution statistics from the histograms.

    For each tensor, computes the weighted mean and deviation of the
    histogram bin centers (optionally under a cube-root transform for
    scenario 'p3') and packages them as a ``TensorData`` record.

    Returns:
        dict: tensor name -> TensorData(avg, std, hist, hist_edges,
        lowest, highest).

    Raises:
        ValueError: if ``self.num_bins`` < 512 or ``self.scenario`` is not
            one of {'same', 'p3'}.
    """
    if self.num_bins < 512:
        raise ValueError("Invalid num_bins. Must be in range 512 <= num_bins.")

    histogram_dict = self.histogram_dict
    thresholds_dict = {}  # per tensor thresholds

    print(f"Number of tensors : {len(histogram_dict)}")
    print(f"Number of histogram bins : {self.num_bins}")
    # The stray closing parenthesis that previously followed {self.scenario!r}
    # in this message was a typo and has been removed.
    print(f"Scenario : {self.scenario!r}")

    for tensor, histogram in histogram_dict.items():
        hist = histogram[0]
        hist_edges = histogram[1]

        # Histograms are expected in a narrower float type; float64 here
        # would indicate an upstream dtype bug.
        assert hist_edges.dtype != np.float64
        if self.scenario == "same":
            avg_coef, std_coef = self._avg_std(hist, hist_edges, power=1)
        elif self.scenario == "p3":
            # Cube-root transform of the bin centers before taking moments.
            avg_coef, std_coef = self._avg_std(hist, hist_edges, power=1.0 / 3.0)
        else:
            raise ValueError("Invalid scenario. Must be in {'same', 'p3'}.")
        assert avg_coef.dtype != np.float64
        assert std_coef.dtype != np.float64
        assert hist_edges.dtype != np.float64
        thresholds_dict[tensor] = TensorData(
            avg=avg_coef,
            std=std_coef,
            hist=hist,
            hist_edges=hist_edges,
            lowest=hist_edges.min(),
            highest=hist_edges.max(),
        )

        # Plot histogram for debug only
        if os.environ.get("QUANTIZATION_DEBUG", "0") in (1, "1"):
            apply_plot(hist, hist_edges)

    return thresholds_dict
|
|
1097
|
+
|
|
1098
|
+
def get_entropy_threshold(self, histogram, num_quantized_bins):
    """Given a dataset, find the optimal threshold for quantizing it.

    The reference distribution is `q`, and the candidate distribution is `p`.
    `q` is a truncated version of the original distribution.
    Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf

    Args:
        histogram: record whose first four entries are
            (hist, hist_edges, min_value, max_value).
        num_quantized_bins: number of bins of the quantized distribution
            each truncated candidate is compared against.

    Returns:
        (low, high) threshold pair minimizing the KL divergence, clamped
        to the observed [min_value, max_value] range.
    """
    hist = histogram[0]
    hist_edges = histogram[1]
    num_bins = hist.size
    # NOTE(review): assumes the histogram is centered so that the value 0
    # falls in the middle bin — confirm against the collector that built it.
    zero_bin_index = num_bins // 2
    num_half_quantized_bin = num_quantized_bins // 2

    dtype = histogram[1].dtype
    # One KL-divergence entry per candidate symmetric truncation width.
    kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1)
    thresholds = [(np.array(0, dtype=dtype), np.array(0, dtype=dtype)) for i in range(kl_divergence.size)]

    # <------------ num bins ---------------->
    #        <--- quantized bins ---->
    # |======|===========|===========|=======|
    #              zero bin index
    #        ^                       ^
    #        |                       |
    #   start index               end index  (start of iteration)
    #      ^                           ^
    #      |                           |
    #  start index                 end index  ...
    # ^                                      ^
    # |                                      |
    # start index                   end index  (end of iteration)

    # Grow the candidate window symmetrically around the zero bin, from the
    # narrowest width (num_half_quantized_bin on each side) outwards.
    for i in range(num_half_quantized_bin, zero_bin_index + 1, 1):
        start_index = zero_bin_index - i
        end_index = min(zero_bin_index + i + 1, num_bins)

        # hist_edges has one more entry than hist, so end_index is valid.
        thresholds[i - num_half_quantized_bin] = (hist_edges[start_index], hist_edges[end_index])

        sliced_distribution = copy.deepcopy(hist[start_index:end_index])

        # reference distribution p
        p = sliced_distribution.copy()  # a copy of np array
        # Counts falling outside the window are folded into the edge bins.
        left_outliers_count = sum(hist[:start_index])
        right_outliers_count = sum(hist[end_index:])
        p[0] += left_outliers_count
        p[-1] += right_outliers_count

        # nonzeros[i] indicates whether p[i] is non-zero
        nonzeros = (p != 0).astype(np.int64)

        # quantize p.size bins into quantized bins (default 128 bins)
        quantized_bins = np.zeros(num_quantized_bins, dtype=np.int64)
        num_merged_bins = sliced_distribution.size // num_quantized_bins

        # merge bins into quantized bins
        for index in range(num_quantized_bins):
            start = index * num_merged_bins
            end = start + num_merged_bins
            quantized_bins[index] = sum(sliced_distribution[start:end])
        # Leftover bins that do not divide evenly are added to the last bucket.
        quantized_bins[-1] += sum(sliced_distribution[num_quantized_bins * num_merged_bins :])

        # in order to compare p and q, we need to make length of q equals to length of p
        # expand quantized bins into p.size bins
        q = np.zeros(p.size, dtype=np.int64)
        for index in range(num_quantized_bins):
            start = index * num_merged_bins
            end = start + num_merged_bins

            # Spread each quantized bucket evenly over the source bins that
            # were non-zero, leaving originally-zero bins at zero.
            norm = sum(nonzeros[start:end])
            if norm != 0:
                q[start:end] = quantized_bins[index] / norm

        # Smooth to avoid zero probabilities before computing KL divergence;
        # smooth_distribution may return None when smoothing is impossible,
        # in which case this candidate is ruled out with an infinite cost.
        p = smooth_distribution(p)
        q = smooth_distribution(q)
        if p is None or q is None:
            div = np.array(np.inf, dtype=dtype)
        else:
            div = np.array(entropy(p, q), dtype=dtype)
        kl_divergence[i - num_half_quantized_bin] = div

    # Pick the truncation window with the smallest divergence.
    min_kl_divergence_idx = np.argmin(kl_divergence)
    optimal_threshold = thresholds[min_kl_divergence_idx]
    min_value = histogram[2]
    max_value = histogram[3]
    # Clamp the winning thresholds to the observed data range.
    if optimal_threshold[0] < min_value:
        optimal_threshold = (min_value, optimal_threshold[1])
    if optimal_threshold[1] > max_value:
        optimal_threshold = (optimal_threshold[0], max_value)
    assert hasattr(optimal_threshold[0], "dtype")
    assert hasattr(optimal_threshold[1], "dtype")
    return optimal_threshold
|
|
1187
|
+
|
|
1188
|
+
|
|
1189
|
+
def create_calibrator(
    model: str | Path,
    op_types_to_calibrate: Sequence[str] | None = None,
    augmented_model_path="augmented_model.onnx",
    calibrate_method=CalibrationMethod.MinMax,
    use_external_data_format=False,
    providers=None,
    extra_options=None,
):
    """Create, augment and initialize a calibrator for the given model.

    Args:
        model: path to the ONNX model to calibrate.
        op_types_to_calibrate: operator types whose outputs are calibrated;
            None selects the calibrator's default.
        augmented_model_path: where the augmented model is written.
        calibrate_method: one of CalibrationMethod.{MinMax, Entropy,
            Percentile, Distribution}.
        use_external_data_format: store augmented model weights externally.
        providers: optional execution providers for the inference session.
        extra_options: per-method tuning knobs (e.g. 'symmetric',
            'num_bins', 'percentile', 'scenario'); defaults to {}.

    Returns:
        The initialized calibrator, with its graph augmented and inference
        session created.

    Raises:
        ValueError: if ``calibrate_method`` is not supported.
    """
    # None sentinel replaces the previous mutable default `extra_options={}`
    # (flake8 B006): a shared dict default persists across calls.
    if extra_options is None:
        extra_options = {}

    calibrator = None
    if calibrate_method == CalibrationMethod.MinMax:
        # default settings for min-max algorithm
        symmetric = extra_options.get("symmetric", False)
        moving_average = extra_options.get("moving_average", False)
        averaging_constant = extra_options.get("averaging_constant", 0.01)
        max_intermediate_outputs = extra_options.get("max_intermediate_outputs", None)
        per_channel = extra_options.get("per_channel", False)
        calibrator = MinMaxCalibrater(
            model,
            op_types_to_calibrate,
            augmented_model_path,
            use_external_data_format=use_external_data_format,
            symmetric=symmetric,
            moving_average=moving_average,
            averaging_constant=averaging_constant,
            max_intermediate_outputs=max_intermediate_outputs,
            per_channel=per_channel,
        )
    elif calibrate_method == CalibrationMethod.Entropy:
        # default settings for entropy algorithm
        num_bins = extra_options.get("num_bins", 128)
        num_quantized_bins = extra_options.get("num_quantized_bins", 128)
        symmetric = extra_options.get("symmetric", False)
        calibrator = EntropyCalibrater(
            model,
            op_types_to_calibrate,
            augmented_model_path,
            use_external_data_format=use_external_data_format,
            symmetric=symmetric,
            num_bins=num_bins,
            num_quantized_bins=num_quantized_bins,
        )
    elif calibrate_method == CalibrationMethod.Percentile:
        # default settings for percentile algorithm
        num_bins = extra_options.get("num_bins", 2048)
        percentile = extra_options.get("percentile", 99.999)
        symmetric = extra_options.get("symmetric", True)
        calibrator = PercentileCalibrater(
            model,
            op_types_to_calibrate,
            augmented_model_path,
            use_external_data_format=use_external_data_format,
            symmetric=symmetric,
            num_bins=num_bins,
            percentile=percentile,
        )
    elif calibrate_method == CalibrationMethod.Distribution:
        # default settings for distribution algorithm
        num_bins = extra_options.get("num_bins", 2048)
        scenario = extra_options.get("scenario", "same")

        calibrator = DistributionCalibrater(
            model,
            op_types_to_calibrate,
            augmented_model_path,
            use_external_data_format=use_external_data_format,
            num_bins=num_bins,
            scenario=scenario,
        )

    if calibrator:
        calibrator.augment_graph()
        if providers:
            calibrator.execution_providers = providers
        calibrator.create_inference_session()
        return calibrator

    raise ValueError(f"Unsupported calibration method {calibrate_method}")