onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from .operators.activation import QDQRemovableActivation, QLinearActivation
|
|
2
|
+
from .operators.argmax import QArgMax
|
|
3
|
+
from .operators.attention import AttentionQuant
|
|
4
|
+
from .operators.base_operator import QuantOperatorBase
|
|
5
|
+
from .operators.binary_op import QLinearBinaryOp
|
|
6
|
+
from .operators.concat import QLinearConcat
|
|
7
|
+
from .operators.conv import ConvInteger, QDQConv, QLinearConv
|
|
8
|
+
from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
|
|
9
|
+
from .operators.embed_layernorm import EmbedLayerNormalizationQuant
|
|
10
|
+
from .operators.gather import GatherQuant, QDQGather
|
|
11
|
+
from .operators.gavgpool import QGlobalAveragePool
|
|
12
|
+
from .operators.gemm import QDQGemm, QLinearGemm
|
|
13
|
+
from .operators.lstm import LSTMQuant
|
|
14
|
+
from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
|
|
15
|
+
from .operators.maxpool import QDQMaxPool, QMaxPool
|
|
16
|
+
from .operators.norm import QDQNormalization
|
|
17
|
+
from .operators.pad import QDQPad, QPad
|
|
18
|
+
from .operators.pooling import QLinearPool
|
|
19
|
+
from .operators.qdq_base_operator import QDQOperatorBase
|
|
20
|
+
from .operators.resize import QDQResize, QResize
|
|
21
|
+
from .operators.softmax import QLinearSoftmax
|
|
22
|
+
from .operators.split import QDQSplit, QSplit
|
|
23
|
+
from .operators.where import QDQWhere, QLinearWhere
|
|
24
|
+
from .quant_utils import QuantizationMode
|
|
25
|
+
|
|
26
|
+
# Ops whose quantized handling is shared between the IntegerOps and
# QLinearOps quantization modes.
CommonOpsRegistry = dict(
    Gather=GatherQuant,
    Transpose=Direct8BitOp,
    EmbedLayerNormalization=EmbedLayerNormalizationQuant,
)

# Ops supported when quantizing in QuantizationMode.IntegerOps.
IntegerOpsRegistry = dict(
    Conv=ConvInteger,
    MatMul=MatMulInteger,
    Attention=AttentionQuant,
    LSTM=LSTMQuant,
)
IntegerOpsRegistry.update(CommonOpsRegistry)

# Ops supported when quantizing with the QLinear* operator set.
QLinearOpsRegistry = dict(
    ArgMax=QArgMax,
    Conv=QLinearConv,
    Gemm=QLinearGemm,
    MatMul=QLinearMatMul,
    Add=QLinearBinaryOp,
    Mul=QLinearBinaryOp,
    Relu=QLinearActivation,
    Clip=QLinearActivation,
    LeakyRelu=QLinearActivation,
    Sigmoid=QLinearActivation,
    MaxPool=QMaxPool,
    GlobalAveragePool=QGlobalAveragePool,
    Split=QSplit,
    Pad=QPad,
    Reshape=Direct8BitOp,
    Squeeze=Direct8BitOp,
    Unsqueeze=Direct8BitOp,
    Resize=QResize,
    AveragePool=QLinearPool,
    Concat=QLinearConcat,
    Softmax=QLinearSoftmax,
    Where=QLinearWhere,
)
QLinearOpsRegistry.update(CommonOpsRegistry)

# Ops with dedicated handling in QDQ (QuantizeLinear/DequantizeLinear) format.
QDQRegistry = dict(
    Conv=QDQConv,
    ConvTranspose=QDQConv,
    Gemm=QDQGemm,
    Clip=QDQRemovableActivation,
    Relu=QDQRemovableActivation,
    Reshape=QDQDirect8BitOp,
    Transpose=QDQDirect8BitOp,
    Squeeze=QDQDirect8BitOp,
    Unsqueeze=QDQDirect8BitOp,
    Resize=QDQResize,
    MaxPool=QDQMaxPool,
    AveragePool=QDQDirect8BitOp,
    Slice=QDQDirect8BitOp,
    Pad=QDQPad,
    MatMul=QDQMatMul,
    Split=QDQSplit,
    Gather=QDQGather,
    GatherElements=QDQGather,
    Where=QDQWhere,
    InstanceNormalization=QDQNormalization,
    LayerNormalization=QDQNormalization,
    BatchNormalization=QDQNormalization,
    TopK=QDQDirect8BitOp,
    CumSum=QDQOperatorBase,
)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def CreateDefaultOpQuantizer(onnx_quantizer, node):  # noqa: N802
    """Wrap *node* in the default QuantOperatorBase quantizer.

    Used when no op-specific quantizer from the registries applies.
    """
    return QuantOperatorBase(onnx_quantizer, node)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def CreateOpQuantizer(onnx_quantizer, node):  # noqa: N802
    """Pick the op-specific quantizer for *node* based on the quantizer's mode.

    Falls back to the default QuantOperatorBase when the op type is not
    registered or the registered quantizer declines to quantize this node.
    """
    if onnx_quantizer.mode == QuantizationMode.IntegerOps:
        registry = IntegerOpsRegistry
    else:
        registry = QLinearOpsRegistry

    quantizer_class = registry.get(node.op_type)
    if quantizer_class is not None:
        candidate = quantizer_class(onnx_quantizer, node)
        if candidate.should_quantize():
            return candidate
    return QuantOperatorBase(onnx_quantizer, node)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def CreateQDQQuantizer(onnx_quantizer, node):  # noqa: N802
    """Pick the QDQ-format quantizer for *node*.

    Looks up the op type in QDQRegistry and defaults to QDQOperatorBase
    when the op has no dedicated QDQ handler.
    """
    quantizer_class = QDQRegistry.get(node.op_type, QDQOperatorBase)
    return quantizer_class(onnx_quantizer, node)
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# --------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft, Intel Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import tempfile
|
|
10
|
+
import traceback
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import onnx
|
|
14
|
+
|
|
15
|
+
import onnxruntime
|
|
16
|
+
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
|
17
|
+
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
|
|
18
|
+
|
|
19
|
+
from .fusions import ReplaceUpsampleWithResize
|
|
20
|
+
from .onnx_model import ONNXModel
|
|
21
|
+
from .quant_utils import add_pre_process_metadata, save_and_reload_model_with_shape_infer
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def quant_pre_process(
    input_model: str | Path | onnx.ModelProto | None = None,
    output_model_path: str | Path | None = None,
    skip_optimization: bool = False,
    skip_onnx_shape: bool = False,
    skip_symbolic_shape: bool = False,
    auto_merge: bool = False,
    int_max: int = 2**31 - 1,
    guess_output_rank: bool = False,
    verbose: int = 0,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = False,
    external_data_location: str | None = None,
    external_data_size_threshold: int = 1024,
    **deprecated_kwargs,
) -> None:
    """Shape inference and model optimization, in preparation for quantization.

    Pipeline (each stage optional): Upsample->Resize rewrite for opset<=10,
    symbolic shape inference, ORT graph optimization, ONNX shape inference,
    then save to ``output_model_path``. Intermediate models live in a temp dir.

    Args:
        input_model: Path to the input model file or ModelProto.
        output_model_path: Path to the output model file. Required.
        skip_optimization: Skip model optimization step if true. This may result in ONNX shape
            inference failure for some models.
        skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
            with transformer based models. Skipping all shape inferences may
            reduce the effectiveness of quantization, as a tensor with unknown
            shape can not be quantized.
        skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
            effective with transformer based models. Skipping all shape
            inferences may reduce the effectiveness of quantization, as a tensor
            with unknown shape can not be quantized.
        auto_merge: For symbolic shape inference, automatically merge symbolic dims when
            conflict happens.
        int_max: For symbolic shape inference, specify the maximum value for integer to be
            treated as boundless for ops like slice.
        guess_output_rank: Guess output rank to be the same as input 0 for unknown ops.
        verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed.
        save_as_external_data: Saving an ONNX model to external data.
        all_tensors_to_one_file: Saving all the external data to one file.
        external_data_location: The file location to save the external file.
        external_data_size_threshold: The size threshold for external data.
        **deprecated_kwargs: Accepts the legacy ``input_model_path`` keyword.
    """

    if input_model is None:
        # Backward compatibility: older callers passed ``input_model_path``.
        input_model = deprecated_kwargs.pop("input_model_path", None)
    assert input_model is not None

    assert output_model_path is not None, "output_model_path is required."

    with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
        temp_path = Path(quant_tmp_dir)
        model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)

        # Since Upsample is deprecated after opset v10, and the model's opset will
        # be upgraded to at least v11 during quantization, we need to replace Upsample
        # with Resize first to avoid generating an invalid model.
        ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
        if len(ai_onnx_domain) == 1:
            opset_version = ai_onnx_domain[0].version
            if opset_version <= 10:
                ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply()
                model = onnx.version_converter.convert_version(model, 11)
                model = save_and_reload_model_with_shape_infer(model)

        if not skip_symbolic_shape:
            logger.info("Performing symbolic shape inference...")
            model = SymbolicShapeInference.infer_shapes(
                model,
                int_max,
                auto_merge,
                guess_output_rank,
                verbose,
            )

        if not skip_optimization:
            # Use ORT optimizers (native code) to optimize model
            if not skip_symbolic_shape:
                # Need to save the inferenced model to file so as to run the optimizer
                input_model = str(temp_path / "symbolic_shape_inferred.onnx")
                if save_as_external_data:
                    onnx.save_model(
                        model,
                        input_model,
                        save_as_external_data=True,
                        all_tensors_to_one_file=all_tensors_to_one_file,
                        size_threshold=external_data_size_threshold,
                        convert_attribute=False,
                    )
                else:
                    onnx.save(model, input_model)
                # ``model`` is now on disk; ``model is None`` signals later
                # stages that ``input_model`` is the current source of truth.
                model = None

            opt_model_path = str(temp_path / "optimized.onnx")
            try:
                sess_option = onnxruntime.SessionOptions()
                # Creating the session with this option writes the optimized
                # graph to disk as a side effect.
                sess_option.optimized_model_filepath = opt_model_path
                sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
                # For large model, extract external data from model and add to session options
                if isinstance(input_model, onnx.ModelProto):
                    if has_external_data(input_model):
                        raise ValueError(
                            "ModelProto has external data not loaded into memory, ORT cannot create session. "
                            "Please load external data before calling this function. "
                            "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
                        )
                    external_names, external_values = extract_raw_data_from_model(input_model)
                    sess_option.add_external_initializers(list(external_names), list(external_values))
                    input_model = input_model.SerializeToString()
                # the saved optimized model otherwise points to the original external data file name
                # which is not available relative to the optimized model file
                elif skip_symbolic_shape and save_as_external_data:
                    sess_option.add_session_config_entry(
                        "session.optimized_model_external_initializers_file_name", "optimized.onnx.data"
                    )

                sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
                # Close the session to avoid the cleanup error on Windows for temp folders
                # https://github.com/microsoft/onnxruntime/issues/17627
                del sess
            except Exception:
                # Best-effort: optimization failure is logged, and the
                # pipeline continues with whatever ``opt_model_path`` holds.
                logger.error(
                    "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
                )
                logger.error(traceback.format_exc())

            input_model = opt_model_path

        if not skip_onnx_shape:
            # ONNX shape inference.
            # According to docs, infer_shapes_path should be used for 2G+ models.
            # If the skip optimization is specified, we could be dealing with a
            # large model. So be on the safe side, save the model
            if model is not None:
                input_model = str(temp_path / "symbolic_shape_inferred.onnx")
                if save_as_external_data:
                    onnx.save_model(
                        model,
                        input_model,
                        save_as_external_data=True,
                        all_tensors_to_one_file=all_tensors_to_one_file,
                        size_threshold=external_data_size_threshold,
                        convert_attribute=False,
                    )
                else:
                    onnx.save(model, input_model)
                model = None

            # NOTE(review): this branch looks unreachable — every path above
            # that leaves ``model is None`` also turns ``input_model`` into a
            # file path (and it would save ``model`` while it is None) — confirm.
            if isinstance(input_model, onnx.ModelProto):
                input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
                onnx.save_model(
                    model,
                    input_model,
                    save_as_external_data=True,
                    all_tensors_to_one_file=all_tensors_to_one_file,
                    size_threshold=external_data_size_threshold,
                    convert_attribute=False,
                )

            inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
            # Path-based shape inference avoids loading 2GB+ models in memory.
            onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
            model = onnx.load(inferred_model_path)

        if model is None:
            # All inference stages skipped or left the model on disk; reload it.
            model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)

        # Tag the model so downstream quantization knows preprocessing ran.
        add_pre_process_metadata(model)

        if save_as_external_data:
            onnx.save_model(
                model,
                output_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=all_tensors_to_one_file,
                location=external_data_location,
                size_threshold=external_data_size_threshold,
                convert_attribute=False,
            )
        else:
            onnx.save(model, output_model_path)
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import onnx
|
|
7
|
+
|
|
8
|
+
import onnxruntime
|
|
9
|
+
from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize
|
|
10
|
+
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OnnxModelCalibrationDataReader(CalibrationDataReader):
    """Feeds calibration samples read from 'test_data_set_*' directories.

    The directories are expected to sit next to the model file (the ONNX model
    zoo layout), each containing one serialized TensorProto per model input
    named 'input_<idx>.pb'. Each directory yields one calibration sample.
    """

    def __init__(self, model_path):
        self.model_dir = os.path.dirname(model_path)
        data_dirs = [
            os.path.join(self.model_dir, entry)
            for entry in os.listdir(self.model_dir)
            if entry.startswith("test_data_set_")
        ]
        # Fail early with a clear message instead of an opaque IndexError
        # further down when no calibration data ships with the model.
        if not data_dirs:
            raise RuntimeError(f"No 'test_data_set_*' directories found in {self.model_dir!r}")

        model_inputs = onnxruntime.InferenceSession(model_path).get_inputs()
        name2tensors = []
        for data_dir in data_dirs:
            data_paths = [os.path.join(data_dir, f"input_{input_idx}.pb") for input_idx in range(len(model_inputs))]
            data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
            # Pair each model input with the tensor loaded for its index.
            name2tensor = {
                model_input.name: data_ndarray
                for model_input, data_ndarray in zip(model_inputs, data_ndarrays, strict=False)
            }
            name2tensors.append(name2tensor)
        # Explicit raises instead of `assert`: asserts are stripped under -O.
        if len(name2tensors) != len(data_dirs):
            raise RuntimeError("Mismatch between calibration samples and test data directories")
        if len(name2tensors[0]) != len(model_inputs):
            raise RuntimeError("First calibration sample does not cover all model inputs")

        self.calibration_data = iter(name2tensors)

    def get_next(self) -> dict:
        """generate the input data dict for ONNXinferenceSession run"""
        # Returns None when the calibration data is exhausted, per the
        # CalibrationDataReader contract.
        return next(self.calibration_data, None)

    def read_onnx_pb_data(self, file_pb):
        """Load a serialized onnx.TensorProto file and return it as a numpy array."""
        tensor = onnx.TensorProto()
        with open(file_pb, "rb") as f:
            tensor.ParseFromString(f.read())
        return onnx.numpy_helper.to_array(tensor)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def parse_arguments():
    """Define and parse the command-line options for the static quantization script.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    # Both --activation_type and --weight_type accept the same spellings.
    quant_type_choices = ["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"]

    arg_parser = argparse.ArgumentParser(description="The arguments for static quantization")
    arg_parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input onnx model")
    arg_parser.add_argument(
        "-o", "--output_quantized_model_path", required=True, help="Path to the output quantized onnx model"
    )
    arg_parser.add_argument(
        "--activation_type", choices=quant_type_choices, default="quint8", help="Activation quantization type used"
    )
    arg_parser.add_argument(
        "--weight_type", choices=quant_type_choices, default="qint8", help="Weight quantization type used"
    )
    arg_parser.add_argument("--enable_subgraph", action="store_true", help="If set, subgraph will be quantized.")
    arg_parser.add_argument(
        "--force_quantize_no_input_check",
        action="store_true",
        help="By default, some latent operators like maxpool, transpose, do not quantize if their input is not"
        " quantized already. Setting to True to force such operator always quantize input and so generate"
        " quantized output. Also the True behavior could be disabled per node using the nodes_to_exclude.",
    )
    arg_parser.add_argument(
        "--matmul_const_b_only", action="store_true", help="If set, only MatMul with const B will be quantized."
    )
    arg_parser.add_argument(
        "--add_qdq_pair_to_weight",
        action="store_true",
        help="If set, it remains floating-point weight and inserts both QuantizeLinear/DeQuantizeLinear"
        " nodes to weight.",
    )
    arg_parser.add_argument(
        "--dedicated_qdq_pair",
        action="store_true",
        help="If set, it will create identical and dedicated QDQ pair for each node.",
    )
    arg_parser.add_argument(
        "--op_types_to_exclude_output_quantization",
        nargs="+",
        default=[],
        help="If any op type is specified, it won't quantize the output of ops with this specific op types.",
    )
    arg_parser.add_argument(
        "--calibration_method",
        default="minmax",
        choices=["minmax", "entropy", "percentile", "distribution"],
        help="Calibration method used",
    )
    arg_parser.add_argument(
        "--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used"
    )
    arg_parser.add_argument(
        "--calib_tensor_range_symmetric",
        action="store_true",
        help="If enabled, the final range of tensor during calibration will be explicitly"
        " set to symmetric to central point 0",
    )
    # TODO: --calib_strided_minmax"
    # TODO: --calib_moving_average_constant"
    # TODO: --calib_max_intermediate_outputs"
    arg_parser.add_argument(
        "--calib_moving_average",
        action="store_true",
        help="If enabled, the moving average of"
        " the minimum and maximum values will be computed when the calibration method selected is MinMax.",
    )
    arg_parser.add_argument(
        "--disable_quantize_bias",
        action="store_true",
        help="Whether to quantize floating-point biases by solely inserting a DeQuantizeLinear node"
        " If not set, it remains floating-point bias and does not insert any quantization nodes"
        " associated with biases.",
    )

    # TODO: Add arguments related to Smooth Quant

    arg_parser.add_argument(
        "--use_qdq_contrib_ops",
        action="store_true",
        help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the com.microsoft domain,"
        " which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op implementations.",
    )
    arg_parser.add_argument(
        "--minimum_real_range",
        type=float,
        default=0.0001,
        help="If set to a floating-point value, the calculation of the quantization parameters"
        " (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax-rmin)"
        " is less than the specified minimum range, rmax will be set to rmin + MinimumRealRange. This is"
        " necessary for EPs like QNN that require a minimum floating-point range when determining "
        " quantization parameters.",
    )
    arg_parser.add_argument(
        "--qdq_keep_removable_activations",
        action="store_true",
        help="If set, removable activations (e.g., Clip or Relu) will not be removed,"
        " and will be explicitly represented in the QDQ model.",
    )
    arg_parser.add_argument(
        "--qdq_disable_weight_adjust_for_int32_bias",
        action="store_true",
        help="If set, QDQ quantizer will not adjust the weight's scale when the bias"
        " has a scale (input_scale * weight_scale) that is too small.",
    )
    arg_parser.add_argument("--per_channel", action="store_true", help="Whether using per-channel quantization")
    arg_parser.add_argument(
        "--nodes_to_quantize",
        nargs="+",
        default=None,
        help="List of nodes names to quantize. When this list is not None only the nodes in this list are quantized.",
    )
    arg_parser.add_argument(
        "--nodes_to_exclude",
        nargs="+",
        default=None,
        help="List of nodes names to exclude. The nodes in this list will be excluded from quantization when it is not None.",
    )
    arg_parser.add_argument(
        "--op_per_channel_axis",
        nargs=2,
        action="append",
        metavar=("OP_TYPE", "PER_CHANNEL_AXIS"),
        default=[],
        help="Set channel axis for specific op type, for example: --op_per_channel_axis MatMul 1, and it's"
        " effective only when per channel quantization is supported and per_channel is True. If specific"
        " op type supports per channel quantization but not explicitly specified with channel axis,"
        " default channel axis will be used.",
    )
    arg_parser.add_argument("--tensor_quant_overrides", help="Set the json file for tensor quantization overrides.")
    return arg_parser.parse_args()
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def get_tensor_quant_overrides(file):
    """Load per-tensor quantization overrides from a JSON file.

    Returns an empty dict when no file is given. Scale entries are converted
    to float32 numpy arrays and zero points to numpy arrays, as expected by
    the quantizer's TensorQuantOverrides option.
    """
    # TODO: Enhance the function to handle more real cases of json file
    if not file:
        return {}
    with open(file) as fp:
        overrides = json.load(fp)
    # Convert the plain JSON numbers into numpy values in place.
    for encodings in overrides.values():
        for encoding in encodings:
            encoding["scale"] = np.array(encoding["scale"], dtype=np.float32)
            encoding["zero_point"] = np.array(encoding["zero_point"])
    return overrides
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def main():
    """Statically quantize the model named by the command-line arguments."""
    args = parse_arguments()
    reader = OnnxModelCalibrationDataReader(model_path=args.input_model_path)

    # Map the CLI spellings onto the quantization enums.
    quant_type_by_name = {
        "qint8": QuantType.QInt8,
        "quint8": QuantType.QUInt8,
        "qint16": QuantType.QInt16,
        "quint16": QuantType.QUInt16,
        "qint4": QuantType.QInt4,
        "quint4": QuantType.QUInt4,
        "qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN,
    }
    calib_method_by_name = {
        "minmax": CalibrationMethod.MinMax,
        "entropy": CalibrationMethod.Entropy,
        "percentile": CalibrationMethod.Percentile,
        "distribution": CalibrationMethod.Distribution,
    }
    quant_format_by_name = {
        "qdq": QuantFormat.QDQ,
        "qoperator": QuantFormat.QOperator,
    }

    per_channel_axis_by_op = dict(args.op_per_channel_axis)
    extra_options = {
        "EnableSubgraph": args.enable_subgraph,
        "ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
        "MatMulConstBOnly": args.matmul_const_b_only,
        "AddQDQPairToWeight": args.add_qdq_pair_to_weight,
        "OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
        "DedicatedQDQPair": args.dedicated_qdq_pair,
        "QDQOpTypePerChannelSupportToAxis": per_channel_axis_by_op,
        "CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
        "CalibMovingAverage": args.calib_moving_average,
        # --disable_quantize_bias is a negative flag; the option is positive.
        "QuantizeBias": not args.disable_quantize_bias,
        "UseQDQContribOps": args.use_qdq_contrib_ops,
        "MinimumRealRange": args.minimum_real_range,
        "QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
        "QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
        # Load json file for encoding override
        "TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides),
    }

    config = StaticQuantConfig(
        calibration_data_reader=reader,
        calibrate_method=calib_method_by_name[args.calibration_method],
        quant_format=quant_format_by_name[args.quant_format],
        activation_type=quant_type_by_name[args.activation_type],
        weight_type=quant_type_by_name[args.weight_type],
        op_types_to_quantize=None,
        nodes_to_quantize=args.nodes_to_quantize,
        nodes_to_exclude=args.nodes_to_exclude,
        per_channel=args.per_channel,
        reduce_range=False,
        use_external_data_format=False,
        calibration_providers=None,  # Use CPUExecutionProvider
        extra_options=extra_options,
    )
    quantize(model_input=args.input_model_path, model_output=args.output_quantized_model_path, quant_config=config)


if __name__ == "__main__":
    main()
|