onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,520 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+from __future__ import annotations
+
+import json
+from collections.abc import MutableMapping
+from dataclasses import dataclass
+from typing import Any
+
+import onnx
+
+from .quant_utils import QuantType
+
+
+@dataclass
+class QuantTypeInfo:  # noqa: PLW1641
+    """
+    The quantization type information for a tensor override.
+    """
+
+    quant_type: QuantType
+    symmetric: bool | None = None  # If None, assumes default is used.
+    reduce_range: bool | None = None  # If None, assumes default is used.
+    axis: int | None = None  # If None, assumes per-tensor quantization
+
+    def __eq__(self, other: object):
+        if isinstance(other, QuantTypeInfo):
+            return (
+                self.quant_type == other.quant_type
+                and (self.symmetric is None or other.symmetric is None or self.symmetric == other.symmetric)
+                and (self.reduce_range is None or other.reduce_range is None or self.reduce_range == other.reduce_range)
+                and (self.axis == other.axis)
+            )
+        return NotImplemented
+
+    @staticmethod
+    def load_from_dict(
+        raw_dict: dict[str, Any],
+        default_qtype: QuantType | None = None,
+        default_symmetric: bool | None = None,
+        default_reduce_range: bool | None = None,
+    ) -> QuantTypeInfo:
+        return QuantTypeInfo(
+            raw_dict.get("quant_type", default_qtype),
+            raw_dict.get("symmetric", default_symmetric),
+            raw_dict.get("reduce_range", default_reduce_range),
+            raw_dict.get("axis"),
+        )
+
+    def save_to_dict(self, raw_dict: dict[str, Any]):
+        raw_dict["quant_type"] = self.quant_type
+        if self.symmetric is not None:
+            raw_dict["symmetric"] = self.symmetric
+        if self.reduce_range is not None:
+            raw_dict["reduce_range"] = self.reduce_range
+        if self.axis is not None:
+            raw_dict["axis"] = self.axis
+
+
+class TensorQuantOverridesHelper(MutableMapping):
+    """
+    Utility wrapper over the tensor quantization overrides passed via extra_options.
+    """
+
+    def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]):
+        self.overrides = raw_overrides
+        self.quant_types = None
+        self.keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"}
+
+    def has_per_tensor_overrides(self, tensor_name: str) -> bool:
+        overrides_list = self.overrides.get(tensor_name)
+        return overrides_list and "axis" not in overrides_list[0]
+
+    def has_per_channel_overrides(self, tensor_name: str) -> bool:
+        overrides_list = self.overrides.get(tensor_name)
+        return overrides_list and "axis" in overrides_list[0]
+
+    def overrides_scale_zp(self, tensor_name: str) -> bool:
+        overrides_list = self.overrides.get(tensor_name)
+        return overrides_list and ("scale" in overrides_list[0]) and ("zero_point" in overrides_list[0])
+
+    def get_per_tensor_overrides(
+        self,
+        tensor_name: str,
+        default_val: dict[str, Any] | None = None,
+    ) -> dict[str, Any] | None:
+        default_list_val = [default_val] if default_val is not None else None
+        overrides_list = self.overrides.get(tensor_name, default_list_val)
+        if overrides_list and "axis" in overrides_list[0]:
+            raise ValueError(
+                f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, "
+                f"but found per-channel overrides."
+            )
+
+        return overrides_list[0] if overrides_list else None
+
+    def get_per_channel_overrides(
+        self,
+        tensor_name: str,
+        default_val: list[dict[str, Any]] | None = None,
+    ) -> list[dict[str, Any]] | None:
+        overrides_list = self.overrides.get(tensor_name, default_val)
+
+        if not overrides_list:
+            return None
+
+        if "axis" not in overrides_list[0]:
+            raise ValueError(
+                f"Expected tensor '{tensor_name}' to have per-channel quantization overrides (axis value is missing).",
+            )
+
+        return overrides_list
+
+    def get_quant_types(self) -> set[QuantType]:
+        if self.quant_types is not None:
+            return self.quant_types
+
+        self.quant_types = set()
+
+        if self.overrides:
+            for quant_overrides_list in self.overrides.values():
+                for quant_overrides in quant_overrides_list:
+                    if "quant_type" in quant_overrides:
+                        self.quant_types.add(quant_overrides["quant_type"])
+
+                    if "convert" in quant_overrides and "quant_type" in quant_overrides["convert"]:
+                        self.quant_types.add(quant_overrides["convert"]["quant_type"])
+
+        return self.quant_types
+
+    def _is_valid_per_tensor(
+        self,
+        initializers,
+        default_activation_qtype,
+        tensor_name: str,
+        quant_overrides: dict[str, Any],
+    ) -> tuple[bool, str | None]:
+        if not isinstance(quant_overrides, dict):
+            return (
+                False,
+                f"Tensor quantization overrides for '{tensor_name}' are not in a dict",
+            )
+
+        is_initializer = tensor_name in initializers
+
+        quant_type = quant_overrides.get("quant_type")
+        if quant_type:
+            self.quant_types.add(quant_type)
+
+        has_scale = "scale" in quant_overrides
+        has_zero_point = "zero_point" in quant_overrides
+
+        if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
+            return (
+                False,
+                "Must provide both 'scale' and 'zero_point' if one of the overrides is provided",
+            )
+
+        if has_scale:
+            keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides))
+            if keys:
+                return (
+                    False,
+                    f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
+                )
+
+        if "reduce_range" in quant_overrides and not is_initializer:
+            return (
+                False,
+                f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}",
+            )
+
+        if "convert" in quant_overrides:
+            if is_initializer:
+                return False, "Cannot use 'convert' override for initializers"
+
+            if "quant_type" not in quant_overrides["convert"]:
+                return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'"
+
+            if "reduce_range" in quant_overrides["convert"]:
+                return (
+                    False,
+                    f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}",
+                )
+
+            convert_quant_type = quant_overrides["convert"]["quant_type"]
+            original_quant_type = quant_type if quant_type is not None else default_activation_qtype
+            if convert_quant_type == original_quant_type:
+                return (
+                    False,
+                    f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')",
+                )
+
+            convert_has_scale = "scale" in quant_overrides["convert"]
+            convert_has_zero_point = "zero_point" in quant_overrides["convert"]
+
+            if (convert_has_scale and not convert_has_zero_point) or (convert_has_zero_point and not convert_has_scale):
+                return (
+                    False,
+                    f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')",
+                )
+
+            if convert_has_scale:
+                keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides["convert"]))
+                if keys:
+                    return (
+                        False,
+                        f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point' "
+                        f"(tensor '{tensor_name}')",
+                    )
+
+            self.quant_types.add(convert_quant_type)
+
+        return True, None
+
+    def _is_valid_per_channel(
+        self,
+        initializers,
+        tensor_name: str,
+        quant_overrides_list: list[dict[str, Any]],
+    ) -> tuple[bool, str | None]:
+        is_initializer = tensor_name in initializers
+
+        if not is_initializer:
+            return (
+                False,
+                f"Tensor '{tensor_name}' has per-channel overrides, but is not an initializer",
+            )
+
+        axis = quant_overrides_list[0].get("axis")
+
+        if axis is None:
+            return (
+                False,
+                f"Per-channel overrides for tensor {tensor_name} is missing an 'axis' value in "
+                "the first channel dictionary.",
+            )
+
+        weight_shape = list(initializers[tensor_name].dims)
+        weight_rank = len(weight_shape)
+        norm_axis = axis
+        if norm_axis < 0:
+            norm_axis += weight_rank
+
+        if norm_axis < 0 or norm_axis >= len(weight_shape):
+            return (
+                False,
+                f"Axis override value is out-of-bounds for tensor {tensor_name} (rank {len(weight_shape)})",
+            )
+
+        if len(quant_overrides_list) > 1 and len(quant_overrides_list) != weight_shape[norm_axis]:
+            return (
+                False,
+                f"Incorrect number of channel overrides for tensor {tensor_name} (axis {axis}), "
+                f"expected {weight_shape[axis]}, but found {len(quant_overrides_list)}.",
+            )
+
+        if "convert" in quant_overrides_list[0]:
+            return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}."
+
+        quant_type = quant_overrides_list[0].get("quant_type")
+        if quant_type:
+            self.quant_types.add(quant_type)
+
+        symmetric = quant_overrides_list[0].get("symmetric")
+        reduce_range = quant_overrides_list[0].get("reduce_range")
+
+        has_scale = "scale" in quant_overrides_list[0]
+        has_zero_point = "zero_point" in quant_overrides_list[0]
+        has_scale_zp = has_scale and has_zero_point
+
+        if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
+            return (
+                False,
+                "Must provide both 'scale' and 'zero_point' if one of the overrides is provided",
+            )
+
+        if has_scale_zp:
+            keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides_list[0]))
+            if keys:
+                return (
+                    False,
+                    f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
+                )
+
+        has_rmin = "rmin" in quant_overrides_list[0]
+        has_rmax = "rmax" in quant_overrides_list[0]
+        has_rmin_rmax = has_rmin and has_rmax
+        if (has_rmin and not has_rmax) or (not has_rmin and has_rmax):
+            return (
+                False,
+                "Must provide both 'rmin' and 'rmax' if one is provided",
+            )
+
+        for index, quant_overrides in enumerate(quant_overrides_list[1:]):
+            if not isinstance(quant_overrides, dict):
+                return (
+                    False,
+                    f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict",
+                )
+
+            if "convert" in quant_overrides:
+                return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}."
+
+            # For per-channel quantization, all channels must use the same quantization type, axis, symmetric
+            # and reduce_range values. And, if specified, they must be present in the first channel dict
+            # (i.e., quant_overrides_list[0]).
+            if "quant_type" in quant_overrides and quant_type != quant_overrides["quant_type"]:
+                return (
+                    False,
+                    "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.",
+                )
+            if "axis" in quant_overrides and axis != quant_overrides["axis"] and norm_axis != quant_overrides["axis"]:
+                return (
+                    False,
+                    "Channel axis for tensor '{tensor_name}' does not match at index {index}.",
+                )
+            if "symmetric" in quant_overrides and symmetric != quant_overrides["symmetric"]:
+                return (
+                    False,
+                    "Channel symmetric value for tensor '{tensor_name}' does not match at index {index}.",
+                )
+            if "reduce_range" in quant_overrides and reduce_range != quant_overrides["reduce_range"]:
+                return (
+                    False,
+                    "Channel reduce_range value for tensor '{tensor_name}' does not match at index {index}.",
+                )
+
+            # If override scale/zp, must do so for all channels.
+            chan_has_scale_zp = "scale" in quant_overrides and "zero_point" in quant_overrides
+
+            if has_scale_zp and not chan_has_scale_zp:
+                return (
+                    False,
+                    "Per-channel overrides that specify scale/zero_point must do so for all channels, "
+                    f"but tensor '{tensor_name}' is missing them at index {index}.",
+                )
+
+            if chan_has_scale_zp:
+                keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides))
+                if keys:
+                    return (
+                        False,
+                        f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
+                    )
+
+            # If override rmin/rmax, must do so for all channels.
+            chan_has_rmin_rmax = "rmin" in quant_overrides and "rmax" in quant_overrides
+            if has_rmin_rmax and not chan_has_rmin_rmax:
+                return (
+                    False,
+                    "Per-channel overrides that specify rmin/rmax must do so for all channels, "
+                    f"but tensor '{tensor_name}' is missing them at index {index}.",
+                )
+
+        return True, None
+
+    def is_valid(
+        self,
+        initializers: dict[str, onnx.TensorProto],
+        activation_names: set[str],
+        default_activation_qtype,
+    ) -> tuple[bool, str | None]:
+        self.quant_types = set()
+
+        # Validate that compatible/valid overrides are provided.
+        if self.overrides:
+            for tensor_name, quant_overrides_list in self.overrides.items():
+                if tensor_name not in initializers and tensor_name not in activation_names:
+                    return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model"
+
+                if not isinstance(quant_overrides_list, list):
+                    return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list"
+
+                if not quant_overrides_list:
+                    continue
+
+                if not isinstance(quant_overrides_list[0], dict):
+                    return False, f"Tensor quantization overrides at index 0 for '{tensor_name}' are not in a dict"
+
+                if not quant_overrides_list[0]:
+                    continue
+
+                axis = quant_overrides_list[0].get("axis")
+                is_per_channel = len(quant_overrides_list) > 1 or axis is not None
+
+                if is_per_channel:
+                    return self._is_valid_per_channel(initializers, tensor_name, quant_overrides_list)
+
+                return self._is_valid_per_tensor(
+                    initializers, default_activation_qtype, tensor_name, quant_overrides_list[0]
+                )
+
+        return True, None
+
+    def update_tensor_overrides(
+        self,
+        tensor_name: str,
+        new_vals: dict[str, Any],
+        channels: list[int] | None = None,
+        overwrite: bool = True,
+    ) -> bool:
+        if not new_vals:
+            return False
+
+        channels = set(channels) if channels is not None else None
+        have_overrides = self.overrides.get(tensor_name)
+
+        # If `overwrite` is False, check if we would overwrite anything.
+        do_update = True
+        if not overwrite and have_overrides:
+            for channel, overrides in enumerate(self.overrides[tensor_name]):
+                if channels is not None and channel not in channels:
+                    continue
+                if set(new_vals).intersection(set(overrides)):
+                    do_update = False
+                    break
+
+        # Do the update if `overwrite` is True or if nothing is overwritten (do not want partial overwrites).
+        if do_update:
+            if not have_overrides:
+                self.overrides[tensor_name] = [{}]
+
+            for channel, overrides in enumerate(self.overrides[tensor_name]):
+                if channels is not None and channel not in channels:
+                    continue
+                overrides.update(new_vals)
+
+        return do_update
+
+    def get_node_output_qtype_info(
+        self,
+        output_name: str,
+        default_qtype: QuantType | None,
+        default_symmetric: bool | None = None,
+    ) -> QuantTypeInfo:
+        # Outputs are activations, which do not support 'reduce_range' or 'axis'
+        if output_name not in self.overrides:
+            return QuantTypeInfo(default_qtype, default_symmetric)
+
+        tensor_overrides = self.overrides[output_name][0]
+
+        return QuantTypeInfo(
+            tensor_overrides.get("quant_type", default_qtype),
+            tensor_overrides.get("symmetric", default_symmetric),
+        )
+
+    def get_node_input_qtype_info(
+        self,
+        input_name: str,
+        node_name: str,
+        default_qtype: QuantType | None,
+        default_symmetric: bool | None = None,
+        default_reduce_range: bool | None = None,
+    ) -> QuantTypeInfo:
+        if input_name not in self.overrides or not self.overrides[input_name]:
+            return QuantTypeInfo(default_qtype, default_symmetric, default_reduce_range)
+
+        # Get the first overrides dict in the list. This works for both per-tensor and per-channel
+        # quantization because all channels must use the same quant type.
+        tensor_overrides = self.overrides[input_name][0]
+        producer_type = tensor_overrides.get("quant_type", default_qtype)
+
+        if "convert" not in tensor_overrides:
+            return QuantTypeInfo(
+                producer_type,
+                tensor_overrides.get("symmetric", default_symmetric),
+                tensor_overrides.get("reduce_range", default_reduce_range),
+                tensor_overrides.get("axis"),
+            )
+
+        # This tensor is converted. Check if the node gets the original qtype or the converted qtype.
+        convert_dict = tensor_overrides["convert"]
+        qtype_info = QuantTypeInfo(
+            producer_type,
+            convert_dict.get("symmetric", default_symmetric),
+            # Converted tensors are not initializers, so do not have 'axis' or 'reduce_range'.
+        )
+
+        # Check if all nodes receive the converted type (i.e., recv_nodes is None) or this node
+        # is in the list of consumers (recv_nodes).
+        if ("recv_nodes" not in convert_dict) or (node_name in convert_dict["recv_nodes"]):
+            qtype_info.quant_type = convert_dict["quant_type"]
+
+        return qtype_info
+
+    def pprint_str(self, indent=None) -> str:
+        return json.dumps(self.overrides, default=str, indent=indent)
+
+    def empty(self) -> bool:
+        return not self.overrides
+
+    def get_dict(self) -> dict[str, list[dict[str, Any]]]:
+        return self.overrides
+
+    # Required implementations of abstract methods in collections.abc.MutableMapping
+    # so that this class can be used like a dict.
+    def __setitem__(self, key: str, value: list[dict]):
+        self.overrides[key] = value
+
+    def __getitem__(self, key: str) -> list[dict]:
+        return self.overrides[key]
+
+    def __delitem__(self, key: str):
+        del self.overrides[key]
+
+    def __iter__(self):
+        return iter(self.overrides)
+
+    def __len__(self):
+        return len(self.overrides)
+
+    def __str__(self) -> str:
+        return str(self.overrides)
+
+    def __repr__(self) -> str:
+        return f"{super().__repr__()}, TensorQuantOverridesHelper({self.overrides})"
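The 520-line hunk above corresponds to onnxruntime/quantization/tensor_quant_overrides.py, the only +520-line addition in the file listing. For orientation, here is a minimal sketch, not taken from the package, of how the helper might be exercised: the tensor name "conv1.weight", the zero-filled weight, and the chosen quant types are made-up placeholders.

    # Hedged sketch: exercising TensorQuantOverridesHelper from the hunk above.
    # The tensor name, weight shape, and quant types are hypothetical examples.
    import numpy as np
    from onnx import numpy_helper
    from onnxruntime.quantization.quant_utils import QuantType
    from onnxruntime.quantization.tensor_quant_overrides import TensorQuantOverridesHelper

    # Per-channel override for a (hypothetical) convolution weight initializer.
    raw_overrides = {
        "conv1.weight": [{"quant_type": QuantType.QInt8, "axis": 0, "symmetric": True}],
    }
    helper = TensorQuantOverridesHelper(raw_overrides)

    # is_valid() cross-checks the overrides against the model's initializers and activations.
    weight = numpy_helper.from_array(np.zeros((8, 3, 3, 3), dtype=np.float32), "conv1.weight")
    ok, err = helper.is_valid(
        initializers={"conv1.weight": weight},
        activation_names=set(),
        default_activation_qtype=QuantType.QUInt8,
    )
    print(ok, err)                                            # True None when consistent
    print(helper.has_per_channel_overrides("conv1.weight"))   # True (first entry has an 'axis')
    print(helper.pprint_str(indent=2))                        # JSON dump of the raw overrides

Per the class docstring, the raw dict wrapped here is the same structure callers pass through extra_options (the validation message refers to it as TensorQuantOverrides), which is presumably why the helper subclasses MutableMapping and can be used like a plain dict.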
@@ -0,0 +1,10 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# appended to the __init__.py in the onnxruntime module's 'tools' folder from /tools/python/util/__init__append.py
+import importlib.util
+
+have_torch = importlib.util.find_spec("torch")
+if have_torch:
+    from .pytorch_export_helpers import infer_input_info  # noqa: F401
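This 10-line hunk matches onnxruntime/tools/__init__.py in the listing: infer_input_info is re-exported only when torch can be found. A downstream script can mirror the same guard; the snippet below is only a sketch of that pattern, not part of the package.

    # Hedged sketch: guarding use of the optional PyTorch export helper.
    import importlib.util

    if importlib.util.find_spec("torch") is not None:
        # Re-exported by onnxruntime.tools only when torch is installed (see hunk above).
        from onnxruntime.tools import infer_input_info  # noqa: F401
        print("torch found: onnxruntime.tools.infer_input_info is available")
    else:
        print("torch not installed: the PyTorch export helpers are unavailable")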
@@ -0,0 +1,47 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import argparse
+import logging
+import pathlib
+
+# need this before the mobile helper imports for some reason
+logging.basicConfig(format="%(levelname)s: %(message)s")
+
+from .mobile_helpers import usability_checker  # noqa: E402
+
+
+def check_usability():
+    parser = argparse.ArgumentParser(
+        description="""Analyze an ONNX model to determine how well it will work in mobile scenarios.""",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("--log_level", choices=["debug", "info"], default="info", help="Logging level")
+    parser.add_argument("model_path", help="Path to ONNX model to check", type=pathlib.Path)
+
+    args = parser.parse_args()
+    logger = logging.getLogger("check_usability")
+
+    if args.log_level == "debug":
+        logger.setLevel(logging.DEBUG)
+    elif args.log_level == "info":
+        logger.setLevel(logging.INFO)
+    elif args.log_level == "warning":
+        logger.setLevel(logging.WARNING)
+    else:
+        logger.setLevel(logging.ERROR)
+
+    try_eps = usability_checker.analyze_model(args.model_path, skip_optimize=False, logger=logger)
+
+    if try_eps:
+        logger.info(
+            "As NNAPI or CoreML may provide benefits with this model it is recommended to compare the "
+            "performance of the model using the NNAPI EP on Android, and the CoreML EP on iOS, "
+            "against the performance using the CPU EP."
+        )
+    else:
+        logger.info("For optimal performance the model should be used with the CPU EP. ")
+
+
+if __name__ == "__main__":
+    check_usability()
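This 47-line hunk matches onnxruntime/tools/check_onnx_model_mobile_usability.py. Because of the __main__ guard it can be run as a module (for example, python -m onnxruntime.tools.check_onnx_model_mobile_usability model.onnx --log_level debug, with a real model path substituted). The same check can also be driven programmatically by calling analyze_model directly, as the script itself does; the sketch below only mirrors those calls and uses a placeholder path.

    # Hedged sketch: calling the usability checker programmatically instead of via the CLI.
    import logging
    import pathlib

    from onnxruntime.tools.mobile_helpers import usability_checker

    logging.basicConfig(format="%(levelname)s: %(message)s")
    logger = logging.getLogger("check_usability")
    logger.setLevel(logging.INFO)

    # Same call the script makes: analyze the model and report whether NNAPI/CoreML look promising.
    try_eps = usability_checker.analyze_model(
        pathlib.Path("/path/to/model.onnx"),  # placeholder path, replace with a real ONNX model
        skip_optimize=False,
        logger=logger,
    )
    print("NNAPI/CoreML EPs worth benchmarking" if try_eps else "CPU EP recommended")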