onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/__init__.py
ADDED
@@ -0,0 +1,418 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exchange (ONNX) models.
For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime/>`_
or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
"""

import contextlib

__version__ = "1.24.1"
__author__ = "Microsoft"

# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
# in order to know whether the onnxruntime package is for training it needs
# to do import onnxruntime.training.ortmodule first.
# onnxruntime.capi._pybind_state is required before import onnxruntime.training.ortmodule.
# however, import onnxruntime.capi._pybind_state will already raise an exception if a required Cuda version
# is not found.
# here we need to save the exception and continue with Cuda version validation in order to post
# meaningful messages to the user.
# the saved exception is raised after device version validation.
try:
    from onnxruntime.capi._pybind_state import (
        ExecutionMode,  # noqa: F401
        ExecutionOrder,  # noqa: F401
        GraphOptimizationLevel,  # noqa: F401
        LoraAdapter,  # noqa: F401
        ModelMetadata,  # noqa: F401
        NodeArg,  # noqa: F401
        OrtAllocatorType,  # noqa: F401
        OrtArenaCfg,  # noqa: F401
        OrtCompileApiFlags,  # noqa: F401
        OrtDeviceMemoryType,  # noqa: F401
        OrtEpAssignedNode,  # noqa: F401
        OrtEpAssignedSubgraph,  # noqa: F401
        OrtEpDevice,  # noqa: F401
        OrtExecutionProviderDevicePolicy,  # noqa: F401
        OrtExternalInitializerInfo,  # noqa: F401
        OrtHardwareDevice,  # noqa: F401
        OrtHardwareDeviceType,  # noqa: F401
        OrtMemoryInfo,  # noqa: F401
        OrtMemoryInfoDeviceType,  # noqa: F401
        OrtMemType,  # noqa: F401
        OrtSparseFormat,  # noqa: F401
        OrtSyncStream,  # noqa: F401
        RunOptions,  # noqa: F401
        SessionIOBinding,  # noqa: F401
        SessionOptions,  # noqa: F401
        create_and_register_allocator,  # noqa: F401
        create_and_register_allocator_v2,  # noqa: F401
        disable_telemetry_events,  # noqa: F401
        enable_telemetry_events,  # noqa: F401
        get_all_providers,  # noqa: F401
        get_available_providers,  # noqa: F401
        get_build_info,  # noqa: F401
        get_device,  # noqa: F401
        get_ep_devices,  # noqa: F401
        get_version_string,  # noqa: F401
        has_collective_ops,  # noqa: F401
        register_execution_provider_library,  # noqa: F401
        set_default_logger_severity,  # noqa: F401
        set_default_logger_verbosity,  # noqa: F401
        set_global_thread_pool_sizes,  # noqa: F401
        set_seed,  # noqa: F401
        unregister_execution_provider_library,  # noqa: F401
    )

    import_capi_exception = None
except Exception as e:
    import_capi_exception = e

from onnxruntime.capi import onnxruntime_validation

if import_capi_exception:
    raise import_capi_exception

from onnxruntime.capi.onnxruntime_inference_collection import (
    AdapterFormat,  # noqa: F401
    InferenceSession,  # noqa: F401
    IOBinding,  # noqa: F401
    ModelCompiler,  # noqa: F401
    OrtDevice,  # noqa: F401
    OrtValue,  # noqa: F401
    SparseTensor,  # noqa: F401
    copy_tensors,  # noqa: F401
)

# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
try:  # noqa: SIM105
    from . import experimental  # noqa: F401
except ImportError:
    pass


package_name, version, cuda_version = onnxruntime_validation.get_package_name_and_version_info()

if version:
    __version__ = version

onnxruntime_validation.check_distro_info()


def _get_package_version(package_name: str):
    from importlib.metadata import PackageNotFoundError, version  # noqa: PLC0415

    try:
        package_version = version(package_name)
    except PackageNotFoundError:
        package_version = None
    return package_version


def _get_package_root(package_name: str, directory_name: str | None = None):
    from importlib.metadata import PackageNotFoundError, distribution  # noqa: PLC0415

    root_directory_name = directory_name or package_name
    try:
        dist = distribution(package_name)
        files = dist.files or []

        for file in files:
            if file.name.endswith("__init__.py") and root_directory_name in file.parts:
                return file.locate().parent

        # Fallback to the first __init__.py
        if not directory_name:
            for file in files:
                if file.name.endswith("__init__.py"):
                    return file.locate().parent
    except PackageNotFoundError:
        # package not found, do nothing
        pass

    return None


def _extract_cuda_major_version(version_str: str) -> str:
    """Extract CUDA major version from version string (e.g., '12.1' -> '12').

    Args:
        version_str: CUDA version string to parse

    Returns:
        Major version as string, or "12" if parsing fails
    """
    return version_str.split(".")[0] if version_str else "12"


def _get_cufft_version(cuda_major: str) -> str:
    """Get cufft library version based on CUDA major version.

    Args:
        cuda_major: CUDA major version as string (e.g., "12", "13")

    Returns:
        cufft version as string
    """
    # cufft versions: CUDA 12.x -> 11, CUDA 13.x -> 12
    return "12" if cuda_major == "13" else "11"


def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = True):
    # Dynamically determine CUDA major version from build info
    cuda_major_version = _extract_cuda_major_version(cuda_version)
    cufft_version = _get_cufft_version(cuda_major_version)

    if is_windows:
        # Path is relative to site-packages directory.
        cuda_dll_paths = [
            ("nvidia", "cublas", "bin", f"cublasLt64_{cuda_major_version}.dll"),
            ("nvidia", "cublas", "bin", f"cublas64_{cuda_major_version}.dll"),
            ("nvidia", "cufft", "bin", f"cufft64_{cufft_version}.dll"),
            ("nvidia", "cuda_runtime", "bin", f"cudart64_{cuda_major_version}.dll"),
        ]
        cudnn_dll_paths = [
            ("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"),
            ("nvidia", "cudnn", "bin", "cudnn_engines_precompiled64_9.dll"),
            ("nvidia", "cudnn", "bin", "cudnn_heuristic64_9.dll"),
            ("nvidia", "cudnn", "bin", "cudnn_ops64_9.dll"),
            ("nvidia", "cudnn", "bin", "cudnn_adv64_9.dll"),
            ("nvidia", "cudnn", "bin", "cudnn_graph64_9.dll"),
            ("nvidia", "cudnn", "bin", "cudnn64_9.dll"),
        ]
    else:  # Linux
        # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first.
        cuda_dll_paths = [
            ("nvidia", "cublas", "lib", f"libcublasLt.so.{cuda_major_version}"),
            ("nvidia", "cublas", "lib", f"libcublas.so.{cuda_major_version}"),
            ("nvidia", "cuda_nvrtc", "lib", f"libnvrtc.so.{cuda_major_version}"),
            ("nvidia", "curand", "lib", "libcurand.so.10"),
            ("nvidia", "cufft", "lib", f"libcufft.so.{cufft_version}"),
            ("nvidia", "cuda_runtime", "lib", f"libcudart.so.{cuda_major_version}"),
        ]

        # Do not load cudnn sub DLLs (they will be dynamically loaded later) to be consistent with PyTorch in Linux.
        cudnn_dll_paths = [
            ("nvidia", "cudnn", "lib", "libcudnn.so.9"),
        ]

    return (cuda_dll_paths if cuda else []) + (cudnn_dll_paths if cudnn else [])


def print_debug_info():
    """Print information to help debugging."""
    import importlib.util  # noqa: PLC0415
    import os  # noqa: PLC0415
    import platform  # noqa: PLC0415
    from importlib.metadata import distributions  # noqa: PLC0415

    print(f"{package_name} version: {__version__}")
    if cuda_version:
        print(f"CUDA version used in build: {cuda_version}")
    print("platform:", platform.platform())

    print("\nPython package, version and location:")
    ort_packages = []
    for dist in distributions():
        package = dist.metadata["Name"]
        if package == "onnxruntime" or package.startswith(("onnxruntime-", "ort-")):
            # Exclude packages whose root directory name is not onnxruntime.
            location = _get_package_root(package, "onnxruntime")
            if location and (package not in ort_packages):
                ort_packages.append(package)
                print(f"{package}=={dist.version} at {location}")

    if len(ort_packages) > 1:
        print(
            "\033[33mWARNING: multiple onnxruntime packages are installed to the same location. "
            "Please 'pip uninstall` all above packages, then `pip install` only one of them.\033[0m"
        )

    if cuda_version:
        # Print version of installed packages that is related to CUDA or cuDNN DLLs.
        cuda_major = _extract_cuda_major_version(cuda_version)

        packages = [
            "torch",
            f"nvidia-cuda-runtime-cu{cuda_major}",
            f"nvidia-cudnn-cu{cuda_major}",
            f"nvidia-cublas-cu{cuda_major}",
            f"nvidia-cufft-cu{cuda_major}",
            f"nvidia-curand-cu{cuda_major}",
            f"nvidia-cuda-nvrtc-cu{cuda_major}",
            f"nvidia-nvjitlink-cu{cuda_major}",
        ]
        for package in packages:
            directory_name = "nvidia" if package.startswith("nvidia-") else None
            version = _get_package_version(package)
            if version:
                print(f"{package}=={version} at {_get_package_root(package, directory_name)}")
            else:
                print(f"{package} not installed")

    if platform.system() == "Windows":
        print(f"\nEnvironment variable:\nPATH={os.environ.get('PATH', '(unset)')}")
    elif platform.system() == "Linux":
        print(f"\nEnvironment variable:\nLD_LIBRARY_PATH={os.environ.get('LD_LIBRARY_PATH', '(unset)')}")

    if importlib.util.find_spec("psutil"):

        def is_target_dll(path: str):
            target_keywords = ["vcruntime140", "msvcp140"]
            if cuda_version:
                target_keywords = ["cufft", "cublas", "cudart", "nvrtc", "curand", "cudnn", *target_keywords]
            return any(keyword in path for keyword in target_keywords)

        import psutil  # noqa: PLC0415

        p = psutil.Process(os.getpid())

        print("\nList of loaded DLLs:")
        for lib in p.memory_maps():
            if is_target_dll(lib.path.lower()):
                print(lib.path)

        if cuda_version:
            if importlib.util.find_spec("cpuinfo") and importlib.util.find_spec("py3nvml"):
                from .transformers.machine_info import get_device_info  # noqa: PLC0415

                print("\nDevice information:")
                print(get_device_info())
            else:
                print("please `pip install py-cpuinfo py3nvml` to show device information.")
    else:
        print("please `pip install psutil` to show loaded DLLs.")


def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, directory=None):
    """Preload CUDA 12.x+ and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.

    When the installed PyTorch is compatible (using same major version of CUDA and cuDNN),
    there is no need to call this function if `import torch` is done before `import onnxruntime`.

    Args:
        cuda (bool, optional): enable loading CUDA DLLs. Defaults to True.
        cudnn (bool, optional): enable loading cuDNN DLLs. Defaults to True.
        msvc (bool, optional): enable loading MSVC DLLs in Windows. Defaults to True.
        directory(str, optional): a directory contains CUDA or cuDNN DLLs. It can be an absolute path,
            or a path relative to the directory of this file.
            If directory is None (default value), the search order: the lib directory of compatible PyTorch in Windows,
            nvidia site packages, default DLL loading paths.
            If directory is empty string (""), the search order: nvidia site packages, default DLL loading paths.
            If directory is a path, the search order: the directory, default DLL loading paths.
    """
    import ctypes  # noqa: PLC0415
    import os  # noqa: PLC0415
    import platform  # noqa: PLC0415
    import sys  # noqa: PLC0415

    if platform.system() not in ["Windows", "Linux"]:
        return

    is_windows = platform.system() == "Windows"
    if is_windows and msvc:
        try:
            ctypes.CDLL("vcruntime140.dll")
            ctypes.CDLL("msvcp140.dll")
            if platform.machine() != "ARM64":
                ctypes.CDLL("vcruntime140_1.dll")
        except OSError:
            print("Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.")
            print("It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe.")

    # Check if CUDA version is supported (12.x or 13.x+)
    ort_cuda_major = None
    if cuda_version:
        try:
            ort_cuda_major = int(cuda_version.split(".")[0])
            if ort_cuda_major < 12 and (cuda or cudnn):
                print(
                    f"\033[33mWARNING: {package_name} is built with CUDA {cuda_version}, which is not supported for preloading. "
                    f"CUDA 12.x or newer is required. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
                )
                return
        except ValueError:
            print(
                f"\033[33mWARNING: Unable to parse CUDA version '{cuda_version}'. "
                "Skipping DLL preloading. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
            )
            return
    elif cuda or cudnn:
        # No CUDA version info available but CUDA/cuDNN preloading requested
        return

    is_cuda_cudnn_imported_by_torch = False

    if is_windows:
        torch_version = _get_package_version("torch")
        # Check if torch CUDA version matches onnxruntime CUDA version
        torch_cuda_major = None
        if torch_version and "+cu" in torch_version:
            with contextlib.suppress(ValueError):
                # Extract CUDA version from torch (e.g., "2.0.0+cu121" -> 12)
                cu_part = torch_version.split("+cu")[1]
                torch_cuda_major = int(cu_part[:2])  # First 2 digits are major version

        is_torch_cuda_compatible = (
            torch_cuda_major == ort_cuda_major if (torch_cuda_major and ort_cuda_major) else False
        )

        if "torch" in sys.modules:
            is_cuda_cudnn_imported_by_torch = is_torch_cuda_compatible
            if torch_cuda_major and ort_cuda_major and torch_cuda_major != ort_cuda_major:
                print(
                    f"\033[33mWARNING: The installed PyTorch {torch_version} uses CUDA {torch_cuda_major}.x, "
                    f"but {package_name} is built with CUDA {ort_cuda_major}.x. "
                    f"Please install PyTorch for CUDA {ort_cuda_major}.x to be compatible.\033[0m"
                )

        if is_torch_cuda_compatible and directory is None:
            torch_root = _get_package_root("torch", "torch")
            if torch_root:
                directory = os.path.join(torch_root, "lib")

    base_directory = directory or ".."
    if not os.path.isabs(base_directory):
        base_directory = os.path.join(os.path.dirname(__file__), base_directory)
    base_directory = os.path.normpath(base_directory)
    if not os.path.isdir(base_directory):
        raise RuntimeError(f"Invalid parameter of directory={directory}. The directory does not exist!")

    if is_cuda_cudnn_imported_by_torch:
        # In Windows, PyTorch has loaded CUDA and cuDNN DLLs during `import torch`, no need to load them again.
        print("Skip loading CUDA and cuDNN DLLs since torch is imported.")
        return

    # Try load DLLs from nvidia site packages.
    dll_paths = _get_nvidia_dll_paths(is_windows, cuda, cudnn)
    loaded_dlls = []
    for relative_path in dll_paths:
        dll_path = (
            os.path.join(base_directory, relative_path[-1])
            if directory
            else os.path.join(base_directory, *relative_path)
        )
        if os.path.isfile(dll_path):
            try:
                _ = ctypes.CDLL(dll_path)
                loaded_dlls.append(relative_path[-1])
            except Exception as e:
                print(f"Failed to load {dll_path}: {e}")

    # Try load DLLs with default path settings.
    has_failure = False
    for relative_path in dll_paths:
        dll_filename = relative_path[-1]
        if dll_filename not in loaded_dlls:
            try:
                _ = ctypes.CDLL(dll_filename)
            except Exception as e:
                has_failure = True
                print(f"Failed to load {dll_filename}: {e}")

    if has_failure:
        print("Please follow https://onnxruntime.ai/docs/install/#cuda-and-cudnn to install CUDA and CuDNN.")
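The module above exposes preload_dlls, print_debug_info and the usual session API. The sketch below is not part of the package diff; the model path "model.onnx" is hypothetical, and since this DirectML build reports no CUDA version in its build info, preload_dlls effectively only checks the MSVC runtime here.

import onnxruntime as ort

# DirectML build: skip CUDA/cuDNN preloading, keep the MSVC runtime check (msvc defaults to True).
ort.preload_dlls(cuda=False, cudnn=False)

# Dump package versions, environment variables and loaded DLLs when debugging load failures.
ort.print_debug_info()

# Prefer the DirectML execution provider when it is available, otherwise fall back to CPU.
providers = (
    ["DmlExecutionProvider"]
    if "DmlExecutionProvider" in ort.get_available_providers()
    else ["CPUExecutionProvider"]
)
session = ort.InferenceSession("model.onnx", providers=providers)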
onnxruntime/backend/__init__.py
ADDED
@@ -0,0 +1,6 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from .backend import is_compatible, prepare, run, supports_device  # noqa: F401
onnxruntime/backend/backend.py
ADDED
@@ -0,0 +1,175 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""

import os
import unittest

import packaging.version
from onnx import ModelProto, helper, version  # noqa: F401
from onnx.backend.base import Backend
from onnx.checker import check_model

from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_device
from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep


class OnnxRuntimeBackend(Backend):
    """
    Implements
    `ONNX's backend API <https://github.com/onnx/onnx/blob/main/docs/ImplementingAnOnnxBackend.md>`_
    with *ONNX Runtime*.
    The backend is mostly used when you need to switch between
    multiple runtimes with the same API.
    `Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
    shows how to use *caffe2* as a backend for a converted model.
    Note: This is not the official Python API.
    """

    allowReleasedOpsetsOnly = bool(os.getenv("ALLOW_RELEASED_ONNX_OPSET_ONLY", "1") == "1")  # noqa: N815

    @classmethod
    def is_compatible(cls, model, device=None, **kwargs):
        """
        Return whether the model is compatible with the backend.

        :param model: unused
        :param device: None to use the default device or a string (ex: `'CPU'`)
        :return: boolean
        """
        if device is None:
            device = get_device()
        return cls.supports_device(device)

    @classmethod
    def is_opset_supported(cls, model):
        """
        Return whether the opset for the model is supported by the backend.
        When By default only released onnx opsets are allowed by the backend
        To test new opsets env variable ALLOW_RELEASED_ONNX_OPSET_ONLY should be set to 0

        :param model: Model whose opsets needed to be verified.
        :return: boolean and error message if opset is not supported.
        """
        if cls.allowReleasedOpsetsOnly:
            for opset in model.opset_import:
                domain = opset.domain if opset.domain else "ai.onnx"
                try:
                    key = (domain, opset.version)
                    if key not in helper.OP_SET_ID_VERSION_MAP:
                        error_message = (
                            "Skipping this test as only released onnx opsets are supported."
                            "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
                            f" Got Domain '{domain}' version '{opset.version}'."
                        )
                        return False, error_message
                except AttributeError:
                    # for some CI pipelines accessing helper.OP_SET_ID_VERSION_MAP
                    # is generating attribute error. TODO investigate the pipelines to
                    # fix this error. Falling back to a simple version check when this error is encountered
                    if (domain == "ai.onnx" and opset.version > 12) or (domain == "ai.ommx.ml" and opset.version > 2):
                        error_message = (
                            "Skipping this test as only released onnx opsets are supported."
                            "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
                            f" Got Domain '{domain}' version '{opset.version}'."
                        )
                        return False, error_message
        return True, ""

    @classmethod
    def supports_device(cls, device):
        """
        Check whether the backend is compiled with particular device support.
        In particular it's used in the testing suite.
        """
        if device == "CUDA":
            device = "GPU"
        return "-" + device in get_device() or device + "-" in get_device() or device == get_device()

    @classmethod
    def prepare(cls, model, device=None, **kwargs):
        """
        Load the model and creates a :class:`onnxruntime.InferenceSession`
        ready to be used as a backend.

        :param model: ModelProto (returned by `onnx.load`),
            string for a filename or bytes for a serialized model
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see :class:`onnxruntime.SessionOptions`
        :return: :class:`onnxruntime.InferenceSession`
        """
        if isinstance(model, OnnxRuntimeBackendRep):
            return model
        elif isinstance(model, InferenceSession):
            return OnnxRuntimeBackendRep(model)
        elif isinstance(model, (str, bytes)):
            options = SessionOptions()
            for k, v in kwargs.items():
                if hasattr(options, k):
                    setattr(options, k, v)

            excluded_providers = os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="").split(",")
            providers = [x for x in get_available_providers() if (x not in excluded_providers)]

            inf = InferenceSession(model, sess_options=options, providers=providers)
            # backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
            # which may hide test failures.
            inf.disable_fallback()
            if device is not None and not cls.supports_device(device):
                raise RuntimeError(f"Incompatible device expected '{device}', got '{get_device()}'")
            return cls.prepare(inf, device, **kwargs)
        else:
            # type: ModelProto
            # check_model serializes the model anyways, so serialize the model once here
            # and reuse it below in the cls.prepare call to avoid an additional serialization
            # only works with onnx >= 1.10.0 hence the version check
            onnx_version = packaging.version.parse(version.version) or packaging.version.Version("0")
            onnx_supports_serialized_model_check = onnx_version.release >= (1, 10, 0)
            bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model
            check_model(bin_or_model)
            opset_supported, error_message = cls.is_opset_supported(model)
            if not opset_supported:
                raise unittest.SkipTest(error_message)
            # Now bin might be serialized, if it's not we need to serialize it otherwise we'll have
            # an infinite recursive call
            bin = bin_or_model
            if not isinstance(bin, (str, bytes)):
                bin = bin.SerializeToString()
            return cls.prepare(bin, device, **kwargs)

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Compute the prediction.

        :param model: :class:`onnxruntime.InferenceSession` returned
            by function *prepare*
        :param inputs: inputs
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see :class:`onnxruntime.RunOptions`
        :return: predictions
        """
        rep = cls.prepare(model, device, **kwargs)
        return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        """
        This method is not implemented as it is much more efficient
        to run a whole model than every node independently.
        """
        raise NotImplementedError("It is much more efficient to run a whole model than every node independently.")


is_compatible = OnnxRuntimeBackend.is_compatible
prepare = OnnxRuntimeBackend.prepare
run = OnnxRuntimeBackend.run_model
supports_device = OnnxRuntimeBackend.supports_device
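A minimal sketch of driving the backend API defined above; it is not part of the package diff, the file name "model.onnx" and the input shape are hypothetical. prepare wraps an InferenceSession in an OnnxRuntimeBackendRep, and providers can be filtered out via the ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS environment variable.

import numpy as np
import onnx

import onnxruntime.backend as backend

model = onnx.load("model.onnx")  # hypothetical model path
if backend.is_compatible(model, "CPU"):
    rep = backend.prepare(model, "CPU")  # returns an OnnxRuntimeBackendRep
    # A list of arrays is mapped positionally onto the model inputs.
    outputs = rep.run([np.zeros((1, 3), dtype=np.float32)])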
onnxruntime/backend/backend_rep.py
ADDED
@@ -0,0 +1,52 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""

from onnx.backend.base import BackendRep

from onnxruntime import RunOptions


class OnnxRuntimeBackendRep(BackendRep):
    """
    Computes the prediction for a pipeline converted into
    an :class:`onnxruntime.InferenceSession` node.
    """

    def __init__(self, session):
        """
        :param session: :class:`onnxruntime.InferenceSession`
        """
        self._session = session

    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
        """
        Computes the prediction.
        See :meth:`onnxruntime.InferenceSession.run`.
        """

        options = RunOptions()
        for k, v in kwargs.items():
            if hasattr(options, k):
                setattr(options, k, v)

        if isinstance(inputs, list):
            inps = {}
            for i, inp in enumerate(self._session.get_inputs()):
                inps[inp.name] = inputs[i]
            outs = self._session.run(None, inps, options)
            if isinstance(outs, list):
                return outs
            else:
                output_names = [o.name for o in self._session.get_outputs()]
                return [outs[name] for name in output_names]
        else:
            inp = self._session.get_inputs()
            if len(inp) != 1:
                raise RuntimeError(f"Model expect {len(inp)} inputs")
            inps = {inp[0].name: inputs}
            return self._session.run(None, inps, options)
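The run method above accepts either a list of arrays (mapped positionally onto get_inputs()) or a single array for a single-input model, and keyword arguments that match RunOptions attributes are forwarded. A one-shot sketch using the module-level run alias from backend.py (model path and input shape are hypothetical, not part of this diff):

import numpy as np
import onnxruntime.backend as backend

# backend.run is OnnxRuntimeBackend.run_model: prepare + OnnxRuntimeBackendRep.run in one call.
# A single array suffices for a single-input model; log_severity_level is forwarded to RunOptions.
outputs = backend.run("model.onnx", np.zeros((1, 3), dtype=np.float32), log_severity_level=3)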
Binary file
onnxruntime/capi/_ld_preload.py
ADDED
@@ -0,0 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

# This file can be modified by setup.py when building a manylinux2010 wheel
# When modified, it will preload some libraries needed for the python C extension