onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import onnx
|
|
2
|
+
|
|
3
|
+
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg
|
|
4
|
+
from .base_operator import QuantOperatorBase
|
|
5
|
+
from .qdq_base_operator import QDQOperatorBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class QSplit(QuantOperatorBase):
    """QLinear-style handler for Split: quantizes input 0 and emits a Split
    that operates directly on the quantized tensor, with every output reusing
    the input's scale/zero-point."""

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        node = self.node
        q_input_names, zp_names, scale_names, new_nodes = self.quantizer.quantize_activation(node, [0])
        if q_input_names is None:
            # Input 0 could not be quantized; defer to the default behavior.
            return super().quantize()

        q_node_name = node.name + "_quant" if node.name else ""

        # Carry every attribute of the original Split over to the new node.
        kwargs = {}
        for attr in node.attribute:
            kwargs.update(attribute_to_kwarg(attr))

        # Each output derives its scale/zero-point from the single quantized input.
        q_output_names = []
        for out_name in node.output:
            q_out_name = out_name + "quantized"
            q_output_names.append(q_out_name)
            self.quantizer.quantized_value_map[out_name] = QuantizedValue(
                out_name,
                q_out_name,
                scale_names[0],
                zp_names[0],
                QuantizedValueType.Input,
            )

        # Any extra inputs (e.g. the optional 'split' sizes) pass through untouched.
        if len(node.input) > 1:
            q_input_names.extend(node.input[1:])

        new_nodes.append(
            onnx.helper.make_node(node.op_type, q_input_names, q_output_names, q_node_name, **kwargs)
        )
        self.quantizer.new_nodes += new_nodes
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class QDQSplit(QDQOperatorBase):
    """QDQ handler for Split: quantizes the data input and propagates its
    quantization parameters to every output."""

    def quantize(self):
        node = self.node
        assert node.op_type == "Split"

        data_input = node.input[0]
        if not self.quantizer.is_tensor_quantized(data_input):
            self.quantizer.quantize_activation_tensor(data_input)

        if self.disable_qdq_for_node_output:
            return
        # Outputs share the input's quantization parameters so Split stays a
        # pure data-movement op in the quantized graph.
        for out in node.output:
            self.quantizer.quantize_output_same_as_input(out, data_input, node.name)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import onnx
|
|
2
|
+
|
|
3
|
+
from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
|
|
4
|
+
from .base_operator import QuantOperatorBase
|
|
5
|
+
from .qdq_base_operator import QDQOperatorBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class QLinearWhere(QuantOperatorBase):
    """Rewrites a Where node into the com.microsoft QLinearWhere contrib op.

    The boolean condition (input 0) is passed through unquantized; the two
    value branches (inputs 1 and 2) and the output carry quantization
    parameters.
    """

    def should_quantize(self):
        # Always considered; the real gate is force_quantize_no_input_check
        # inside quantize().
        return True

    def quantize(self):
        node = self.node
        assert node.op_type == "Where"
        quantizer = self.quantizer

        if not quantizer.force_quantize_no_input_check:
            # Quantization of Where is not forced: emit the node unchanged.
            quantizer.new_nodes += [node]
            return

        data_found, out_scale, out_zp, _, _ = quantizer._get_quantization_params(node.output[0])
        q_input_names, zp_names, scale_names, extra_nodes = quantizer.quantize_activation(node, [1, 2])
        if not data_found or q_input_names is None:
            # No output quantization params, or inputs could not be quantized:
            # fall back to the default handling.
            return super().quantize()

        q_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX
        q_node_name = node.name + "_quant" if node.name else ""

        quantizer.quantized_value_map[node.output[0]] = QuantizedValue(
            node.output[0],
            q_output,
            out_scale,
            out_zp,
            QuantizedValueType.Input,
        )

        # Preserve original attributes and tag the node with the MS domain,
        # where the QLinearWhere contrib op is registered.
        kwargs = {}
        for attr in node.attribute:
            kwargs.update(attribute_to_kwarg(attr))
        kwargs["domain"] = ms_domain

        where_inputs = [
            node.input[0],  # condition, left unquantized
            q_input_names[0],
            scale_names[0],
            zp_names[0],
            q_input_names[1],
            scale_names[1],
            zp_names[1],
            out_scale,
            out_zp,
        ]
        where_node = onnx.helper.make_node(
            "QLinearWhere", where_inputs, [q_output], q_node_name, **kwargs
        )

        quantizer.new_nodes += extra_nodes
        quantizer.new_nodes += [where_node]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class QDQWhere(QDQOperatorBase):
    """QDQ handler for Where: quantizes the two value branches (inputs 1, 2)
    and, when allowed, the outputs. The boolean condition (input 0) is never
    quantized."""

    def quantize(self):
        node = self.node
        assert node.op_type == "Where"
        quantizer = self.quantizer

        if quantizer.force_quantize_no_input_check:
            # Forced mode: make sure both value branches carry Q/DQ.
            for branch in (node.input[1], node.input[2]):
                if not quantizer.is_tensor_quantized(branch):
                    quantizer.quantize_activation_tensor(branch)
            if not self.disable_qdq_for_node_output:
                for out in node.output:
                    quantizer.quantize_activation_tensor(out)
        else:
            # Unforced mode: only quantize outputs when both value branches
            # are already quantized (short-circuit order preserved).
            if (
                quantizer.is_tensor_quantized(node.input[1])
                and quantizer.is_tensor_quantized(node.input[2])
                and not self.disable_qdq_for_node_output
            ):
                for out in node.output:
                    quantizer.quantize_activation_tensor(out)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# --------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft, Intel Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import logging
|
|
9
|
+
import sys
|
|
10
|
+
|
|
11
|
+
from .shape_inference import quant_pre_process
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` is a well-known pitfall: ``bool(s)`` is True for
    ANY non-empty string, so ``--skip_optimization False`` used to yield True.
    This converter accepts the usual spellings explicitly and rejects
    everything else with a proper argparse error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def parse_arguments():
    """Parse CLI arguments for the quantization pre-processing tool.

    Returns:
        argparse.Namespace with input/output paths, the three skip flags,
        symbolic-shape-inference tuning options, and external-data options.
    """
    parser = argparse.ArgumentParser(
        description="""Model optimizer and shape inferencer, in preparation for quantization,
Consists of three optional steps:
1. Symbolic shape inference (best for transformer models).
2. Model optimization.
3. ONNX shape inference.

Model quantization with QDQ format, i.e. inserting QuantizeLinear/DeQuantizeLinear on
the tensor, requires tensor shape information to perform its best. Currently, shape inferencing
works best with optimized model. As a result, it is highly recommended to run quantization
on optimized model with shape information. This is the tool for optimization and shape
inferencing.

Essentially this tool performs the following three (skippable) steps:

1. Symbolic shape inference.
2. Model optimization
3. ONNX shape inference"""
    )

    parser.add_argument("--input", required=True, help="Path to the input model file")
    parser.add_argument("--output", required=True, help="Path to the output model file")
    # The three skip flags use _str2bool (not type=bool, which misparses
    # "False" as True). nargs="?" with const=True additionally allows the flag
    # to be given bare (e.g. just "--skip_optimization") — backward compatible
    # with the old "--skip_optimization True" form.
    parser.add_argument(
        "--skip_optimization",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Skip model optimization step if true. It's a known issue that ORT"
        " optimization has difficulty with model size greater than 2GB, rerun with"
        " this option to get around this issue.",
    )
    parser.add_argument(
        "--skip_onnx_shape",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Skip ONNX shape inference. Symbolic shape inference is most effective"
        " with transformer based models. Skipping all shape inferences may"
        " reduce the effectiveness of quantization, as a tensor with unknown"
        " shape can not be quantized.",
    )
    parser.add_argument(
        "--skip_symbolic_shape",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Skip symbolic shape inference. Symbolic shape inference is most"
        " effective with transformer based models. Skipping all shape"
        " inferences may reduce the effectiveness of quantization, as a tensor"
        " with unknown shape can not be quantized.",
    )
    parser.add_argument(
        "--auto_merge",
        help="Automatically merge symbolic dims when confliction happens",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--int_max",
        help="maximum value for integer to be treated as boundless for ops like slice",
        type=int,
        default=2**31 - 1,
    )
    parser.add_argument(
        "--guess_output_rank",
        help="guess output rank to be the same as input 0 for unknown ops",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--verbose",
        help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--save_as_external_data",
        help="Saving an ONNX model to external data",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--all_tensors_to_one_file",
        help="Saving all the external data to one file",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--external_data_location",
        help="The file location to save the external file",
        default=None,
    )
    parser.add_argument(
        "--external_data_size_threshold",
        help="The size threshold for external data",
        type=int,
        default=1024,
    )
    return parser.parse_args()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _main():
    """Script entry point: validate CLI flags, then run quant_pre_process."""
    args = parse_arguments()

    # Nothing to do when every processing step is disabled.
    if args.skip_optimization and args.skip_onnx_shape and args.skip_symbolic_shape:
        logger.error("Skipping all three steps, nothing to be done. Quitting...")
        sys.exit()

    # The optimization path cannot handle externally-stored tensor data.
    if not args.skip_optimization and args.save_as_external_data:
        logger.error("ORT model optimization does not support external data yet!")
        sys.exit()

    logger.info("input model: %s", args.input)
    logger.info("output model: %s", args.output)
    quant_pre_process(
        args.input,
        args.output,
        args.skip_optimization,
        args.skip_onnx_shape,
        args.skip_symbolic_shape,
        args.auto_merge,
        args.int_max,
        args.guess_output_rank,
        args.verbose,
        args.save_as_external_data,
        args.all_tensors_to_one_file,
        args.external_data_location,
        args.external_data_size_threshold,
    )


if __name__ == "__main__":
    _main()
|
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
# --------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft, Intel Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
"""Utilities to run a given ONNX model, while saving input/output tensors of
|
|
8
|
+
eligible operator nodes.
|
|
9
|
+
|
|
10
|
+
A use case is to debug quantization induced accuracy drop. An AI engineer can
|
|
11
|
+
run the original float32 model and the quantized model with the same inputs,
|
|
12
|
+
then compare the corresponding activations between the two models to find
|
|
13
|
+
where the divergence is.
|
|
14
|
+
|
|
15
|
+
Example Usage:
|
|
16
|
+
|
|
17
|
+
```python
|
|
18
|
+
class ExampleDataReader(CalibrationDataReader):
|
|
19
|
+
def __init__(self):
|
|
20
|
+
...
|
|
21
|
+
def get_next(self):
|
|
22
|
+
...
|
|
23
|
+
|
|
24
|
+
input_data_reader = ExampleDataReader()
|
|
25
|
+
|
|
26
|
+
augmented_model_path = str(Path(self._tmp_model_dir.name).joinpath("augmented_model.onnx"))
|
|
27
|
+
modify_model_output_intermediate_tensors(path_to_onnx_model, augmented_model_path)
|
|
28
|
+
|
|
29
|
+
tensor_dict = collect_activations(augmented_model_path, input_data_reader)
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
`tensor_dict` points to a dictionary where the keys are tensor names and each value
|
|
33
|
+
is a list of tensors, one from each model run
|
|
34
|
+
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
import logging
|
|
38
|
+
import math
|
|
39
|
+
import time
|
|
40
|
+
from collections.abc import Callable, Sequence
|
|
41
|
+
from pathlib import Path
|
|
42
|
+
|
|
43
|
+
import numpy
|
|
44
|
+
import onnx
|
|
45
|
+
from onnx import helper, numpy_helper
|
|
46
|
+
|
|
47
|
+
import onnxruntime
|
|
48
|
+
|
|
49
|
+
from .calibrate import CalibraterBase, CalibrationDataReader
|
|
50
|
+
from .onnx_model import ONNXModel
|
|
51
|
+
from .quant_utils import (
|
|
52
|
+
DEQUANT_OP_NAME,
|
|
53
|
+
DEQUANT_OUTPUT_SUFFIX,
|
|
54
|
+
QUANT_INPUT_SUFFIX,
|
|
55
|
+
TENSOR_NAME_QUANT_SUFFIX,
|
|
56
|
+
find_by_name,
|
|
57
|
+
load_model_with_shape_infer,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
_TENSOR_SAVE_POSTFIX = "_ReshapedSavedOutput"
|
|
61
|
+
_TENSOR_SAVE_POSTFIX_LEN = len(_TENSOR_SAVE_POSTFIX)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def modify_model_output_intermediate_tensors(
    input_model_path: str | Path,
    output_model_path: str | Path,
    op_types_for_saving: Sequence[str] | None = None,
    save_as_external_data: bool = False,
) -> None:
    """Augment a given ONNX model to save node input/output tensors.

    Add all input/output tensors of operator nodes to model outputs
    so that their values can be retrieved for debugging purposes.

    Args:
        input_model_path: the path to load the model from.
        output_model_path: the path to save the augmented model to.
        op_types_for_saving: Operator types for which the
            input/output should be saved. By default, saving all the
            float32/float16 tensors.
        save_as_external_data: whether to store tensor data of the
            augmented model in a separate external-data file.
    """

    if op_types_for_saving is None:
        op_types_for_saving = []
    # CalibraterBase already implements the tensor-selection logic; reuse it
    # to pick the eligible tensors and their value_info entries.
    saver = CalibraterBase(input_model_path, op_types_to_calibrate=op_types_for_saving)
    model_to_augment = saver.model
    tensors, value_infos = saver.select_tensors_to_calibrate(model_to_augment)
    # One shared 1-D shape initializer ([-1]) for every inserted Reshape node.
    # The timestamp suffix avoids name clashes with existing initializers.
    reshape_shape_name = "LinearReshape_" + str(time.time())
    reshape_shape = numpy_helper.from_array(numpy.array([-1], dtype=numpy.int64), reshape_shape_name)
    model_to_augment.graph.initializer.append(reshape_shape)

    for tensor_name in tensors:
        # Flatten each selected tensor through a Reshape node and expose the
        # result as a new graph output named "<tensor>_ReshapedSavedOutput".
        reshape_output = tensor_name + _TENSOR_SAVE_POSTFIX
        reshape_node = onnx.helper.make_node(
            "Reshape",
            inputs=[tensor_name, reshape_shape_name],
            outputs=[reshape_output],
            name=reshape_output,
        )
        model_to_augment.graph.node.append(reshape_node)
        reshape_output_value_info = helper.make_tensor_value_info(
            reshape_output, value_infos[tensor_name].type.tensor_type.elem_type, [-1]
        )
        model_to_augment.graph.output.append(reshape_output_value_info)

    onnx.save(
        model_to_augment,
        output_model_path,
        save_as_external_data=save_as_external_data,
    )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def collect_activations(
    augmented_model: str,
    input_reader: CalibrationDataReader,
    session_options=None,
    execution_providers: Sequence[str] | None = None,
) -> dict[str, list[numpy.ndarray]]:
    """Run an augmented model on every batch and gather the saved tensors.

    Args:
        augmented_model: Path to a model produced by `modify_model_output_intermediate_tensors()`.
        input_reader: Supplies the input batches; the augmented model takes the
            same inputs as the original model.
        session_options: Optional OnnxRuntime session options for controlling model run.
            By default graph optimization is turned off
        execution_providers: Collection of execution providers for running the model.
            Only CPU EP is used by default.

    Returns:
        A dictionary where the key is tensor name and values are list of tensors from each batch

    Raises:
        RuntimeError: if the reader yields no input at all.
    """

    if session_options is None:
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    if execution_providers is None:
        execution_providers = ["CPUExecutionProvider"]

    session = onnxruntime.InferenceSession(
        augmented_model,
        sess_options=session_options,
        providers=execution_providers,
    )

    # One model run per batch supplied by the reader.
    per_batch_outputs = [session.run(None, feed) for feed in input_reader]
    if not per_batch_outputs:
        raise RuntimeError("No data is collected while running augmented model!")

    collected: dict[str, list[numpy.ndarray]] = {}
    output_metadata = session.get_outputs()
    for batch in per_batch_outputs:
        for meta, tensor in zip(output_metadata, batch, strict=False):
            full_name = meta.name
            # Only outputs added by the augmentation step carry the postfix;
            # strip it to recover the original tensor name.
            if full_name.endswith(_TENSOR_SAVE_POSTFIX):
                collected.setdefault(full_name[:-_TENSOR_SAVE_POSTFIX_LEN], []).append(tensor)

    return collected
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
_POST_QDQ_POSTFIX1 = DEQUANT_OUTPUT_SUFFIX + "_1"
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _add_pre_post_qdq_pair(
|
|
169
|
+
qdq_cmp: dict[str, dict[str, Sequence[numpy.ndarray]]],
|
|
170
|
+
activation_name: str,
|
|
171
|
+
pre_qdq_tensors: Sequence[numpy.ndarray] | None,
|
|
172
|
+
post_qdq_tensors: Sequence[numpy.ndarray] | None,
|
|
173
|
+
) -> None:
|
|
174
|
+
if post_qdq_tensors is not None and pre_qdq_tensors is not None:
|
|
175
|
+
qdq_cmp[activation_name] = {}
|
|
176
|
+
qdq_cmp[activation_name]["pre_qdq"] = pre_qdq_tensors
|
|
177
|
+
qdq_cmp[activation_name]["post_qdq"] = post_qdq_tensors
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def create_activation_matching(
    qdq_activations: dict[str, Sequence[numpy.ndarray]],
    float_activations: dict[str, Sequence[numpy.ndarray]] | None = None,
) -> dict[str, dict[str, Sequence[numpy.ndarray]]]:
    """Comparing activation values to help debugging accuracy loss due to quantization.

    This function takes saved activations from the QDQ model and (optionally) the
    float point model, and provides a data structure for comparing:
    * from the qdq model, activation values before and after QDQ operation
    * across both models, activations from the original model vs the corresponding
      activations in the QDQ model

    Arg:
        qdq_activations: Output of `collect_activations`. This must be from a quantized
            model with QDQ format.
        float_activations: Output of `collect_activations`. This must be from the float
            point model.

    Returns:
        Dict for comparing pre and post quantized activation tensors. E.g.
        ```
        qdq_cmp = cmp_qdq_input_output(qdq_activations)
        print(qdq_cmp['activation1']['pre_qdq'][0])
        print(qdq_cmp['activation1']['post_qdq'][0])


        qdq_cmp = cmp_qdq_input_output(qdq_activations, float_activations)
        print(qdq_cmp['activation1']['float'][0])
        print(qdq_cmp['activation1']['pre_qdq'][0])
        print(qdq_cmp['activation1']['post_qdq'][0])
        ```
    """

    qdq_cmp: dict[str, dict[str, Sequence[numpy.ndarray]]] = {}
    for tensor_name, tensors in qdq_activations.items():
        if tensor_name.endswith(QUANT_INPUT_SUFFIX):
            # Input of a QuantizeLinear node: these are the pre-QDQ values;
            # the unsuffixed tensor holds the matching post-QDQ values.
            pre_name = tensor_name[: -len(QUANT_INPUT_SUFFIX)]
            _add_pre_post_qdq_pair(qdq_cmp, pre_name, tensors, qdq_activations.get(pre_name))
        else:
            # Output of a DequantizeLinear node (plain suffix, or the "_1"
            # variant used for duplicated names): these are the post-QDQ
            # values; the unsuffixed tensor holds the pre-QDQ values.
            # Check DEQUANT_OUTPUT_SUFFIX before its "_1" variant, matching
            # the original elif ordering.
            for suffix in (DEQUANT_OUTPUT_SUFFIX, _POST_QDQ_POSTFIX1):
                if tensor_name.endswith(suffix):
                    pre_name = tensor_name[: -len(suffix)]
                    _add_pre_post_qdq_pair(qdq_cmp, pre_name, qdq_activations.get(pre_name), tensors)
                    break

    if not float_activations:
        return qdq_cmp

    # Attach the float-model activations wherever a matching name exists.
    for act_name, act_values in qdq_cmp.items():
        float_acts = float_activations.get(act_name)
        if float_acts is not None:
            act_values["float"] = float_acts

    return qdq_cmp
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _run_dequantize_linear(
|
|
243
|
+
weight_tensor: numpy.ndarray, weight_scale: numpy.ndarray, weight_zp: numpy.ndarray, channel_axis: int
|
|
244
|
+
) -> numpy.ndarray | None:
|
|
245
|
+
assert weight_scale.shape == weight_zp.shape
|
|
246
|
+
if weight_zp.size == 1:
|
|
247
|
+
return (weight_tensor - weight_zp) * weight_scale
|
|
248
|
+
|
|
249
|
+
assert weight_zp.ndim == 1
|
|
250
|
+
reshape_dims = list(weight_tensor.shape) # deep copy
|
|
251
|
+
reshape_dims[channel_axis] = 1 # only one per channel for reshape
|
|
252
|
+
channel_count = weight_tensor.shape[channel_axis]
|
|
253
|
+
dequantized_weights = None
|
|
254
|
+
for i in range(channel_count):
|
|
255
|
+
per_channel_data = weight_tensor.take(i, channel_axis)
|
|
256
|
+
dequantized_per_channel_data = (per_channel_data - weight_zp[i]) * weight_scale[i]
|
|
257
|
+
if i == 0:
|
|
258
|
+
dequantized_weights = numpy.asarray(dequantized_per_channel_data).reshape(reshape_dims)
|
|
259
|
+
else:
|
|
260
|
+
channel_weights = numpy.asarray(dequantized_per_channel_data).reshape(reshape_dims)
|
|
261
|
+
dequantized_weights = numpy.concatenate((dequantized_weights, channel_weights), channel_axis)
|
|
262
|
+
|
|
263
|
+
if dequantized_weights is None:
|
|
264
|
+
return None
|
|
265
|
+
|
|
266
|
+
dequantized_weights.reshape(weight_tensor.shape)
|
|
267
|
+
return dequantized_weights
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def create_weight_matching(float_model_path: str, qdq_model_path: str) -> dict[str, dict[str, numpy.ndarray]]:
    """Comparing weight values to help debugging accuracy loss due to quantization.

    This function takes the float model and the qdq model, and provides a data structure
    for comparing their corresponding weights to locate quantization errors.

    Arg:
        float_model_path: Path points to the float point model.
        qdq_model_path: Path points to the qdq model.

    Returns:
        Dict for comparing weight tensors. E.g.
        ```
        qdq_weight_cmp = create_weight_matching(float_model, qdq_model)
        print(qdq_weight_cmp['activation1']['float'])
        print(qdq_weight_cmp['activation1']['dequantized'])
        ```
    """
    float_onnx_model = ONNXModel(load_model_with_shape_infer(Path(float_model_path)))
    qdq_onnx_model = ONNXModel(load_model_with_shape_infer(Path(qdq_model_path)))

    matched_weights: dict[str, dict[str, numpy.ndarray]] = {}
    initializers = qdq_onnx_model.initializer()
    for node in qdq_onnx_model.nodes():
        if node.op_type != DEQUANT_OP_NAME:
            continue  # Only care about DQ node
        weight_name: str = node.input[0]
        weight_values = find_by_name(weight_name, initializers)
        if not weight_values:
            continue  # Only care about DQ node with const inputs
        # Quantized weight initializers are expected to carry the quant suffix.
        if not weight_name.endswith(TENSOR_NAME_QUANT_SUFFIX):
            logging.error(f"Model Error in '{qdq_model_path}': Dequantized tensor name '{weight_name}' not recognized!")
            continue

        # NOTE(review): -1 is used when no "axis" attribute is present; the
        # ONNX DequantizeLinear spec default is 1 — confirm -1 is intended here.
        axis = -1
        for attr in node.attribute:
            if attr.name == "axis":
                axis = attr.i

        weight_tensor = numpy_helper.to_array(weight_values)
        weight_scale = numpy_helper.to_array(find_by_name(node.input[1], initializers))
        # Zero point is optional in DequantizeLinear; default to zeros.
        if len(node.input) > 2:
            weight_zp = numpy_helper.to_array(find_by_name(node.input[2], initializers))
        else:
            weight_zp = numpy.zeros(weight_scale.shape, dtype=numpy.int32)

        # Perform dequantization:
        if weight_scale.size == weight_zp.size == 1:
            # Avoids the confusion between a scaler and a tensor of one element.
            weight_scale = weight_scale.reshape(())
            weight_zp = weight_zp.reshape(())
        if weight_scale.shape != weight_zp.shape:
            raise RuntimeError(
                f"scale and zero_point must have the same shape but {weight_scale.shape} != {weight_zp.shape}"
            )
        weight_quant = _run_dequantize_linear(weight_tensor, weight_scale, weight_zp, channel_axis=axis)
        # Strip the quant suffix to recover the float model's weight name.
        weight_name = weight_name[: -len(TENSOR_NAME_QUANT_SUFFIX)]
        if weight_quant is None:
            logging.error(f"Model Error in '{qdq_model_path}': '{weight_name}' per-channel quantization on 0 channel")
            continue

        float_values = find_by_name(weight_name, float_onnx_model.initializer())
        if not float_values:
            logging.error(f"Model Error in '{float_model_path}': weight tensor '{weight_name}' not found!")
            continue
        weight_float = numpy_helper.to_array(float_values)
        matched_weights[weight_name] = {"float": weight_float, "dequantized": weight_quant}

    return matched_weights
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def compute_signal_to_quantization_noice_ratio(
|
|
342
|
+
x: Sequence[numpy.ndarray] | numpy.ndarray, y: Sequence[numpy.ndarray] | numpy.ndarray
|
|
343
|
+
) -> float:
|
|
344
|
+
if isinstance(x, numpy.ndarray):
|
|
345
|
+
xlist = [x]
|
|
346
|
+
else:
|
|
347
|
+
xlist = x
|
|
348
|
+
if isinstance(y, numpy.ndarray):
|
|
349
|
+
ylist = [y]
|
|
350
|
+
else:
|
|
351
|
+
ylist = y
|
|
352
|
+
if len(xlist) != len(ylist):
|
|
353
|
+
raise RuntimeError("Unequal number of tensors to compare!")
|
|
354
|
+
|
|
355
|
+
left = numpy.concatenate(xlist).flatten()
|
|
356
|
+
right = numpy.concatenate(ylist).flatten()
|
|
357
|
+
|
|
358
|
+
epsilon = numpy.finfo("float").eps
|
|
359
|
+
tensor_norm = max(numpy.linalg.norm(left), epsilon)
|
|
360
|
+
diff_norm = max(numpy.linalg.norm(left - right), epsilon)
|
|
361
|
+
res = tensor_norm / diff_norm
|
|
362
|
+
return 20 * math.log10(res)
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def compute_weight_error(
    weights_match: dict[str, dict[str, numpy.ndarray]],
    err_func: Callable[[numpy.ndarray, numpy.ndarray], float] = compute_signal_to_quantization_noice_ratio,
) -> dict[str, float]:
    """Apply *err_func* to every float/dequantized weight pair.

    Args:
        weights_match: Output of `create_weight_matching`.
        err_func: Metric comparing a float weight with its dequantized
            counterpart; defaults to the signal-to-quantization-noise ratio.

    Returns:
        Mapping from weight name to its computed error value.
    """
    return {
        name: err_func(pair["float"], pair["dequantized"])
        for name, pair in weights_match.items()
    }
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def compute_activation_error(
    activations_match: dict[str, dict[str, Sequence[numpy.ndarray]]],
    err_func: Callable[
        [Sequence[numpy.ndarray], Sequence[numpy.ndarray]], float
    ] = compute_signal_to_quantization_noice_ratio,
) -> dict[str, dict[str, float]]:
    """Apply *err_func* to each activation matched by `create_activation_matching`.

    Args:
        activations_match: Output of `create_activation_matching`.
        err_func: Metric over two tensor sequences; defaults to the
            signal-to-quantization-noise ratio.

    Returns:
        Per-activation dict with "qdq_err" (pre- vs post-QDQ error) and, when
        float-model activations were matched, "xmodel_err" (float vs post-QDQ).
    """
    result: dict[str, dict[str, float]] = {}
    for name, match in activations_match.items():
        err_result: dict[str, float] = {}
        # Error introduced by the QDQ pair itself.
        err_result["qdq_err"] = err_func(match["pre_qdq"], match["post_qdq"])
        # Bug fix: "float" is only present when float-model activations were
        # supplied to create_activation_matching; use .get() so a match built
        # without a float model no longer raises KeyError here.
        float_activation = match.get("float")
        if float_activation:
            err_result["xmodel_err"] = err_func(float_activation, match["post_qdq"])
        result[name] = err_result
    return result
|