onnxruntime-directml 1.21.1-cp313-cp313-win_amd64.whl → 1.22.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. onnxruntime/ThirdPartyNotices.txt +74 -0
  2. onnxruntime/__init__.py +1 -1
  3. onnxruntime/capi/DirectML.dll +0 -0
  4. onnxruntime/capi/build_and_package_info.py +1 -1
  5. onnxruntime/capi/onnxruntime.dll +0 -0
  6. onnxruntime/capi/onnxruntime_inference_collection.py +104 -62
  7. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  8. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  9. onnxruntime/quantization/{matmul_4bits_quantizer.py → matmul_nbits_quantizer.py} +92 -41
  10. onnxruntime/quantization/quantize.py +2 -2
  11. onnxruntime/quantization/static_quantize_runner.py +256 -0
  12. onnxruntime/tools/pytorch_export_contrib_ops.py +22 -2
  13. onnxruntime/tools/symbolic_shape_infer.py +71 -14
  14. onnxruntime/transformers/benchmark_helper.py +46 -45
  15. onnxruntime/transformers/convert_generation.py +528 -105
  16. onnxruntime/transformers/fusion_attention.py +10 -2
  17. onnxruntime/transformers/fusion_attention_clip.py +62 -21
  18. onnxruntime/transformers/fusion_bart_attention.py +21 -18
  19. onnxruntime/transformers/fusion_conformer_attention.py +8 -3
  20. onnxruntime/transformers/fusion_fastgelu.py +12 -5
  21. onnxruntime/transformers/fusion_simplified_layernorm.py +41 -13
  22. onnxruntime/transformers/fusion_utils.py +1 -3
  23. onnxruntime/transformers/io_binding_helper.py +4 -0
  24. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +0 -3
  25. onnxruntime/transformers/models/llama/convert_to_onnx.py +85 -13
  26. onnxruntime/transformers/models/llama/llama_inputs.py +7 -2
  27. onnxruntime/transformers/models/llama/llama_parity.py +34 -1
  28. onnxruntime/transformers/models/phi2/convert_to_onnx.py +9 -3
  29. onnxruntime/transformers/models/sam2/benchmark_sam2.py +14 -1
  30. onnxruntime/transformers/models/sam2/convert_to_onnx.py +12 -2
  31. onnxruntime/transformers/models/sam2/image_decoder.py +1 -1
  32. onnxruntime/transformers/models/sam2/image_encoder.py +62 -12
  33. onnxruntime/transformers/models/sam2/mask_decoder.py +1 -1
  34. onnxruntime/transformers/models/sam2/prompt_encoder.py +1 -1
  35. onnxruntime/transformers/models/t5/convert_to_onnx.py +72 -32
  36. onnxruntime/transformers/models/t5/t5_encoder.py +8 -108
  37. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +95 -32
  38. onnxruntime/transformers/models/t5/t5_helper.py +91 -60
  39. onnxruntime/transformers/models/whisper/convert_to_onnx.py +60 -42
  40. onnxruntime/transformers/models/whisper/whisper_chain.py +5 -3
  41. onnxruntime/transformers/models/whisper/whisper_decoder.py +386 -322
  42. onnxruntime/transformers/models/whisper/whisper_encoder.py +101 -100
  43. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +301 -235
  44. onnxruntime/transformers/models/whisper/whisper_helper.py +143 -179
  45. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  46. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  47. onnxruntime/transformers/onnx_model.py +12 -2
  48. onnxruntime/transformers/onnx_model_t5.py +225 -136
  49. {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/METADATA +6 -1
  50. {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/RECORD +54 -52
  51. {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/WHEEL +1 -1
  52. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +0 -84
  53. /onnxruntime/transformers/{models/t5/past_helper.py → past_helper.py} +0 -0
  54. {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/entry_points.txt +0 -0
  55. {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/top_level.txt +0 -0
@@ -6080,3 +6080,77 @@ https://dawn.googlesource.com/dawn
6080
6080
  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
6081
6081
  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
6082
6082
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
6083
+
6084
+ _____
6085
+
6086
+ KleidiAI
6087
+
6088
+ https://gitlab.arm.com/kleidi/kleidiai
6089
+
6090
+ Apache License
6091
+ Version 2.0, January 2004
6092
+ http://www.apache.org/licenses/
6093
+
6094
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6095
+
6096
+ 1. Definitions.
6097
+
6098
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
6099
+
6100
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
6101
+
6102
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
6103
+
6104
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
6105
+
6106
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
6107
+
6108
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
6109
+
6110
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
6111
+
6112
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
6113
+
6114
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
6115
+
6116
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
6117
+
6118
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
6119
+
6120
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
6121
+
6122
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
6123
+
6124
+ (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
6125
+
6126
+ (b) You must cause any modified files to carry prominent notices stating that You changed the files; and
6127
+
6128
+ (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
6129
+
6130
+ (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
6131
+
6132
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
6133
+
6134
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6135
+
6136
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
6137
+
6138
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
6139
+
6140
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
6141
+
6142
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
6143
+
6144
+ END OF TERMS AND CONDITIONS
6145
+
6146
+ Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
6147
+
6148
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6149
+
6150
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6151
+
6152
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
6153
+
6154
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
6155
+
6156
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
onnxruntime/__init__.py CHANGED
@@ -8,7 +8,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://ak
8
8
  or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
9
9
  """
10
10
 
11
- __version__ = "1.21.1"
11
+ __version__ = "1.22.0"
12
12
  __author__ = "Microsoft"
13
13
 
14
14
  # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
Binary file
@@ -1,2 +1,2 @@
1
1
  package_name = 'onnxruntime-directml'
2
- __version__ = '1.21.1'
2
+ __version__ = '1.22.0'
Binary file
@@ -15,6 +15,9 @@ from typing import Any
15
15
  from onnxruntime.capi import _pybind_state as C
16
16
 
17
17
  if typing.TYPE_CHECKING:
18
+ import numpy as np
19
+ import numpy.typing as npt
20
+
18
21
  import onnxruntime
19
22
 
20
23
 
@@ -59,22 +62,22 @@ class AdapterFormat:
59
62
  """
60
63
  self._adapter.export_adapter(file_path)
61
64
 
62
- def get_format_version(self):
65
+ def get_format_version(self) -> int:
63
66
  return self._adapter.format_version
64
67
 
65
- def set_adapter_version(self, adapter_version: int):
68
+ def set_adapter_version(self, adapter_version: int) -> None:
66
69
  self._adapter.adapter_version = adapter_version
67
70
 
68
- def get_adapter_version(self):
71
+ def get_adapter_version(self) -> int:
69
72
  return self._adapter.adapter_version
70
73
 
71
- def set_model_version(self, model_version: int):
74
+ def set_model_version(self, model_version: int) -> None:
72
75
  self._adapter.model_version = model_version
73
76
 
74
- def get_model_version(self):
77
+ def get_model_version(self) -> int:
75
78
  return self._adapter.model_version
76
79
 
77
- def set_parameters(self, params: dict[str, OrtValue]):
80
+ def set_parameters(self, params: dict[str, OrtValue]) -> None:
78
81
  self._adapter.parameters = {k: v._ortvalue for k, v in params.items()}
79
82
 
80
83
  def get_parameters(self) -> dict[str, OrtValue]:
@@ -174,27 +177,27 @@ class Session:
174
177
  self._sess = None
175
178
  self._enable_fallback = True
176
179
 
177
- def get_session_options(self):
180
+ def get_session_options(self) -> onnxruntime.SessionOptions:
178
181
  "Return the session options. See :class:`onnxruntime.SessionOptions`."
179
182
  return self._sess_options
180
183
 
181
- def get_inputs(self):
184
+ def get_inputs(self) -> Sequence[onnxruntime.NodeArg]:
182
185
  "Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
183
186
  return self._inputs_meta
184
187
 
185
- def get_outputs(self):
188
+ def get_outputs(self) -> Sequence[onnxruntime.NodeArg]:
186
189
  "Return the outputs metadata as a list of :class:`onnxruntime.NodeArg`."
187
190
  return self._outputs_meta
188
191
 
189
- def get_overridable_initializers(self):
192
+ def get_overridable_initializers(self) -> Sequence[onnxruntime.NodeArg]:
190
193
  "Return the inputs (including initializers) metadata as a list of :class:`onnxruntime.NodeArg`."
191
194
  return self._overridable_initializers
192
195
 
193
- def get_modelmeta(self):
196
+ def get_modelmeta(self) -> onnxruntime.ModelMetadata:
194
197
  "Return the metadata. See :class:`onnxruntime.ModelMetadata`."
195
198
  return self._model_meta
196
199
 
197
- def get_providers(self):
200
+ def get_providers(self) -> Sequence[str]:
198
201
  "Return list of registered execution providers."
199
202
  return self._providers
200
203
 
@@ -202,7 +205,7 @@ class Session:
202
205
  "Return registered execution providers' configurations."
203
206
  return self._provider_options
204
207
 
205
- def set_providers(self, providers=None, provider_options=None):
208
+ def set_providers(self, providers=None, provider_options=None) -> None:
206
209
  """
207
210
  Register the input list of execution providers. The underlying session is re-created.
208
211
 
@@ -224,13 +227,13 @@ class Session:
224
227
  # recreate the underlying C.InferenceSession
225
228
  self._reset_session(providers, provider_options)
226
229
 
227
- def disable_fallback(self):
230
+ def disable_fallback(self) -> None:
228
231
  """
229
232
  Disable session.run() fallback mechanism.
230
233
  """
231
234
  self._enable_fallback = False
232
235
 
233
- def enable_fallback(self):
236
+ def enable_fallback(self) -> None:
234
237
  """
235
238
  Enable session.Run() fallback mechanism. If session.Run() fails due to an internal Execution Provider failure,
236
239
  reset the Execution Providers enabled for this session.
@@ -249,7 +252,7 @@ class Session:
249
252
  f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names})."
250
253
  )
251
254
 
252
- def run(self, output_names, input_feed, run_options=None):
255
+ def run(self, output_names, input_feed, run_options=None) -> Sequence[np.ndarray | SparseTensor | list | dict]:
253
256
  """
254
257
  Compute the predictions.
255
258
 
@@ -308,7 +311,7 @@ class Session:
308
311
  output_names = [output.name for output in self._outputs_meta]
309
312
  return self._sess.run_async(output_names, input_feed, callback, user_data, run_options)
310
313
 
311
- def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None):
314
+ def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None) -> Sequence[OrtValue]:
312
315
  """
313
316
  Compute the predictions.
314
317
 
@@ -367,7 +370,7 @@ class Session:
367
370
  """
368
371
  return self._sess.get_profiling_start_time_ns
369
372
 
370
- def io_binding(self):
373
+ def io_binding(self) -> IOBinding:
371
374
  "Return an onnxruntime.IOBinding object`."
372
375
  return IOBinding(self)
373
376
 
@@ -504,6 +507,23 @@ class InferenceSession(Session):
504
507
  self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
505
508
  else:
506
509
  self._fallback_providers = ["CPUExecutionProvider"]
510
+ if "NvTensorRTRTXExecutionProvider" in available_providers:
511
+ if (
512
+ providers
513
+ and any(
514
+ provider == "CUDAExecutionProvider"
515
+ or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
516
+ for provider in providers
517
+ )
518
+ and any(
519
+ provider == "NvTensorRTRTXExecutionProvider"
520
+ or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider")
521
+ for provider in providers
522
+ )
523
+ ):
524
+ self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
525
+ else:
526
+ self._fallback_providers = ["CPUExecutionProvider"]
507
527
  # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU.
508
528
  elif "MIGraphXExecutionProvider" in available_providers:
509
529
  if providers and any(
@@ -550,7 +570,7 @@ class InferenceSession(Session):
550
570
  self._provider_options = self._sess.get_provider_options()
551
571
  self._profiling_start_time_ns = self._sess.get_profiling_start_time_ns
552
572
 
553
- def _reset_session(self, providers, provider_options):
573
+ def _reset_session(self, providers, provider_options) -> None:
554
574
  "release underlying session object."
555
575
  # meta data references session internal structures
556
576
  # so they must be set to None to decrement _sess reference count.
@@ -579,6 +599,15 @@ class InferenceSession(Session):
579
599
  ):
580
600
  C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])
581
601
 
602
+ if providers[i] in available_providers and providers[i] == "NvTensorRTRTXExecutionProvider":
603
+ C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, provider_options[i])
604
+ elif (
605
+ isinstance(providers[i], tuple)
606
+ and providers[i][0] in available_providers
607
+ and providers[i][0] == "NvTensorrtRTXExecutionProvider"
608
+ ):
609
+ C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1])
610
+
582
611
 
583
612
  class IOBinding:
584
613
  """
@@ -721,7 +750,7 @@ class OrtValue:
721
750
  This class provides APIs to construct and deal with OrtValues.
722
751
  """
723
752
 
724
- def __init__(self, ortvalue, numpy_obj=None):
753
+ def __init__(self, ortvalue: C.OrtValue, numpy_obj: np.ndarray | None = None):
725
754
  if isinstance(ortvalue, C.OrtValue):
726
755
  self._ortvalue = ortvalue
727
756
  # Hold a ref count to the numpy object if the OrtValue is backed directly
@@ -733,11 +762,11 @@ class OrtValue:
733
762
  "`Provided ortvalue` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`"
734
763
  )
735
764
 
736
- def _get_c_value(self):
765
+ def _get_c_value(self) -> C.OrtValue:
737
766
  return self._ortvalue
738
767
 
739
- @staticmethod
740
- def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
768
+ @classmethod
769
+ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0) -> OrtValue:
741
770
  """
742
771
  Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
743
772
  A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu
@@ -749,7 +778,7 @@ class OrtValue:
749
778
  # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
750
779
  # is backed directly by the data buffer of the numpy object and so the numpy object
751
780
  # must be around until this OrtValue instance is around
752
- return OrtValue(
781
+ return cls(
753
782
  C.OrtValue.ortvalue_from_numpy(
754
783
  numpy_obj,
755
784
  C.OrtDevice(
@@ -761,8 +790,8 @@ class OrtValue:
761
790
  numpy_obj if device_type.lower() == "cpu" else None,
762
791
  )
763
792
 
764
- @staticmethod
765
- def ortvalue_from_numpy_with_onnx_type(data, onnx_element_type: int):
793
+ @classmethod
794
+ def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_type: int) -> OrtValue:
766
795
  """
767
796
  This method creates an instance of OrtValue on top of the numpy array.
768
797
  No data copy is made and the lifespan of the resulting OrtValue should never
@@ -771,12 +800,14 @@ class OrtValue:
771
800
  when we want to use an ONNX data type that is not supported by numpy.
772
801
 
773
802
  :param data: numpy.ndarray.
774
- :param onnx_elemenet_type: a valid onnx TensorProto::DataType enum value
803
+ :param onnx_element_type: a valid onnx TensorProto::DataType enum value
775
804
  """
776
- return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
805
+ return cls(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
777
806
 
778
- @staticmethod
779
- def ortvalue_from_shape_and_type(shape, element_type, device_type: str = "cpu", device_id: int = 0):
807
+ @classmethod
808
+ def ortvalue_from_shape_and_type(
809
+ cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0
810
+ ) -> OrtValue:
780
811
  """
781
812
  Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
782
813
 
@@ -788,7 +819,7 @@ class OrtValue:
788
819
  # Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
789
820
  # This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy.
790
821
  if isinstance(element_type, int):
791
- return OrtValue(
822
+ return cls(
792
823
  C.OrtValue.ortvalue_from_shape_and_onnx_type(
793
824
  shape,
794
825
  element_type,
@@ -800,7 +831,7 @@ class OrtValue:
800
831
  )
801
832
  )
802
833
 
803
- return OrtValue(
834
+ return cls(
804
835
  C.OrtValue.ortvalue_from_shape_and_type(
805
836
  shape,
806
837
  element_type,
@@ -812,77 +843,77 @@ class OrtValue:
812
843
  )
813
844
  )
814
845
 
815
- @staticmethod
816
- def ort_value_from_sparse_tensor(sparse_tensor):
846
+ @classmethod
847
+ def ort_value_from_sparse_tensor(cls, sparse_tensor: SparseTensor) -> OrtValue:
817
848
  """
818
849
  The function will construct an OrtValue instance from a valid SparseTensor
819
850
  The new instance of OrtValue will assume the ownership of sparse_tensor
820
851
  """
821
- return OrtValue(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
852
+ return cls(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
822
853
 
823
- def as_sparse_tensor(self):
854
+ def as_sparse_tensor(self) -> SparseTensor:
824
855
  """
825
856
  The function will return SparseTensor contained in this OrtValue
826
857
  """
827
858
  return SparseTensor(self._ortvalue.as_sparse_tensor())
828
859
 
829
- def data_ptr(self):
860
+ def data_ptr(self) -> int:
830
861
  """
831
862
  Returns the address of the first element in the OrtValue's data buffer
832
863
  """
833
864
  return self._ortvalue.data_ptr()
834
865
 
835
- def device_name(self):
866
+ def device_name(self) -> str:
836
867
  """
837
868
  Returns the name of the device where the OrtValue's data buffer resides e.g. cpu, cuda, cann
838
869
  """
839
870
  return self._ortvalue.device_name().lower()
840
871
 
841
- def shape(self):
872
+ def shape(self) -> Sequence[int]:
842
873
  """
843
874
  Returns the shape of the data in the OrtValue
844
875
  """
845
876
  return self._ortvalue.shape()
846
877
 
847
- def data_type(self):
878
+ def data_type(self) -> str:
848
879
  """
849
- Returns the data type of the data in the OrtValue
880
+ Returns the data type of the data in the OrtValue. E.g. 'tensor(int64)'
850
881
  """
851
882
  return self._ortvalue.data_type()
852
883
 
853
- def element_type(self):
884
+ def element_type(self) -> int:
854
885
  """
855
886
  Returns the proto type of the data in the OrtValue
856
887
  if the OrtValue is a tensor.
857
888
  """
858
889
  return self._ortvalue.element_type()
859
890
 
860
- def has_value(self):
891
+ def has_value(self) -> bool:
861
892
  """
862
893
  Returns True if the OrtValue corresponding to an
863
894
  optional type contains data, else returns False
864
895
  """
865
896
  return self._ortvalue.has_value()
866
897
 
867
- def is_tensor(self):
898
+ def is_tensor(self) -> bool:
868
899
  """
869
900
  Returns True if the OrtValue contains a Tensor, else returns False
870
901
  """
871
902
  return self._ortvalue.is_tensor()
872
903
 
873
- def is_sparse_tensor(self):
904
+ def is_sparse_tensor(self) -> bool:
874
905
  """
875
906
  Returns True if the OrtValue contains a SparseTensor, else returns False
876
907
  """
877
908
  return self._ortvalue.is_sparse_tensor()
878
909
 
879
- def is_tensor_sequence(self):
910
+ def is_tensor_sequence(self) -> bool:
880
911
  """
881
912
  Returns True if the OrtValue contains a Tensor Sequence, else returns False
882
913
  """
883
914
  return self._ortvalue.is_tensor_sequence()
884
915
 
885
- def numpy(self):
916
+ def numpy(self) -> np.ndarray:
886
917
  """
887
918
  Returns a Numpy object from the OrtValue.
888
919
  Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
@@ -890,7 +921,7 @@ class OrtValue:
890
921
  """
891
922
  return self._ortvalue.numpy()
892
923
 
893
- def update_inplace(self, np_arr):
924
+ def update_inplace(self, np_arr) -> None:
894
925
  """
895
926
  Update the OrtValue in place with a new Numpy array. The numpy contents
896
927
  are copied over to the device memory backing the OrtValue. It can be used
@@ -948,7 +979,7 @@ class SparseTensor:
948
979
  depending on the format
949
980
  """
950
981
 
951
- def __init__(self, sparse_tensor):
982
+ def __init__(self, sparse_tensor: C.SparseTensor):
952
983
  """
953
984
  Internal constructor
954
985
  """
@@ -960,11 +991,17 @@ class SparseTensor:
960
991
  "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`"
961
992
  )
962
993
 
963
- def _get_c_tensor(self):
994
+ def _get_c_tensor(self) -> C.SparseTensor:
964
995
  return self._tensor
965
996
 
966
- @staticmethod
967
- def sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device):
997
+ @classmethod
998
+ def sparse_coo_from_numpy(
999
+ cls,
1000
+ dense_shape: npt.NDArray[np.int64],
1001
+ values: np.ndarray,
1002
+ coo_indices: npt.NDArray[np.int64],
1003
+ ort_device: OrtDevice,
1004
+ ) -> SparseTensor:
968
1005
  """
969
1006
  Factory method to construct a SparseTensor in COO format from given arguments
970
1007
 
@@ -985,12 +1022,17 @@ class SparseTensor:
985
1022
  For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
986
1023
  on other devices and their memory can not be mapped.
987
1024
  """
988
- return SparseTensor(
989
- C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device())
990
- )
1025
+ return cls(C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device()))
991
1026
 
992
- @staticmethod
993
- def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort_device):
1027
+ @classmethod
1028
+ def sparse_csr_from_numpy(
1029
+ cls,
1030
+ dense_shape: npt.NDArray[np.int64],
1031
+ values: np.ndarray,
1032
+ inner_indices: npt.NDArray[np.int64],
1033
+ outer_indices: npt.NDArray[np.int64],
1034
+ ort_device: OrtDevice,
1035
+ ) -> SparseTensor:
994
1036
  """
995
1037
  Factory method to construct a SparseTensor in CSR format from given arguments
996
1038
 
@@ -1011,7 +1053,7 @@ class SparseTensor:
1011
1053
  For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
1012
1054
  on other devices and their memory can not be mapped.
1013
1055
  """
1014
- return SparseTensor(
1056
+ return cls(
1015
1057
  C.SparseTensor.sparse_csr_from_numpy(
1016
1058
  dense_shape,
1017
1059
  values,
@@ -1021,7 +1063,7 @@ class SparseTensor:
1021
1063
  )
1022
1064
  )
1023
1065
 
1024
- def values(self):
1066
+ def values(self) -> np.ndarray:
1025
1067
  """
1026
1068
  The method returns a numpy array that is backed by the native memory
1027
1069
  if the data type is numeric. Otherwise, the returned numpy array that contains
@@ -1093,19 +1135,19 @@ class SparseTensor:
1093
1135
  """
1094
1136
  return self._tensor.format
1095
1137
 
1096
- def dense_shape(self):
1138
+ def dense_shape(self) -> npt.NDArray[np.int64]:
1097
1139
  """
1098
1140
  Returns a numpy array(int64) containing a dense shape of a sparse tensor
1099
1141
  """
1100
1142
  return self._tensor.dense_shape()
1101
1143
 
1102
- def data_type(self):
1144
+ def data_type(self) -> str:
1103
1145
  """
1104
1146
  Returns a string data type of the data in the OrtValue
1105
1147
  """
1106
1148
  return self._tensor.data_type()
1107
1149
 
1108
- def device_name(self):
1150
+ def device_name(self) -> str:
1109
1151
  """
1110
1152
  Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda
1111
1153
  """