tensorrt-cu12-bindings 10.13.3.9.post1__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorrt_bindings/__init__.py +224 -0
- tensorrt_bindings/plugin/__init__.py +46 -0
- tensorrt_bindings/plugin/_autotune.py +270 -0
- tensorrt_bindings/plugin/_export.py +39 -0
- tensorrt_bindings/plugin/_lib.py +691 -0
- tensorrt_bindings/plugin/_plugin_class.py +459 -0
- tensorrt_bindings/plugin/_tensor.py +1128 -0
- tensorrt_bindings/plugin/_top_level.py +132 -0
- tensorrt_bindings/plugin/_utils.py +77 -0
- tensorrt_bindings/plugin/_validate.py +475 -0
- tensorrt_bindings/tensorrt.cp312-win_amd64.pyd +0 -0
- tensorrt_cu12_bindings-10.13.3.9.post1.dist-info/LICENSE.txt +180 -0
- tensorrt_cu12_bindings-10.13.3.9.post1.dist-info/METADATA +17 -0
- tensorrt_cu12_bindings-10.13.3.9.post1.dist-info/RECORD +17 -0
- tensorrt_cu12_bindings-10.13.3.9.post1.dist-info/WHEEL +5 -0
- tensorrt_cu12_bindings-10.13.3.9.post1.dist-info/top_level.txt +1 -0
- tensorrt_cu12_bindings-10.13.3.9.post1.dist-info/zip-safe +1 -0
tensorrt_bindings/plugin/_top_level.py
@@ -0,0 +1,132 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Union, Tuple
import tensorrt as trt
from ._tensor import ShapeExpr, TensorDesc, ShapeExprs, SizeTensorDesc
from ._export import public_api

# Miscellaneous top-level functions accessible through `tensorrt.plugin`

# Performs `trt.DimensionOperation.CEIL_DIV`
@public_api()
def cdiv(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr:
    """
    Computes symbolic ceiling division of `first` by `second`

    Args:
        first (Union[int, ShapeExpr]): Dividend
        second (Union[int, ShapeExpr]): Divisor

    Raises:
        ValueError: If both arguments are `int`\s or if `second` evaluates to 0

    Returns:
        ShapeExpr: Symbolic expression for the ceiling division of `first` by `second`
    """
    if isinstance(first, int):
        if isinstance(second, int):
            raise ValueError("Both arguments cannot be 'int's")
        first = ShapeExpr(first)

    return first._op(trt.DimensionOperation.CEIL_DIV, second)


# Performs `trt.DimensionOperation.MAX`
@public_api()
def max(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr:
    """
    Computes the maximum of `first` and `second`

    Args:
        first (Union[int, ShapeExpr]): First operand
        second (Union[int, ShapeExpr]): Second operand

    Raises:
        ValueError: If both arguments are `int`\s

    Returns:
        ShapeExpr: Symbolic expression for the maximum of `first` and `second`
    """
    if isinstance(first, int):
        if isinstance(second, int):
            raise ValueError("Both arguments cannot be 'int's")
        first = ShapeExpr(first)

    return first._op(trt.DimensionOperation.MAX, second)


# Performs `trt.DimensionOperation.MIN`
@public_api()
def min(first: Union[int, ShapeExpr], second: Union[int, ShapeExpr]) -> ShapeExpr:
    """
    Computes the minimum of `first` and `second`

    Args:
        first (Union[int, ShapeExpr]): First operand
        second (Union[int, ShapeExpr]): Second operand

    Raises:
        ValueError: If both arguments are `int`\s

    Returns:
        ShapeExpr: Symbolic expression for the minimum of `first` and `second`
    """
    if isinstance(first, int):
        if isinstance(second, int):
            raise ValueError("Both arguments cannot be 'int's")
        first = ShapeExpr(first)

    return first._op(trt.DimensionOperation.MIN, second)


# Declare a size tensor descriptor with the specified autotune shape expression `opt` and `upper-bound` shape expression
@public_api()
def size_tensor(opt: ShapeExpr, upper_bound: ShapeExpr) -> SizeTensorDesc:
    """
    Constructs a size tensor with the specified autotune shape expression `opt` and `upper_bound`

    Args:
        opt (ShapeExpr): Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build
        upper_bound (ShapeExpr): Symbolic expression for the upper-bound of this size tensor

    Returns:
        SizeTensorDesc: A tensor descriptor for a size tensor with the specified autotune extent and upper-bound
    """
    return SizeTensorDesc(opt, upper_bound)

# Create a TensorDesc using shape expressions and a dtype
@public_api()
def from_shape_expr(shape_expr: Union[Tuple[Union[ShapeExpr, int]], ShapeExprs], dtype: trt.DataType) -> TensorDesc:
    """
    Constructs a tensor descriptor with the specified shape expression and data type

    Args:
        shape_expr (Union[Tuple[Union[ShapeExpr, int]], ShapeExprs]): Expressions or constants denoting the shape of the tensor
        dtype (trt.DataType): Data type of the tensor

    Returns:
        TensorDesc: Tensor descriptor with the specified shape expression and data type
    """
    if isinstance(shape_expr, tuple):
        shape_expr_ = ShapeExprs.from_tuple(shape_expr)
    else:
        shape_expr_ = shape_expr

    return TensorDesc(shape_expr_, dtype)
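The five helpers above make up the symbolic shape vocabulary exposed through `tensorrt.plugin`. A minimal sketch of how they might compose inside a registered plugin's shape declaration (hypothetical: the plugin id `example::block_reduce`, the `block` attribute, and the `TensorDesc.shape_expr`/`TensorDesc.dtype` accessors are assumed from the wider `tensorrt.plugin` API, which is not shown in this diff):

import tensorrt.plugin as trtp

@trtp.register("example::block_reduce")  # hypothetical plugin id
def block_reduce_desc(inp: trtp.TensorDesc, block: int) -> trtp.TensorDesc:
    # ceil-divide the trailing extent by `block`; mixing int and ShapeExpr
    # is fine as long as at least one operand is a ShapeExpr
    cols = trtp.cdiv(inp.shape_expr[1], block)
    # assemble the output descriptor from (ShapeExpr, ShapeExpr) and a dtype
    return trtp.from_shape_expr((inp.shape_expr[0], cols), dtype=inp.dtype)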
tensorrt_bindings/plugin/_utils.py
@@ -0,0 +1,77 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorrt as trt
import numpy as np
import typing

_numpy_to_plugin_field_type = {
    np.dtype('int32'): trt.PluginFieldType.INT32,
    np.dtype('int16'): trt.PluginFieldType.INT16,
    np.dtype('int8'): trt.PluginFieldType.INT8,
    np.dtype('bool'): trt.PluginFieldType.INT8,
    np.dtype('int64'): trt.PluginFieldType.INT64,
    np.dtype('float32'): trt.PluginFieldType.FLOAT32,
    np.dtype('float64'): trt.PluginFieldType.FLOAT64,
    np.dtype('float16'): trt.PluginFieldType.FLOAT16
}

_built_in_to_plugin_field_type = {
    int: trt.PluginFieldType.INT64,
    float: trt.PluginFieldType.FLOAT64,
    bool: trt.PluginFieldType.INT8,
    # str is handled separately, so not needed here
}

def _str_to_data_type(dtype: str) -> trt.DataType:
    if dtype == "FP32":
        return trt.DataType.FLOAT
    if dtype == "FP16":
        return trt.DataType.HALF
    try:
        return getattr(trt.DataType, dtype)
    except AttributeError:
        raise ValueError(f"Unknown data type string '{dtype}'") from None


def _join_with(lst, middle=False, delim=", "):
    if len(lst) == 0:
        return ""

    ret = ""
    if middle:
        ret += ", "

    ret += delim.join(lst)

    return ret

def _is_npt_ndarray(annotation):
    return (typing.get_origin(annotation) == np.ndarray) or (hasattr(annotation, "__origin__") and annotation.__origin__ == np.ndarray)

def _is_numpy_array(annotation):
    return (annotation == np.ndarray) or _is_npt_ndarray(annotation)

def _infer_numpy_type(annotation):
    assert _is_npt_ndarray(annotation)
    annot_args = typing.get_args(annotation) or annotation.__args__
    if len(annot_args) >= 2:
        np_type = typing.get_args(annot_args[1]) or annot_args[1].__args__
        if len(np_type) >= 1:
            return np_type[0]

    raise AttributeError("Improper annotation for numpy array. Annotate numpy array attributes using 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype of the array.")
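A quick sanity check of the annotation helpers above (a sketch, not part of the package, assuming the helpers are in scope as in this module): `numpy.typing.NDArray[dtype]` parameterizes `np.ndarray` with a shape argument and `np.dtype[dtype]`, and `_infer_numpy_type` unwraps the second argument:

import numpy as np
import numpy.typing as npt

annotation = npt.NDArray[np.float32]
assert _is_npt_ndarray(annotation)   # typing.get_origin(annotation) is np.ndarray
assert _is_numpy_array(annotation)
# typing.get_args(annotation)[1] is np.dtype[np.float32], whose single
# type argument is the element type
assert _infer_numpy_type(annotation) is np.float32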
tensorrt_bindings/plugin/_validate.py
@@ -0,0 +1,475 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import inspect
import numpy as np
import typing
import types

from ._utils import _is_numpy_array, _join_with, _infer_numpy_type, _is_npt_ndarray
from ._tensor import TensorDesc, Tensor, SymExprs
from ._export import IS_AOT_ENABLED
if IS_AOT_ENABLED:
    from ._tensor import KernelLaunchParams
    from ._autotune import AutoTuneCombination

SERIALIZABLE_BUILTIN_TYPES = (int, float, bytes, bool, str)
SERIALIZABLE_NP_DTYPES = (
    np.int8,
    np.int16,
    np.int32,
    np.int64,
    np.float16,
    np.float32,
    np.float64,
    bool,
    np.bool_,
)

# Reserve some namespaces for future use/avoid confusion
RESERVED_NAMESPACES = {
    "",
    "trt",
    "tensorrt",
    "std",
}

DISALLOWED_ATTR_NAMES = {
    "outputs",
    "stream",
    "tactic",
}

def _validate_name_and_namespace(ns: str, name: str):
    if "." in ns:
        raise ValueError(
            f"Provided namespace {ns} cannot have any '.' in trt.plugin.register(\"{ns}::{name}\", ...)"
        )

    if "." in name:
        raise ValueError(
            f"Provided name {name} cannot have any '.' in trt.plugin.register(\"{ns}::{name}\", ...)"
        )

    if ns in RESERVED_NAMESPACES:
        raise ValueError(
            f"Provided namespace {ns} is a reserved namespace"
        )


# Parse `tensorrt.plugin.register` schema
def _parse_register_inputs(register_func, lazy_register):
    tensor_names = []
    input_attrs = (
        dict()
    )  # order is important here but for Python >= 3.7, dict respects key order

    schema_chunks = []

    # TensorDescs and attribute args cannot be interspersed, so remember when we saw the first attribute arg
    saw_first_attr = False

    # Map of (attr_name: str) -> (is_builtin_type?: bool, type annotation: str)
    attrs_types = {}

    sig = inspect.signature(register_func)

    for idx, (name, param) in enumerate(sig.parameters.items()):

        if param.kind not in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            raise ValueError(
                f"Argument {name} is not a positional-or-keyword or keyword-only arg"
            )

        # Type annotations are mandatory for `tensorrt.plugin.register` args
        if param.annotation == inspect.Parameter.empty:
            raise ValueError(
                f"Argument {name} does not have a type annotation. Please mark as TensorDesc or one of the serializable attribute types."
            )

        # Presently, we do not support default values for attributes
        if param.default is not inspect.Parameter.empty:
            raise ValueError(
                f"Argument {name} has a default value. Default values are not supported yet."
            )

        if issubclass(param.annotation, TensorDesc):
            if saw_first_attr:
                raise ValueError(
                    f"TensorDescs args and attribute args cannot be interspersed. Received function with signature {sig}."
                )

            tensor_names.append(name)
            schema_chunks.append(f"TensorDesc {name}")
        # At this point, we don't validate attribute types since we only care about the types of serializable attributes
        # However, we memorize name and type so that we may validate that the autotune function maintains consistency
        else:
            if idx == 0:
                raise ValueError(
                    f"TensorDescs args should come first, followed by attributes. Received function with signature {sig}."
                )

            if name in DISALLOWED_ATTR_NAMES:
                raise ValueError(
                    f"'{name}' is not allowed as a plugin attribute name."
                )

            if param.annotation not in SERIALIZABLE_BUILTIN_TYPES:
                if _is_numpy_array(param.annotation):
                    if not lazy_register:
                        if param.annotation == np.ndarray:
                            raise ValueError(
                                "If using non-lazy registration, annotate numpy array attributes using 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype of the array."
                            )

                    if _is_npt_ndarray(param.annotation):
                        np_dtype = _infer_numpy_type(param.annotation)
                        if np_dtype not in SERIALIZABLE_NP_DTYPES:
                            raise ValueError(
                                f"Attribute '{name}' is not a supported numpy array type. Supported numpy array types are {SERIALIZABLE_NP_DTYPES}."
                            )
                        attrs_types[name] = (False, np_dtype)

                else:
                    raise ValueError(
                        f"Attribute '{name}' of type {param.annotation} is not a supported serializable type. Supported types are {SERIALIZABLE_BUILTIN_TYPES} or numpy arrays of type {SERIALIZABLE_NP_DTYPES}."
                    )
            else:
                attrs_types[name] = (True, param.annotation)

            saw_first_attr = True

            schema_chunks.append(f"{param.annotation} {name}")
            input_attrs[name] = param.annotation

    return (
        tensor_names,
        input_attrs,
        f"({_join_with(schema_chunks)})",
        attrs_types,
    )


def _parse_register_return(register_func):
    sig = inspect.signature(register_func)

    ret_annotation = sig.return_annotation

    if ret_annotation == inspect.Parameter.empty:
        raise ValueError(
            f"No return annotation found for register function. Received signature {sig}."
        )

    if typing.get_origin(ret_annotation) is not tuple:
        if not inspect.isclass(ret_annotation) or not issubclass(
            ret_annotation, TensorDesc
        ):
            raise ValueError(
                f"Return argument is of type {ret_annotation}. Return types can only be TensorDesc or Tuple[TensorDesc]."
            )

        num_outputs = 1
    else:
        args = typing.get_args(ret_annotation)

        for arg in args:
            if not issubclass(arg, TensorDesc):
                raise ValueError(
                    f"Return argument is of type {ret_annotation}. Return types can only be TensorDesc or Tuple[TensorDesc]."
                )

        num_outputs = len(args)

    return num_outputs


def _validate_impl(impl_func, plugin_def):
    impl_attr_names = []
    found_tactic = False

    sig = inspect.signature(impl_func)
    registered_attr_names = plugin_def.input_attrs.keys()

    # input arg annotations are optional, but we will validate if provided
    for name, param in sig.parameters.items():
        # tactic arg is optional in impl function. If specified, remember so that we can pass it during enqueue.
        if name == "tactic":
            found_tactic = True
        if param.annotation != inspect.Parameter.empty:
            if name == "outputs":
                if typing.get_origin(param.annotation) is not tuple:
                    raise ValueError(
                        f"'outputs' should be of type Tuple[Tensor]. Received {param.annotation}."
                    )
                args = typing.get_args(param.annotation)
                for arg in args:
                    if not issubclass(arg, Tensor):
                        raise ValueError(
                            f"Argument for receiving output Tensor, '{name}' contains a {param.annotation}. '{name}' should be a Tuple[Tensor]."
                        )
            elif name == "stream":
                if not issubclass(param.annotation, int):
                    raise ValueError("'stream' input argument should be an int")
            elif name == "tactic":
                if not issubclass(param.annotation, int):
                    raise ValueError("'tactic' input argument should be an int")
            elif issubclass(param.annotation, Tensor):
                if name not in plugin_def.input_tensor_names:
                    raise ValueError(
                        f"Unexpected tensor '{name}' specified in impl function. Expected one of {plugin_def.input_tensor_names}."
                    )
            else:
                if name not in plugin_def.input_attrs:
                    raise ValueError(
                        f"Unexpected attribute '{name}' specified in impl function. Expected one of {list(registered_attr_names)}."
                    )

                if param.annotation != plugin_def.input_attrs[name]:
                    raise ValueError(
                        f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
                    )

                impl_attr_names.append(name)
        else:
            if name in plugin_def.input_attrs:
                impl_attr_names.append(name)

    # Expected attribute schema should be constructed in the order they appeared in the register function
    expected_attr_schema_chunks = [
        n for n in registered_attr_names if n in impl_attr_names
    ]

    expected_schema = (
        "("
        + _join_with(plugin_def.input_tensor_names)
        + _join_with(expected_attr_schema_chunks, True)
        + ", outputs, stream"
    )
    if found_tactic:
        expected_schema += ", tactic)"
    else:
        expected_schema += ")"

    if f"({', '.join(sig.parameters.keys())})" != expected_schema:
        raise ValueError(
            f"Signature of the impl function '{sig}' does not match the expected input arg schema: {expected_schema}"
        )

    # Return annotation is optional, but we will validate if one is specified
    if sig.return_annotation != inspect.Parameter.empty and sig.return_annotation is not None:
        raise ValueError("Return annotation should be None.")

    return impl_attr_names, found_tactic

def _validate_aot_impl(aot_impl_func, plugin_def):
    aot_impl_attr_names = []

    sig = inspect.signature(aot_impl_func)
    registered_attr_names = plugin_def.input_attrs.keys()

    # input arg annotations are optional, but we will validate if provided
    for name, param in sig.parameters.items():
        if param.annotation != inspect.Parameter.empty:
            if name == "outputs":
                if typing.get_origin(param.annotation) is not tuple:
                    raise ValueError(
                        f"'outputs' should be of type Tuple[TensorDesc]. Received {param.annotation}."
                    )
                args = typing.get_args(param.annotation)
                for arg in args:
                    if not issubclass(arg, TensorDesc):
                        raise ValueError(
                            f"Argument for receiving output TensorDesc, '{name}' contains a {param.annotation}. '{name}' should be a Tuple[TensorDesc]."
                        )
            elif name == "tactic":
                if not issubclass(param.annotation, int):
                    raise ValueError("'tactic' input argument should be an int")
            elif issubclass(param.annotation, TensorDesc):
                if name not in plugin_def.input_tensor_names:
                    raise ValueError(
                        f"Unexpected tensor '{name}' specified in aot_impl function. Expected one of {plugin_def.input_tensor_names}."
                    )
            else:
                if name not in plugin_def.input_attrs:
                    raise ValueError(
                        f"Unexpected attribute '{name}' specified in aot_impl function. Expected one of {list(registered_attr_names)}."
                    )

                if param.annotation != plugin_def.input_attrs[name]:
                    raise ValueError(
                        f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
                    )

                aot_impl_attr_names.append(name)
        else:
            if name in plugin_def.input_attrs:
                aot_impl_attr_names.append(name)

    # Expected attribute schema should be constructed in the order they appeared in the register function
    expected_attr_schema_chunks = [
        n for n in registered_attr_names if n in aot_impl_attr_names
    ]

    expected_schema = (
        "("
        + _join_with(plugin_def.input_tensor_names)
        + _join_with(expected_attr_schema_chunks, True)
        + ", outputs, tactic)"
    )

    if f"({', '.join(sig.parameters.keys())})" != expected_schema:
        raise ValueError(
            f"Signature of the aot_impl function '{sig}' does not match the expected input arg schema: {expected_schema}"
        )

    ret_annotation = sig.return_annotation

    if ret_annotation == inspect.Parameter.empty:
        raise ValueError(
            f"No return annotation found for aot_impl function. Received signature {sig}."
        )

    expected_return_schema = "tuple[str | bytes, str | bytes, tensorrt.plugin.KernelLaunchParams, tensorrt.plugin.SymIntExprs]"

    # Return annotation is optional, but we will validate if one is specified
    if ret_annotation != inspect.Parameter.empty:
        if typing.get_origin(ret_annotation) is not tuple:
            raise ValueError(
                f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
            )
        else:
            args = typing.get_args(ret_annotation)

            if len(args) != 4:
                raise ValueError(
                    f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
                )

            def validate_union_str_or_bytes(index):
                def validate_str_or_bytes(arg_):
                    if (arg_ is not str) and (arg_ is not bytes):
                        raise ValueError(
                            f"Return annotation for argument at {index} is '{arg_}'. Expected 'str' or 'bytes'."
                        )

                orig = typing.get_origin(args[index])
                # orig is `typing.Union` when annotation uses typing module (e.g., Union[str, bytes])
                # orig is `types.UnionType` when annotation is of the new (3.10+) native syntax (e.g., str | bytes)
                if orig is typing.Union or orig is types.UnionType:
                    for a in typing.get_args(args[index]):
                        validate_str_or_bytes(a)
                else:
                    # when annotated with `str` or `bytes`
                    validate_str_or_bytes(args[index])

            # kernel name should be str or bytes encoding
            validate_union_str_or_bytes(0)
            # kernel PTX should be str or bytes encoding
            validate_union_str_or_bytes(1)

            if not issubclass(args[2], KernelLaunchParams):
                raise ValueError(f"Argument at index 2 of return annotation is '{args[2]}'. Expected 'tensorrt.plugin.KernelLaunchParams'.")

            if not issubclass(args[3], SymExprs):
                raise ValueError(f"Argument at index 3 of return annotation is '{args[3]}'. Expected a descendant of tensorrt.plugin.SymExprs.")

    return aot_impl_attr_names


def _validate_autotune(autotune_func, plugin_def):

    sig = inspect.signature(autotune_func)
    registered_attr_names = plugin_def.input_attrs.keys()

    autotune_attr_names = []

    # input arg annotations are optional, but we will validate if provided
    for name, param in sig.parameters.items():
        if param.annotation != inspect.Parameter.empty:
            if name == "outputs":
                if typing.get_origin(param.annotation) is not tuple:
                    raise ValueError(
                        f"'outputs' should be of type Tuple[TensorDesc]. Received {param.annotation}."
                    )
                args = typing.get_args(param.annotation)
                for arg in args:
                    if not issubclass(arg, TensorDesc):
                        raise ValueError(
                            f"Argument for receiving output TensorDescs, '{name}' contains a {param.annotation}. '{name}' should be a Tuple[TensorDesc]."
                        )
            elif issubclass(param.annotation, TensorDesc):
                if name not in plugin_def.input_tensor_names:
                    raise ValueError(
                        f"Unexpected tensor '{name}' specified in autotune function. Expected one of {plugin_def.input_tensor_names}."
                    )
            else:
                if name not in plugin_def.input_attrs:
                    raise ValueError(
                        f"Unexpected attribute '{name}' specified in autotune function. Expected one of {list(registered_attr_names)}."
                    )
                if param.annotation != plugin_def.input_attrs[name]:
                    raise ValueError(
                        f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
                    )

                autotune_attr_names.append(name)
        else:
            if name in plugin_def.input_attrs:
                autotune_attr_names.append(name)

    # Expected attribute schema should be constructed in the order they appeared in the register function
    expected_attr_schema_chunks = [
        n for n in registered_attr_names if n in autotune_attr_names
    ]

    expected_schema = (
        "("
        + _join_with(plugin_def.input_tensor_names)
        + _join_with(expected_attr_schema_chunks, True)
        + ", outputs)"
    )

    if f"({', '.join(sig.parameters.keys())})" != expected_schema:
        raise ValueError(
            f"Specified autotune function signature {sig} is not consistent with the expected input arg schema {expected_schema}."
        )

    ret_annotation = sig.return_annotation

    # Return annotation is optional, but we will validate if one is specified
    if ret_annotation != inspect.Parameter.empty:
        if typing.get_origin(ret_annotation) is not list:
            if not inspect.isclass(ret_annotation) or not issubclass(
                ret_annotation, AutoTuneCombination
            ):
                raise ValueError(
                    f"Return argument is of type {ret_annotation}. Return types can only be AutoTuneCombination or List[AutoTuneCombination]."
                )
        else:
            args = typing.get_args(ret_annotation)

            for arg in args:
                if not issubclass(arg, AutoTuneCombination):
                    raise ValueError(
                        f"Return argument is of type {ret_annotation}. Return types can only be AutoTuneCombination or List[AutoTuneCombination]."
                    )

    return autotune_attr_names
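Taken together, these validators fix the contract of a `trt.plugin.register` function: TensorDesc parameters first, then annotated serializable attributes, and a `TensorDesc` or `Tuple[TensorDesc]` return annotation. A hypothetical signature that passes `_parse_register_inputs` and `_parse_register_return` (illustrative names only, assuming `TensorDesc` and the parsers are in scope as in this module):

from typing import Tuple

def my_plugin(inp0: TensorDesc, inp1: TensorDesc, scale: float, axis: int) -> Tuple[TensorDesc, TensorDesc]:
    ...

# TensorDesc args are collected first; builtin-typed attributes follow
tensor_names, input_attrs, schema, attrs_types = _parse_register_inputs(my_plugin, lazy_register=False)
# tensor_names == ["inp0", "inp1"]; input_attrs == {"scale": float, "axis": int}
assert _parse_register_return(my_plugin) == 2   # two output TensorDescs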
Binary file: tensorrt_bindings/tensorrt.cp312-win_amd64.pyd (contents not rendered)