onnxruntime-directml 1.21.1__cp313-cp313-win_amd64.whl → 1.22.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/ThirdPartyNotices.txt +74 -0
- onnxruntime/__init__.py +1 -1
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/build_and_package_info.py +1 -1
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +104 -62
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/quantization/{matmul_4bits_quantizer.py → matmul_nbits_quantizer.py} +92 -41
- onnxruntime/quantization/quantize.py +2 -2
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +22 -2
- onnxruntime/tools/symbolic_shape_infer.py +71 -14
- onnxruntime/transformers/benchmark_helper.py +46 -45
- onnxruntime/transformers/convert_generation.py +528 -105
- onnxruntime/transformers/fusion_attention.py +10 -2
- onnxruntime/transformers/fusion_attention_clip.py +62 -21
- onnxruntime/transformers/fusion_bart_attention.py +21 -18
- onnxruntime/transformers/fusion_conformer_attention.py +8 -3
- onnxruntime/transformers/fusion_fastgelu.py +12 -5
- onnxruntime/transformers/fusion_simplified_layernorm.py +41 -13
- onnxruntime/transformers/fusion_utils.py +1 -3
- onnxruntime/transformers/io_binding_helper.py +4 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +0 -3
- onnxruntime/transformers/models/llama/convert_to_onnx.py +85 -13
- onnxruntime/transformers/models/llama/llama_inputs.py +7 -2
- onnxruntime/transformers/models/llama/llama_parity.py +34 -1
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +9 -3
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +14 -1
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +12 -2
- onnxruntime/transformers/models/sam2/image_decoder.py +1 -1
- onnxruntime/transformers/models/sam2/image_encoder.py +62 -12
- onnxruntime/transformers/models/sam2/mask_decoder.py +1 -1
- onnxruntime/transformers/models/sam2/prompt_encoder.py +1 -1
- onnxruntime/transformers/models/t5/convert_to_onnx.py +72 -32
- onnxruntime/transformers/models/t5/t5_encoder.py +8 -108
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +95 -32
- onnxruntime/transformers/models/t5/t5_helper.py +91 -60
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +60 -42
- onnxruntime/transformers/models/whisper/whisper_chain.py +5 -3
- onnxruntime/transformers/models/whisper/whisper_decoder.py +386 -322
- onnxruntime/transformers/models/whisper/whisper_encoder.py +101 -100
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +301 -235
- onnxruntime/transformers/models/whisper/whisper_helper.py +143 -179
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_model.py +12 -2
- onnxruntime/transformers/onnx_model_t5.py +225 -136
- {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/METADATA +6 -1
- {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/RECORD +54 -52
- {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/WHEEL +1 -1
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +0 -84
- /onnxruntime/transformers/{models/t5/past_helper.py → past_helper.py} +0 -0
- {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/entry_points.txt +0 -0
- {onnxruntime_directml-1.21.1.dist-info → onnxruntime_directml-1.22.0.dist-info}/top_level.txt +0 -0
|
@@ -6080,3 +6080,77 @@ https://dawn.googlesource.com/dawn
|
|
|
6080
6080
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
6081
6081
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
6082
6082
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
6083
|
+
|
|
6084
|
+
_____
|
|
6085
|
+
|
|
6086
|
+
KleidiAI
|
|
6087
|
+
|
|
6088
|
+
https://gitlab.arm.com/kleidi/kleidiai
|
|
6089
|
+
|
|
6090
|
+
Apache License
|
|
6091
|
+
Version 2.0, January 2004
|
|
6092
|
+
http://www.apache.org/licenses/
|
|
6093
|
+
|
|
6094
|
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
6095
|
+
|
|
6096
|
+
1. Definitions.
|
|
6097
|
+
|
|
6098
|
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
|
6099
|
+
|
|
6100
|
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
|
6101
|
+
|
|
6102
|
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
|
6103
|
+
|
|
6104
|
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
|
6105
|
+
|
|
6106
|
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
|
6107
|
+
|
|
6108
|
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
|
6109
|
+
|
|
6110
|
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
|
6111
|
+
|
|
6112
|
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
|
6113
|
+
|
|
6114
|
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
|
6115
|
+
|
|
6116
|
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
|
6117
|
+
|
|
6118
|
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
|
6119
|
+
|
|
6120
|
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
|
6121
|
+
|
|
6122
|
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
|
6123
|
+
|
|
6124
|
+
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
|
6125
|
+
|
|
6126
|
+
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
|
6127
|
+
|
|
6128
|
+
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
|
6129
|
+
|
|
6130
|
+
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
|
6131
|
+
|
|
6132
|
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
|
6133
|
+
|
|
6134
|
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
|
6135
|
+
|
|
6136
|
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
|
6137
|
+
|
|
6138
|
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
|
6139
|
+
|
|
6140
|
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
|
6141
|
+
|
|
6142
|
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
|
6143
|
+
|
|
6144
|
+
END OF TERMS AND CONDITIONS
|
|
6145
|
+
|
|
6146
|
+
Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
6147
|
+
|
|
6148
|
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
|
6149
|
+
|
|
6150
|
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
|
6151
|
+
|
|
6152
|
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
|
6153
|
+
|
|
6154
|
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
|
6155
|
+
|
|
6156
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
onnxruntime/__init__.py
CHANGED
|
@@ -8,7 +8,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://ak
|
|
|
8
8
|
or the `Github project <https://github.com/microsoft/onnxruntime/>`_.
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
|
-
__version__ = "1.21.1"
|
|
11
|
+
__version__ = "1.22.0"
|
|
12
12
|
__author__ = "Microsoft"
|
|
13
13
|
|
|
14
14
|
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
|
onnxruntime/capi/DirectML.dll
CHANGED
|
Binary file
|
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
package_name = 'onnxruntime-directml'
|
|
2
|
-
__version__ = '1.21.1'
|
|
2
|
+
__version__ = '1.22.0'
|
onnxruntime/capi/onnxruntime.dll
CHANGED
|
Binary file
|
|
@@ -15,6 +15,9 @@ from typing import Any
|
|
|
15
15
|
from onnxruntime.capi import _pybind_state as C
|
|
16
16
|
|
|
17
17
|
if typing.TYPE_CHECKING:
|
|
18
|
+
import numpy as np
|
|
19
|
+
import numpy.typing as npt
|
|
20
|
+
|
|
18
21
|
import onnxruntime
|
|
19
22
|
|
|
20
23
|
|
|
@@ -59,22 +62,22 @@ class AdapterFormat:
|
|
|
59
62
|
"""
|
|
60
63
|
self._adapter.export_adapter(file_path)
|
|
61
64
|
|
|
62
|
-
def get_format_version(self):
|
|
65
|
+
def get_format_version(self) -> int:
|
|
63
66
|
return self._adapter.format_version
|
|
64
67
|
|
|
65
|
-
def set_adapter_version(self, adapter_version: int):
|
|
68
|
+
def set_adapter_version(self, adapter_version: int) -> None:
|
|
66
69
|
self._adapter.adapter_version = adapter_version
|
|
67
70
|
|
|
68
|
-
def get_adapter_version(self):
|
|
71
|
+
def get_adapter_version(self) -> int:
|
|
69
72
|
return self._adapter.adapter_version
|
|
70
73
|
|
|
71
|
-
def set_model_version(self, model_version: int):
|
|
74
|
+
def set_model_version(self, model_version: int) -> None:
|
|
72
75
|
self._adapter.model_version = model_version
|
|
73
76
|
|
|
74
|
-
def get_model_version(self):
|
|
77
|
+
def get_model_version(self) -> int:
|
|
75
78
|
return self._adapter.model_version
|
|
76
79
|
|
|
77
|
-
def set_parameters(self, params: dict[str, OrtValue]):
|
|
80
|
+
def set_parameters(self, params: dict[str, OrtValue]) -> None:
|
|
78
81
|
self._adapter.parameters = {k: v._ortvalue for k, v in params.items()}
|
|
79
82
|
|
|
80
83
|
def get_parameters(self) -> dict[str, OrtValue]:
|
|
@@ -174,27 +177,27 @@ class Session:
|
|
|
174
177
|
self._sess = None
|
|
175
178
|
self._enable_fallback = True
|
|
176
179
|
|
|
177
|
-
def get_session_options(self):
|
|
180
|
+
def get_session_options(self) -> onnxruntime.SessionOptions:
|
|
178
181
|
"Return the session options. See :class:`onnxruntime.SessionOptions`."
|
|
179
182
|
return self._sess_options
|
|
180
183
|
|
|
181
|
-
def get_inputs(self):
|
|
184
|
+
def get_inputs(self) -> Sequence[onnxruntime.NodeArg]:
|
|
182
185
|
"Return the inputs metadata as a list of :class:`onnxruntime.NodeArg`."
|
|
183
186
|
return self._inputs_meta
|
|
184
187
|
|
|
185
|
-
def get_outputs(self):
|
|
188
|
+
def get_outputs(self) -> Sequence[onnxruntime.NodeArg]:
|
|
186
189
|
"Return the outputs metadata as a list of :class:`onnxruntime.NodeArg`."
|
|
187
190
|
return self._outputs_meta
|
|
188
191
|
|
|
189
|
-
def get_overridable_initializers(self):
|
|
192
|
+
def get_overridable_initializers(self) -> Sequence[onnxruntime.NodeArg]:
|
|
190
193
|
"Return the inputs (including initializers) metadata as a list of :class:`onnxruntime.NodeArg`."
|
|
191
194
|
return self._overridable_initializers
|
|
192
195
|
|
|
193
|
-
def get_modelmeta(self):
|
|
196
|
+
def get_modelmeta(self) -> onnxruntime.ModelMetadata:
|
|
194
197
|
"Return the metadata. See :class:`onnxruntime.ModelMetadata`."
|
|
195
198
|
return self._model_meta
|
|
196
199
|
|
|
197
|
-
def get_providers(self):
|
|
200
|
+
def get_providers(self) -> Sequence[str]:
|
|
198
201
|
"Return list of registered execution providers."
|
|
199
202
|
return self._providers
|
|
200
203
|
|
|
@@ -202,7 +205,7 @@ class Session:
|
|
|
202
205
|
"Return registered execution providers' configurations."
|
|
203
206
|
return self._provider_options
|
|
204
207
|
|
|
205
|
-
def set_providers(self, providers=None, provider_options=None):
|
|
208
|
+
def set_providers(self, providers=None, provider_options=None) -> None:
|
|
206
209
|
"""
|
|
207
210
|
Register the input list of execution providers. The underlying session is re-created.
|
|
208
211
|
|
|
@@ -224,13 +227,13 @@ class Session:
|
|
|
224
227
|
# recreate the underlying C.InferenceSession
|
|
225
228
|
self._reset_session(providers, provider_options)
|
|
226
229
|
|
|
227
|
-
def disable_fallback(self):
|
|
230
|
+
def disable_fallback(self) -> None:
|
|
228
231
|
"""
|
|
229
232
|
Disable session.run() fallback mechanism.
|
|
230
233
|
"""
|
|
231
234
|
self._enable_fallback = False
|
|
232
235
|
|
|
233
|
-
def enable_fallback(self):
|
|
236
|
+
def enable_fallback(self) -> None:
|
|
234
237
|
"""
|
|
235
238
|
Enable session.Run() fallback mechanism. If session.Run() fails due to an internal Execution Provider failure,
|
|
236
239
|
reset the Execution Providers enabled for this session.
|
|
@@ -249,7 +252,7 @@ class Session:
|
|
|
249
252
|
f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names})."
|
|
250
253
|
)
|
|
251
254
|
|
|
252
|
-
def run(self, output_names, input_feed, run_options=None):
|
|
255
|
+
def run(self, output_names, input_feed, run_options=None) -> Sequence[np.ndarray | SparseTensor | list | dict]:
|
|
253
256
|
"""
|
|
254
257
|
Compute the predictions.
|
|
255
258
|
|
|
@@ -308,7 +311,7 @@ class Session:
|
|
|
308
311
|
output_names = [output.name for output in self._outputs_meta]
|
|
309
312
|
return self._sess.run_async(output_names, input_feed, callback, user_data, run_options)
|
|
310
313
|
|
|
311
|
-
def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None):
|
|
314
|
+
def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None) -> Sequence[OrtValue]:
|
|
312
315
|
"""
|
|
313
316
|
Compute the predictions.
|
|
314
317
|
|
|
@@ -367,7 +370,7 @@ class Session:
|
|
|
367
370
|
"""
|
|
368
371
|
return self._sess.get_profiling_start_time_ns
|
|
369
372
|
|
|
370
|
-
def io_binding(self):
|
|
373
|
+
def io_binding(self) -> IOBinding:
|
|
371
374
|
"Return an onnxruntime.IOBinding object`."
|
|
372
375
|
return IOBinding(self)
|
|
373
376
|
|
|
@@ -504,6 +507,23 @@ class InferenceSession(Session):
|
|
|
504
507
|
self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
505
508
|
else:
|
|
506
509
|
self._fallback_providers = ["CPUExecutionProvider"]
|
|
510
|
+
if "NvTensorRTRTXExecutionProvider" in available_providers:
|
|
511
|
+
if (
|
|
512
|
+
providers
|
|
513
|
+
and any(
|
|
514
|
+
provider == "CUDAExecutionProvider"
|
|
515
|
+
or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider")
|
|
516
|
+
for provider in providers
|
|
517
|
+
)
|
|
518
|
+
and any(
|
|
519
|
+
provider == "NvTensorRTRTXExecutionProvider"
|
|
520
|
+
or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider")
|
|
521
|
+
for provider in providers
|
|
522
|
+
)
|
|
523
|
+
):
|
|
524
|
+
self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
525
|
+
else:
|
|
526
|
+
self._fallback_providers = ["CPUExecutionProvider"]
|
|
507
527
|
# MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU.
|
|
508
528
|
elif "MIGraphXExecutionProvider" in available_providers:
|
|
509
529
|
if providers and any(
|
|
@@ -550,7 +570,7 @@ class InferenceSession(Session):
|
|
|
550
570
|
self._provider_options = self._sess.get_provider_options()
|
|
551
571
|
self._profiling_start_time_ns = self._sess.get_profiling_start_time_ns
|
|
552
572
|
|
|
553
|
-
def _reset_session(self, providers, provider_options):
|
|
573
|
+
def _reset_session(self, providers, provider_options) -> None:
|
|
554
574
|
"release underlying session object."
|
|
555
575
|
# meta data references session internal structures
|
|
556
576
|
# so they must be set to None to decrement _sess reference count.
|
|
@@ -579,6 +599,15 @@ class InferenceSession(Session):
|
|
|
579
599
|
):
|
|
580
600
|
C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])
|
|
581
601
|
|
|
602
|
+
if providers[i] in available_providers and providers[i] == "NvTensorRTRTXExecutionProvider":
|
|
603
|
+
C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, provider_options[i])
|
|
604
|
+
elif (
|
|
605
|
+
isinstance(providers[i], tuple)
|
|
606
|
+
and providers[i][0] in available_providers
|
|
607
|
+
and providers[i][0] == "NvTensorrtRTXExecutionProvider"
|
|
608
|
+
):
|
|
609
|
+
C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1])
|
|
610
|
+
|
|
582
611
|
|
|
583
612
|
class IOBinding:
|
|
584
613
|
"""
|
|
@@ -721,7 +750,7 @@ class OrtValue:
|
|
|
721
750
|
This class provides APIs to construct and deal with OrtValues.
|
|
722
751
|
"""
|
|
723
752
|
|
|
724
|
-
def __init__(self, ortvalue, numpy_obj=None):
|
|
753
|
+
def __init__(self, ortvalue: C.OrtValue, numpy_obj: np.ndarray | None = None):
|
|
725
754
|
if isinstance(ortvalue, C.OrtValue):
|
|
726
755
|
self._ortvalue = ortvalue
|
|
727
756
|
# Hold a ref count to the numpy object if the OrtValue is backed directly
|
|
@@ -733,11 +762,11 @@ class OrtValue:
|
|
|
733
762
|
"`Provided ortvalue` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`"
|
|
734
763
|
)
|
|
735
764
|
|
|
736
|
-
def _get_c_value(self):
|
|
765
|
+
def _get_c_value(self) -> C.OrtValue:
|
|
737
766
|
return self._ortvalue
|
|
738
767
|
|
|
739
|
-
@staticmethod
|
|
740
|
-
def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
|
|
768
|
+
@classmethod
|
|
769
|
+
def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0) -> OrtValue:
|
|
741
770
|
"""
|
|
742
771
|
Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object
|
|
743
772
|
A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu
|
|
@@ -749,7 +778,7 @@ class OrtValue:
|
|
|
749
778
|
# Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue
|
|
750
779
|
# is backed directly by the data buffer of the numpy object and so the numpy object
|
|
751
780
|
# must be around until this OrtValue instance is around
|
|
752
|
-
return OrtValue(
|
|
781
|
+
return cls(
|
|
753
782
|
C.OrtValue.ortvalue_from_numpy(
|
|
754
783
|
numpy_obj,
|
|
755
784
|
C.OrtDevice(
|
|
@@ -761,8 +790,8 @@ class OrtValue:
|
|
|
761
790
|
numpy_obj if device_type.lower() == "cpu" else None,
|
|
762
791
|
)
|
|
763
792
|
|
|
764
|
-
@staticmethod
|
|
765
|
-
def ortvalue_from_numpy_with_onnx_type(data, onnx_element_type: int):
|
|
793
|
+
@classmethod
|
|
794
|
+
def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_type: int) -> OrtValue:
|
|
766
795
|
"""
|
|
767
796
|
This method creates an instance of OrtValue on top of the numpy array.
|
|
768
797
|
No data copy is made and the lifespan of the resulting OrtValue should never
|
|
@@ -771,12 +800,14 @@ class OrtValue:
|
|
|
771
800
|
when we want to use an ONNX data type that is not supported by numpy.
|
|
772
801
|
|
|
773
802
|
:param data: numpy.ndarray.
|
|
774
|
-
:param
|
|
803
|
+
:param onnx_element_type: a valid onnx TensorProto::DataType enum value
|
|
775
804
|
"""
|
|
776
|
-
return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
|
|
805
|
+
return cls(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
|
|
777
806
|
|
|
778
|
-
@staticmethod
|
|
779
|
-
def ortvalue_from_shape_and_type(
|
|
807
|
+
@classmethod
|
|
808
|
+
def ortvalue_from_shape_and_type(
|
|
809
|
+
cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0
|
|
810
|
+
) -> OrtValue:
|
|
780
811
|
"""
|
|
781
812
|
Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
|
|
782
813
|
|
|
@@ -788,7 +819,7 @@ class OrtValue:
|
|
|
788
819
|
# Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
|
|
789
820
|
# This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy.
|
|
790
821
|
if isinstance(element_type, int):
|
|
791
|
-
return OrtValue(
|
|
822
|
+
return cls(
|
|
792
823
|
C.OrtValue.ortvalue_from_shape_and_onnx_type(
|
|
793
824
|
shape,
|
|
794
825
|
element_type,
|
|
@@ -800,7 +831,7 @@ class OrtValue:
|
|
|
800
831
|
)
|
|
801
832
|
)
|
|
802
833
|
|
|
803
|
-
return OrtValue(
|
|
834
|
+
return cls(
|
|
804
835
|
C.OrtValue.ortvalue_from_shape_and_type(
|
|
805
836
|
shape,
|
|
806
837
|
element_type,
|
|
@@ -812,77 +843,77 @@ class OrtValue:
|
|
|
812
843
|
)
|
|
813
844
|
)
|
|
814
845
|
|
|
815
|
-
@staticmethod
|
|
816
|
-
def ort_value_from_sparse_tensor(sparse_tensor):
|
|
846
|
+
@classmethod
|
|
847
|
+
def ort_value_from_sparse_tensor(cls, sparse_tensor: SparseTensor) -> OrtValue:
|
|
817
848
|
"""
|
|
818
849
|
The function will construct an OrtValue instance from a valid SparseTensor
|
|
819
850
|
The new instance of OrtValue will assume the ownership of sparse_tensor
|
|
820
851
|
"""
|
|
821
|
-
return OrtValue(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
|
|
852
|
+
return cls(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor()))
|
|
822
853
|
|
|
823
|
-
def as_sparse_tensor(self):
|
|
854
|
+
def as_sparse_tensor(self) -> SparseTensor:
|
|
824
855
|
"""
|
|
825
856
|
The function will return SparseTensor contained in this OrtValue
|
|
826
857
|
"""
|
|
827
858
|
return SparseTensor(self._ortvalue.as_sparse_tensor())
|
|
828
859
|
|
|
829
|
-
def data_ptr(self):
|
|
860
|
+
def data_ptr(self) -> int:
|
|
830
861
|
"""
|
|
831
862
|
Returns the address of the first element in the OrtValue's data buffer
|
|
832
863
|
"""
|
|
833
864
|
return self._ortvalue.data_ptr()
|
|
834
865
|
|
|
835
|
-
def device_name(self):
|
|
866
|
+
def device_name(self) -> str:
|
|
836
867
|
"""
|
|
837
868
|
Returns the name of the device where the OrtValue's data buffer resides e.g. cpu, cuda, cann
|
|
838
869
|
"""
|
|
839
870
|
return self._ortvalue.device_name().lower()
|
|
840
871
|
|
|
841
|
-
def shape(self):
|
|
872
|
+
def shape(self) -> Sequence[int]:
|
|
842
873
|
"""
|
|
843
874
|
Returns the shape of the data in the OrtValue
|
|
844
875
|
"""
|
|
845
876
|
return self._ortvalue.shape()
|
|
846
877
|
|
|
847
|
-
def data_type(self):
|
|
878
|
+
def data_type(self) -> str:
|
|
848
879
|
"""
|
|
849
|
-
Returns the data type of the data in the OrtValue
|
|
880
|
+
Returns the data type of the data in the OrtValue. E.g. 'tensor(int64)'
|
|
850
881
|
"""
|
|
851
882
|
return self._ortvalue.data_type()
|
|
852
883
|
|
|
853
|
-
def element_type(self):
|
|
884
|
+
def element_type(self) -> int:
|
|
854
885
|
"""
|
|
855
886
|
Returns the proto type of the data in the OrtValue
|
|
856
887
|
if the OrtValue is a tensor.
|
|
857
888
|
"""
|
|
858
889
|
return self._ortvalue.element_type()
|
|
859
890
|
|
|
860
|
-
def has_value(self):
|
|
891
|
+
def has_value(self) -> bool:
|
|
861
892
|
"""
|
|
862
893
|
Returns True if the OrtValue corresponding to an
|
|
863
894
|
optional type contains data, else returns False
|
|
864
895
|
"""
|
|
865
896
|
return self._ortvalue.has_value()
|
|
866
897
|
|
|
867
|
-
def is_tensor(self):
|
|
898
|
+
def is_tensor(self) -> bool:
|
|
868
899
|
"""
|
|
869
900
|
Returns True if the OrtValue contains a Tensor, else returns False
|
|
870
901
|
"""
|
|
871
902
|
return self._ortvalue.is_tensor()
|
|
872
903
|
|
|
873
|
-
def is_sparse_tensor(self):
|
|
904
|
+
def is_sparse_tensor(self) -> bool:
|
|
874
905
|
"""
|
|
875
906
|
Returns True if the OrtValue contains a SparseTensor, else returns False
|
|
876
907
|
"""
|
|
877
908
|
return self._ortvalue.is_sparse_tensor()
|
|
878
909
|
|
|
879
|
-
def is_tensor_sequence(self):
|
|
910
|
+
def is_tensor_sequence(self) -> bool:
|
|
880
911
|
"""
|
|
881
912
|
Returns True if the OrtValue contains a Tensor Sequence, else returns False
|
|
882
913
|
"""
|
|
883
914
|
return self._ortvalue.is_tensor_sequence()
|
|
884
915
|
|
|
885
|
-
def numpy(self):
|
|
916
|
+
def numpy(self) -> np.ndarray:
|
|
886
917
|
"""
|
|
887
918
|
Returns a Numpy object from the OrtValue.
|
|
888
919
|
Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
|
|
@@ -890,7 +921,7 @@ class OrtValue:
|
|
|
890
921
|
"""
|
|
891
922
|
return self._ortvalue.numpy()
|
|
892
923
|
|
|
893
|
-
def update_inplace(self, np_arr):
|
|
924
|
+
def update_inplace(self, np_arr) -> None:
|
|
894
925
|
"""
|
|
895
926
|
Update the OrtValue in place with a new Numpy array. The numpy contents
|
|
896
927
|
are copied over to the device memory backing the OrtValue. It can be used
|
|
@@ -948,7 +979,7 @@ class SparseTensor:
|
|
|
948
979
|
depending on the format
|
|
949
980
|
"""
|
|
950
981
|
|
|
951
|
-
def __init__(self, sparse_tensor):
|
|
982
|
+
def __init__(self, sparse_tensor: C.SparseTensor):
|
|
952
983
|
"""
|
|
953
984
|
Internal constructor
|
|
954
985
|
"""
|
|
@@ -960,11 +991,17 @@ class SparseTensor:
|
|
|
960
991
|
"`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`"
|
|
961
992
|
)
|
|
962
993
|
|
|
963
|
-
def _get_c_tensor(self):
|
|
994
|
+
def _get_c_tensor(self) -> C.SparseTensor:
|
|
964
995
|
return self._tensor
|
|
965
996
|
|
|
966
|
-
@staticmethod
|
|
967
|
-
def sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device):
|
|
997
|
+
@classmethod
|
|
998
|
+
def sparse_coo_from_numpy(
|
|
999
|
+
cls,
|
|
1000
|
+
dense_shape: npt.NDArray[np.int64],
|
|
1001
|
+
values: np.ndarray,
|
|
1002
|
+
coo_indices: npt.NDArray[np.int64],
|
|
1003
|
+
ort_device: OrtDevice,
|
|
1004
|
+
) -> SparseTensor:
|
|
968
1005
|
"""
|
|
969
1006
|
Factory method to construct a SparseTensor in COO format from given arguments
|
|
970
1007
|
|
|
@@ -985,12 +1022,17 @@ class SparseTensor:
|
|
|
985
1022
|
For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
|
|
986
1023
|
on other devices and their memory can not be mapped.
|
|
987
1024
|
"""
|
|
988
|
-
return SparseTensor(
|
|
989
|
-
C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device())
|
|
990
|
-
)
|
|
1025
|
+
return cls(C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device()))
|
|
991
1026
|
|
|
992
|
-
@staticmethod
|
|
993
|
-
def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort_device):
|
|
1027
|
+
@classmethod
|
|
1028
|
+
def sparse_csr_from_numpy(
|
|
1029
|
+
cls,
|
|
1030
|
+
dense_shape: npt.NDArray[np.int64],
|
|
1031
|
+
values: np.ndarray,
|
|
1032
|
+
inner_indices: npt.NDArray[np.int64],
|
|
1033
|
+
outer_indices: npt.NDArray[np.int64],
|
|
1034
|
+
ort_device: OrtDevice,
|
|
1035
|
+
) -> SparseTensor:
|
|
994
1036
|
"""
|
|
995
1037
|
Factory method to construct a SparseTensor in CSR format from given arguments
|
|
996
1038
|
|
|
@@ -1011,7 +1053,7 @@ class SparseTensor:
|
|
|
1011
1053
|
For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those
|
|
1012
1054
|
on other devices and their memory can not be mapped.
|
|
1013
1055
|
"""
|
|
1014
|
-
return SparseTensor(
|
|
1056
|
+
return cls(
|
|
1015
1057
|
C.SparseTensor.sparse_csr_from_numpy(
|
|
1016
1058
|
dense_shape,
|
|
1017
1059
|
values,
|
|
@@ -1021,7 +1063,7 @@ class SparseTensor:
|
|
|
1021
1063
|
)
|
|
1022
1064
|
)
|
|
1023
1065
|
|
|
1024
|
-
def values(self):
|
|
1066
|
+
def values(self) -> np.ndarray:
|
|
1025
1067
|
"""
|
|
1026
1068
|
The method returns a numpy array that is backed by the native memory
|
|
1027
1069
|
if the data type is numeric. Otherwise, the returned numpy array that contains
|
|
@@ -1093,19 +1135,19 @@ class SparseTensor:
|
|
|
1093
1135
|
"""
|
|
1094
1136
|
return self._tensor.format
|
|
1095
1137
|
|
|
1096
|
-
def dense_shape(self):
|
|
1138
|
+
def dense_shape(self) -> npt.NDArray[np.int64]:
|
|
1097
1139
|
"""
|
|
1098
1140
|
Returns a numpy array(int64) containing a dense shape of a sparse tensor
|
|
1099
1141
|
"""
|
|
1100
1142
|
return self._tensor.dense_shape()
|
|
1101
1143
|
|
|
1102
|
-
def data_type(self):
|
|
1144
|
+
def data_type(self) -> str:
|
|
1103
1145
|
"""
|
|
1104
1146
|
Returns a string data type of the data in the OrtValue
|
|
1105
1147
|
"""
|
|
1106
1148
|
return self._tensor.data_type()
|
|
1107
1149
|
|
|
1108
|
-
def device_name(self):
|
|
1150
|
+
def device_name(self) -> str:
|
|
1109
1151
|
"""
|
|
1110
1152
|
Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda
|
|
1111
1153
|
"""
|
|
Binary file
|
|
Binary file
|