onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from .operators.activation import QDQRemovableActivation, QLinearActivation
|
|
2
|
+
from .operators.argmax import QArgMax
|
|
3
|
+
from .operators.attention import AttentionQuant
|
|
4
|
+
from .operators.base_operator import QuantOperatorBase
|
|
5
|
+
from .operators.binary_op import QLinearBinaryOp
|
|
6
|
+
from .operators.concat import QLinearConcat
|
|
7
|
+
from .operators.conv import ConvInteger, QDQConv, QLinearConv
|
|
8
|
+
from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
|
|
9
|
+
from .operators.embed_layernorm import EmbedLayerNormalizationQuant
|
|
10
|
+
from .operators.gather import GatherQuant, QDQGather
|
|
11
|
+
from .operators.gavgpool import QGlobalAveragePool
|
|
12
|
+
from .operators.gemm import QDQGemm, QLinearGemm
|
|
13
|
+
from .operators.lstm import LSTMQuant
|
|
14
|
+
from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
|
|
15
|
+
from .operators.maxpool import QDQMaxPool, QMaxPool
|
|
16
|
+
from .operators.norm import QDQNormalization
|
|
17
|
+
from .operators.pad import QPad
|
|
18
|
+
from .operators.pooling import QLinearPool
|
|
19
|
+
from .operators.qdq_base_operator import QDQOperatorBase
|
|
20
|
+
from .operators.resize import QDQResize, QResize
|
|
21
|
+
from .operators.softmax import QLinearSoftmax
|
|
22
|
+
from .operators.split import QDQSplit, QSplit
|
|
23
|
+
from .operators.where import QDQWhere, QLinearWhere
|
|
24
|
+
from .quant_utils import QuantizationMode
|
|
25
|
+
|
|
26
|
+
# Quantizers shared by every quantization mode; merged into the
# integer and QLinear registries below via dict.update().
CommonOpsRegistry = dict(
    Gather=GatherQuant,
    Transpose=Direct8BitOp,
    EmbedLayerNormalization=EmbedLayerNormalizationQuant,
)
|
|
31
|
+
|
|
32
|
+
# Op-type -> quantizer class used when QuantizationMode.IntegerOps is selected.
IntegerOpsRegistry = dict(
    Conv=ConvInteger,
    MatMul=MatMulInteger,
    Attention=AttentionQuant,
    LSTM=LSTMQuant,
)
IntegerOpsRegistry.update(CommonOpsRegistry)
|
|
39
|
+
|
|
40
|
+
# Op-type -> quantizer class used when QuantizationMode.QLinearOps is selected.
QLinearOpsRegistry = dict(
    ArgMax=QArgMax,
    Conv=QLinearConv,
    Gemm=QLinearGemm,
    MatMul=QLinearMatMul,
    Add=QLinearBinaryOp,
    Mul=QLinearBinaryOp,
    Relu=QLinearActivation,
    Clip=QLinearActivation,
    LeakyRelu=QLinearActivation,
    Sigmoid=QLinearActivation,
    MaxPool=QMaxPool,
    GlobalAveragePool=QGlobalAveragePool,
    Split=QSplit,
    Pad=QPad,
    Reshape=Direct8BitOp,
    Squeeze=Direct8BitOp,
    Unsqueeze=Direct8BitOp,
    Resize=QResize,
    AveragePool=QLinearPool,
    Concat=QLinearConcat,
    Softmax=QLinearSoftmax,
    Where=QLinearWhere,
)
QLinearOpsRegistry.update(CommonOpsRegistry)
|
|
65
|
+
|
|
66
|
+
# Op-type -> quantizer class used by the QDQ (QuantizeLinear/DequantizeLinear)
# quantization format.
QDQRegistry = dict(
    Conv=QDQConv,
    ConvTranspose=QDQConv,
    Gemm=QDQGemm,
    Clip=QDQRemovableActivation,
    Relu=QDQRemovableActivation,
    Reshape=QDQDirect8BitOp,
    Transpose=QDQDirect8BitOp,
    Squeeze=QDQDirect8BitOp,
    Unsqueeze=QDQDirect8BitOp,
    Resize=QDQResize,
    MaxPool=QDQMaxPool,
    AveragePool=QDQDirect8BitOp,
    MatMul=QDQMatMul,
    Split=QDQSplit,
    Gather=QDQGather,
    GatherElements=QDQGather,
    Where=QDQWhere,
    InstanceNormalization=QDQNormalization,
    LayerNormalization=QDQNormalization,
    BatchNormalization=QDQNormalization,
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def CreateDefaultOpQuantizer(onnx_quantizer, node):  # noqa: N802
    """Return the fallback quantizer (no op-specific handling) for *node*."""
    fallback = QuantOperatorBase(onnx_quantizer, node)
    return fallback
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def CreateOpQuantizer(onnx_quantizer, node):  # noqa: N802
    """Build the op-specific quantizer for *node* from the registry that
    matches the quantizer's mode; fall back to QuantOperatorBase when the
    op has no registered quantizer or the candidate declines to quantize.
    """
    if onnx_quantizer.mode == QuantizationMode.IntegerOps:
        registry = IntegerOpsRegistry
    else:
        registry = QLinearOpsRegistry

    quantizer_cls = registry.get(node.op_type)
    if quantizer_cls is not None:
        candidate = quantizer_cls(onnx_quantizer, node)
        if candidate.should_quantize():
            return candidate
    return QuantOperatorBase(onnx_quantizer, node)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def CreateQDQQuantizer(onnx_quantizer, node):  # noqa: N802
    """Build the QDQ-format quantizer for *node*, defaulting to
    QDQOperatorBase for op types without a registered handler.
    """
    quantizer_cls = QDQRegistry.get(node.op_type, QDQOperatorBase)
    return quantizer_cls(onnx_quantizer, node)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# --------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft, Intel Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import tempfile
|
|
10
|
+
import traceback
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Optional, Union
|
|
13
|
+
|
|
14
|
+
import onnx
|
|
15
|
+
|
|
16
|
+
import onnxruntime
|
|
17
|
+
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
|
18
|
+
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
|
|
19
|
+
|
|
20
|
+
from .quant_utils import add_pre_process_metadata
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _save_model(
    model: onnx.ModelProto,
    model_path: Union[str, Path],
    save_as_external_data: bool,
    all_tensors_to_one_file: bool,
    size_threshold: int,
    location: Optional[str] = None,
) -> None:
    """Persist *model* to *model_path*, honoring the external-data options.

    When ``save_as_external_data`` is False this is a plain ``onnx.save``;
    otherwise large initializers are split out per the given options.
    """
    if save_as_external_data:
        onnx.save_model(
            model,
            model_path,
            save_as_external_data=True,
            all_tensors_to_one_file=all_tensors_to_one_file,
            location=location,
            size_threshold=size_threshold,
            convert_attribute=False,
        )
    else:
        onnx.save(model, model_path)


def quant_pre_process(
    input_model: Optional[Union[str, Path, onnx.ModelProto]] = None,
    output_model_path: Optional[Union[str, Path]] = None,
    skip_optimization: bool = False,
    skip_onnx_shape: bool = False,
    skip_symbolic_shape: bool = False,
    auto_merge: bool = False,
    int_max: int = 2**31 - 1,
    guess_output_rank: bool = False,
    verbose: int = 0,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = False,
    external_data_location: Optional[str] = None,
    external_data_size_threshold: int = 1024,
    **deprecated_kwargs,
) -> None:
    """Shape inference and model optimization, in preparation for quantization.

    Args:
        input_model: Path to the input model file or ModelProto
        output_model_path: Path to the output model file
        skip_optimization: Skip model optimization step if true. This may result in ONNX shape
            inference failure for some models.
        skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
            with transformer based models. Skipping all shape inferences may
            reduce the effectiveness of quantization, as a tensor with unknown
            shape can not be quantized.
        skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
            effective with transformer based models. Skipping all shape
            inferences may reduce the effectiveness of quantization, as a tensor
            with unknown shape can not be quantized.
        auto_merge: For symbolic shape inference, automatically merge symbolic dims when
            conflict happens.
        int_max: For symbolic shape inference, specify the maximum value for integer to be
            treated as boundless for ops like slice
        guess_output_rank: Guess output rank to be the same as input 0 for unknown ops
        verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed
        save_as_external_data: Saving an ONNX model to external data
        all_tensors_to_one_file: Saving all the external data to one file
        external_data_location: The file location to save the external file
        external_data_size_threshold: The size threshold for external data

    Raises:
        ValueError: when a ModelProto input still references external data that
            was not loaded into memory (ORT cannot create a session from it).
    """

    if input_model is None:
        # Backward compatibility: accept the deprecated keyword name.
        input_model = deprecated_kwargs.pop("input_model_path", None)
    assert input_model is not None

    assert output_model_path is not None, "output_model_path is required."

    with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
        temp_path = Path(quant_tmp_dir)
        # `model` holds the latest in-memory ModelProto; None means the current
        # state lives on disk at (or inside) `input_model`.
        model = None

        if not skip_symbolic_shape:
            logger.info("Performing symbolic shape inference...")
            loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
            model = SymbolicShapeInference.infer_shapes(
                loaded_model,
                int_max,
                auto_merge,
                guess_output_rank,
                verbose,
            )

        if not skip_optimization:
            # Use ORT optimizers (native code) to optimize model
            if not skip_symbolic_shape:
                # Need to save the inferenced model to file so as to run the optimizer
                input_model = str(temp_path / "symbolic_shape_inferred.onnx")
                _save_model(
                    model,
                    input_model,
                    save_as_external_data,
                    all_tensors_to_one_file,
                    external_data_size_threshold,
                )
                model = None

            opt_model_path = str(temp_path / "optimized.onnx")
            try:
                sess_option = onnxruntime.SessionOptions()
                sess_option.optimized_model_filepath = opt_model_path
                sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
                # For large model, extract external data from model and add to session options
                if isinstance(input_model, onnx.ModelProto):
                    if has_external_data(input_model):
                        raise ValueError(
                            "ModelProto has external data not loaded into memory, ORT cannot create session. "
                            "Please load external data before calling this function. "
                            "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
                        )
                    external_names, external_values = extract_raw_data_from_model(input_model)
                    sess_option.add_external_initializers(list(external_names), list(external_values))
                    input_model = input_model.SerializeToString()

                sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
                # Close the session to avoid the cleanup error on Windows for temp folders
                # https://github.com/microsoft/onnxruntime/issues/17627
                del sess
            except Exception:
                logger.error(
                    "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
                )
                logger.error(traceback.format_exc())

            input_model = opt_model_path

        if not skip_onnx_shape:
            # ONNX shape inference.
            # According to docs, infer_shapes_path should be used for 2G+ models.
            # If the skip optimization is specified, we could be dealing with a
            # large model. So be on the safe side, save the model
            if model is not None:
                input_model = str(temp_path / "symbolic_shape_inferred.onnx")
                _save_model(
                    model,
                    input_model,
                    save_as_external_data,
                    all_tensors_to_one_file,
                    external_data_size_threshold,
                )
                model = None

            if isinstance(input_model, onnx.ModelProto):
                # BUG FIX: the original passed `model` to onnx.save_model here,
                # but `model` is always None at this point (the branch above
                # resets it), so this path crashed whenever both symbolic shape
                # inference and optimization were skipped for a ModelProto
                # input. Save the ModelProto itself instead.
                model_proto = input_model
                input_model = str(temp_path / "model_input.onnx")
                onnx.save_model(
                    model_proto,
                    input_model,
                    save_as_external_data=True,
                    all_tensors_to_one_file=all_tensors_to_one_file,
                    size_threshold=external_data_size_threshold,
                    convert_attribute=False,
                )

            # infer_shapes_path works file-to-file, which is safe for 2G+ models.
            inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
            onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
            model = onnx.load(inferred_model_path)

        if model is None:
            # Every stage was skipped (or left its result on disk): load it back.
            model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)

        add_pre_process_metadata(model)

        _save_model(
            model,
            output_model_path,
            save_as_external_data,
            all_tensors_to_one_file,
            external_data_size_threshold,
            location=external_data_location,
        )
|