tensorrt-cu12-bindings 10.13.3.9.post1__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,224 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import ctypes
+ import os
+ import sys
+ import warnings
+
+ # For standalone wheels, attempt to import the wheel containing the libraries.
+ _libs_wheel_imported = False
+ try:
+     import tensorrt_libs
+ except (ImportError, ModuleNotFoundError):
+     pass
+ else:
+     _libs_wheel_imported = True
+
+ # Build-time template artifact: this wheel was generated for "nvinfer", so the
+ # comparison below is always False here and the library suffix stays empty.
+ _trt_lib_suffix = ""
+ if "nvinfer".strip() == "tensorrt_rtx":
+     _trt_lib_suffix = "_13"
+
+ if not _libs_wheel_imported and sys.platform.startswith("win"):
+     log_found_dlls = bool(int(os.environ.get("TRT_LOG_FOUND_DLLS", 0)))
+     # On Windows, we need to manually open the TensorRT libraries - otherwise we are unable to
+     # load the bindings. If we imported the tensorrt_libs wheel, then that should have taken care of it for us.
+     def find_lib(name):
+         paths = os.environ["PATH"].split(os.path.pathsep)
+
+         # Add ../tensorrt.libs to the search path. This allows repackaging non-standalone TensorRT wheels as standalone
+         # using delvewheel (with the --no-mangle-all flag set) to work properly.
+         paths.append(os.path.join(os.path.dirname(__file__), os.pardir, "tensorrt.libs"))
+
+         for path in paths:
+             libpath = os.path.join(path, name)
+             if os.path.isfile(libpath):
+                 if log_found_dlls:
+                     print(f"Found {name} in path: {libpath}")
+                 return libpath
+
+         # Dead code: the `False and` guard disables this branch entirely
+         # (a build-time template artifact left in the published file).
+         if False and name.startswith("nvinfer_plugin"):
+             return None
+
+         # The builder resource library may legitimately be absent, so failing
+         # to find it is not an error.
+         if name.startswith("nvinfer_builder_resource"):
+             return None
+
+         raise FileNotFoundError(
+             "Could not find: {:}. Is it on your PATH?\nNote: Paths searched were:\n{:}".format(name, paths)
+         )
+
+     # Order matters here because of dependencies. The dict is immediately
+     # indexed with the flavor this wheel was built for ("tensorrt"), so only
+     # that entry is used; the others are build-time template alternatives.
+     LIBRARIES = {
+         "tensorrt": [
+             f"nvinfer_10{_trt_lib_suffix}.dll",
+             "nvinfer_plugin_10.dll",
+             f"nvonnxparser_10{_trt_lib_suffix}.dll",
+             "nvinfer_builder_resource_10.dll",
+         ],
+         "tensorrt_rtx": [
+             f"nvinfer_10{_trt_lib_suffix}.dll",
+             "nvinfer_plugin_10.dll",
+             f"nvonnxparser_10{_trt_lib_suffix}.dll",
+             "nvinfer_builder_resource_10.dll",
+         ],
+         "tensorrt_dispatch": [
+             "nvinfer_dispatch_10.dll",
+         ],
+         "tensorrt_lean": [
+             "nvinfer_lean_10.dll",
+         ],
+     }["tensorrt"]
+
+     for lib in LIBRARIES:
+         lib_path = find_lib(lib)
+         if not lib_path:
+             continue
+         assert os.path.isfile(lib_path)
+         ctypes.CDLL(lib_path)
+
+ del _libs_wheel_imported
+ del _trt_lib_suffix
+
+ from .tensorrt import *
+
+ __version__ = "10.13.3.9.post1"
+
+
+ # Provides Python's `with` syntax
+ def common_enter(this):
+     warnings.warn(
+         "Context managers for TensorRT types are deprecated. "
+         "Memory will be freed automatically when the reference count reaches 0.",
+         DeprecationWarning,
+     )
+     return this
+
+
+ def common_exit(this, exc_type, exc_value, traceback):
+     """
+     Context managers are deprecated and have no effect. Objects are automatically freed when
+     the reference count reaches 0.
+     """
+     pass
+
+
+ # Logger does not have a destructor.
+ ILogger.__enter__ = common_enter
+ ILogger.__exit__ = lambda this, exc_type, exc_value, traceback: None
+
+ ICudaEngine.__enter__ = common_enter
+ ICudaEngine.__exit__ = common_exit
+
+ IExecutionContext.__enter__ = common_enter
+ IExecutionContext.__exit__ = common_exit
+
+ Runtime.__enter__ = common_enter
+ Runtime.__exit__ = common_exit
+
+ IHostMemory.__enter__ = common_enter
+ IHostMemory.__exit__ = common_exit
+
+ # Build-time template artifact: the flavor literals were substituted in at
+ # build time, so this condition is always True for this ("tensorrt") wheel.
+ if "tensorrt" == "tensorrt" or "tensorrt" == "tensorrt_rtx":
+     Builder.__enter__ = common_enter
+     Builder.__exit__ = common_exit
+
+     INetworkDefinition.__enter__ = common_enter
+     INetworkDefinition.__exit__ = common_exit
+
+     OnnxParser.__enter__ = common_enter
+     OnnxParser.__exit__ = common_exit
+
+     IBuilderConfig.__enter__ = common_enter
+     IBuilderConfig.__exit__ = common_exit
+
+
+ # Add logger severity into the default implementation to preserve backwards compatibility.
+ Logger.Severity = ILogger.Severity
+
+ for attr, value in ILogger.Severity.__members__.items():
+     setattr(Logger, attr, value)
+
+
+ # Computes the volume of an iterable.
+ def volume(iterable):
+     """
+     Computes the volume of an iterable.
+
+     :arg iterable: Any Python iterable, including a :class:`Dims` object.
+
+     :returns: The volume of the iterable. This will return 1 for empty iterables, as a scalar has an empty shape and the volume of a tensor with empty shape is 1.
+     """
+     vol = 1
+     for elem in iterable:
+         vol *= elem
+     return vol
+
+
+ # Converts a TensorRT datatype to the equivalent numpy type.
+ def nptype(trt_type):
+     """
+     Returns the numpy-equivalent of a TensorRT :class:`DataType`.
+
+     :arg trt_type: The TensorRT data type to convert.
+
+     :returns: The equivalent numpy type.
+     """
+     import numpy as np
+
+     mapping = {
+         float32: np.float32,
+         float16: np.float16,
+         int8: np.int8,
+         int32: np.int32,
+         int64: np.int64,
+         bool: np.bool_,
+         uint8: np.uint8,
+         # Note: fp8 and bfloat16 have no equivalent numpy type
+     }
+     if trt_type in mapping:
+         return mapping[trt_type]
+     raise TypeError("Could not resolve TensorRT datatype to an equivalent numpy datatype.")
+
+
+ # Add a numpy-like itemsize property to the datatype.
+ def _itemsize(trt_type):
+     """
+     Returns the size in bytes of this :class:`DataType`.
+     The returned size may be fractional: sub-byte types such as INT4 and FP4 occupy half a byte.
+
+     :arg trt_type: The TensorRT data type.
+
+     :returns: The size of the type.
+     """
+     mapping = {
+         float32: 4,
+         float16: 2,
+         bfloat16: 2,
+         int8: 1,
+         int32: 4,
+         int64: 8,
+         bool: 1,
+         uint8: 1,
+         fp8: 1,
+         int4: 0.5,
+         fp4: 0.5,
+     }
+     if trt_type in mapping:
+         return mapping[trt_type]
+     # Types absent from the mapping fall through and return None implicitly.
+
+
+ DataType.itemsize = property(lambda this: _itemsize(this))
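
A minimal usage sketch of the helpers defined above, assuming this wheel (and a matching tensorrt_libs wheel) is installed; every name used below comes from the file shown in this hunk:

    import os

    # Read by the DLL loader above, so it must be set before `import tensorrt`.
    os.environ["TRT_LOG_FOUND_DLLS"] = "1"

    import numpy as np
    import tensorrt as trt

    print(trt.__version__)        # "10.13.3.9.post1"
    print(trt.volume((2, 3, 4)))  # 24
    print(trt.volume(()))         # 1 -- a scalar has an empty shape
    assert trt.nptype(trt.float16) is np.float16
    print(trt.int4.itemsize)      # 0.5 -- half a byte, via the property added above

    # Context managers still work but now warn via common_enter.
    with trt.Runtime(trt.Logger()) as runtime:
        pass
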
@@ -0,0 +1,46 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import tensorrt as trt
+
+ logger = trt.Logger()
+ logger.log(trt.Logger.WARNING, "Functionality provided through the tensorrt.plugin module is experimental.")
+
+ # _export.public_api() will expose symbols here. To make sure that happens, we just need to
+ # import all the submodules so that the decorator is actually executed (__discover_modules() below).
+ __all__ = []
+
+ def __discover_modules():
+     import importlib
+     import pkgutil
+
+     mods = [importlib.import_module(__package__)]
+     while mods:
+         mod = mods.pop(0)
+
+         yield mod
+
+         if hasattr(mod, "__path__"):
+             mods.extend(
+                 [
+                     importlib.import_module(f"{mod.__name__}.{submod.name}")
+                     for submod in pkgutil.iter_modules(mod.__path__)
+                 ]
+             )
+
+
+ _ = list(__discover_modules())
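
The loop above is a breadth-first walk of the package: importing every submodule is the side effect that makes each public_api() decorator run. A standalone sketch of the same pattern, where `mypkg` is a hypothetical package name used only for illustration:

    import importlib
    import pkgutil

    def discover_modules(root: str):
        # Packages carry a __path__ attribute; plain modules do not,
        # so only packages enqueue further submodules.
        mods = [importlib.import_module(root)]
        while mods:
            mod = mods.pop(0)
            yield mod
            if hasattr(mod, "__path__"):
                mods.extend(
                    importlib.import_module(f"{mod.__name__}.{sub.name}")
                    for sub in pkgutil.iter_modules(mod.__path__)
                )

    # Only the import side effects matter; the result is discarded,
    # mirroring `_ = list(__discover_modules())` above.
    list(discover_modules("mypkg"))
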
@@ -0,0 +1,270 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import builtins
+ import tensorrt as trt
+ from typing import List, Iterable
+ import copy
+
+ from ._utils import _str_to_data_type
+ from ._export import public_api
+
+
+ # "Onesided" means either type combinations or format combinations on their own; after the combinations
+ # for each side are generated separately, they are combined later.
+ # e.g. io_variants = ["FP32|FP16", "FP32|FP16", "FP32*FP16"] for a plugin with 3 I/Os, i.e. I/O indices 0 and 1
+ # are dependently either FP32/FP16 and index 2 is independently FP32/FP16.
+ # There will be 2 * 2 = 4 combinations here: ["FP32", "FP32", "FP32"], ["FP16", "FP16", "FP32"], ["FP32", "FP32", "FP16"], ["FP16", "FP16", "FP16"]
+ def _gen_onesided_combinations(io_variants):
+
+     # Algorithm:
+     # (1) Ignore independent variants and count the (max) number of dependent variants `mx_poly`
+     # (2) Compile initial list of #`mx_poly` combinations using the first option (option 0) for any independent variants
+     # (3) For each independent variant IO index, add combinations with that index replaced by option 1, 2, ...
+
+     combinations = []
+     mx_poly = 0  # This is the number of dependent variants
+
+     for io_variant in io_variants:
+         io_variant_list = io_variant.split("|")
+
+         if len(io_variant_list) > 1:
+             if "*" in io_variant:
+                 raise ValueError(
+                     f"Type/Format '{io_variant}' contains both '|' and '*'"
+                 )
+             if mx_poly > 1:
+                 if mx_poly != len(io_variant_list):
+                     raise ValueError(
+                         f"Type/Format combinations {io_variants} contain illegal dependent lengths"
+                     )
+
+         mx_poly = builtins.max(mx_poly, len(io_variant_list))
+
+     for _ in range(mx_poly):
+         combinations.append([None] * len(io_variants))
+
+     for j, io_variant in enumerate(io_variants):
+         io_variant_list = io_variant.split("|")
+
+         if len(io_variant_list) == 1:
+             if "*" in io_variant:
+                 io_variant_list = io_variant.split("*")
+             for i in range(len(combinations)):
+                 combinations[i][j] = io_variant_list[0]
+         else:
+             for k in range(len(io_variant_list)):
+                 combinations[k][j] = io_variant_list[k]
+
+     for j, io_variant in enumerate(io_variants):
+         new_combs = []
+         if "*" in io_variant:
+             io_variant_list = io_variant.split("*")
+             for k in range(1, len(io_variant_list)):
+                 for c in combinations:
+                     new_c = copy.deepcopy(c)
+                     new_c[j] = io_variant_list[k]
+                     new_combs.append(new_c)
+             combinations.extend(new_combs)
+
+     return combinations
+
+
+ class _TypeFormatCombination:
+     def __init__(self, num=0):
+         self.types = [None] * num
+         self.layouts = [None] * num
+         self.tactics = []
+
+     def set_types(self, types):
+         self.types = types
+
+     def set_layouts(self, layouts=None):
+         if isinstance(layouts, List):
+             self.layouts = layouts
+         else:
+             self.layouts = [layouts] * len(self.types)
+
+     def __hash__(self):
+         return hash((tuple(self.types), tuple(self.layouts)))
+
+     def __eq__(self, other):
+         return (
+             isinstance(other, _TypeFormatCombination)
+             and self.types == other.types
+             and self.layouts == other.layouts
+         )
+
+     def __str__(self) -> str:
+         return "{" + str(self.types) + ", " + str(self.layouts) + "}"
+
+
+ @public_api()
+ class AutoTuneCombination:
+     def __init__(
+         self, io_types: str = None, layouts: str = None, tactics: Iterable[int] = None
+     ):
+         """
+         Construct a set of supported type/format combinations of a plugin's I/O.
+
+         Custom *tactics* can also be advertised for each such type/format combination. A tactic is simply another way to
+         calculate the output of a plugin for the same type/format combination of the I/O (e.g. if there are multiple kernels available).
+
+         Args:
+             io_types (str, optional): A string representation of a type combination.
+
+                 Valid format is "type0,type1,...,type#io" where 'type' is of the form "TYPE0[sep]TYPE1[sep]...".
+
+                 TYPE is a valid string representation of a `trt.DataType`, e.g. "FP32" for trt.float32 or "FP16" for trt.float16. The string representation of other data types is the same as their name in the trt.DataType enum.
+
+                 [sep] is a valid separator, which is either '|' or '*'. Only one of these separators can appear in a given `io_types`.
+
+                 (1) '|' indicates a dependent combination: the type of one I/O depends on another I/O. e.g. "FP32|FP16,FP32|FP16" indicates the I/O can only be both FP32 or both FP16.
+
+                 (2) '*' indicates an independent combination. e.g. "FP32*FP16,FP32|FP16,FP32|FP16" indicates that the first input is independently either FP32 or FP16, regardless of the rest of the I/O.
+
+             layouts (str, optional): A string representation of a format combination.
+
+                 Valid format is "format0,format1,...,format#io" where 'format' is of the form "FORMAT0[sep]FORMAT1[sep]...".
+
+                 FORMAT is a valid string representation of a `trt.TensorFormat`. These are string versions of the enum values of `trt.TensorFormat`, e.g. "LINEAR" for `trt.TensorFormat.LINEAR`.
+
+                 [sep] is a valid separator, which is either '|' or '*'. The rules are the same as for `io_types`.
+
+             tactics (Iterable[int], optional): Custom tactics for this type/format combination. Each custom tactic must be a positive integer. Defaults to the default tactic (0).
+
+         .. code-block:: python
+             :linenos:
+             :caption: For a plugin with 3 I/Os, I/O indices 0 and 1 are dependently either FP32/FP16 and index 2 is independently FP32/FP16.
+
+             @trtp.autotune("my::plugin")
+             def autotune(inp0: trtp.TensorDesc, inp1: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:
+                 # The following results in these type combinations:
+                 # [FP32, FP32, FP32], [FP16, FP16, FP32], [FP32, FP32, FP16], [FP16, FP16, FP16]
+                 return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16, FP32*FP16", "LINEAR", [1, 2])]
+
+         .. code-block:: python
+             :linenos:
+             :caption: For a plugin with 2 I/Os, the input/output supports either LINEAR or HWC format for FP32 and LINEAR format for FP16.
+
+             @trtp.autotune("my::plugin")
+             def autotune(inp0: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:
+                 # Even though (FP16, HWC) is not a valid combination (see next example), TRT should intelligently reject it
+                 # and pass the following combinations to the impl function:
+                 # [{FP32, FP32}, {LINEAR, LINEAR}], [{FP32, FP32}, {HWC, LINEAR}], [{FP16, FP32}, {LINEAR, LINEAR}]
+                 return [trtp.AutoTuneCombination("FP32*FP16, FP32", "LINEAR*HWC, LINEAR", [1, 2])]
+
+         .. code-block:: python
+             :linenos:
+             :caption: For a plugin with 2 I/Os, the input/output supports either LINEAR or HWC format for FP32 and LINEAR format for FP16 (second method).
+
+             @trtp.autotune("my::plugin")
+             def autotune(inp0: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:
+                 # We can use two AutoTuneCombination objects to avoid communicating illegal combinations
+                 return [trtp.AutoTuneCombination("FP32*FP16, FP32", "LINEAR, LINEAR", [1, 2]), trtp.AutoTuneCombination("FP32, FP32", "HWC, LINEAR", [1, 2])]
+         """
+
+         if io_types is not None:
+             self.io_types = [s.strip() for s in io_types.split(",")]
+             if layouts is None:
+                 layouts = "LINEAR"
+             self.layouts = [s.strip() for s in layouts.split(",")]
+
+             if len(self.layouts) > 1:
+                 assert len(self.io_types) == len(self.layouts)
+
+             if len(self.io_types) > len(self.layouts):
+                 assert len(self.layouts) == 1
+                 self.layouts = [self.layouts[0]] * len(self.io_types)
+         else:
+             self.io_types = []
+             self.layouts = []
+
+         self.combinations = []
+         self._tactics = tactics
+
+     def pos(self, pos: Iterable[int], io_types: str, layouts: str = "LINEAR") -> None:
+         """
+         Specify I/O types and formats for a specified set of I/O indices.
+
+         Args:
+             pos (Iterable[int]): I/O indices. Input indices are [0, 1, ..., #inputs - 1] and output indices are [#inputs, #inputs + 1, ..., #inputs + #outputs - 1].
+             io_types (str): Data types for these I/O indices.
+             layouts (str, optional): Tensor format(s) for these I/O indices. Defaults to "LINEAR".
+
+         Raises:
+             ValueError: If types or layouts for any of these I/O indices are already specified.
+
+         .. code-block:: python
+             :linenos:
+             :caption: For a plugin with 3 I/Os, I/O indices 0 and 1 are dependently either FP32/FP16 and index 2 is independently FP32/FP16.
+
+             @trtp.autotune("my::plugin")
+             def autotune(inp0: trtp.TensorDesc, inp1: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:
+                 c = trtp.AutoTuneCombination()
+                 c.pos([0, 1], "FP32|FP16", "LINEAR")
+                 c.pos([2], "FP32*FP16")  # Omitting the format is the same as declaring it to be LINEAR.
+                 c.tactics([1, 2])
+                 return [c]
+         """
+         if max(pos) >= len(self.io_types):
+             self.io_types.extend([None] * (max(pos) + 1 - len(self.io_types)))
+             self.layouts.extend([None] * (max(pos) + 1 - len(self.layouts)))
+         assert len(self.io_types) == len(self.layouts)
+
+         for p in pos:
+             if self.io_types[p] is not None:
+                 raise ValueError(f"Type(s) for position {p} already specified")
+             if self.layouts[p] is not None:
+                 raise ValueError(f"Layout(s) for position {p} already specified")
+             self.io_types[p] = io_types
+             self.layouts[p] = layouts
+
+     def tactics(self, tactics: Iterable[int]) -> None:
+         """
+         Specify custom tactics for this type/format combination.
+
+         Args:
+             tactics (Iterable[int]): Custom tactics. These must be positive integers.
+         """
+         self._tactics = tactics
+
+     def _generate_combinations(self):
+         self.combinations = []
+
+         type_combinations = _gen_onesided_combinations(self.io_types)
+         layout_combinations = _gen_onesided_combinations(self.layouts)
+
+         # Cross every type combination with every layout combination.
+         for t in type_combinations:
+             for l in layout_combinations:
+                 c = _TypeFormatCombination(len(self.io_types))
+                 c.types = [_str_to_data_type(tt) for tt in t]
+                 c.layouts = [getattr(trt.TensorFormat, ff) for ff in l]
+                 c.tactics = self._tactics
+                 self.combinations.append(c)
+
+     def _get_combinations(self):
+         self._generate_combinations()
+         return self.combinations
+
+     def _check(self, pos, type, layout):
+         for i in range(len(self.combinations)):
+             if (
+                 self.combinations[i].types[pos] == _str_to_data_type(type)
+                 and self.combinations[i].layouts[pos] == layout.name
+             ):
+                 return True
+         return False
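
To make the "onesided" expansion concrete, here is _gen_onesided_combinations traced by hand on the example from the top-of-file comment: the '|' entries advance in lockstep to produce two base rows, then the '*' entry fans each base row out, giving 2 * 2 = 4 rows in this order:

    combos = _gen_onesided_combinations(["FP32|FP16", "FP32|FP16", "FP32*FP16"])
    assert combos == [
        ["FP32", "FP32", "FP32"],  # base row: '|' option 0, '*' option 0
        ["FP16", "FP16", "FP32"],  # base row: '|' option 1, '*' option 0
        ["FP32", "FP32", "FP16"],  # fan-out of base row 0: '*' option 1
        ["FP16", "FP16", "FP16"],  # fan-out of base row 1: '*' option 1
    ]
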
@@ -0,0 +1,39 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import tensorrt as trt
+ from types import ModuleType
+ import importlib
+
+ def public_api(module: ModuleType = None, symbol: str = None):
+     def export_impl(obj):
+         nonlocal module, symbol
+
+         # Default to this package's root module and the decorated object's own name.
+         module = module or importlib.import_module(__package__)
+         symbol = symbol or obj.__name__
+
+         if not hasattr(module, "__all__"):
+             module.__all__ = []
+
+         module.__all__.append(symbol)
+         setattr(module, symbol, obj)
+
+         return obj
+
+     return export_impl
+
+ # AOT plugin support is detected by probing the bindings for this attribute.
+ IS_AOT_ENABLED = hasattr(trt, "QuickPluginCreationRequest")
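
A self-contained sketch of what public_api() does, substituting a throwaway module object for the package root that the real decorator resolves via importlib (the names below are illustrative only):

    import types

    fake_pkg = types.ModuleType("fake_pkg")  # stand-in for the package root

    def public_api(module=None, symbol=None):
        def export_impl(obj):
            # Mirror the decorator above: default the symbol name, create
            # __all__ on demand, then publish the object on the module.
            name = symbol or obj.__name__
            if not hasattr(module, "__all__"):
                module.__all__ = []
            module.__all__.append(name)
            setattr(module, name, obj)
            return obj
        return export_impl

    @public_api(module=fake_pkg)
    def my_helper():  # hypothetical symbol, for illustration
        return 42

    assert fake_pkg.__all__ == ["my_helper"]
    assert fake_pkg.my_helper() == 42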