tensorrt-cu12-bindings 10.14.1.48.post1__cp39-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,132 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ from typing import Union, Tuple
+ import tensorrt as trt
+ from ._tensor import ShapeExpr, TensorDesc, ShapeExprs, SizeTensorDesc
+ from ._export import public_api
+
+ # Miscellaneous top-level functions accessible through `tensorrt.plugin`
+
+ # Performs `trt.DimensionOperation.CEIL_DIV`
+ @public_api()
+ def cdiv(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr:
+     """
+     Computes symbolic ceiling division of `first` by `second`
+
+     Args:
+         first (Union[int, ShapeExpr]): Dividend
+         second (Union[int, ShapeExpr]): Divisor
+
+     Raises:
+         ValueError: If both arguments are `int`\s or if `second` evaluates to 0
+
+     Returns:
+         ShapeExpr: Symbolic expression for the ceiling division of `first` by `second`
+     """
+     if isinstance(first, int):
+         if isinstance(second, int):
+             raise ValueError("Both arguments cannot be 'int's")
+         first = ShapeExpr(first)
+
+     return first._op(trt.DimensionOperation.CEIL_DIV, second)
+
+
+ # Performs `trt.DimensionOperation.MAX`
+ @public_api()
+ def max(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr:
+     """
+     Computes the maximum of `first` and `second`
+
+     Args:
+         first (Union[int, ShapeExpr]): First operand
+         second (Union[int, ShapeExpr]): Second operand
+
+     Raises:
+         ValueError: If both arguments are `int`\s
+
+     Returns:
+         ShapeExpr: Symbolic expression for the maximum of `first` and `second`
+     """
+     if isinstance(first, int):
+         if isinstance(second, int):
+             raise ValueError("Both arguments cannot be 'int's")
+         first = ShapeExpr(first)
+
+     return first._op(trt.DimensionOperation.MAX, second)
+
+
+ # Performs `trt.DimensionOperation.MIN`
+ @public_api()
+ def min(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr:
+     """
+     Computes the minimum of `first` and `second`
+
+     Args:
+         first (Union[int, ShapeExpr]): First operand
+         second (Union[int, ShapeExpr]): Second operand
+
+     Raises:
+         ValueError: If both arguments are `int`\s
+
+     Returns:
+         ShapeExpr: Symbolic expression for the minimum of `first` and `second`
+     """
+     if isinstance(first, int):
+         if isinstance(second, int):
+             raise ValueError("Both arguments cannot be 'int's")
+         first = ShapeExpr(first)
+
+     return first._op(trt.DimensionOperation.MIN, second)
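
The three helpers above all lower to `trt.DimensionOperation` ops on symbolic dimensions. As a point of reference, a minimal sketch of how they compose in shape logic (the tensor name and the `shape_expr` indexing follow the `tensorrt.plugin` TensorDesc API; the block-count computation itself is hypothetical):

    import tensorrt.plugin as trtp

    def block_count(inp: trtp.TensorDesc, block_size: int) -> trtp.ShapeExpr:
        # ceil(dim0 / block_size), clamped to at least one block
        n = trtp.cdiv(inp.shape_expr[0], block_size)
        return trtp.max(1, n)
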
+
+
+ # Declare a size tensor descriptor with the specified autotune shape expression `opt` and upper-bound shape expression `upper_bound`
+ @public_api()
+ def size_tensor(opt: ShapeExpr, upper_bound: ShapeExpr) -> SizeTensorDesc:
+     """
+     Constructs a size tensor with the specified autotune shape expression `opt` and `upper_bound`
+
+     Args:
+         opt (ShapeExpr): Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build
+         upper_bound (ShapeExpr): Symbolic expression for the upper bound of this size tensor
+
+     Returns:
+         SizeTensorDesc: A tensor descriptor for a size tensor with the specified autotune extent and upper bound
+     """
+     return SizeTensorDesc(opt, upper_bound)
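
Size tensors back data-dependent output shapes: `opt` seeds the autotuner, while `upper_bound` caps the allocation. A hedged sketch (the tensor name is hypothetical; only `shape_expr` indexing and `cdiv` from this module are used):

    # Inside a hypothetical shape-registration function:
    upper = inp.shape_expr[0]  # symbolic upper bound on the output extent
    count = trtp.size_tensor(opt=trtp.cdiv(upper, 2), upper_bound=upper)
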
+
+ # Create a TensorDesc using shape expressions and a dtype
+ @public_api()
+ def from_shape_expr(shape_expr: Union[Tuple[Union[ShapeExpr, int], ...], ShapeExprs], dtype: trt.DataType) -> TensorDesc:
+     """
+     Constructs a tensor descriptor with the specified shape expression and data type
+
+     Args:
+         shape_expr (Union[Tuple[Union[ShapeExpr, int], ...], ShapeExprs]): Expressions or constants denoting the shape of the tensor
+         dtype (trt.DataType): Data type of the tensor
+
+     Returns:
+         TensorDesc: Tensor descriptor with the specified shape expression and data type
+     """
+     if isinstance(shape_expr, tuple):
+         shape_expr_ = ShapeExprs.from_tuple(shape_expr)
+     else:
+         shape_expr_ = shape_expr
+
+     return TensorDesc(shape_expr_, dtype)
+
+
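Tying the two constructors together, a sketch of a registration function for a plugin with a data-dependent output (the plugin id, the `register` decorator usage, and the `count.expr()` accessor are assumptions here, not defined in this file):

    import tensorrt as trt
    import tensorrt.plugin as trtp
    from typing import Tuple

    @trtp.register("example::gather_nonzero")  # hypothetical plugin id
    def gather_nonzero_desc(inp: trtp.TensorDesc) -> Tuple[trtp.TensorDesc, trtp.SizeTensorDesc]:
        upper = inp.shape_expr[0]
        count = trtp.size_tensor(opt=trtp.cdiv(upper, 2), upper_bound=upper)
        # Rows are determined at runtime by `count`; columns are fixed at 2.
        out = trtp.from_shape_expr((count.expr(), 2), dtype=trt.int32)
        return out, count
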
@@ -0,0 +1,77 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import tensorrt as trt
+ import numpy as np
+ import typing
+
+ _numpy_to_plugin_field_type = {
+     np.dtype('int32'): trt.PluginFieldType.INT32,
+     np.dtype('int16'): trt.PluginFieldType.INT16,
+     np.dtype('int8'): trt.PluginFieldType.INT8,
+     np.dtype('bool'): trt.PluginFieldType.INT8,
+     np.dtype('int64'): trt.PluginFieldType.INT64,
+     np.dtype('float32'): trt.PluginFieldType.FLOAT32,
+     np.dtype('float64'): trt.PluginFieldType.FLOAT64,
+     np.dtype('float16'): trt.PluginFieldType.FLOAT16
+ }
+
+ _built_in_to_plugin_field_type = {
+     int: trt.PluginFieldType.INT64,
+     float: trt.PluginFieldType.FLOAT64,
+     bool: trt.PluginFieldType.INT8,
+     # str is handled separately, so not needed here
+ }
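
One plausible consumer of these tables, sketched for orientation (not necessarily how this module uses them): converting a serialized attribute into a `trt.PluginField` of the matching field type.

    import numpy as np
    import tensorrt as trt

    arr = np.array([1, 2, 3], dtype=np.int32)
    field_type = _numpy_to_plugin_field_type[arr.dtype]       # trt.PluginFieldType.INT32
    field = trt.PluginField("block_sizes", arr, field_type)   # "block_sizes" is a made-up name
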
+
+ def _str_to_data_type(dtype: str) -> trt.DataType:
+     if dtype == "FP32":
+         return trt.DataType.FLOAT
+     if dtype == "FP16":
+         return trt.DataType.HALF
+     try:
+         return getattr(trt.DataType, dtype)
+     except AttributeError:  # getattr raises AttributeError, not KeyError, on a missing member
+         raise ValueError(f"Unknown data type string '{dtype}'") from None
+
+
+ def _join_with(lst, middle=False, delim=", "):
+     if len(lst) == 0:
+         return ""
+
+     ret = ""
+     if middle:
+         ret += ", "
+
+     ret += delim.join(lst)
+
+     return ret
+
+ def _is_npt_ndarray(annotation):
+     return (typing.get_origin(annotation) == np.ndarray) or (hasattr(annotation, "__origin__") and annotation.__origin__ == np.ndarray)
+
+ def _is_numpy_array(annotation):
+     return (annotation == np.ndarray) or _is_npt_ndarray(annotation)
+
+ def _infer_numpy_type(annotation):
+     assert _is_npt_ndarray(annotation)
+     annot_args = typing.get_args(annotation) or annotation.__args__
+     if len(annot_args) >= 2:
+         np_type = typing.get_args(annot_args[1]) or annot_args[1].__args__
+         if len(np_type) >= 1:
+             return np_type[0]
+
+     raise AttributeError("Improper annotation for numpy array. Annotate numpy array attributes using 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype of the array.")
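
To make the introspection concrete: `numpy.typing.NDArray[np.float32]` is an alias for `np.ndarray[..., np.dtype[np.float32]]`, so one `typing.get_args` call exposes the `np.dtype[...]` slot and a second call recovers the scalar type. A standalone sketch mirroring `_infer_numpy_type`:

    import typing
    import numpy as np
    import numpy.typing as npt

    annotation = npt.NDArray[np.float32]
    args = typing.get_args(annotation)       # (<shape type>, np.dtype[np.float32])
    (np_dtype,) = typing.get_args(args[1])   # the dtype parameter: np.float32
    assert np_dtype is np.float32
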
@@ -0,0 +1,475 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import inspect
+ import numpy as np
+ import typing
+ import types
+
+ from ._utils import _is_numpy_array, _join_with, _infer_numpy_type, _is_npt_ndarray
+ from ._tensor import TensorDesc, Tensor, SymExprs
+ from ._export import IS_AOT_ENABLED
+ if IS_AOT_ENABLED:
+     from ._tensor import KernelLaunchParams
+     from ._autotune import AutoTuneCombination
+
+ SERIALIZABLE_BUILTIN_TYPES = (int, float, bytes, bool, str)
+ SERIALIZABLE_NP_DTYPES = (
+     np.int8,
+     np.int16,
+     np.int32,
+     np.int64,
+     np.float16,
+     np.float32,
+     np.float64,
+     bool,
+     np.bool_,
+ )
+
+ # Reserve some namespaces for future use/avoid confusion
+ RESERVED_NAMESPACES = {
+     "",
+     "trt",
+     "tensorrt",
+     "std",
+ }
+
+ DISALLOWED_ATTR_NAMES = {
+     "outputs",
+     "stream",
+     "tactic",
+ }
+
+ def _validate_name_and_namespace(ns: str, name: str):
+     if "." in ns:
+         raise ValueError(
+             f"Provided namespace {ns} cannot have any '.' in trt.plugin.register(\"{ns}::{name}\", ...)"
+         )
+
+     if "." in name:
+         raise ValueError(
+             f"Provided name {name} cannot have any '.' in trt.plugin.register(\"{ns}::{name}\", ...)"
+         )
+
+     if ns in RESERVED_NAMESPACES:
+         raise ValueError(
+             f"Provided namespace {ns} is a reserved namespace"
+         )
+
+
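So `trt.plugin.register("my_ops::circ_pad", ...)` would pass, while a reserved namespace or a dotted name fails:

    _validate_name_and_namespace("my_ops", "circ_pad")    # OK
    _validate_name_and_namespace("tensorrt", "circ_pad")  # ValueError: reserved namespace
    _validate_name_and_namespace("my.ops", "circ_pad")    # ValueError: '.' in namespace
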
+ # Parse `tensorrt.plugin.register` schema
+ def _parse_register_inputs(register_func, lazy_register):
+     tensor_names = []
+     input_attrs = (
+         dict()
+     )  # order is important here but for Python >= 3.7, dict respects key order
+
+     schema_chunks = []
+
+     # TensorDescs and attribute args cannot be interspersed, so remember when we saw the first attribute arg
+     saw_first_attr = False
+
+     # Map of (attr_name: str) -> (is_builtin_type?: bool, type annotation: str)
+     attrs_types = {}
+
+     sig = inspect.signature(register_func)
+
+     for idx, (name, param) in enumerate(sig.parameters.items()):
+
+         if param.kind not in (
+             inspect.Parameter.POSITIONAL_OR_KEYWORD,
+             inspect.Parameter.KEYWORD_ONLY,
+         ):
+             raise ValueError(
+                 f"Argument {name} is not a positional-or-keyword or keyword-only arg"
+             )
+
+         # Type annotations are mandatory for `tensorrt.plugin.register` args
+         if param.annotation == inspect.Parameter.empty:
+             raise ValueError(
+                 f"Argument {name} does not have a type annotation. Please mark as TensorDesc or one of the serializable attribute types."
+             )
+
+         # Presently, we do not support default values for attributes
+         if param.default is not inspect.Parameter.empty:
+             raise ValueError(
+                 f"Argument {name} has a default value. Default values are not supported yet."
+             )
+
+         if issubclass(param.annotation, TensorDesc):
+             if saw_first_attr:
+                 raise ValueError(
+                     f"TensorDesc args and attribute args cannot be interspersed. Received function with signature {sig}."
+                 )
+
+             tensor_names.append(name)
+             schema_chunks.append(f"TensorDesc {name}")
+         # At this point, we don't validate attribute types since we only care about the types of serializable attributes
+         # However, we memorize name and type so that we may validate that the autotune function maintains consistency
+         else:
+             if idx == 0:
+                 raise ValueError(
+                     f"TensorDesc args should come first, followed by attributes. Received function with signature {sig}."
+                 )
+
+             if name in DISALLOWED_ATTR_NAMES:
+                 raise ValueError(
+                     f"'{name}' is not allowed as a plugin attribute name."
+                 )
+
+             if param.annotation not in SERIALIZABLE_BUILTIN_TYPES:
+                 if _is_numpy_array(param.annotation):
+                     if not lazy_register:
+                         if param.annotation == np.ndarray:
+                             raise ValueError(
+                                 "If using non-lazy registration, annotate numpy array attributes using 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype of the array."
+                             )
+
+                     if _is_npt_ndarray(param.annotation):
+                         np_dtype = _infer_numpy_type(param.annotation)
+                         if np_dtype not in SERIALIZABLE_NP_DTYPES:
+                             raise ValueError(
+                                 f"Attribute '{name}' is not a supported numpy array type. Supported numpy array types are {SERIALIZABLE_NP_DTYPES}."
+                             )
+                         attrs_types[name] = (False, np_dtype)
+
+                 else:
+                     raise ValueError(
+                         f"Attribute '{name}' of type {param.annotation} is not a supported serializable type. Supported types are {SERIALIZABLE_BUILTIN_TYPES} or numpy arrays of type {SERIALIZABLE_NP_DTYPES}."
+                     )
+             else:
+                 attrs_types[name] = (True, param.annotation)
+
+             saw_first_attr = True
+
+             schema_chunks.append(f"{param.annotation} {name}")
+             input_attrs[name] = param.annotation
+
+     return (
+         tensor_names,
+         input_attrs,
+         f"({_join_with(schema_chunks)})",
+         attrs_types,
+     )
+
+
+ def _parse_register_return(register_func):
+     sig = inspect.signature(register_func)
+
+     ret_annotation = sig.return_annotation
+
+     if ret_annotation == inspect.Parameter.empty:
+         raise ValueError(
+             f"No return annotation found for register function. Received signature {sig}."
+         )
+
+     if typing.get_origin(ret_annotation) is not tuple:
+         if not inspect.isclass(ret_annotation) or not issubclass(
+             ret_annotation, TensorDesc
+         ):
+             raise ValueError(
+                 f"Return argument is of type {ret_annotation}. Return types can only be TensorDesc or Tuple[TensorDesc]."
+             )
+
+         num_outputs = 1
+     else:
+         args = typing.get_args(ret_annotation)
+
+         for arg in args:
+             if not issubclass(arg, TensorDesc):
+                 raise ValueError(
+                     f"Return argument is of type {ret_annotation}. Return types can only be TensorDesc or Tuple[TensorDesc]."
+                 )
+
+         num_outputs = len(args)
+
+     return num_outputs
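
Putting the two parsers together on a hypothetical registration function (TensorDesc args first, then one attribute):

    from typing import Tuple

    def my_register(inp0: TensorDesc, inp1: TensorDesc, block_size: int) -> Tuple[TensorDesc]:
        ...  # shape logic elided

    names, attrs, schema, attr_types = _parse_register_inputs(my_register, lazy_register=False)
    # names      == ["inp0", "inp1"]
    # attrs      == {"block_size": int}
    # schema     == "(TensorDesc inp0, TensorDesc inp1, <class 'int'> block_size)"
    # attr_types == {"block_size": (True, int)}
    num_outputs = _parse_register_return(my_register)  # 1
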
+
+
+ def _validate_impl(impl_func, plugin_def):
+     impl_attr_names = []
+     found_tactic = False
+
+     sig = inspect.signature(impl_func)
+     registered_attr_names = plugin_def.input_attrs.keys()
+
+     # input arg annotations are optional, but we will validate if provided
+     for name, param in sig.parameters.items():
+         # tactic arg is optional in impl function. If specified, remember so that we can pass it during enqueue.
+         if name == "tactic":
+             found_tactic = True
+         if param.annotation != inspect.Parameter.empty:
+             if name == "outputs":
+                 if typing.get_origin(param.annotation) is not tuple:
+                     raise ValueError(
+                         f"'outputs' should be of type Tuple[Tensor]. Received {param.annotation}."
+                     )
+                 args = typing.get_args(param.annotation)
+                 for arg in args:
+                     if not issubclass(arg, Tensor):
+                         raise ValueError(
+                             f"Argument '{name}' for receiving output Tensors contains a {param.annotation}. '{name}' should be a Tuple[Tensor]."
+                         )
+             elif name == "stream":
+                 if not issubclass(param.annotation, int):
+                     raise ValueError("'stream' input argument should be an int")
+             elif name == "tactic":
+                 if not issubclass(param.annotation, int):
+                     raise ValueError("'tactic' input argument should be an int")
+             elif issubclass(param.annotation, Tensor):
+                 if name not in plugin_def.input_tensor_names:
+                     raise ValueError(
+                         f"Unexpected tensor '{name}' specified in impl function. Expected one of {plugin_def.input_tensor_names}."
+                     )
+             else:
+                 if name not in plugin_def.input_attrs:
+                     raise ValueError(
+                         f"Unexpected attribute '{name}' specified in impl function. Expected one of {list(registered_attr_names)}."
+                     )
+
+                 if param.annotation != plugin_def.input_attrs[name]:
+                     raise ValueError(
+                         f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
+                     )
+
+                 impl_attr_names.append(name)
+         else:
+             if name in plugin_def.input_attrs:
+                 impl_attr_names.append(name)
+
+     # The expected attribute schema is constructed in the order the attributes appeared in the register function
+     expected_attr_schema_chunks = [
+         n for n in registered_attr_names if n in impl_attr_names
+     ]
+
+     expected_schema = (
+         "("
+         + _join_with(plugin_def.input_tensor_names)
+         + _join_with(expected_attr_schema_chunks, True)
+         + ", outputs, stream"
+     )
+     if found_tactic:
+         expected_schema += ", tactic)"
+     else:
+         expected_schema += ")"
+
+     if f"({', '.join(sig.parameters.keys())})" != expected_schema:
+         raise ValueError(
+             f"Signature of the impl function '{sig}' does not match the expected input arg schema: {expected_schema}"
+         )
+
+     # A return annotation is optional, but if one is specified it must be None
+     if sig.return_annotation != inspect.Parameter.empty and sig.return_annotation is not None:
+         raise ValueError("Return annotation should be None.")
+
+     return impl_attr_names, found_tactic
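
For the hypothetical `my_register` above, an impl signature that satisfies this validation reads as below; annotations are optional, and adding a trailing `tactic: int` parameter would set `found_tactic` (in the public API this function would carry a `trt.plugin.impl` decorator):

    from typing import Tuple

    def my_impl(inp0: Tensor, inp1: Tensor, block_size: int,
                outputs: Tuple[Tensor], stream: int) -> None:
        ...  # enqueue the kernel on `stream`
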
+
+ def _validate_aot_impl(aot_impl_func, plugin_def):
+     aot_impl_attr_names = []
+
+     sig = inspect.signature(aot_impl_func)
+     registered_attr_names = plugin_def.input_attrs.keys()
+
+     # input arg annotations are optional, but we will validate if provided
+     for name, param in sig.parameters.items():
+         if param.annotation != inspect.Parameter.empty:
+             if name == "outputs":
+                 if typing.get_origin(param.annotation) is not tuple:
+                     raise ValueError(
+                         f"'outputs' should be of type Tuple[TensorDesc]. Received {param.annotation}."
+                     )
+                 args = typing.get_args(param.annotation)
+                 for arg in args:
+                     if not issubclass(arg, TensorDesc):
+                         raise ValueError(
+                             f"Argument '{name}' for receiving output TensorDescs contains a {param.annotation}. '{name}' should be a Tuple[TensorDesc]."
+                         )
+             elif name == "tactic":
+                 if not issubclass(param.annotation, int):
+                     raise ValueError("'tactic' input argument should be an int")
+             elif issubclass(param.annotation, TensorDesc):
+                 if name not in plugin_def.input_tensor_names:
+                     raise ValueError(
+                         f"Unexpected tensor '{name}' specified in aot_impl function. Expected one of {plugin_def.input_tensor_names}."
+                     )
+             else:
+                 if name not in plugin_def.input_attrs:
+                     raise ValueError(
+                         f"Unexpected attribute '{name}' specified in aot_impl function. Expected one of {list(registered_attr_names)}."
+                     )
+
+                 if param.annotation != plugin_def.input_attrs[name]:
+                     raise ValueError(
+                         f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
+                     )
+
+                 aot_impl_attr_names.append(name)
+         else:
+             if name in plugin_def.input_attrs:
+                 aot_impl_attr_names.append(name)
+
+     # The expected attribute schema is constructed in the order the attributes appeared in the register function
+     expected_attr_schema_chunks = [
+         n for n in registered_attr_names if n in aot_impl_attr_names
+     ]
+
+     expected_schema = (
+         "("
+         + _join_with(plugin_def.input_tensor_names)
+         + _join_with(expected_attr_schema_chunks, True)
+         + ", outputs, tactic)"
+     )
+
+     if f"({', '.join(sig.parameters.keys())})" != expected_schema:
+         raise ValueError(
+             f"Signature of the aot_impl function '{sig}' does not match the expected input arg schema: {expected_schema}"
+         )
+
+     ret_annotation = sig.return_annotation
+
+     if ret_annotation == inspect.Parameter.empty:
+         raise ValueError(
+             f"No return annotation found for aot_impl function. Received signature {sig}."
+         )
+
+     expected_return_schema = "tuple[str | bytes, str | bytes, tensorrt.plugin.KernelLaunchParams, tensorrt.plugin.SymIntExprs]"
+
+     # The return annotation is mandatory for aot_impl (checked above); validate its structure
+     if typing.get_origin(ret_annotation) is not tuple:
+         raise ValueError(
+             f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
+         )
+     else:
+         args = typing.get_args(ret_annotation)
+
+         if len(args) != 4:
+             raise ValueError(
+                 f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
+             )
+
+         def validate_union_str_or_bytes(index):
+             def validate_str_or_bytes(arg_):
+                 if (arg_ is not str) and (arg_ is not bytes):
+                     raise ValueError(
+                         f"Return annotation for argument at {index} is '{arg_}'. Expected 'str' or 'bytes'."
+                     )
+
+             orig = typing.get_origin(args[index])
+             # orig is `typing.Union` when the annotation uses the typing module (e.g., Union[str, bytes])
+             # orig is `types.UnionType` when the annotation uses the native 3.10+ syntax (e.g., str | bytes);
+             # guard with getattr since `types.UnionType` does not exist before Python 3.10
+             union_type = getattr(types, "UnionType", None)
+             if orig is typing.Union or (union_type is not None and orig is union_type):
+                 for a in typing.get_args(args[index]):
+                     validate_str_or_bytes(a)
+             else:
+                 # when annotated with a plain `str` or `bytes`
+                 validate_str_or_bytes(args[index])
+
+         # kernel name should be str or bytes
+         validate_union_str_or_bytes(0)
+         # kernel PTX should be str or bytes
+         validate_union_str_or_bytes(1)
+
+         if not issubclass(args[2], KernelLaunchParams):
+             raise ValueError(f"Argument at index 2 of return annotation is '{args[2]}'. Expected 'tensorrt.plugin.KernelLaunchParams'.")
+
+         if not issubclass(args[3], SymExprs):
+             raise ValueError(f"Argument at index 3 of return annotation is '{args[3]}'. Expected a descendant of tensorrt.plugin.SymExprs.")
+
+     return aot_impl_attr_names
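
A return annotation that clears all four positional checks, sketched with a single-input plugin (the `aot_impl` decorator and `SymIntExprs` as the concrete `SymExprs` descendant are taken on trust from the expected-schema string above):

    import tensorrt.plugin as trtp
    from typing import Tuple, Union

    def my_aot_impl(
        inp0: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc], tactic: int
    ) -> Tuple[Union[str, bytes], Union[str, bytes],
               trtp.KernelLaunchParams, trtp.SymIntExprs]:
        ...  # return (kernel_name, kernel_ptx, launch_params, extra_sym_exprs)
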
+
+
+ def _validate_autotune(autotune_func, plugin_def):
+     sig = inspect.signature(autotune_func)
+     registered_attr_names = plugin_def.input_attrs.keys()
+
+     autotune_attr_names = []
+
+     # input arg annotations are optional, but we will validate if provided
+     for name, param in sig.parameters.items():
+         if param.annotation != inspect.Parameter.empty:
+             if name == "outputs":
+                 if typing.get_origin(param.annotation) is not tuple:
+                     raise ValueError(
+                         f"'outputs' should be of type Tuple[TensorDesc]. Received {param.annotation}."
+                     )
+                 args = typing.get_args(param.annotation)
+                 for arg in args:
+                     if not issubclass(arg, TensorDesc):
+                         raise ValueError(
+                             f"Argument '{name}' for receiving output TensorDescs contains a {param.annotation}. '{name}' should be a Tuple[TensorDesc]."
+                         )
+             elif issubclass(param.annotation, TensorDesc):
+                 if name not in plugin_def.input_tensor_names:
+                     raise ValueError(
+                         f"Unexpected tensor '{name}' specified in autotune function. Expected one of {plugin_def.input_tensor_names}."
+                     )
+             else:
+                 if name not in plugin_def.input_attrs:
+                     raise ValueError(
+                         f"Unexpected attribute '{name}' specified in autotune function. Expected one of {list(registered_attr_names)}."
+                     )
+                 if param.annotation != plugin_def.input_attrs[name]:
+                     raise ValueError(
+                         f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
+                     )
+
+                 autotune_attr_names.append(name)
+         else:
+             if name in plugin_def.input_attrs:
+                 autotune_attr_names.append(name)
+
+     # The expected attribute schema is constructed in the order the attributes appeared in the register function
+     expected_attr_schema_chunks = [
+         n for n in registered_attr_names if n in autotune_attr_names
+     ]
+
+     expected_schema = (
+         "("
+         + _join_with(plugin_def.input_tensor_names)
+         + _join_with(expected_attr_schema_chunks, True)
+         + ", outputs)"
+     )
+
+     if f"({', '.join(sig.parameters.keys())})" != expected_schema:
+         raise ValueError(
+             f"Specified autotune function signature {sig} is not consistent with the expected input arg schema {expected_schema}."
+         )
+
+     ret_annotation = sig.return_annotation
+
+     # A return annotation is optional, but we will validate it if one is specified
+     if ret_annotation != inspect.Parameter.empty:
+         if typing.get_origin(ret_annotation) is not list:
+             if not inspect.isclass(ret_annotation) or not issubclass(
+                 ret_annotation, AutoTuneCombination
+             ):
+                 raise ValueError(
+                     f"Return argument is of type {ret_annotation}. Return types can only be AutoTuneCombination or List[AutoTuneCombination]."
+                 )
+         else:
+             args = typing.get_args(ret_annotation)
+
+             for arg in args:
+                 if not issubclass(arg, AutoTuneCombination):
+                     raise ValueError(
+                         f"Return argument is of type {ret_annotation}. Return types can only be AutoTuneCombination or List[AutoTuneCombination]."
+                     )
+
+     return autotune_attr_names
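
Finally, an autotune signature consistent with the expected schema, again for a single-input plugin (the `AutoTuneCombination` format string follows the `tensorrt.plugin` convention of per-I/O type and format alternatives; the concrete combination here is illustrative only):

    import tensorrt.plugin as trtp
    from typing import List, Tuple

    def my_autotune(
        inp0: trtp.TensorDesc, outputs: Tuple[trtp.TensorDesc]
    ) -> List[trtp.AutoTuneCombination]:
        # Offer FP32->FP32 and FP16->FP16, both in linear layout.
        return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR")]
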
Binary file