onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff reflects the content of publicly available package versions as released to their public registries, and is provided for informational purposes only.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0

onnxruntime/transformers/profiler.py

@@ -0,0 +1,434 @@
+import argparse
+import os
+
+import numpy
+import psutil
+from onnx import TensorProto
+
+"""
+This profiler tool could run a transformer model and print out the kernel time spent on each Node of the model.
+Example of profiling of longformer model:
+    python profiler.py --model longformer-base-4096_fp32.onnx --batch_size 1 --sequence_length 4096 --global_length 8 --samples 1000 --thread_num 8 --dummy_inputs longformer --use_gpu
+Example of importing profile result file from onnxruntime_perf_test:
+    python profiler.py --input profile_2021-10-25_12-02-41.json
+"""
+
+
+def parse_arguments(argv=None):
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input",
+        required=False,
+        type=str,
+        help="Set the input file for reading the profile results",
+    )
+
+    parser.add_argument(
+        "-m",
+        "--model",
+        required=False,
+        type=str,
+        help="onnx model path to run profiling. Required when --input is not specified.",
+    )
+
+    parser.add_argument(
+        "-b",
+        "--batch_size",
+        required=False,
+        type=int,
+        default=1,
+        help="batch size of input",
+    )
+
+    parser.add_argument(
+        "-s",
+        "--sequence_length",
+        required=False,
+        type=int,
+        default=32,
+        help="sequence length of input",
+    )
+
+    parser.add_argument(
+        "--past_sequence_length",
+        required=False,
+        type=int,
+        default=1,
+        help="past sequence length for gpt2",
+    )
+
+    parser.add_argument(
+        "--global_length",
+        required=False,
+        type=int,
+        default=1,
+        help="number of global tokens for longformer",
+    )
+
+    parser.add_argument(
+        "--samples",
+        required=False,
+        type=int,
+        default=1000,
+        help="number of samples to test. Set it large enough to reduce the variance of performance result.",
+    )
+
+    parser.add_argument(
+        "--threshold",
+        required=False,
+        type=float,
+        default=0.01,
+        help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.",
+    )
+
+    parser.add_argument(
+        "--thread_num",
+        required=False,
+        type=int,
+        default=-1,
+        help="number of threads to use",
+    )
+
+    parser.add_argument(
+        "--input_ids_name",
+        required=False,
+        type=str,
+        default=None,
+        help="input name for input IDs, for bert",
+    )
+    parser.add_argument(
+        "--segment_ids_name",
+        required=False,
+        type=str,
+        default=None,
+        help="input name for segment IDs, for bert",
+    )
+    parser.add_argument(
+        "--input_mask_name",
+        required=False,
+        type=str,
+        default=None,
+        help="input name for attention mask, for bert",
+    )
+
+    parser.add_argument(
+        "--dummy_inputs",
+        required=False,
+        default="default",
+        choices=["bert", "gpt2", "longformer", "default"],
+        help="Type of model inputs. The default will create dummy inputs with ones.",
+    )
+
+    parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="use GPU")
+    parser.set_defaults(use_gpu=False)
+
+    parser.add_argument(
+        "--provider",
+        required=False,
+        type=str,
+        default="cuda",
+        help="Execution provider to use",
+    )
+
+    parser.add_argument(
+        "--basic_optimization",
+        required=False,
+        action="store_true",
+        help="Enable only basic graph optimizations. By default, all optimizations are enabled in OnnxRuntime",
+    )
+    parser.set_defaults(basic_optimization=False)
+
+    parser.add_argument(
+        "--kernel_time_only",
+        required=False,
+        action="store_true",
+        help="Only include the kernel time and no fence time",
+    )
+    parser.set_defaults(kernel_time_only=False)
+
+    parser.add_argument("-v", "--verbose", required=False, action="store_true")
+    parser.set_defaults(verbose=False)
+
+    return parser.parse_args(argv)
+
+
+def run_profile(onnx_model_path, use_gpu, provider, basic_optimization, thread_num, all_inputs):
+    from benchmark_helper import create_onnxruntime_session  # noqa: PLC0415
+
+    session = create_onnxruntime_session(
+        onnx_model_path,
+        use_gpu,
+        provider,
+        enable_all_optimization=not basic_optimization,
+        num_threads=thread_num,
+        enable_profiling=True,
+    )
+
+    for inputs in all_inputs:
+        _ = session.run(None, inputs)
+
+    profile_file = session.end_profiling()
+    return profile_file
+
+
+def get_dim_from_type_proto(dim):
+    return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) == str else None  # noqa: E721
+
+
+def get_shape_from_type_proto(type_proto):
+    return [get_dim_from_type_proto(d) for d in type_proto.tensor_type.shape.dim]
+
+
+def create_dummy_inputs(onnx_model, batch_size, sequence_length, samples):
+    """Create dummy inputs for ONNX model.
+
+    Args:
+        onnx_model (OnnxModel): ONNX model
+        batch_size (int): batch size
+        sequence_length (int): sequence length
+        samples (int): number of samples
+
+    Returns:
+        List[Dict]: list of inputs
+    """
+    dummy_inputs = {}
+    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
+        shape = get_shape_from_type_proto(graph_input.type)
+        symbol_dims = []
+        for i, dim in enumerate(shape):
+            if isinstance(dim, str):
+                symbol_dims.append(i)
+
+        # allowed symbolic dimensions: batch_size and sequence_length
+        if len(symbol_dims) > 2:
+            return None
+        if len(symbol_dims) > 0:
+            shape[symbol_dims[0]] = batch_size
+        if len(symbol_dims) > 1:
+            shape[symbol_dims[1]] = sequence_length
+
+        elem_type = graph_input.type.tensor_type.elem_type
+        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
+        data_type = (
+            numpy.float32
+            if elem_type == TensorProto.FLOAT
+            else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
+        )
+        data = numpy.ones(shape, dtype=data_type)
+        dummy_inputs[graph_input.name] = data
+
+    all_inputs = [dummy_inputs for _ in range(samples)]
+    return all_inputs
+
+
+def create_bert_inputs(
+    onnx_model,
+    batch_size,
+    sequence_length,
+    samples,
+    input_ids_name=None,
+    segment_ids_name=None,
+    input_mask_name=None,
+):
+    """Create dummy inputs for BERT model.
+
+    Args:
+        onnx_model (OnnxModel): ONNX model
+        batch_size (int): batch size
+        sequence_length (int): sequence length
+        samples (int): number of samples
+        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
+        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
+        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
+
+    Returns:
+        List[Dict]: list of inputs
+    """
+    from bert_test_data import find_bert_inputs, generate_test_data  # noqa: PLC0415
+
+    input_ids, segment_ids, input_mask = find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
+    all_inputs = generate_test_data(
+        batch_size,
+        sequence_length,
+        test_cases=samples,
+        seed=123,
+        verbose=False,
+        input_ids=input_ids,
+        segment_ids=segment_ids,
+        input_mask=input_mask,
+        random_mask_length=False,
+    )
+
+    return all_inputs
+
+
+def create_gpt2_inputs(onnx_model, batch_size, sequence_length, past_sequence_length, samples):
+    """Create dummy inputs for GPT-2 model.
+
+    Args:
+        onnx_model (OnnxModel): ONNX model
+        batch_size (int): batch size
+        sequence_length (int): sequence length
+        past_sequence_length (int): past sequence length
+        samples (int): number of samples
+
+    Raises:
+        RuntimeError: symbolic is not supported. Use the tool convert_to_onnx.py to export ONNX model instead.
+
+    Returns:
+        List[Dict]: list of inputs
+    """
+    # The symbolic names shall be same as those used in Gpt2Helper.export_onnx(...) function.
+    symbols = {
+        "batch_size": batch_size,
+        "seq_len": sequence_length,
+        "past_seq_len": past_sequence_length,
+        "total_seq_len": sequence_length + past_sequence_length,
+    }
+
+    dummy_inputs = {}
+    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
+        shape = get_shape_from_type_proto(graph_input.type)
+        for i, dim in enumerate(shape):
+            if isinstance(dim, str):
+                if dim not in symbols:
+                    raise RuntimeError(f"symbol is not supported: {dim}")
+                else:
+                    shape[i] = symbols[dim]
+
+        elem_type = graph_input.type.tensor_type.elem_type
+        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
+        data_type = (
+            numpy.float32
+            if elem_type == TensorProto.FLOAT
+            else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
+        )
+        data = numpy.ones(shape, dtype=data_type)
+        dummy_inputs[graph_input.name] = data
+
+    all_inputs = [dummy_inputs for _ in range(samples)]
+    return all_inputs
+
+
+def create_longformer_inputs(onnx_model, batch_size, sequence_length, global_length, samples):
+    """Create dummy inputs for Longformer model.
+
+    Args:
+        onnx_model (OnnxModel): ONNX model
+        batch_size (int): batch size
+        sequence_length (int): sequence length
+        global_length (int): number of global tokens
+        samples (int): number of samples
+
+    Raises:
+        RuntimeError: symbolic is not supported. Use the tool convert_longformer_to_onnx.py to export ONNX model instead.
+
+    Returns:
+        List[Dict]: list of inputs
+    """
+    symbols = {"batch_size": batch_size, "sequence_length": sequence_length}
+
+    dummy_inputs = {}
+    for graph_input in onnx_model.get_graph_inputs_excluding_initializers():
+        shape = get_shape_from_type_proto(graph_input.type)
+        for i, dim in enumerate(shape):
+            if isinstance(dim, str):
+                if dim not in symbols:
+                    raise RuntimeError(f"symbol is not supported: {dim}")
+                else:
+                    shape[i] = symbols[dim]
+
+        elem_type = graph_input.type.tensor_type.elem_type
+        assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
+        data_type = (
+            numpy.float32
+            if elem_type == TensorProto.FLOAT
+            else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32)
+        )
+
+        if "global" in graph_input.name:
+            data = numpy.zeros(shape, dtype=data_type)
+            data[:, :global_length] = 1
+        else:
+            data = numpy.ones(shape, dtype=data_type)
+        dummy_inputs[graph_input.name] = data
+
+    all_inputs = [dummy_inputs for _ in range(samples)]
+    return all_inputs
+
+
+def run(args):
+    num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(logical=False)
+
+    # Set OMP environment variable before importing onnxruntime. Needed for cpu only, and no impact for onnxruntime-gpu package.
+    if "OMP_NUM_THREADS" not in os.environ:
+        os.environ["OMP_NUM_THREADS"] = str(num_threads)
+
+    from onnx import load  # noqa: PLC0415
+    from onnx_model import OnnxModel  # noqa: PLC0415
+
+    onnx_model = OnnxModel(load(args.model))
+
+    all_inputs = None
+    if args.dummy_inputs == "bert":
+        all_inputs = create_bert_inputs(
+            onnx_model,
+            args.batch_size,
+            args.sequence_length,
+            args.samples,
+            args.input_ids_name,
+            args.segment_ids_name,
+            args.input_mask_name,
+        )
+    elif args.dummy_inputs == "gpt2":
+        all_inputs = create_gpt2_inputs(
+            onnx_model,
+            args.batch_size,
+            args.sequence_length,
+            args.past_sequence_length,
+            args.samples,
+        )
+    elif args.dummy_inputs == "longformer":
+        all_inputs = create_longformer_inputs(
+            onnx_model,
+            args.batch_size,
+            args.sequence_length,
+            args.global_length,
+            args.samples,
+        )
+    else:  # default
+        all_inputs = create_dummy_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples)
+
+    profile_file = run_profile(
+        args.model,
+        args.use_gpu,
+        args.provider,
+        args.basic_optimization,
+        args.thread_num,
+        all_inputs,
+    )
+
+    return profile_file
+
+
+if __name__ == "__main__":
+    arguments = parse_arguments()
+    print("Arguments", arguments)
+
+    from benchmark_helper import setup_logger
+
+    setup_logger(arguments.verbose)
+
+    if not arguments.input:
+        assert arguments.model, "requires either --model to run profiling or --input to read profiling results"
+        profile_file = run(arguments)
+    else:
+        profile_file = arguments.input
+    from profile_result_processor import process_results
+
+    results = process_results(profile_file, arguments)
+
+    for line in results:
+        print(line)
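
For context, `run_profile` above is a thin wrapper around onnxruntime's built-in profiler. The sketch below reproduces the same flow with only the public `onnxruntime` API; the model path, input name, and shape are illustrative placeholders, not files shipped in this wheel.

```python
# Minimal profiling sketch: enable ORT profiling, run a few inferences,
# then collect the JSON trace. "model.onnx" and "input_ids" are placeholders.
import numpy
import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.enable_profiling = True  # emit a profile_*.json trace on end_profiling()

session = onnxruntime.InferenceSession(
    "model.onnx", sess_options, providers=["CPUExecutionProvider"]
)

# Repeat the same dummy input so per-node kernel times can be averaged.
inputs = {"input_ids": numpy.ones((1, 32), dtype=numpy.int64)}
for _ in range(100):
    session.run(None, inputs)

profile_file = session.end_profiling()  # returns the path of the JSON trace
print(profile_file)
```

The resulting JSON is what `profiler.py` post-processes, whether produced by its own run or passed in through `--input`.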

onnxruntime/transformers/quantize_helper.py

@@ -0,0 +1,76 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import os
+
+import onnx
+import torch
+from transformers.modeling_utils import Conv1D
+
+logger = logging.getLogger(__name__)
+
+
+def _conv1d_to_linear(module):
+    in_size, out_size = module.weight.shape
+    linear = torch.nn.Linear(in_size, out_size)
+    linear.weight.data = module.weight.data.T.contiguous()
+    linear.bias.data = module.bias.data
+    return linear
+
+
+def conv1d_to_linear(model):
+    """in-place
+    This is for Dynamic Quantization, as Conv1D is not recognized by PyTorch, convert it to nn.Linear
+    """
+    logger.debug("replace Conv1D with Linear")
+    for name in list(model._modules):
+        module = model._modules[name]
+        if isinstance(module, Conv1D):
+            linear = _conv1d_to_linear(module)
+            model._modules[name] = linear
+        else:
+            conv1d_to_linear(module)
+
+
+def _get_size_of_pytorch_model(model):
+    torch.save(model.state_dict(), "temp.p")
+    size = os.path.getsize("temp.p") / (1024 * 1024)
+    os.remove("temp.p")
+    return size
+
+
+class QuantizeHelper:
+    @staticmethod
+    def quantize_torch_model(model, dtype=torch.qint8):
+        """
+        Usage: model = quantize_model(model)
+
+        TODO: mix of in-place and return, but results are different
+        """
+        conv1d_to_linear(model)
+        quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)
+        logger.info(f"Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}")
+        logger.info(f"Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized_model)}")
+        return quantized_model
+
+    @staticmethod
+    def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data_format=False):
+        from pathlib import Path  # noqa: PLC0415
+
+        from onnxruntime.quantization import quantize_dynamic  # noqa: PLC0415
+
+        Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True)
+        logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path) / (1024 * 1024)}")
+        quantize_dynamic(
+            onnx_model_path,
+            quantized_model_path,
+            use_external_data_format=use_external_data_format,
+            extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
+        )
+        logger.info(f"quantized model saved to:{quantized_model_path}")
+        # TODO: include external data in total model size.
+        logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path) / (1024 * 1024)}")
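
A hedged usage sketch for `QuantizeHelper` follows. GPT-2 is a natural example because its `Conv1D` layers are exactly what `conv1d_to_linear` rewrites before dynamic quantization; the checkpoint name and file paths are placeholders, and a compatible `transformers` install is assumed.

```python
# Usage sketch; "gpt2", "gpt2.onnx" and "gpt2_int8.onnx" are illustrative.
import torch
from transformers import GPT2LMHeadModel

from onnxruntime.transformers.quantize_helper import QuantizeHelper

# Torch path: Conv1D modules are replaced with nn.Linear in place, then the
# Linear weights are dynamically quantized to INT8.
model = GPT2LMHeadModel.from_pretrained("gpt2")
quantized_model = QuantizeHelper.quantize_torch_model(model, dtype=torch.qint8)

# ONNX path: dynamic quantization of an already exported model file.
QuantizeHelper.quantize_onnx_model("gpt2.onnx", "gpt2_int8.onnx")
```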

onnxruntime/transformers/shape_infer_helper.py

@@ -0,0 +1,121 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import logging
+import os
+import sys
+
+# In ORT Package the symbolic_shape_infer.py is in ../tools
+file_path = os.path.dirname(__file__)
+if os.path.exists(os.path.join(file_path, "../tools/symbolic_shape_infer.py")):
+    sys.path.append(os.path.join(file_path, "../tools"))
+else:
+    sys.path.append(os.path.join(file_path, ".."))
+
+from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy  # noqa: E402
+
+logger = logging.getLogger(__name__)
+
+
+class SymbolicShapeInferenceHelper(SymbolicShapeInference):
+    def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False):
+        super().__init__(int_max, auto_merge, guess_output_rank, verbose)
+        self.model_ = model
+        self.all_shapes_inferred_: bool = False
+        self.is_inferred_: bool = False
+        self.dynamic_axis_mapping_: dict[str, int] = {}
+
+    def infer(self, dynamic_axis_mapping: dict[str, int], max_runs: int = 200):
+        """Run shape inference, and try to replace each dynamic axis from string to integer when a mapping is provided.
+
+        Args:
+            dynamic_axis_mapping (_type_): a dictionary with name of dynamic axis as key, like {"batch_size" : 4}
+            max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200.
+
+        Returns:
+            bool: whether all shapes have been inferred or not.
+        """
+        assert dynamic_axis_mapping is not None
+
+        if self.is_inferred_ and self.dynamic_axis_mapping_ == dynamic_axis_mapping:
+            return self.all_shapes_inferred_
+
+        self.dynamic_axis_mapping_ = dynamic_axis_mapping
+
+        self._preprocess(self.model_)
+
+        count = 0
+        while self.run_:
+            logger.debug(f"shape infer run {count}")
+            self.all_shapes_inferred_ = self._infer_impl()
+            count += 1
+            if max_runs > 0 and count >= max_runs:
+                break
+
+        self.is_inferred_ = True
+        return self.all_shapes_inferred_
+
+    def _get_sympy_shape(self, node, idx):
+        """Override it to ensure shape inference by giving the actual value of dynamic axis."""
+        sympy_shape = []
+
+        shape = self._get_shape(node, idx)
+        if shape:
+            for dim in shape:
+                if isinstance(dim, str):
+                    if dim in self.dynamic_axis_mapping_:
+                        sympy_shape.append(self.dynamic_axis_mapping_[dim])
+                    elif dim in self.symbolic_dims_:
+                        sympy_shape.append(self.symbolic_dims_[dim])
+                    else:
+                        sympy_shape.append(sympy.Symbol(dim, integer=True))
+                else:
+                    assert dim is not None
+                    sympy_shape.append(dim)
+        return sympy_shape
+
+    def get_edge_shape(self, edge):
+        """Get shape of an edge.
+
+        Args:
+            edge (str): name of edge
+
+        Returns:
+            Optional[List[int]]: the shape, or None if shape is unknown
+        """
+        assert self.all_shapes_inferred_
+        if edge not in self.known_vi_:
+            print("Cannot retrieve the shape of " + str(edge))
+            return None
+
+        type_proto = self.known_vi_[edge].type
+        shape = get_shape_from_type_proto(type_proto)
+
+        if shape is not None:
+            for i, dim in enumerate(shape):
+                if isinstance(dim, str) and dim in self.dynamic_axis_mapping_:
+                    shape[i] = self.dynamic_axis_mapping_[dim]
+
+        return shape
+
+    def compare_shape(self, edge, edge_other):
+        """Compare shape of two edges.
+
+        Args:
+            edge (str): name of edge
+            edge_other (str): name of another edge
+
+        Raises:
+            Exception: At least one shape is missed for edges to compare
+
+        Returns:
+            bool: whether the two shapes are the same
+        """
+        assert self.all_shapes_inferred_
+        shape = self.get_edge_shape(edge)
+        shape_other = self.get_edge_shape(edge_other)
+        if shape is None or shape_other is None:
+            raise Exception("At least one shape is missed for edges to compare")
+        return shape == shape_other
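
A sketch of driving `SymbolicShapeInferenceHelper`: bind the model's symbolic axes to concrete integers, run inference to a fixed point, then query edge shapes. The model path, axis names (`batch_size`, `seq_len`), and edge names (`logits`, `hidden`) are assumptions that must match the actual graph.

```python
# Usage sketch; all names below are placeholders for a real model's graph.
import onnx

from onnxruntime.transformers.shape_infer_helper import SymbolicShapeInferenceHelper

model = onnx.load("model.onnx")
helper = SymbolicShapeInferenceHelper(model)

# Map each dynamic axis name to a concrete value, then infer all shapes.
if helper.infer({"batch_size": 1, "seq_len": 128}):
    print(helper.get_edge_shape("logits"))  # concrete shape list, or None if unknown
    print(helper.compare_shape("logits", "hidden"))  # True when the shapes match
```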