onnxruntime-directml 1.22.1.dev20250710002__cp313-cp313-win_amd64.whl → 1.24.1__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. onnxruntime/ThirdPartyNotices.txt +0 -35
  2. onnxruntime/__init__.py +119 -46
  3. onnxruntime/capi/DirectML.dll +0 -0
  4. onnxruntime/capi/build_and_package_info.py +1 -1
  5. onnxruntime/capi/onnxruntime.dll +0 -0
  6. onnxruntime/capi/onnxruntime_inference_collection.py +338 -52
  7. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  8. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  9. onnxruntime/capi/onnxruntime_validation.py +10 -10
  10. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +1 -1
  11. onnxruntime/quantization/base_quantizer.py +3 -38
  12. onnxruntime/quantization/calibrate.py +20 -5
  13. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  14. onnxruntime/quantization/execution_providers/qnn/preprocess.py +46 -0
  15. onnxruntime/quantization/execution_providers/qnn/quant_config.py +0 -17
  16. onnxruntime/quantization/fusions/__init__.py +1 -0
  17. onnxruntime/quantization/fusions/fusion_layernorm.py +18 -7
  18. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  19. onnxruntime/quantization/matmul_bnb4_quantizer.py +1 -1
  20. onnxruntime/quantization/matmul_nbits_quantizer.py +151 -49
  21. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  22. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  23. onnxruntime/quantization/neural_compressor/util.py +80 -0
  24. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  25. onnxruntime/quantization/onnx_model.py +1 -1
  26. onnxruntime/quantization/onnx_quantizer.py +156 -1
  27. onnxruntime/quantization/operators/gemm.py +3 -3
  28. onnxruntime/quantization/qdq_quantizer.py +0 -1
  29. onnxruntime/quantization/quant_utils.py +67 -37
  30. onnxruntime/quantization/quantize.py +16 -6
  31. onnxruntime/quantization/registry.py +2 -0
  32. onnxruntime/quantization/shape_inference.py +16 -4
  33. onnxruntime/quantization/static_quantize_runner.py +1 -1
  34. onnxruntime/quantization/tensor_quant_overrides.py +1 -1
  35. onnxruntime/tools/convert_onnx_models_to_ort.py +6 -3
  36. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +3 -0
  37. onnxruntime/tools/mobile_helpers/usability_checker.py +1 -1
  38. onnxruntime/tools/onnx_model_utils.py +3 -0
  39. onnxruntime/tools/optimize_onnx_model.py +1 -1
  40. onnxruntime/tools/ort_format_model/utils.py +1 -2
  41. onnxruntime/tools/pytorch_export_contrib_ops.py +1 -1
  42. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  43. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  44. onnxruntime/tools/qnn/preprocess.py +165 -0
  45. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  46. onnxruntime/tools/symbolic_shape_infer.py +13 -12
  47. onnxruntime/transformers/benchmark.py +7 -10
  48. onnxruntime/transformers/benchmark_helper.py +11 -15
  49. onnxruntime/transformers/bert_perf_test.py +5 -9
  50. onnxruntime/transformers/bert_test_data.py +1 -1
  51. onnxruntime/transformers/compare_bert_results.py +1 -1
  52. onnxruntime/transformers/convert_generation.py +106 -48
  53. onnxruntime/transformers/convert_tf_models_to_pytorch.py +8 -8
  54. onnxruntime/transformers/convert_to_packing_mode.py +4 -5
  55. onnxruntime/transformers/dynamo_onnx_helper.py +1 -1
  56. onnxruntime/transformers/fusion_attention.py +2 -2
  57. onnxruntime/transformers/fusion_attention_clip.py +38 -33
  58. onnxruntime/transformers/fusion_bart_attention.py +205 -414
  59. onnxruntime/transformers/fusion_base.py +2 -2
  60. onnxruntime/transformers/fusion_nhwc_conv.py +1 -1
  61. onnxruntime/transformers/fusion_utils.py +9 -5
  62. onnxruntime/transformers/io_binding_helper.py +61 -21
  63. onnxruntime/transformers/machine_info.py +11 -9
  64. onnxruntime/transformers/models/bert/eval_squad.py +1 -1
  65. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +10 -2
  66. onnxruntime/transformers/models/gpt2/gpt2_parity.py +1 -1
  67. onnxruntime/transformers/models/gpt2/gpt2_tester.py +3 -3
  68. onnxruntime/transformers/models/gpt2/parity_check_helper.py +2 -2
  69. onnxruntime/transformers/models/llama/benchmark.py +1 -4
  70. onnxruntime/transformers/models/llama/benchmark_all.py +3 -3
  71. onnxruntime/transformers/models/llama/convert_to_onnx.py +28 -62
  72. onnxruntime/transformers/models/llama/dist_settings.py +4 -4
  73. onnxruntime/transformers/models/llama/llama_parity.py +7 -6
  74. onnxruntime/transformers/models/longformer/benchmark_longformer.py +3 -3
  75. onnxruntime/transformers/models/longformer/convert_to_onnx.py +1 -1
  76. onnxruntime/transformers/models/phi2/convert_to_onnx.py +14 -6
  77. onnxruntime/transformers/models/phi2/inference_example.py +3 -3
  78. onnxruntime/transformers/models/sam2/benchmark_sam2.py +6 -6
  79. onnxruntime/transformers/models/sam2/convert_to_onnx.py +1 -1
  80. onnxruntime/transformers/models/sam2/image_decoder.py +2 -2
  81. onnxruntime/transformers/models/sam2/image_encoder.py +4 -4
  82. onnxruntime/transformers/models/sam2/mask_decoder.py +1 -1
  83. onnxruntime/transformers/models/sam2/prompt_encoder.py +1 -1
  84. onnxruntime/transformers/models/sam2/sam2_demo.py +1 -1
  85. onnxruntime/transformers/models/stable_diffusion/benchmark.py +42 -46
  86. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +6 -6
  87. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +3 -2
  88. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +4 -3
  89. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +2 -2
  90. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +1 -1
  91. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +1 -1
  92. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +1 -1
  93. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +8 -3
  94. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +1 -1
  95. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +4 -4
  96. onnxruntime/transformers/models/t5/t5_helper.py +2 -2
  97. onnxruntime/transformers/models/whisper/benchmark.py +3 -28
  98. onnxruntime/transformers/models/whisper/benchmark_all.py +4 -4
  99. onnxruntime/transformers/models/whisper/convert_to_onnx.py +106 -40
  100. onnxruntime/transformers/models/whisper/whisper_chain.py +10 -7
  101. onnxruntime/transformers/models/whisper/whisper_decoder.py +7 -8
  102. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +2 -2
  103. onnxruntime/transformers/models/whisper/whisper_helper.py +562 -10
  104. onnxruntime/transformers/models/whisper/whisper_inputs.py +3 -3
  105. onnxruntime/transformers/models/whisper/whisper_jump_times.py +2 -2
  106. onnxruntime/transformers/onnx_exporter.py +10 -10
  107. onnxruntime/transformers/onnx_model.py +15 -3
  108. onnxruntime/transformers/onnx_model_mmdit.py +2 -2
  109. onnxruntime/transformers/onnx_model_sam2.py +2 -2
  110. onnxruntime/transformers/onnx_model_t5.py +1 -1
  111. onnxruntime/transformers/onnx_model_unet.py +2 -2
  112. onnxruntime/transformers/optimizer.py +11 -14
  113. onnxruntime/transformers/profiler.py +4 -4
  114. onnxruntime/transformers/quantize_helper.py +2 -2
  115. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.24.1.dist-info}/METADATA +9 -5
  116. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.24.1.dist-info}/RECORD +119 -109
  117. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.24.1.dist-info}/WHEEL +1 -1
  118. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.24.1.dist-info}/entry_points.txt +0 -0
  119. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.24.1.dist-info}/top_level.txt +0 -0
onnxruntime/ThirdPartyNotices.txt CHANGED
@@ -5806,41 +5806,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 _____
 
-composable_kernel
-
-https://github.com/ROCmSoftwarePlatform/composable_kernel
-
-Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang)
-Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang)
-Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan)
-Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang)
-Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah)
-Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou)
-Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan)
-
-SPDX-License-Identifier: MIT
-Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-_____
-
 neural-speed
 
 https://github.com/intel/neural-speed
onnxruntime/__init__.py CHANGED
@@ -8,7 +8,9 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://ak
 or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
 """
 
-__version__ = "1.22.1"
+import contextlib
+
+__version__ = "1.24.1"
 __author__ = "Microsoft"
 
 # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
@@ -30,9 +32,20 @@ try:
         NodeArg,  # noqa: F401
         OrtAllocatorType,  # noqa: F401
         OrtArenaCfg,  # noqa: F401
+        OrtCompileApiFlags,  # noqa: F401
+        OrtDeviceMemoryType,  # noqa: F401
+        OrtEpAssignedNode,  # noqa: F401
+        OrtEpAssignedSubgraph,  # noqa: F401
+        OrtEpDevice,  # noqa: F401
+        OrtExecutionProviderDevicePolicy,  # noqa: F401
+        OrtExternalInitializerInfo,  # noqa: F401
+        OrtHardwareDevice,  # noqa: F401
+        OrtHardwareDeviceType,  # noqa: F401
         OrtMemoryInfo,  # noqa: F401
+        OrtMemoryInfoDeviceType,  # noqa: F401
         OrtMemType,  # noqa: F401
         OrtSparseFormat,  # noqa: F401
+        OrtSyncStream,  # noqa: F401
         RunOptions,  # noqa: F401
         SessionIOBinding,  # noqa: F401
         SessionOptions,  # noqa: F401
@@ -44,11 +57,15 @@ try:
         get_available_providers,  # noqa: F401
         get_build_info,  # noqa: F401
         get_device,  # noqa: F401
+        get_ep_devices,  # noqa: F401
         get_version_string,  # noqa: F401
         has_collective_ops,  # noqa: F401
+        register_execution_provider_library,  # noqa: F401
         set_default_logger_severity,  # noqa: F401
         set_default_logger_verbosity,  # noqa: F401
+        set_global_thread_pool_sizes,  # noqa: F401
         set_seed,  # noqa: F401
+        unregister_execution_provider_library,  # noqa: F401
     )
 
     import_capi_exception = None
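
Note: the newly exported symbols above expose the execution provider device-discovery and plugin-EP registration APIs to Python. A minimal sketch of how they compose follows; only get_ep_devices, register_execution_provider_library, and unregister_execution_provider_library are confirmed by this diff, while the OrtEpDevice attribute names (ep_name, device, type) are assumptions based on the corresponding C API, and the registration name and library path are placeholders.

import onnxruntime as ort

# Hypothetical plugin EP: registering the library makes its devices discoverable.
# ort.register_execution_provider_library("example_ep", "C:/path/to/example_ep.dll")

for ep_device in ort.get_ep_devices():
    # Each OrtEpDevice pairs an execution provider with a hardware device.
    # Attribute names here are assumed to mirror the C API.
    print(ep_device.ep_name, ep_device.device.type)

# ort.unregister_execution_provider_library("example_ep")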
@@ -64,9 +81,11 @@ from onnxruntime.capi.onnxruntime_inference_collection import (
     AdapterFormat,  # noqa: F401
     InferenceSession,  # noqa: F401
     IOBinding,  # noqa: F401
+    ModelCompiler,  # noqa: F401
     OrtDevice,  # noqa: F401
     OrtValue,  # noqa: F401
     SparseTensor,  # noqa: F401
+    copy_tensors,  # noqa: F401
 )
 
 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
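
Note: ModelCompiler and copy_tensors join the inference-collection exports. A hedged sketch of the ahead-of-time compile workflow follows; only the ModelCompiler name is confirmed by this diff, and the constructor arguments, method name, and file paths shown are assumptions.

import onnxruntime as ort

sess_options = ort.SessionOptions()
# Assumed shape of the compile API: compile an ONNX model ahead of time
# for the configured execution provider, writing the result to disk.
compiler = ort.ModelCompiler(sess_options, "model.onnx")
compiler.compile_to_file("model_compiled.onnx")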
@@ -85,7 +104,7 @@ onnxruntime_validation.check_distro_info()
 
 
 def _get_package_version(package_name: str):
-    from importlib.metadata import PackageNotFoundError, version
+    from importlib.metadata import PackageNotFoundError, version  # noqa: PLC0415
 
     try:
         package_version = version(package_name)
@@ -95,7 +114,7 @@ def _get_package_version(package_name: str):
 
 
 def _get_package_root(package_name: str, directory_name: str | None = None):
-    from importlib.metadata import PackageNotFoundError, distribution
+    from importlib.metadata import PackageNotFoundError, distribution  # noqa: PLC0415
 
     root_directory_name = directory_name or package_name
     try:
@@ -118,14 +137,43 @@ def _get_package_root(package_name: str, directory_name: str | None = None):
     return None
 
 
+def _extract_cuda_major_version(version_str: str) -> str:
+    """Extract CUDA major version from version string (e.g., '12.1' -> '12').
+
+    Args:
+        version_str: CUDA version string to parse
+
+    Returns:
+        Major version as string, or "12" if parsing fails
+    """
+    return version_str.split(".")[0] if version_str else "12"
+
+
+def _get_cufft_version(cuda_major: str) -> str:
+    """Get cufft library version based on CUDA major version.
+
+    Args:
+        cuda_major: CUDA major version as string (e.g., "12", "13")
+
+    Returns:
+        cufft version as string
+    """
+    # cufft versions: CUDA 12.x -> 11, CUDA 13.x -> 12
+    return "12" if cuda_major == "13" else "11"
+
+
 def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = True):
+    # Dynamically determine CUDA major version from build info
+    cuda_major_version = _extract_cuda_major_version(cuda_version)
+    cufft_version = _get_cufft_version(cuda_major_version)
+
     if is_windows:
         # Path is relative to site-packages directory.
         cuda_dll_paths = [
-            ("nvidia", "cublas", "bin", "cublasLt64_12.dll"),
-            ("nvidia", "cublas", "bin", "cublas64_12.dll"),
-            ("nvidia", "cufft", "bin", "cufft64_11.dll"),
-            ("nvidia", "cuda_runtime", "bin", "cudart64_12.dll"),
+            ("nvidia", "cublas", "bin", f"cublasLt64_{cuda_major_version}.dll"),
+            ("nvidia", "cublas", "bin", f"cublas64_{cuda_major_version}.dll"),
+            ("nvidia", "cufft", "bin", f"cufft64_{cufft_version}.dll"),
+            ("nvidia", "cuda_runtime", "bin", f"cudart64_{cuda_major_version}.dll"),
         ]
         cudnn_dll_paths = [
            ("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"),
@@ -139,12 +187,12 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru
     else:  # Linux
         # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first.
         cuda_dll_paths = [
-            ("nvidia", "cublas", "lib", "libcublasLt.so.12"),
-            ("nvidia", "cublas", "lib", "libcublas.so.12"),
-            ("nvidia", "cuda_nvrtc", "lib", "libnvrtc.so.12"),
+            ("nvidia", "cublas", "lib", f"libcublasLt.so.{cuda_major_version}"),
+            ("nvidia", "cublas", "lib", f"libcublas.so.{cuda_major_version}"),
+            ("nvidia", "cuda_nvrtc", "lib", f"libnvrtc.so.{cuda_major_version}"),
             ("nvidia", "curand", "lib", "libcurand.so.10"),
-            ("nvidia", "cufft", "lib", "libcufft.so.11"),
-            ("nvidia", "cuda_runtime", "lib", "libcudart.so.12"),
+            ("nvidia", "cufft", "lib", f"libcufft.so.{cufft_version}"),
+            ("nvidia", "cuda_runtime", "lib", f"libcudart.so.{cuda_major_version}"),
         ]
 
         # Do not load cudnn sub DLLs (they will be dynamically loaded later) to be consistent with PyTorch in Linux.
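
Note: the effect of the two new helpers is easiest to see with concrete versions. Below is a standalone repro of the mapping, mirroring the logic added above (in the package, cuda_version comes from the build info module rather than a local variable):

def extract_cuda_major(version_str: str) -> str:
    # Same logic as _extract_cuda_major_version above.
    return version_str.split(".")[0] if version_str else "12"

def cufft_major(cuda_major: str) -> str:
    # Same logic as _get_cufft_version above: CUDA 12.x -> cufft 11, CUDA 13.x -> cufft 12.
    return "12" if cuda_major == "13" else "11"

for build_cuda in ("12.8", "13.0"):
    major = extract_cuda_major(build_cuda)
    print(f"CUDA {build_cuda}: cublas64_{major}.dll, cufft64_{cufft_major(major)}.dll, cudart64_{major}.dll")

# CUDA 12.8: cublas64_12.dll, cufft64_11.dll, cudart64_12.dll
# CUDA 13.0: cublas64_13.dll, cufft64_12.dll, cudart64_13.dll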
@@ -157,10 +205,10 @@
 
 def print_debug_info():
     """Print information to help debugging."""
-    import importlib.util
-    import os
-    import platform
-    from importlib.metadata import distributions
+    import importlib.util  # noqa: PLC0415
+    import os  # noqa: PLC0415
+    import platform  # noqa: PLC0415
+    from importlib.metadata import distributions  # noqa: PLC0415
 
     print(f"{package_name} version: {__version__}")
     if cuda_version:
@@ -186,15 +234,17 @@ def print_debug_info():
 
     if cuda_version:
         # Print version of installed packages that is related to CUDA or cuDNN DLLs.
+        cuda_major = _extract_cuda_major_version(cuda_version)
+
         packages = [
             "torch",
-            "nvidia-cuda-runtime-cu12",
-            "nvidia-cudnn-cu12",
-            "nvidia-cublas-cu12",
-            "nvidia-cufft-cu12",
-            "nvidia-curand-cu12",
-            "nvidia-cuda-nvrtc-cu12",
-            "nvidia-nvjitlink-cu12",
+            f"nvidia-cuda-runtime-cu{cuda_major}",
+            f"nvidia-cudnn-cu{cuda_major}",
+            f"nvidia-cublas-cu{cuda_major}",
+            f"nvidia-cufft-cu{cuda_major}",
+            f"nvidia-curand-cu{cuda_major}",
+            f"nvidia-cuda-nvrtc-cu{cuda_major}",
+            f"nvidia-nvjitlink-cu{cuda_major}",
         ]
         for package in packages:
             directory_name = "nvidia" if package.startswith("nvidia-") else None
@@ -205,9 +255,9 @@ def print_debug_info():
            print(f"{package} not installed")
 
     if platform.system() == "Windows":
-        print(f"\nEnvironment variable:\nPATH={os.environ['PATH']}")
+        print(f"\nEnvironment variable:\nPATH={os.environ.get('PATH', '(unset)')}")
     elif platform.system() == "Linux":
-        print(f"\nEnvironment variable:\nLD_LIBRARY_PATH={os.environ['LD_LIBRARY_PATH']}")
+        print(f"\nEnvironment variable:\nLD_LIBRARY_PATH={os.environ.get('LD_LIBRARY_PATH', '(unset)')}")
 
     if importlib.util.find_spec("psutil"):
 
@@ -217,7 +267,7 @@ def print_debug_info():
         target_keywords = ["cufft", "cublas", "cudart", "nvrtc", "curand", "cudnn", *target_keywords]
         return any(keyword in path for keyword in target_keywords)
 
-    import psutil
+    import psutil  # noqa: PLC0415
 
     p = psutil.Process(os.getpid())
 
@@ -228,7 +278,7 @@ def print_debug_info():
 
     if cuda_version:
         if importlib.util.find_spec("cpuinfo") and importlib.util.find_spec("py3nvml"):
-            from .transformers.machine_info import get_device_info
+            from .transformers.machine_info import get_device_info  # noqa: PLC0415
 
             print("\nDevice information:")
             print(get_device_info())
@@ -239,7 +289,7 @@
 
 
 def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, directory=None):
-    """Preload CUDA 12.x and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.
+    """Preload CUDA 12.x+ and cuDNN 9.x DLLs in Windows or Linux, and MSVC runtime DLLs in Windows.
 
     When the installed PyTorch is compatible (using same major version of CUDA and cuDNN),
     there is no need to call this function if `import torch` is done before `import onnxruntime`.
@@ -255,10 +305,10 @@ def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, direc
     If directory is empty string (""), the search order: nvidia site packages, default DLL loading paths.
     If directory is a path, the search order: the directory, default DLL loading paths.
     """
-    import ctypes
-    import os
-    import platform
-    import sys
+    import ctypes  # noqa: PLC0415
+    import os  # noqa: PLC0415
+    import platform  # noqa: PLC0415
+    import sys  # noqa: PLC0415
 
     if platform.system() not in ["Windows", "Linux"]:
         return
@@ -274,30 +324,53 @@
         print("Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.")
         print("It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe.")
 
-    if not (cuda_version and cuda_version.startswith("12.")) and (cuda or cudnn):
-        print(
-            f"\033[33mWARNING: {package_name} is not built with CUDA 12.x support. "
-            "Please install a version that supports CUDA 12.x, or call preload_dlls with cuda=False and cudnn=False.\033[0m"
-        )
-        return
-
-    if not (cuda_version and cuda_version.startswith("12.") and (cuda or cudnn)):
+    # Check if CUDA version is supported (12.x or 13.x+)
+    ort_cuda_major = None
+    if cuda_version:
+        try:
+            ort_cuda_major = int(cuda_version.split(".")[0])
+            if ort_cuda_major < 12 and (cuda or cudnn):
+                print(
+                    f"\033[33mWARNING: {package_name} is built with CUDA {cuda_version}, which is not supported for preloading. "
+                    f"CUDA 12.x or newer is required. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
+                )
+                return
+        except ValueError:
+            print(
+                f"\033[33mWARNING: Unable to parse CUDA version '{cuda_version}'. "
+                "Skipping DLL preloading. Call preload_dlls with cuda=False and cudnn=False.\033[0m"
+            )
+            return
+    elif cuda or cudnn:
+        # No CUDA version info available but CUDA/cuDNN preloading requested
         return
 
     is_cuda_cudnn_imported_by_torch = False
 
     if is_windows:
         torch_version = _get_package_version("torch")
-        is_torch_for_cuda_12 = torch_version and "+cu12" in torch_version
+        # Check if torch CUDA version matches onnxruntime CUDA version
+        torch_cuda_major = None
+        if torch_version and "+cu" in torch_version:
+            with contextlib.suppress(ValueError):
+                # Extract CUDA version from torch (e.g., "2.0.0+cu121" -> 12)
+                cu_part = torch_version.split("+cu")[1]
+                torch_cuda_major = int(cu_part[:2])  # First 2 digits are major version
+
+        is_torch_cuda_compatible = (
+            torch_cuda_major == ort_cuda_major if (torch_cuda_major and ort_cuda_major) else False
+        )
+
         if "torch" in sys.modules:
-            is_cuda_cudnn_imported_by_torch = is_torch_for_cuda_12
-            if (torch_version and "+cu" in torch_version) and not is_torch_for_cuda_12:
+            is_cuda_cudnn_imported_by_torch = is_torch_cuda_compatible
+            if torch_cuda_major and ort_cuda_major and torch_cuda_major != ort_cuda_major:
                 print(
-                    f"\033[33mWARNING: The installed PyTorch {torch_version} does not support CUDA 12.x. "
-                    f"Please install PyTorch for CUDA 12.x to be compatible with {package_name}.\033[0m"
+                    f"\033[33mWARNING: The installed PyTorch {torch_version} uses CUDA {torch_cuda_major}.x, "
+                    f"but {package_name} is built with CUDA {ort_cuda_major}.x. "
+                    f"Please install PyTorch for CUDA {ort_cuda_major}.x to be compatible.\033[0m"
                )
 
-        if is_torch_for_cuda_12 and directory is None:
+        if is_torch_cuda_compatible and directory is None:
             torch_root = _get_package_root("torch", "torch")
             if torch_root:
                 directory = os.path.join(torch_root, "lib")
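
Note: the user-visible change is that preload_dlls now keys off the build's CUDA major version instead of hard-coding 12.x. Typical usage is unchanged, as in the sketch below; for this DirectML wheel there is no CUDA build info, so the CUDA/cuDNN branches return early and only the MSVC runtime check applies.

import onnxruntime as ort

# Preload CUDA/cuDNN DLLs from the nvidia pip packages (or from torch's lib
# directory when a CUDA-major-compatible PyTorch is installed on Windows).
ort.preload_dlls()

# Skip CUDA/cuDNN preloading entirely, e.g. on a DirectML- or CPU-only setup.
ort.preload_dlls(cuda=False, cudnn=False)

ort.print_debug_info()  # dumps loaded DLLs and related package versions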
onnxruntime/capi/DirectML.dll CHANGED
Binary file
@@ -1,2 +1,2 @@
1
1
  package_name = 'onnxruntime-directml'
2
- __version__ = '1.22.1.dev20250710002'
2
+ __version__ = '1.24.1'
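
Note: this module is the source of the package_name global used in __init__.py (CUDA builds also carry cuda_version here). After upgrading, the reported version can be checked directly:

import onnxruntime as ort
from onnxruntime.capi.build_and_package_info import package_name

print(package_name)     # 'onnxruntime-directml'
print(ort.__version__)  # '1.24.1'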
onnxruntime/capi/onnxruntime.dll CHANGED
Binary file