onnxruntime-directml 1.22.1.dev20250710002-cp313-cp313-win_amd64.whl → 1.23.0-cp313-cp313-win_amd64.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (104)
  1. onnxruntime/__init__.py +24 -13
  2. onnxruntime/capi/DirectML.dll +0 -0
  3. onnxruntime/capi/build_and_package_info.py +1 -1
  4. onnxruntime/capi/onnxruntime.dll +0 -0
  5. onnxruntime/capi/onnxruntime_inference_collection.py +264 -35
  6. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  7. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  8. onnxruntime/capi/onnxruntime_validation.py +8 -8
  9. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +1 -1
  10. onnxruntime/quantization/base_quantizer.py +3 -38
  11. onnxruntime/quantization/calibrate.py +3 -3
  12. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  13. onnxruntime/quantization/execution_providers/qnn/preprocess.py +28 -0
  14. onnxruntime/quantization/fusions/__init__.py +1 -0
  15. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  16. onnxruntime/quantization/matmul_bnb4_quantizer.py +1 -1
  17. onnxruntime/quantization/matmul_nbits_quantizer.py +120 -38
  18. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  19. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  20. onnxruntime/quantization/neural_compressor/util.py +80 -0
  21. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  22. onnxruntime/quantization/onnx_model.py +1 -1
  23. onnxruntime/quantization/onnx_quantizer.py +156 -1
  24. onnxruntime/quantization/operators/gemm.py +3 -3
  25. onnxruntime/quantization/quant_utils.py +58 -13
  26. onnxruntime/quantization/quantize.py +16 -6
  27. onnxruntime/quantization/registry.py +1 -0
  28. onnxruntime/quantization/shape_inference.py +18 -1
  29. onnxruntime/quantization/tensor_quant_overrides.py +1 -1
  30. onnxruntime/tools/convert_onnx_models_to_ort.py +6 -3
  31. onnxruntime/tools/mobile_helpers/usability_checker.py +1 -1
  32. onnxruntime/tools/onnx_model_utils.py +3 -0
  33. onnxruntime/tools/optimize_onnx_model.py +1 -1
  34. onnxruntime/tools/ort_format_model/utils.py +1 -2
  35. onnxruntime/tools/pytorch_export_contrib_ops.py +1 -1
  36. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  37. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  38. onnxruntime/tools/qnn/preprocess.py +165 -0
  39. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  40. onnxruntime/tools/symbolic_shape_infer.py +13 -12
  41. onnxruntime/transformers/benchmark.py +6 -6
  42. onnxruntime/transformers/benchmark_helper.py +5 -5
  43. onnxruntime/transformers/bert_perf_test.py +5 -3
  44. onnxruntime/transformers/bert_test_data.py +1 -1
  45. onnxruntime/transformers/compare_bert_results.py +1 -1
  46. onnxruntime/transformers/convert_generation.py +106 -48
  47. onnxruntime/transformers/convert_tf_models_to_pytorch.py +8 -8
  48. onnxruntime/transformers/dynamo_onnx_helper.py +1 -1
  49. onnxruntime/transformers/fusion_attention.py +2 -2
  50. onnxruntime/transformers/fusion_attention_clip.py +38 -32
  51. onnxruntime/transformers/fusion_bart_attention.py +205 -414
  52. onnxruntime/transformers/fusion_nhwc_conv.py +1 -1
  53. onnxruntime/transformers/io_binding_helper.py +1 -0
  54. onnxruntime/transformers/machine_info.py +4 -4
  55. onnxruntime/transformers/models/bert/eval_squad.py +1 -1
  56. onnxruntime/transformers/models/gpt2/gpt2_parity.py +1 -1
  57. onnxruntime/transformers/models/gpt2/gpt2_tester.py +3 -3
  58. onnxruntime/transformers/models/gpt2/parity_check_helper.py +2 -2
  59. onnxruntime/transformers/models/llama/benchmark_all.py +2 -2
  60. onnxruntime/transformers/models/llama/convert_to_onnx.py +17 -61
  61. onnxruntime/transformers/models/llama/dist_settings.py +4 -4
  62. onnxruntime/transformers/models/llama/llama_parity.py +6 -5
  63. onnxruntime/transformers/models/longformer/benchmark_longformer.py +2 -2
  64. onnxruntime/transformers/models/phi2/convert_to_onnx.py +6 -6
  65. onnxruntime/transformers/models/phi2/inference_example.py +3 -3
  66. onnxruntime/transformers/models/sam2/benchmark_sam2.py +6 -6
  67. onnxruntime/transformers/models/sam2/convert_to_onnx.py +1 -1
  68. onnxruntime/transformers/models/sam2/image_decoder.py +2 -2
  69. onnxruntime/transformers/models/sam2/image_encoder.py +4 -4
  70. onnxruntime/transformers/models/sam2/mask_decoder.py +1 -1
  71. onnxruntime/transformers/models/sam2/prompt_encoder.py +1 -1
  72. onnxruntime/transformers/models/sam2/sam2_demo.py +1 -1
  73. onnxruntime/transformers/models/stable_diffusion/benchmark.py +37 -38
  74. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +6 -6
  75. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +1 -1
  76. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +2 -2
  77. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +1 -1
  78. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +1 -1
  79. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +1 -1
  80. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +0 -1
  81. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +1 -1
  82. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +4 -4
  83. onnxruntime/transformers/models/t5/t5_helper.py +2 -2
  84. onnxruntime/transformers/models/whisper/benchmark_all.py +2 -2
  85. onnxruntime/transformers/models/whisper/convert_to_onnx.py +31 -1
  86. onnxruntime/transformers/models/whisper/whisper_decoder.py +7 -8
  87. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +2 -2
  88. onnxruntime/transformers/models/whisper/whisper_helper.py +561 -9
  89. onnxruntime/transformers/models/whisper/whisper_inputs.py +3 -3
  90. onnxruntime/transformers/models/whisper/whisper_jump_times.py +2 -2
  91. onnxruntime/transformers/onnx_exporter.py +10 -10
  92. onnxruntime/transformers/onnx_model.py +15 -3
  93. onnxruntime/transformers/onnx_model_mmdit.py +2 -2
  94. onnxruntime/transformers/onnx_model_sam2.py +2 -2
  95. onnxruntime/transformers/onnx_model_t5.py +1 -1
  96. onnxruntime/transformers/onnx_model_unet.py +2 -2
  97. onnxruntime/transformers/optimizer.py +6 -4
  98. onnxruntime/transformers/profiler.py +4 -4
  99. onnxruntime/transformers/quantize_helper.py +2 -2
  100. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.23.0.dist-info}/METADATA +3 -3
  101. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.23.0.dist-info}/RECORD +104 -94
  102. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.23.0.dist-info}/WHEEL +0 -0
  103. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.23.0.dist-info}/entry_points.txt +0 -0
  104. {onnxruntime_directml-1.22.1.dev20250710002.dist-info → onnxruntime_directml-1.23.0.dist-info}/top_level.txt +0 -0
onnxruntime/__init__.py CHANGED
@@ -8,7 +8,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://ak
 or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
 """
 
-__version__ = "1.22.1"
+__version__ = "1.23.0"
 __author__ = "Microsoft"
 
 # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
@@ -30,6 +30,12 @@ try:
         NodeArg,  # noqa: F401
         OrtAllocatorType,  # noqa: F401
         OrtArenaCfg,  # noqa: F401
+        OrtCompileApiFlags,  # noqa: F401
+        OrtEpDevice,  # noqa: F401
+        OrtExecutionProviderDevicePolicy,  # noqa: F401
+        OrtExternalInitializerInfo,  # noqa: F401
+        OrtHardwareDevice,  # noqa: F401
+        OrtHardwareDeviceType,  # noqa: F401
         OrtMemoryInfo,  # noqa: F401
         OrtMemType,  # noqa: F401
         OrtSparseFormat,  # noqa: F401
@@ -44,11 +50,15 @@ try:
         get_available_providers,  # noqa: F401
         get_build_info,  # noqa: F401
         get_device,  # noqa: F401
+        get_ep_devices,  # noqa: F401
         get_version_string,  # noqa: F401
         has_collective_ops,  # noqa: F401
+        register_execution_provider_library,  # noqa: F401
         set_default_logger_severity,  # noqa: F401
         set_default_logger_verbosity,  # noqa: F401
+        set_global_thread_pool_sizes,  # noqa: F401
         set_seed,  # noqa: F401
+        unregister_execution_provider_library,  # noqa: F401
     )
 
     import_capi_exception = None
@@ -64,6 +74,7 @@ from onnxruntime.capi.onnxruntime_inference_collection import (
     AdapterFormat,  # noqa: F401
     InferenceSession,  # noqa: F401
     IOBinding,  # noqa: F401
+    ModelCompiler,  # noqa: F401
     OrtDevice,  # noqa: F401
     OrtValue,  # noqa: F401
     SparseTensor,  # noqa: F401
@@ -85,7 +96,7 @@ onnxruntime_validation.check_distro_info()
 
 
 def _get_package_version(package_name: str):
-    from importlib.metadata import PackageNotFoundError, version
+    from importlib.metadata import PackageNotFoundError, version  # noqa: PLC0415
 
     try:
         package_version = version(package_name)
@@ -95,7 +106,7 @@ def _get_package_version(package_name: str):
 
 
 def _get_package_root(package_name: str, directory_name: str | None = None):
-    from importlib.metadata import PackageNotFoundError, distribution
+    from importlib.metadata import PackageNotFoundError, distribution  # noqa: PLC0415
 
     root_directory_name = directory_name or package_name
     try:
@@ -157,10 +168,10 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru
 
 def print_debug_info():
     """Print information to help debugging."""
-    import importlib.util
-    import os
-    import platform
-    from importlib.metadata import distributions
+    import importlib.util  # noqa: PLC0415
+    import os  # noqa: PLC0415
+    import platform  # noqa: PLC0415
+    from importlib.metadata import distributions  # noqa: PLC0415
 
     print(f"{package_name} version: {__version__}")
     if cuda_version:
@@ -217,7 +228,7 @@ def print_debug_info():
             target_keywords = ["cufft", "cublas", "cudart", "nvrtc", "curand", "cudnn", *target_keywords]
             return any(keyword in path for keyword in target_keywords)
 
-        import psutil
+        import psutil  # noqa: PLC0415
 
         p = psutil.Process(os.getpid())
 
@@ -228,7 +239,7 @@ def print_debug_info():
 
     if cuda_version:
         if importlib.util.find_spec("cpuinfo") and importlib.util.find_spec("py3nvml"):
-            from .transformers.machine_info import get_device_info
+            from .transformers.machine_info import get_device_info  # noqa: PLC0415
 
             print("\nDevice information:")
             print(get_device_info())
@@ -255,10 +266,10 @@ def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, direc
     If directory is empty string (""), the search order: nvidia site packages, default DLL loading paths.
     If directory is a path, the search order: the directory, default DLL loading paths.
    """
-    import ctypes
-    import os
-    import platform
-    import sys
+    import ctypes  # noqa: PLC0415
+    import os  # noqa: PLC0415
+    import platform  # noqa: PLC0415
+    import sys  # noqa: PLC0415
 
     if platform.system() not in ["Windows", "Linux"]:
         return
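The upshot of the `__init__.py` changes above is a set of new top-level exports for the plugin-EP and compile APIs. A minimal sketch of how they might compose, assuming a build that ships these bindings; the plugin path and the `OrtEpDevice` attribute names used below are assumptions, not confirmed by this diff:

```
import onnxruntime as ort

# Optionally register a plugin execution provider library first (name/path are illustrative).
# ort.register_execution_provider_library("example_ep", "C:/plugins/example_ep.dll")

# Enumerate the execution-provider/device combinations known to this build.
for ep_device in ort.get_ep_devices():
    print(ep_device.ep_name, ep_device.device.type)  # attribute names are assumptions

# Global thread pool sizing is also newly exported; argument names are assumptions.
# ort.set_global_thread_pool_sizes(intra_op_num_threads=4, inter_op_num_threads=1)
```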
onnxruntime/capi/DirectML.dll CHANGED
Binary file
onnxruntime/capi/build_and_package_info.py CHANGED
@@ -1,2 +1,2 @@
 package_name = 'onnxruntime-directml'
-__version__ = '1.22.1.dev20250710002'
+__version__ = '1.23.0'
onnxruntime/capi/onnxruntime.dll CHANGED
Binary file
onnxruntime/capi/onnxruntime_inference_collection.py CHANGED
@@ -9,7 +9,7 @@ import collections.abc
 import os
 import typing
 import warnings
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from typing import Any
 
 from onnxruntime.capi import _pybind_state as C
@@ -21,7 +21,7 @@ if typing.TYPE_CHECKING:
     import onnxruntime
 
 
-def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
+def get_ort_device_type(device_type: str) -> int:
     if device_type == "cuda":
         return C.OrtDevice.cuda()
     elif device_type == "cann":
@@ -32,8 +32,10 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice:
         return C.OrtDevice.dml()
     elif device_type == "webgpu":
         return C.OrtDevice.webgpu()
-    elif device_type == "ort":
-        return C.get_ort_device(device_index).device_type()
+    elif device_type == "gpu":
+        return C.OrtDevice.gpu()
+    elif device_type == "npu":
+        return C.OrtDevice.npu()
     else:
         raise Exception("Unsupported device type: " + device_type)
 
@@ -172,10 +174,10 @@ class Session:
     This is the main class used to run a model.
     """
 
-    def __init__(self):
+    def __init__(self, enable_fallback: bool = True):
         # self._sess is managed by the derived class and relies on bindings from C.InferenceSession
         self._sess = None
-        self._enable_fallback = True
+        self._enable_fallback = enable_fallback
 
     def get_session_options(self) -> onnxruntime.SessionOptions:
         "Return the session options. See :class:`onnxruntime.SessionOptions`."
@@ -446,7 +448,7 @@ class InferenceSession(Session):
         means execute a node using `CUDAExecutionProvider`
         if capable, otherwise execute using `CPUExecutionProvider`.
         """
-        super().__init__()
+        super().__init__(enable_fallback=int(kwargs.get("enable_fallback", 1)) == 1)
 
         if isinstance(path_or_bytes, (str, os.PathLike)):
             self._model_path = os.fspath(path_or_bytes)
@@ -459,7 +461,6 @@ class InferenceSession(Session):
 
         self._sess_options = sess_options
         self._sess_options_initial = sess_options
-        self._enable_fallback = True
         if "read_config_from_model" in kwargs:
             self._read_config_from_model = int(kwargs["read_config_from_model"]) == 1
         else:
@@ -542,6 +543,16 @@ class InferenceSession(Session):
             providers, provider_options, available_providers
         )
 
+        # Print a warning if user passed providers to InferenceSession() but the SessionOptions instance
+        # already has provider information (e.g., via add_provider_for_devices()). The providers specified
+        # here will take precedence.
+        if self._sess_options is not None and (providers or provider_options) and self._sess_options.has_providers():
+            warnings.warn(
+                "Specified 'providers'/'provider_options' when creating InferenceSession but SessionOptions has "
+                "already been configured with providers. InferenceSession will only use the providers "
+                "passed to InferenceSession()."
+            )
+
        session_options = self._sess_options if self._sess_options else C.get_default_session_options()
 
        self._register_ep_custom_ops(session_options, providers, provider_options, available_providers)
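The warning hunk above changes what happens when providers are configured in both places. A small sketch of the situation that now warns; whether `add_provider()` accepts a given provider name depends on the build, and `model.onnx` is illustrative:

```
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.add_provider("CPUExecutionProvider", {})

# SessionOptions already carries provider configuration, so passing `providers`
# here emits the new warning, and this explicit list is what the session uses.
sess = ort.InferenceSession(
    "model.onnx",
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)
```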
@@ -609,6 +620,197 @@
                 C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1])
 
 
+def make_get_initializer_location_func_wrapper(
+    get_initializer_location_func: GetInitializerLocationFunc,
+) -> GetInitializerLocationWrapperFunc:
+    """
+    Wraps a user's "get initializer location" function. The returned wrapper function adheres to the
+    signature expected by ORT.
+
+    Need this wrapper to:
+      - Convert the `initializer_value` parameter from `C.OrtValue` to `onnxruntime.OrtValue`, which is more
+        convenient for the user's function to use.
+      - Allow the user's function to return the original `external_info` parameter (this wrapper makes a copy)
+    """
+
+    def get_initializer_location_func_wrapper(
+        initializer_name: str,
+        initializer_value: C.OrtValue,
+        external_info: C.OrtExternalInitializerInfo | None,
+    ) -> C.OrtExternalInitializerInfo | None:
+        ret_val: C.OrtExternalInitializerInfo | None = get_initializer_location_func(
+            initializer_name, OrtValue(initializer_value), external_info
+        )
+        if ret_val is not None and ret_val == external_info:
+            # User returned `external_info` (const and owned by ORT). ORT expects the returned value to be
+            # a new instance (that it deletes), so make a copy.
+            ret_val = C.OrtExternalInitializerInfo(ret_val.filepath, ret_val.file_offset, ret_val.byte_size)
+        return ret_val
+
+    return get_initializer_location_func_wrapper
+
+
+class ModelCompiler:
+    """
+    This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each
+    encapsulates a subgraph compiled/optimized for a specific execution provider.
+
+    Refer to the EPContext design document for more information about EPContext models:
+    https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html
+
+    ::
+
+        sess_options = onnxruntime.SessionOptions()
+        sess_options.add_provider("SomeExecutionProvider", {"option1": "value1"})
+        # Alternatively, allow ONNX Runtime to select the provider automatically given a policy:
+        # sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU)
+
+        model_compiler = onnxruntime.ModelCompiler(sess_options, "input_model.onnx")
+        model_compiler.compile_to_file("output_model.onnx")
+    """
+
+    def __init__(
+        self,
+        sess_options: onnxruntime.SessionOptions,
+        input_model_path_or_bytes: str | os.PathLike | bytes,
+        embed_compiled_data_into_model: bool = False,
+        external_initializers_file_path: str | os.PathLike | None = None,
+        external_initializers_size_threshold: int = 1024,
+        flags: int = C.OrtCompileApiFlags.NONE,
+        graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL,
+        get_initializer_location_func: GetInitializerLocationFunc | None = None,
+    ):
+        """
+        Creates a ModelCompiler instance.
+
+        :param sess_options: Session options containing the providers for which the model will be compiled.
+            Refer to SessionOptions.add_provider() and SessionOptions.set_provider_selection_policy().
+        :param input_model_path_or_bytes: The path to the input model file or bytes representing a serialized
+            ONNX model.
+        :param embed_compiled_data_into_model: Defaults to False. Set to True to embed compiled binary data into
+            EPContext nodes in the compiled model.
+        :param external_initializers_file_path: Defaults to None. Set to a path for a file that will store the
+            initializers for non-compiled nodes.
+        :param external_initializers_size_threshold: Defaults to 1024. Ignored if `external_initializers_file_path`
+            is None or empty. Initializers larger than this threshold are stored in the external initializers file.
+        :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of
+            flags in onnxruntime.OrtCompileApiFlags.
+        :param graph_optimization_level: The graph optimization level.
+            Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL.
+        :param get_initializer_location_func: Optional function called for every initializer to allow user to specify
+            whether an initializer should be stored within the model or externally. Example:
+            ```
+            def get_initializer_location(
+                initializer_name: str,
+                initializer_value: onnxrt.OrtValue,
+                external_info: onnxrt.OrtExternalInitializerInfo | None,
+            ) -> onnxrt.OrtExternalInitializerInfo | None:
+                byte_size = initializer_value.tensor_size_in_bytes()
+
+                if byte_size < 64:
+                    return None  # Store small initializer within compiled model.
+
+                # Else, write initializer to new external file.
+                value_np = initializer_value.numpy()
+                file_offset = ext_init_file.tell()
+                ext_init_file.write(value_np.tobytes())
+                return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size)
+            ```
+        """
+        input_model_path: str | os.PathLike | None = None
+        input_model_bytes: bytes | None = None
+        if isinstance(input_model_path_or_bytes, (str, os.PathLike)):
+            if not input_model_path_or_bytes:
+                raise ValueError("Input model path is empty")
+            input_model_path = os.fspath(input_model_path_or_bytes)
+        elif isinstance(input_model_path_or_bytes, bytes):
+            if len(input_model_path_or_bytes) == 0:
+                raise ValueError("Input model bytes array is empty")
+            input_model_bytes = input_model_path_or_bytes
+        else:
+            raise TypeError(f"Unable to load from type '{type(input_model_path_or_bytes)}'")
+
+        if external_initializers_file_path:
+            if not isinstance(external_initializers_file_path, (str, os.PathLike)):
+                arg_type = type(external_initializers_file_path)
+                raise TypeError(f"Output external initializer filepath is of unexpected type '{arg_type}'")
+            external_initializers_file_path = os.fspath(external_initializers_file_path)
+        else:
+            external_initializers_file_path = ""
+
+        if get_initializer_location_func is not None:
+            if external_initializers_file_path:
+                raise ValueError(
+                    "Cannot initialize ModelCompiler with both `external_initializers_file_path` "
+                    "and `get_initializer_location_func`"
+                )
+            self.get_initializer_location_func_wrapper = make_get_initializer_location_func_wrapper(
+                get_initializer_location_func
+            )
+        else:
+            self.get_initializer_location_func_wrapper = None
+
+        if input_model_path:
+            self._model_compiler = C.ModelCompiler(
+                sess_options,
+                input_model_path,
+                True,  # is path
+                embed_compiled_data_into_model,
+                external_initializers_file_path,
+                external_initializers_size_threshold,
+                flags,
+                graph_optimization_level,
+                self.get_initializer_location_func_wrapper,
+            )
+        else:
+            self._model_compiler = C.ModelCompiler(
+                sess_options,
+                input_model_bytes,
+                False,  # is bytes
+                embed_compiled_data_into_model,
+                external_initializers_file_path,
+                external_initializers_size_threshold,
+                flags,
+                graph_optimization_level,
+                self.get_initializer_location_func_wrapper,
+            )
+
+    def compile_to_file(self, output_model_path: str | None = None):
+        """
+        Compiles to an output file. If an output file path is not provided,
+        the output file path is generated based on the input model path by replacing
+        '.onnx' with '_ctx.onnx'. Ex: The generated output file is 'model_ctx.onnx' for
+        an input model with path 'model.onnx'.
+
+        Raises an 'InvalidArgument' exception if the compilation options are invalid.
+
+        :param output_model_path: Defaults to None. The path for the output/compiled model.
+        """
+        if output_model_path:
+            if not isinstance(output_model_path, (str, os.PathLike)):
+                raise TypeError(f"Output model's filepath is of unexpected type '{type(output_model_path)}'")
+            output_model_path = os.fspath(output_model_path)
+        self._model_compiler.compile_to_file(output_model_path)
+
+    def compile_to_bytes(self) -> bytes:
+        """
+        Compiles to bytes representing the serialized compiled ONNX model.
+
+        Raises an 'InvalidArgument' exception if the compilation options are invalid.
+
+        :return: A bytes object representing the compiled ONNX model.
+        """
+        return self._model_compiler.compile_to_bytes()
+
+    def compile_to_stream(self, write_function: Callable[[bytes], None]):
+        """
+        Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function.
+
+        Raises an 'InvalidArgument' exception if the compilation options are invalid.
+
+        :param write_function: A callable that accepts a bytes buffer to write.
+        """
+        self._model_compiler.compile_to_stream(write_function)
+
+
 class IOBinding:
     """
     This class provides API to bind input/output to a specified device, e.g. GPU.
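The new `ModelCompiler` supports three output targets (file, bytes, stream). A minimal usage sketch grounded in the methods above; the provider name, its options, and the model paths are illustrative:

```
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.add_provider("DmlExecutionProvider", {})  # provider name/options are illustrative

# 1) Compile to a file; with no argument the output defaults to 'model_ctx.onnx'.
ort.ModelCompiler(sess_options, "model.onnx").compile_to_file()

# 2) Compile to an in-memory bytes object.
compiled_bytes = ort.ModelCompiler(sess_options, "model.onnx").compile_to_bytes()

# 3) Stream the serialized model through a write callback.
with open("model_stream_ctx.onnx", "wb") as f:
    ort.ModelCompiler(sess_options, "model.onnx").compile_to_stream(f.write)
```

A fresh compiler instance is created per call above, since the diff does not state whether a single instance may compile more than once.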
@@ -642,7 +844,7 @@ class IOBinding:
         self._iobinding.bind_input(
             name,
             C.OrtDevice(
-                get_ort_device_type(device_type, device_id),
+                get_ort_device_type(device_type),
                 C.OrtDevice.default_memory(),
                 device_id,
             ),
@@ -689,7 +891,7 @@ class IOBinding:
         self._iobinding.bind_output(
             name,
             C.OrtDevice(
-                get_ort_device_type(device_type, device_id),
+                get_ort_device_type(device_type),
                 C.OrtDevice.default_memory(),
                 device_id,
             ),
@@ -700,7 +902,7 @@ class IOBinding:
         self._iobinding.bind_output(
             name,
             C.OrtDevice(
-                get_ort_device_type(device_type, device_id),
+                get_ort_device_type(device_type),
                 C.OrtDevice.default_memory(),
                 device_id,
             ),
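The three call sites above now build the `C.OrtDevice` from the device type string alone. For context, a standard I/O-binding round trip against this package's DirectML device looks roughly like this; the model name and tensor shapes are illustrative:

```
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider"])
binding = sess.io_binding()

x = np.zeros((1, 3, 224, 224), dtype=np.float32)
binding.bind_cpu_input("input", x)       # input stays in CPU memory
binding.bind_output("output", "dml", 0)  # output allocated on the DML device

sess.run_with_iobinding(binding)
result = binding.copy_outputs_to_cpu()[0]
```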
@@ -766,7 +968,7 @@ class OrtValue:
         return self._ortvalue
 
     @classmethod
-    def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0) -> OrtValue:
+    def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0, vendor_id=-1) -> OrtValue:
         """
         Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
         A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu
@@ -774,6 +976,7 @@ class OrtValue:
         :param numpy_obj: The Numpy object to construct the OrtValue from
         :param device_type: e.g. cpu, cuda, cann, cpu by default
         :param device_id: device id, e.g. 0
+        :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu".
         """
         # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
         # is backed directly by the data buffer of the numpy object and so the numpy object
@@ -781,11 +984,7 @@ class OrtValue:
         return cls(
             C.OrtValue.ortvalue_from_numpy(
                 numpy_obj,
-                C.OrtDevice(
-                    get_ort_device_type(device_type, device_id),
-                    C.OrtDevice.default_memory(),
-                    device_id,
-                ),
+                OrtDevice.make(device_type, device_id, vendor_id)._get_c_device(),
             ),
             numpy_obj if device_type.lower() == "cpu" else None,
         )
@@ -806,7 +1005,7 @@ class OrtValue:
 
     @classmethod
     def ortvalue_from_shape_and_type(
-        cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0
+        cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0, vendor_id: int = -1
     ) -> OrtValue:
         """
         Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
@@ -815,7 +1014,11 @@ class OrtValue:
         :param element_type: The data type of the elements. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16).
         :param device_type: e.g. cpu, cuda, cann, cpu by default
         :param device_id: device id, e.g. 0
+        :param vendor_id: If provided the device type should be "gpu" or "npu".
         """
+
+        device = OrtDevice.make(device_type, device_id, vendor_id)._get_c_device()
+
         # Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
         # This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy.
         if isinstance(element_type, int):
@@ -823,11 +1026,7 @@ class OrtValue:
                 C.OrtValue.ortvalue_from_shape_and_onnx_type(
                     shape,
                     element_type,
-                    C.OrtDevice(
-                        get_ort_device_type(device_type, device_id),
-                        C.OrtDevice.default_memory(),
-                        device_id,
-                    ),
+                    device,
                 )
             )
 
@@ -835,11 +1034,7 @@ class OrtValue:
             C.OrtValue.ortvalue_from_shape_and_type(
                 shape,
                 element_type,
-                C.OrtDevice(
-                    get_ort_device_type(device_type, device_id),
-                    C.OrtDevice.default_memory(),
-                    device_id,
-                ),
+                device,
             )
         )
 
@@ -888,6 +1083,13 @@ class OrtValue:
         """
         return self._ortvalue.element_type()
 
+    def tensor_size_in_bytes(self) -> int:
+        """
+        Returns the size of the data in the OrtValue in bytes
+        if the OrtValue is a tensor.
+        """
+        return self._ortvalue.tensor_size_in_bytes()
+
     def has_value(self) -> bool:
         """
         Returns True if the OrtValue corresponding to an
@@ -955,14 +1157,27 @@ class OrtDevice:
         return self._ort_device
 
     @staticmethod
-    def make(ort_device_name, device_id):
-        return OrtDevice(
-            C.OrtDevice(
-                get_ort_device_type(ort_device_name, device_id),
-                C.OrtDevice.default_memory(),
-                device_id,
+    def make(ort_device_name, device_id, vendor_id=-1):
+        if vendor_id < 0:
+            # backwards compatibility with predefined OrtDevice names
+            return OrtDevice(
+                C.OrtDevice(
+                    get_ort_device_type(ort_device_name),
+                    C.OrtDevice.default_memory(),
+                    device_id,
+                )
+            )
+        else:
+            # generic. use GPU or NPU for ort_device_name and provide a vendor id.
+            # vendor id of 0 is valid in some cases (e.g. webgpu is generic and does not have a vendor id)
+            return OrtDevice(
+                C.OrtDevice(
+                    get_ort_device_type(ort_device_name),
+                    C.OrtDevice.default_memory(),
+                    vendor_id,
+                    device_id,
+                )
             )
-        )
 
     def device_id(self):
         return self._ort_device.device_id()
@@ -970,6 +1185,9 @@ class OrtDevice:
     def device_type(self):
         return self._ort_device.device_type()
 
+    def device_vendor_id(self):
+        return self._ort_device.vendor_id()
+
 
 class SparseTensor:
     """
@@ -1152,3 +1370,14 @@ class SparseTensor:
         Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda
         """
         return self._tensor.device_name().lower()
+
+
+# Type hint for user-specified function that allows the user to specify initializer locations when compiling a model.
+GetInitializerLocationFunc = Callable[
+    [str, OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
+]
+
+# Type hint that adheres to the signature expected by ORT.
+GetInitializerLocationWrapperFunc = Callable[
+    [str, C.OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None
+]
onnxruntime/capi/onnxruntime_providers_shared.dll CHANGED
Binary file
onnxruntime/capi/onnxruntime_pybind11_state.pyd CHANGED
Binary file
onnxruntime/capi/onnxruntime_validation.py CHANGED
@@ -57,7 +57,7 @@ def check_distro_info():
             f"Unsupported macOS version ({__my_distro_ver__}). ONNX Runtime supports macOS 11.0 or later."
         )
     elif __my_system__ == "aix":
-        import subprocess
+        import subprocess  # noqa: PLC0415
 
         returned_output = subprocess.check_output("oslevel")
         __my_distro_ver__str = returned_output.decode("utf-8")
@@ -74,11 +74,11 @@ def get_package_name_and_version_info():
     cuda_version = ""
 
     try:
-        from .build_and_package_info import __version__ as version
-        from .build_and_package_info import package_name
+        from .build_and_package_info import __version__ as version  # noqa: PLC0415
+        from .build_and_package_info import package_name  # noqa: PLC0415
 
         try:  # noqa: SIM105
-            from .build_and_package_info import cuda_version
+            from .build_and_package_info import cuda_version  # noqa: PLC0415
         except ImportError:
             # cuda_version is optional. For example, cpu only package does not have the attribute.
             pass
@@ -94,7 +94,7 @@ def check_training_module():
 
     has_ortmodule = False
     try:
-        from onnxruntime.training.ortmodule import ORTModule  # noqa: F401
+        from onnxruntime.training.ortmodule import ORTModule  # noqa: F401, PLC0415
 
         has_ortmodule = True
     except ImportError:
@@ -105,7 +105,7 @@ def check_training_module():
         # for any exception other than not having ortmodule, we want to continue
         # device version validation and raise the exception after.
         try:
-            from onnxruntime.training.ortmodule._fallback import ORTModuleInitException
+            from onnxruntime.training.ortmodule._fallback import ORTModuleInitException  # noqa: PLC0415
 
             if isinstance(e, ORTModuleInitException):
                 # ORTModule is present but not ready to run yet
@@ -125,7 +125,7 @@ def check_training_module():
         # collect cuda library build info. the library info may not be available
         # when the build environment has none or multiple libraries installed
         try:
-            from .build_and_package_info import cudart_version
+            from .build_and_package_info import cudart_version  # noqa: PLC0415
         except ImportError:
             warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
             cudart_version = None
@@ -137,7 +137,7 @@ def check_training_module():
         warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
 
         # collection cuda library info from current environment.
-        from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions
+        from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions  # noqa: PLC0415
 
         local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
         if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py CHANGED
@@ -34,7 +34,7 @@ class TrtTable:
             x = self._tab.Vector(o)
             x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
             x = self._tab.Indirect(x)
-            from onnxruntime.quantization.CalTableFlatBuffers.KeyValue import KeyValue
+            from onnxruntime.quantization.CalTableFlatBuffers.KeyValue import KeyValue  # noqa: PLC0415
 
             obj = KeyValue()
             obj.Init(self._tab.Bytes, x)