onnxruntime-directml 1.20.0 (onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/__init__.py
ADDED
@@ -0,0 +1,78 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""
+ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exchange (ONNX) models.
+For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime/>`_
+or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
+"""
+__version__ = "1.20.0"
+__author__ = "Microsoft"
+
+# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
+# in order to know whether the onnxruntime package is for training it needs
+# to do import onnxruntime.training.ortmodule first.
+# onnxruntime.capi._pybind_state is required before import onnxruntime.training.ortmodule.
+# however, import onnxruntime.capi._pybind_state will already raise an exception if a required Cuda version
+# is not found.
+# here we need to save the exception and continue with Cuda version validation in order to post
+# meaningful messages to the user.
+# the saved exception is raised after device version validation.
+try:
+    from onnxruntime.capi._pybind_state import ExecutionMode  # noqa: F401
+    from onnxruntime.capi._pybind_state import ExecutionOrder  # noqa: F401
+    from onnxruntime.capi._pybind_state import GraphOptimizationLevel  # noqa: F401
+    from onnxruntime.capi._pybind_state import LoraAdapter  # noqa: F401
+    from onnxruntime.capi._pybind_state import ModelMetadata  # noqa: F401
+    from onnxruntime.capi._pybind_state import NodeArg  # noqa: F401
+    from onnxruntime.capi._pybind_state import OrtAllocatorType  # noqa: F401
+    from onnxruntime.capi._pybind_state import OrtArenaCfg  # noqa: F401
+    from onnxruntime.capi._pybind_state import OrtMemoryInfo  # noqa: F401
+    from onnxruntime.capi._pybind_state import OrtMemType  # noqa: F401
+    from onnxruntime.capi._pybind_state import OrtSparseFormat  # noqa: F401
+    from onnxruntime.capi._pybind_state import RunOptions  # noqa: F401
+    from onnxruntime.capi._pybind_state import SessionIOBinding  # noqa: F401
+    from onnxruntime.capi._pybind_state import SessionOptions  # noqa: F401
+    from onnxruntime.capi._pybind_state import create_and_register_allocator  # noqa: F401
+    from onnxruntime.capi._pybind_state import create_and_register_allocator_v2  # noqa: F401
+    from onnxruntime.capi._pybind_state import disable_telemetry_events  # noqa: F401
+    from onnxruntime.capi._pybind_state import enable_telemetry_events  # noqa: F401
+    from onnxruntime.capi._pybind_state import get_all_providers  # noqa: F401
+    from onnxruntime.capi._pybind_state import get_available_providers  # noqa: F401
+    from onnxruntime.capi._pybind_state import get_build_info  # noqa: F401
+    from onnxruntime.capi._pybind_state import get_device  # noqa: F401
+    from onnxruntime.capi._pybind_state import get_version_string  # noqa: F401
+    from onnxruntime.capi._pybind_state import has_collective_ops  # noqa: F401
+    from onnxruntime.capi._pybind_state import set_default_logger_severity  # noqa: F401
+    from onnxruntime.capi._pybind_state import set_default_logger_verbosity  # noqa: F401
+    from onnxruntime.capi._pybind_state import set_seed  # noqa: F401
+
+    import_capi_exception = None
+except Exception as e:
+    import_capi_exception = e
+
+from onnxruntime.capi import onnxruntime_validation
+
+if import_capi_exception:
+    raise import_capi_exception
+
+from onnxruntime.capi.onnxruntime_inference_collection import AdapterFormat  # noqa: F401
+from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession  # noqa: F401
+from onnxruntime.capi.onnxruntime_inference_collection import IOBinding  # noqa: F401
+from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice  # noqa: F401
+from onnxruntime.capi.onnxruntime_inference_collection import OrtValue  # noqa: F401
+from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor  # noqa: F401
+
+# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
+try:  # noqa: SIM105
+    from . import experimental  # noqa: F401
+except ImportError:
+    pass
+
+from onnxruntime.capi.onnxruntime_validation import cuda_version, package_name, version  # noqa: F401
+
+if version:
+    __version__ = version
+
+onnxruntime_validation.check_distro_info()
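For orientation, everything re-exported above (InferenceSession, SessionOptions, GraphOptimizationLevel, and friends) forms the package's public API. A minimal usage sketch against this DirectML wheel; the model path "model.onnx" and the input name "x" are hypothetical placeholders:

    import numpy as np
    import onnxruntime as ort

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # This wheel bundles DirectML.dll, so DmlExecutionProvider is the natural
    # first choice; CPUExecutionProvider is the fallback.
    session = ort.InferenceSession(
        "model.onnx",  # hypothetical path
        sess_options=sess_options,
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )
    outputs = session.run(None, {"x": np.zeros((1, 3), dtype=np.float32)})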
onnxruntime/backend/__init__.py
ADDED
@@ -0,0 +1,6 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from .backend import is_compatible, prepare, run, supports_device  # noqa: F401
onnxruntime/backend/backend.py
ADDED
@@ -0,0 +1,174 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""
+Implements ONNX's backend API.
+"""
+import os
+import unittest
+
+import packaging.version
+from onnx import ModelProto, helper, version  # noqa: F401
+from onnx.backend.base import Backend
+from onnx.checker import check_model
+
+from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_device
+from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep
+
+
+class OnnxRuntimeBackend(Backend):
+    """
+    Implements
+    `ONNX's backend API <https://github.com/onnx/onnx/blob/main/docs/ImplementingAnOnnxBackend.md>`_
+    with *ONNX Runtime*.
+    The backend is mostly used when you need to switch between
+    multiple runtimes with the same API.
+    `Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
+    shows how to use *caffe2* as a backend for a converted model.
+    Note: This is not the official Python API.
+    """
+
+    allowReleasedOpsetsOnly = bool(os.getenv("ALLOW_RELEASED_ONNX_OPSET_ONLY", "1") == "1")  # noqa: N815
+
+    @classmethod
+    def is_compatible(cls, model, device=None, **kwargs):
+        """
+        Return whether the model is compatible with the backend.
+
+        :param model: unused
+        :param device: None to use the default device or a string (ex: `'CPU'`)
+        :return: boolean
+        """
+        if device is None:
+            device = get_device()
+        return cls.supports_device(device)
+
+    @classmethod
+    def is_opset_supported(cls, model):
+        """
+        Return whether the opset for the model is supported by the backend.
+        By default, only released onnx opsets are allowed by the backend.
+        To test new opsets, set the env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0.
+
+        :param model: Model whose opsets need to be verified.
+        :return: boolean and error message if opset is not supported.
+        """
+        if cls.allowReleasedOpsetsOnly:
+            for opset in model.opset_import:
+                domain = opset.domain if opset.domain else "ai.onnx"
+                try:
+                    key = (domain, opset.version)
+                    if key not in helper.OP_SET_ID_VERSION_MAP:
+                        error_message = (
+                            "Skipping this test as only released onnx opsets are supported. "
+                            "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
+                            f" Got Domain '{domain}' version '{opset.version}'."
+                        )
+                        return False, error_message
+                except AttributeError:
+                    # for some CI pipelines accessing helper.OP_SET_ID_VERSION_MAP
+                    # is generating attribute error. TODO investigate the pipelines to
+                    # fix this error. Falling back to a simple version check when this error is encountered
+                    if (domain == "ai.onnx" and opset.version > 12) or (domain == "ai.onnx.ml" and opset.version > 2):
+                        error_message = (
+                            "Skipping this test as only released onnx opsets are supported. "
+                            "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0."
+                            f" Got Domain '{domain}' version '{opset.version}'."
+                        )
+                        return False, error_message
+        return True, ""
+
+    @classmethod
+    def supports_device(cls, device):
+        """
+        Check whether the backend is compiled with particular device support.
+        In particular it's used in the testing suite.
+        """
+        if device == "CUDA":
+            device = "GPU"
+        return device in get_device()
+
+    @classmethod
+    def prepare(cls, model, device=None, **kwargs):
+        """
+        Loads the model and creates a :class:`onnxruntime.InferenceSession`
+        ready to be used as a backend.
+
+        :param model: ModelProto (returned by `onnx.load`),
+            string for a filename or bytes for a serialized model
+        :param device: requested device for the computation,
+            None means the default one which depends on
+            the compilation settings
+        :param kwargs: see :class:`onnxruntime.SessionOptions`
+        :return: :class:`onnxruntime.InferenceSession`
+        """
+        if isinstance(model, OnnxRuntimeBackendRep):
+            return model
+        elif isinstance(model, InferenceSession):
+            return OnnxRuntimeBackendRep(model)
+        elif isinstance(model, (str, bytes)):
+            options = SessionOptions()
+            for k, v in kwargs.items():
+                if hasattr(options, k):
+                    setattr(options, k, v)
+
+            excluded_providers = os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="").split(",")
+            providers = [x for x in get_available_providers() if (x not in excluded_providers)]
+
+            inf = InferenceSession(model, sess_options=options, providers=providers)
+            # backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
+            # which may hide test failures.
+            inf.disable_fallback()
+            if device is not None and not cls.supports_device(device):
+                raise RuntimeError(f"Incompatible device expected '{device}', got '{get_device()}'")
+            return cls.prepare(inf, device, **kwargs)
+        else:
+            # type: ModelProto
+            # check_model serializes the model anyways, so serialize the model once here
+            # and reuse it below in the cls.prepare call to avoid an additional serialization
+            # only works with onnx >= 1.10.0 hence the version check
+            onnx_version = packaging.version.parse(version.version) or packaging.version.Version("0")
+            onnx_supports_serialized_model_check = onnx_version.release >= (1, 10, 0)
+            bin_or_model = model.SerializeToString() if onnx_supports_serialized_model_check else model
+            check_model(bin_or_model)
+            opset_supported, error_message = cls.is_opset_supported(model)
+            if not opset_supported:
+                raise unittest.SkipTest(error_message)
+            # Now bin might be serialized, if it's not we need to serialize it otherwise we'll have
+            # an infinite recursive call
+            bin = bin_or_model
+            if not isinstance(bin, (str, bytes)):
+                bin = bin.SerializeToString()
+            return cls.prepare(bin, device, **kwargs)
+
+    @classmethod
+    def run_model(cls, model, inputs, device=None, **kwargs):
+        """
+        Compute the prediction.
+
+        :param model: :class:`onnxruntime.InferenceSession` returned
+            by function *prepare*
+        :param inputs: inputs
+        :param device: requested device for the computation,
+            None means the default one which depends on
+            the compilation settings
+        :param kwargs: see :class:`onnxruntime.RunOptions`
+        :return: predictions
+        """
+        rep = cls.prepare(model, device, **kwargs)
+        return rep.run(inputs, **kwargs)
+
+    @classmethod
+    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
+        """
+        This method is not implemented as it is much more efficient
+        to run a whole model than every node independently.
+        """
+        raise NotImplementedError("It is much more efficient to run a whole model than every node independently.")
+
+
+is_compatible = OnnxRuntimeBackend.is_compatible
+prepare = OnnxRuntimeBackend.prepare
+run = OnnxRuntimeBackend.run_model
+supports_device = OnnxRuntimeBackend.supports_device
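A short sketch of how these onnx.backend-style entry points are typically driven; "model.onnx" and the input shape are hypothetical placeholders, and providers can be excluded via the ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS environment variable handled in `prepare` above:

    import numpy as np
    import onnx

    import onnxruntime.backend as backend

    model = onnx.load("model.onnx")  # hypothetical path
    if backend.is_compatible(model) and backend.supports_device("CPU"):
        rep = backend.prepare(model, device="CPU")  # returns an OnnxRuntimeBackendRep
        outputs = rep.run(np.zeros((1, 3), dtype=np.float32))  # single-input model assumed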
onnxruntime/backend/backend_rep.py
ADDED
@@ -0,0 +1,53 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""
+Implements ONNX's backend API.
+"""
+from typing import Any, Tuple  # noqa: F401
+
+from onnx.backend.base import BackendRep
+
+from onnxruntime import RunOptions
+
+
+class OnnxRuntimeBackendRep(BackendRep):
+    """
+    Computes the prediction for a pipeline converted into
+    an :class:`onnxruntime.InferenceSession` node.
+    """
+
+    def __init__(self, session):
+        """
+        :param session: :class:`onnxruntime.InferenceSession`
+        """
+        self._session = session
+
+    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
+        """
+        Computes the prediction.
+        See :meth:`onnxruntime.InferenceSession.run`.
+        """
+
+        options = RunOptions()
+        for k, v in kwargs.items():
+            if hasattr(options, k):
+                setattr(options, k, v)
+
+        if isinstance(inputs, list):
+            inps = {}
+            for i, inp in enumerate(self._session.get_inputs()):
+                inps[inp.name] = inputs[i]
+            outs = self._session.run(None, inps, options)
+            if isinstance(outs, list):
+                return outs
+            else:
+                output_names = [o.name for o in self._session.get_outputs()]
+                return [outs[name] for name in output_names]
+        else:
+            inp = self._session.get_inputs()
+            if len(inp) != 1:
+                raise RuntimeError(f"Model expects {len(inp)} inputs")
+            inps = {inp[0].name: inputs}
+            return self._session.run(None, inps, options)
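As `run` above shows, the rep accepts either a positional list of arrays (mapped onto `get_inputs()` order) or a bare array for single-input models. A hedged sketch, reusing the hypothetical model path from the previous example:

    import numpy as np

    import onnxruntime.backend as backend

    rep = backend.prepare("model.onnx", device="CPU")  # hypothetical path
    x = np.zeros((1, 3), dtype=np.float32)
    outputs = rep.run([x])  # positional list, keyed by get_inputs() order
    outputs = rep.run(x)    # bare array, valid only when the model has exactly one input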
onnxruntime/capi/DirectML.dll
ADDED
Binary file
onnxruntime/capi/_ld_preload.py
ADDED
@@ -0,0 +1,7 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+# This file can be modified by setup.py when building a manylinux2010 wheel
+# When modified, it will preload some libraries needed for the python C extension
onnxruntime/capi/_pybind_state.py
ADDED
@@ -0,0 +1,33 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""
+Ensure that dependencies are available and then load the extension module.
+"""
+import os
+import platform
+import warnings
+
+from . import _ld_preload  # noqa: F401
+
+if platform.system() == "Windows":
+    from . import version_info
+
+    # If on Windows, check if this import error is caused by the user not installing the 2019 VC Runtime
+    # The VC Redist installer usually puts the VC Runtime dlls in the System32 folder, but it may also be found
+    # in some other locations.
+    # TODO, we may want to try to load the VC Runtime dlls instead of checking if the hardcoded file path
+    # is valid, and raise ImportError if the load fails
+    if version_info.vs2019 and platform.architecture()[0] == "64bit":
+        system_root = os.getenv("SystemRoot") or "C:\\Windows"
+        if not os.path.isfile(os.path.join(system_root, "System32", "vcruntime140_1.dll")):
+            warnings.warn("Please install the 2019 Visual C++ runtime and then try again. "
+                          "If you've installed the runtime in a non-standard location "
+                          "(other than %SystemRoot%\\System32), "
+                          "make sure it can be found by setting the correct path.")
+
+
+from .onnxruntime_pybind11_state import *  # noqa
onnxruntime/capi/convert_npz_to_onnx_adapter.py
ADDED
@@ -0,0 +1,48 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# This script helps convert .npz files to .onnx_adapter files
+
+import argparse
+import os
+import sys
+
+import numpy as np
+
+import onnxruntime as ort
+
+
+def get_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--npz_file_path", type=str, required=True)
+    parser.add_argument("--output_file_path", type=str, required=True)
+    parser.add_argument("--adapter_version", type=int, required=True)
+    parser.add_argument("--model_version", type=int, required=True)
+    return parser.parse_args()
+
+
+def export_lora_parameters(
+    npz_file_path: os.PathLike, adapter_version: int, model_version: int, output_file_path: os.PathLike
+):
+    """Converts lora parameters in npz format to the onnx_adapter format."""
+    adapter_format = ort.AdapterFormat()
+    adapter_format.set_adapter_version(adapter_version)
+    adapter_format.set_model_version(model_version)
+    name_to_ort_value = {}
+    with np.load(npz_file_path) as data:
+        for name, np_arr in data.items():
+            ort_value = ort.OrtValue.ortvalue_from_numpy(np_arr)
+            name_to_ort_value[name] = ort_value
+
+    adapter_format.set_parameters(name_to_ort_value)
+    adapter_format.export_adapter(output_file_path)
+
+
+def main() -> int:
+    args = get_args()
+    export_lora_parameters(args.npz_file_path, args.adapter_version, args.model_version, args.output_file_path)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
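An illustrative round trip for the converter above; the array name, shape, and file paths are placeholders. The .npz keys become the adapter's parameter names:

    import numpy as np
    import onnxruntime as ort

    # Stand-in LoRA weights; any dict of arrays saved with np.savez works.
    np.savez("lora_params.npz", lora_A=np.zeros((8, 16), dtype=np.float32))

    # Equivalent to the CLI form:
    #   python -m onnxruntime.capi.convert_npz_to_onnx_adapter \
    #       --npz_file_path lora_params.npz --output_file_path lora_params.onnx_adapter \
    #       --adapter_version 1 --model_version 1
    adapter = ort.AdapterFormat()
    adapter.set_adapter_version(1)
    adapter.set_model_version(1)
    with np.load("lora_params.npz") as data:
        adapter.set_parameters({k: ort.OrtValue.ortvalue_from_numpy(v) for k, v in data.items()})
    adapter.export_adapter("lora_params.onnx_adapter")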
onnxruntime/capi/onnxruntime.dll
ADDED
Binary file
onnxruntime/capi/onnxruntime_collect_build_info.py
ADDED
@@ -0,0 +1,47 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import ctypes
+import sys
+import warnings
+
+
+def find_cudart_versions(build_env=False, build_cuda_version=None):
+    # ctypes.CDLL and ctypes.util.find_library load the latest installed library.
+    # it may not be the library that would be loaded by onnxruntime.
+    # for example, in an environment with Cuda 11.1 and subsequently
+    # conda cudatoolkit 10.2.89 installed, ctypes will find cudart 10.2. however,
+    # onnxruntime built with Cuda 11.1 will find and load cudart for Cuda 11.1.
+    # for the above reason, we need to find all versions in the environment and
+    # only give warnings if the expected cuda version is not found.
+    # in the onnxruntime build environment, we expect only one Cuda version.
+    if not sys.platform.startswith("linux"):
+        warnings.warn("find_cudart_versions only works on Linux")
+        return None
+
+    cudart_possible_versions = {None, build_cuda_version}
+
+    def get_cudart_version(find_cudart_version=None):
+        cudart_lib_filename = "libcudart.so"
+        if find_cudart_version:
+            cudart_lib_filename = cudart_lib_filename + "." + find_cudart_version
+
+        try:
+            cudart = ctypes.CDLL(cudart_lib_filename)
+            cudart.cudaRuntimeGetVersion.restype = int
+            cudart.cudaRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
+            version = ctypes.c_int()
+            status = cudart.cudaRuntimeGetVersion(ctypes.byref(version))
+            if status != 0:
+                return None
+        except Exception:
+            return None
+
+        return version.value
+
+    # use a set to avoid duplicates
+    cudart_found_versions = {get_cudart_version(cudart_version) for cudart_version in cudart_possible_versions}
+
+    # convert to a list and remove None
+    return [ver for ver in cudart_found_versions if ver]
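A small usage sketch of the probe above; it is Linux-only as the warning notes, and the version string "11.8" is illustrative:

    # cudaRuntimeGetVersion reports e.g. 11080 for CUDA 11.8; on non-Linux the
    # function warns and returns None, and it returns [] when no cudart resolves.
    from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions

    versions = find_cudart_versions(build_cuda_version="11.8")
    print(versions)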