tensorrt-cu12-bindings 10.14.1.48.post1__cp39-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorrt_bindings/__init__.py +224 -0
- tensorrt_bindings/plugin/__init__.py +46 -0
- tensorrt_bindings/plugin/_autotune.py +270 -0
- tensorrt_bindings/plugin/_export.py +39 -0
- tensorrt_bindings/plugin/_lib.py +691 -0
- tensorrt_bindings/plugin/_plugin_class.py +459 -0
- tensorrt_bindings/plugin/_tensor.py +1128 -0
- tensorrt_bindings/plugin/_top_level.py +132 -0
- tensorrt_bindings/plugin/_utils.py +77 -0
- tensorrt_bindings/plugin/_validate.py +475 -0
- tensorrt_bindings/tensorrt.so +0 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/LICENSE.txt +180 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/METADATA +17 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/RECORD +17 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/WHEEL +5 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/top_level.txt +1 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/zip-safe +1 -0
@@ -0,0 +1,691 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorrt as trt
import types
import typing
from typing import Callable, Tuple, List
import numpy as np
from ._plugin_class import _TemplateJITPlugin
from ._export import IS_AOT_ENABLED
if IS_AOT_ENABLED:
    from ._plugin_class import _TemplateAOTPlugin
from ._validate import (
    _parse_register_inputs,
    _parse_register_return,
    _validate_autotune,
    _validate_impl,
    _validate_aot_impl,
    _validate_name_and_namespace,
)
from ._utils import (
    _built_in_to_plugin_field_type,
    _join_with,
    _numpy_to_plugin_field_type,
    _is_numpy_array,
    _infer_numpy_type,
)

from ._export import public_api

# Namespace to which plugins are dynamically bound.
# A namespace can be thought of as a library of plugins from the same author or with a common objective.
class _PluginNamespace(types.ModuleType):
    def __init__(self, namespace):
        super().__init__("tensorrt.plugin.op." + namespace)
        self._namespace = namespace

    def define(self, name, plugin_def):
        assert not hasattr(self, name)
        setattr(self, name, plugin_def)

    def __getattr__(self, name):
        raise AttributeError(
            f"'{self.__class__.__name__}' object '{self._namespace}' has no attribute '{name}'"
        )

    def __repr__(self):
        return f'_PluginNamespace(namespace="{self._namespace}")'


# `tensorrt.plugin.op` module to which plugin namespaces are dynamically bound
class _Op(types.ModuleType):
    def __init__(self):
        super().__init__("tensorrt.plugin.op")

    def define_or_get(self, namespace):
        if hasattr(self, namespace):
            return getattr(self, namespace)

        ns = _PluginNamespace(namespace)
        setattr(self, namespace, ns)

        return ns

    def __getattr__(self, name):
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )


op = _Op()
public_api(symbol="op")(op)

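# Illustrative note (not part of the original module): once `register` (defined
# below) has run for, say, "my::add_plugin", the attribute chain resolves as
#
#     import tensorrt.plugin as trtp
#     trtp.op.my.add_plugin    # `my` via _Op.define_or_get, `add_plugin` via _PluginNamespace.define
#
# "my::add_plugin" is a hypothetical example name, not a plugin shipped here.
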
QDP_CREATORS = {}
QDP_REGISTRY = {}

# Contains metadata about a registered plugin and a `__call__()` that allows a plugin instance to be created
class PluginDef:
    def __init__(self):
        self.plugin_id = None  # includes namespace (format is ns::name)
        self.register_func = None
        self.impl_func = None
        self.aot_impl_func = None
        self.autotune_func = None
        self.autotune_attr_names = None
        self.input_tensor_names = None
        self.input_attrs = None  # map name -> type
        self.impl_attr_names = None
        self.aot_impl_attr_names = None
        self.num_outputs = None
        self.input_arg_schema = None
        self.expects_tactic = None

    def __call__(
        self, *args, **kwargs
    ) -> Callable[..., Tuple[List[trt.ITensor], List[trt.ITensor], trt.IPluginV3]]:
        namespace, name = self.plugin_id.split("::")

        input_tensors = []
        schema_chunks = []

        for t in args:
            if not isinstance(t, trt.ITensor):
                raise ValueError(
                    f"Expected trt.ITensor but got input of type {type(t)}"
                )

            schema_chunks.append("ITensor")
            input_tensors.append(t)

        attrs = {}
        for key, value in kwargs.items():
            if key not in self.input_attrs:
                raise ValueError(
                    f"Unexpected attribute {key} provided. Expected one of {self.input_attrs.keys()}."
                )
            attrs[key] = value
            attr_annotation = self.input_attrs[key]
            if isinstance(value, np.ndarray):
                if typing.get_origin(attr_annotation) == np.ndarray:
                    np_dtype = typing.get_args(typing.get_args(attr_annotation)[1])[0]
                    if np.dtype(np_dtype) != np.dtype(value.dtype):
                        raise ValueError(
                            f"Unexpected dtype '{np.dtype(value.dtype)}' for attribute '{key}'. Expected '{np_dtype}'."
                        )
            else:
                if attr_annotation is not type(value):
                    raise ValueError(
                        f"Unexpected type '{type(value)}' for attribute '{key}'. Expected '{attr_annotation}'."
                    )

            schema_chunks.append(key)

        expected_schema = (
            f"({_join_with(['ITensor'] * len(self.input_tensor_names))}"
            + _join_with(self.input_attrs.keys(), True)
            + ")"
        )
        schema = f"({', '.join(schema_chunks)})"

        if schema != expected_schema:
            raise ValueError(
                f"Unexpected schema {schema} received. Expected {expected_schema}."
            )

        if self.plugin_id in QDP_CREATORS:
            plg_creator = trt.get_plugin_registry().get_creator(name, "1", namespace)
        else:
            attrs_types = {}
            for key, value in kwargs.items():
                if isinstance(value, np.ndarray):
                    attrs_types[key] = (False, value.dtype)  # (builtin?, type)
                else:
                    attrs_types[key] = (True, type(value))  # (builtin?, type)

            plg_creator = _register_plugin_creator(name, namespace, attrs_types)

        fields = []
        for key, value in attrs.items():
            if isinstance(value, np.ndarray):
                np_type = np.dtype(value.dtype)
                if np_type == np.float16:
                    fields.append(
                        trt.PluginField(
                            key, value.tobytes(), trt.PluginFieldType.UNKNOWN
                        )
                    )
                else:
                    fields.append(
                        trt.PluginField(
                            key, value, _numpy_to_plugin_field_type[np_type]
                        )
                    )
            elif isinstance(value, str):
                fields.append(
                    trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR)
                )
            elif isinstance(value, bytes):
                fields.append(trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN))
            else:
                fields.append(
                    trt.PluginField(
                        key,
                        np.array([value]),
                        _built_in_to_plugin_field_type[type(value)],
                    )
                )

        def create_plugin_instance(quick_plugin_creation_request: "trt.QuickPluginCreationRequest" = None):
            if quick_plugin_creation_request is None:
                plg = plg_creator.create_plugin(
                    name,
                    namespace,
                    trt.PluginFieldCollection(fields),
                    trt.TensorRTPhase.BUILD
                )
            else:
                plg = plg_creator.create_plugin(
                    name,
                    namespace,
                    trt.PluginFieldCollection(fields),
                    trt.TensorRTPhase.BUILD,
                    quick_plugin_creation_request
                )

            return input_tensors, [], plg

        return create_plugin_instance

class _TemplatePluginCreator(trt.IPluginCreatorV3Quick):
    def __init__(self, name, namespace, attrs):
        trt.IPluginCreatorV3Quick.__init__(self)
        self.name = name
        self.plugin_namespace = namespace
        self.plugin_version = "1"
        field_names = []
        for name, (builtin, type_) in attrs.items():
            if builtin:
                if type_ is str:
                    field_names.append(
                        trt.PluginField(name, b"", trt.PluginFieldType.CHAR)
                    )
                elif type_ is bytes:
                    field_names.append(
                        trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN)
                    )
                else:
                    field_names.append(
                        trt.PluginField(
                            name, np.array([]), _built_in_to_plugin_field_type[type_]
                        )
                    )
            else:
                field_names.append(
                    trt.PluginField(
                        name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)]
                    )
                )

        self.field_names = trt.PluginFieldCollection(field_names)

    def create_plugin(self, name, namespace, fc, phase, qpcr: "trt.QuickPluginCreationRequest" = None):
        desc = QDP_REGISTRY[f"{namespace}::{name}"]

        attrs = {}
        for f in fc:
            if f.name not in desc.input_attrs:
                raise AssertionError(
                    f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
                )

            attr_type_annot = desc.input_attrs[f.name]
            if _is_numpy_array(attr_type_annot):
                np_type = _infer_numpy_type(attr_type_annot)
                if np_type == np.float16:
                    attrs[f.name] = np.frombuffer(f.data.tobytes(), dtype=np.float16)
                else:
                    attrs[f.name] = f.data.astype(np_type)
            else:
                if issubclass(attr_type_annot, str):
                    attrs[f.name] = f.data.tobytes().decode("utf-8")
                else:
                    attrs[f.name] = attr_type_annot(f.data)

        jit_or_aot = None  # True if JIT is to be created, False if AOT. Not None will be asserted before plugin creation.

        if qpcr is None:
            plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)

            plg.init(
                desc.register_func,
                attrs,
                desc.impl_attr_names,
                desc.impl_func,
                desc.autotune_attr_names,
                desc.autotune_func,
                desc.expects_tactic,
            )

            return plg

        # If there is a strict preference, that takes precedence
        if qpcr == trt.QuickPluginCreationRequest.STRICT_AOT:
            if desc.aot_impl_func is None:
                raise ValueError(f"AOT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.aot_impl defined?")
            jit_or_aot = False
        elif qpcr == trt.QuickPluginCreationRequest.STRICT_JIT:
            if desc.impl_func is None:
                raise ValueError(f"JIT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.impl defined?")
            jit_or_aot = True
        else:
            aot_defined = desc.aot_impl_func is not None
            jit_defined = desc.impl_func is not None

            # A preference must be indicated if both AOT and JIT implementations are defined
            if aot_defined and jit_defined:
                if qpcr == trt.QuickPluginCreationRequest.PREFER_AOT:
                    jit_or_aot = False
                elif qpcr == trt.QuickPluginCreationRequest.PREFER_JIT:
                    jit_or_aot = True
                else:
                    raise ValueError(f"Plugin '{desc.plugin_id}' has both AOT and JIT implementations. NetworkDefinitionCreationFlag.PREFER_AOT_PYTHON_PLUGINS or NetworkDefinitionCreationFlag.PREFER_JIT_PYTHON_PLUGINS should be specified.")
            else:
                # If only one implementation is defined, use that.
                # Any preference specified is ignored. If the preference is strong, a strict flag should have been specified.
                if aot_defined:
                    jit_or_aot = False
                elif jit_defined:
                    jit_or_aot = True
                else:
                    raise ValueError(f"Plugin '{desc.plugin_id}' does not have either an AOT or JIT implementation.")

        assert jit_or_aot is not None

        if jit_or_aot:
            plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)

            plg.init(
                desc.register_func,
                attrs,
                desc.impl_attr_names,
                desc.impl_func,
                desc.autotune_attr_names,
                desc.autotune_func,
                desc.expects_tactic,
            )

        else:
            plg = _TemplateAOTPlugin(name, namespace, desc.num_outputs)

            plg.init(
                desc.register_func,
                attrs,
                desc.aot_impl_attr_names,
                desc.aot_impl_func,
                desc.autotune_attr_names,
                desc.autotune_func
            )

        # The caller can determine whether the created plugin is an AOT or JIT plugin by inspecting the interface info.
        return plg

def _register_plugin_creator(name: str, namespace: str, attrs_types):
    plg_registry = trt.get_plugin_registry()
    plg_creator = _TemplatePluginCreator(name, namespace, attrs_types)
    plg_registry.register_creator(plg_creator, namespace)
    plg_creator = plg_registry.get_creator(name, "1", namespace)
    QDP_CREATORS[f"{namespace}::{name}"] = plg_creator
    return plg_creator


# Decorator for `tensorrt.plugin.register`
# By default, the plugin will be immediately registered in the TRT plugin registry.
# During plugin development, or when building an engine, lazy registration may be used to delay plugin registration until the plugin is explicitly instantiated using `trt.plugin.op.ns.plugin_name(...)`
@public_api()
def register(plugin_id: str, lazy_register: bool = False) -> Callable:
    """
    Wraps a function to register and describe a TensorRT plugin's IO characteristics. In addition, a complete plugin at least needs a `trt.plugin.impl` function to be registered.

    This API is only intended to be used as a decorator. The decorated function must have type hints for all inputs as well as the return value.

    .. code-block:: text

        (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, ...) -> Union[TensorDesc, Tuple[TensorDesc]]

    * Input tensors are declared first, each described by a tensor descriptor TensorDesc.
    * Plugin attributes are declared next. "SupportedAttrType" must be one of:
        * Supported built-in types: int, float, str, bool, bytes (Note: Lists/tuples of these types are not supported)
        * 1-D Numpy arrays of the following types: int8, int16, int32, int64, float16, float32, float64, bool. These must be annotated with 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype.
    * If the plugin has only one output, the return annotation may be TensorDesc. Tuple[TensorDesc] can be used for any number of outputs.

    By default, the plugin will be immediately registered in the TRT plugin registry. Use the lazy_register argument to change this.

    Args:
        plugin_id: An ID for the plugin in the form "{namespace}::{name}",
            e.g. "my_project::add_plugin". The namespace is used to avoid collisions,
            so using your product/project name is recommended.

        lazy_register: During plugin development, or when building an engine, lazy registration may be used to delay plugin registration until the plugin is explicitly instantiated using `trt.plugin.op.ns.plugin_name(...)`

    .. code-block:: python
        :linenos:
        :caption: Registration of an elementwise plugin (output has same characteristics as the input)

        import tensorrt.plugin as trtp

        @trtp.register("my::add_plugin")
        def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
            return inp0.like()

    """

    def decorator(register_func: Callable):

        plugin_ns, plugin_name = plugin_id.split("::")
        _validate_name_and_namespace(plugin_ns, plugin_name)

        op_namespace = op.define_or_get(plugin_ns)

        if hasattr(op_namespace, plugin_name):
            raise ValueError(
                f"'{op.__class__.__name__}' already has a definition for '{plugin_name}'"
            )

        (
            tensor_names,
            input_attrs,
            input_arg_schema,
            attrs_types,
        ) = _parse_register_inputs(register_func, lazy_register)

        plugin_def = PluginDef()
        plugin_def.plugin_id = plugin_id
        plugin_def.register_func = register_func
        plugin_def.input_tensor_names = tensor_names
        plugin_def.input_attrs = input_attrs
        plugin_def.input_arg_schema = input_arg_schema

        num_outputs = _parse_register_return(register_func)

        plugin_def.num_outputs = num_outputs
        QDP_REGISTRY[plugin_id] = plugin_def

        if not lazy_register:
            _register_plugin_creator(plugin_name, plugin_ns, attrs_types)

        op_namespace.define(plugin_name, plugin_def)

        return register_func

    return decorator


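# Illustrative sketch of the lazy-registration path described in the docstring above
# ("my::add_plugin" and `inp0` are hypothetical names, not part of this module):
#
#     @trtp.register("my::add_plugin", lazy_register=True)
#     def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
#         return inp0.like()
#
#     # Nothing is placed in the TRT plugin registry yet; _register_plugin_creator
#     # runs later, inside PluginDef.__call__, on first instantiation:
#     # trtp.op.my.add_plugin(inp0_itensor, block_size=256)
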
# Decorator for `tensorrt.plugin.impl`
@public_api()
def impl(plugin_id: str) -> Callable:
    """
    Wraps a function to define an implementation for a plugin already registered through `trt.plugin.register`.

    This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value;
    however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.

    The schema for the function is as follows:

    .. code-block:: text

        (inp0: Tensor, inp1: Tensor, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[Tensor], stream: int, tactic: Optional[int]) -> None

    * Input tensors are passed first, each described by a `Tensor`.
    * Plugin attributes are declared next.
        * Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
        * Included attributes will be serialized to the TRT engine. Therefore, only attributes the plugin actually needs to perform inference (within the body of `trt.plugin.impl`) should be included.
    * `tactic` is an optional argument. If the plugin is using custom tactics, it must be specified to receive the tactic value to use for the current execution of the plugin.

    Args:
        plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`

    .. code-block:: python
        :linenos:
        :caption: Implementation of an elementwise plugin with an OpenAI Triton kernel

        import tensorrt.plugin as trtp
        import torch
        import triton
        import triton.language as tl

        @triton.jit
        def add_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(0)
            offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(x_ptr + offsets, mask=mask)
            tl.store(y_ptr + offsets, x + 1, mask=mask)

        @trtp.register("my::add_plugin")
        def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
            return inp0.like()

        @trtp.impl("my::add_plugin")
        def add_plugin_impl(inp0: trtp.Tensor, block_size: int, outputs: Tuple[trtp.Tensor], stream: int) -> None:

            n = inp0.numel()
            inp0_t = torch.as_tensor(inp0, device="cuda")
            out_t = torch.as_tensor(outputs[0], device="cuda")

            add_kernel[(triton.cdiv(n, block_size),)](inp0_t, out_t, n, BLOCK_SIZE=block_size)
    """

    def decorator(impl_func: Callable):
        if plugin_id not in QDP_REGISTRY:
            raise ValueError(
                f"Plugin {plugin_id} is not registered. Did you register it with the tensorrt.plugin.register API?"
            )

        plugin_def = QDP_REGISTRY[plugin_id]
        impl_attr_names, found_tactic = _validate_impl(impl_func, plugin_def)

        plugin_def.impl_func = impl_func
        plugin_def.impl_attr_names = impl_attr_names
        plugin_def.expects_tactic = found_tactic
        return impl_func

    return decorator

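# Illustrative sketch of a tactic-aware implementation, per the `tactic` note in the
# docstring above ("my::add_plugin" and the kernel launches are hypothetical; tactics
# 1 and 2 would have to be advertised by a matching trt.plugin.autotune function):
#
#     @trtp.impl("my::add_plugin")
#     def add_plugin_impl(inp0: trtp.Tensor, outputs: Tuple[trtp.Tensor], stream: int, tactic: int) -> None:
#         if tactic == 1:
#             ...  # launch the first kernel variant
#         elif tactic == 2:
#             ...  # launch the second kernel variant
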
# Decorator for `tensorrt.plugin.aot_impl`
@public_api()
def aot_impl(plugin_id: str) -> Callable:
    """
    Wraps a function to define an Ahead-of-Time (AOT) implementation for a plugin already registered through `trt.plugin.register`.

    This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value;
    however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.

    The schema for the function is as follows:

    .. code-block:: text

        (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc], tactic: Optional[int]) -> Tuple[str, str, KernelLaunchParams, SymExprs]

    * Input tensors are passed first, each described by a `TensorDesc`.
    * Plugin attributes are declared next.
        * Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
        * NOTE: Plugin attributes are not serialized into the engine when using an AOT implementation.
    * `tactic` is an optional argument. If the plugin is using custom tactics, it must be specified to receive the tactic value to use for the current execution of the plugin.

    Args:
        plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`

    :returns:
        - kernel_name: The name of the kernel.
        - compiled_kernel: Compiled form of the kernel. Presently, only PTX is supported.
        - launch_params: The launch parameters for the kernel
        - extra_args: Symbolic expressions for scalar inputs to the kernel, located after the tensor inputs and before the tensor outputs

    .. code-block:: python
        :linenos:
        :caption: AOT implementation of an elementwise plugin with an OpenAI Triton kernel

        import tensorrt.plugin as trtp
        import triton
        import triton.language as tl

        @triton.jit
        def add_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(0)
            offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(x_ptr + offsets, mask=mask)
            tl.store(y_ptr + offsets, x + 1, mask=mask)

        @trtp.register("my::add_plugin")
        def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
            return inp0.like()

        @trtp.aot_impl("my::add_plugin")
        def add_plugin_aot_impl(
            inp0: trtp.TensorDesc, block_size: int, single_tactic: bool, outputs: Tuple[trtp.TensorDesc], tactic: int
        ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:

            type_str = "fp32" if inp0.dtype == trt.float32 else "fp16"

            src = triton.compiler.ASTSource(
                fn=add_kernel,
                signature=f"*{type_str},i32,*{type_str}",
                constants={
                    "BLOCK_SIZE": block_size,
                },
            )

            compiled_kernel = triton.compile(src)

            N = inp0.shape_expr.numel()
            launch_params = trtp.KernelLaunchParams()

            # grid dims
            launch_params.grid_x = trtp.cdiv(N, block_size)
            # block dims
            launch_params.block_x = compiled_kernel.metadata.num_warps * 32
            # shared memory
            launch_params.shared_mem = compiled_kernel.metadata.shared

            extra_args = trtp.SymIntExprs(1)
            extra_args[0] = trtp.SymInt32(N)

            return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args
    """
    def decorator(aot_impl_func: Callable):
        if plugin_id not in QDP_REGISTRY:
            raise ValueError(
                f"Plugin {plugin_id} is not registered. Did you register it with the tensorrt.plugin.register API?"
            )

        plugin_def = QDP_REGISTRY[plugin_id]
        aot_impl_attr_names = _validate_aot_impl(aot_impl_func, plugin_def)

        plugin_def.aot_impl_func = aot_impl_func
        plugin_def.aot_impl_attr_names = aot_impl_attr_names
        return aot_impl_func

    return decorator


# Decorator for `tensorrt.plugin.autotune`
@public_api()
def autotune(plugin_id: str) -> Callable:
    """
    Wraps a function to define autotune logic for a plugin already registered through `trt.plugin.register`.

    Autotuning is the process by which TensorRT executes the plugin over IO type/format combinations, and any custom tactics advertised as being supported by the plugin.
    The (type, format, tactic) combination with the lowest latency is used to execute the plugin once the engine is built.

    .. note:: An autotune function is optional. If not specified, TensorRT will assume the plugin only supports input types specified at network creation, output types specified through `trt.plugin.register`, and linear formats for all I/O.

    This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value; however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.

    The schema for the function is as follows:

    .. code-block:: text

        (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc]) -> List[AutoTuneCombination]

    * Input tensors are passed first, each described by a :class:`TensorDesc`.
    * Plugin attributes are declared next. Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
    * The function should return a list of :class:`AutoTuneCombination`\\s.

    Args:
        plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`

    .. code-block:: python
        :linenos:
        :caption: An elementwise add plugin which supports both FP32 and FP16 linear I/O and wants to be tuned over 2 custom tactics.

        import tensorrt.plugin as trtp

        @trtp.register("my::add_plugin")
        def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
            return inp0.like()

        @trtp.autotune("my::add_plugin")
        def add_plugin_autotune(inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:

            return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR", [1, 2])]

    .. code-block:: python
        :linenos:
        :caption: Same as the above example but using index-by-index construction of an `AutoTuneCombination`

        import tensorrt.plugin as trtp

        @trtp.register("my::add_plugin")
        def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
            return inp0.like()

        @trtp.autotune("my::add_plugin")
        def add_plugin_autotune(inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:
            c = trtp.AutoTuneCombination()
            c.pos(0, "FP32|FP16", "LINEAR")
            c.pos(1, "FP32|FP16")  # index 1 is the output. Omitting the format is the same as declaring it to be LINEAR.
            c.tactics([1, 2])
            return [c]
    """

    def decorator(autotune_func: Callable):
        if plugin_id not in QDP_REGISTRY:
            raise ValueError(
                f"Plugin {plugin_id} is not registered. Did you register it with the tensorrt.plugin.register API?"
            )

        plugin_def = QDP_REGISTRY[plugin_id]
        autotune_attr_names = _validate_autotune(autotune_func, plugin_def)

        plugin_def.autotune_func = autotune_func
        plugin_def.autotune_attr_names = autotune_attr_names

        return autotune_func

    return decorator
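
# Illustrative summary of how the decorators in this module compose (hypothetical
# "my::add_plugin"; see the individual docstrings above for full examples):
#
#     @trtp.register("my::add_plugin")    # records a PluginDef in QDP_REGISTRY and,
#     def add_plugin_desc(...): ...       # unless lazy, registers a creator with TRT
#
#     @trtp.impl("my::add_plugin")        # attaches the JIT implementation
#     def add_plugin_impl(...): ...
#
#     @trtp.autotune("my::add_plugin")    # optional: type/format/tactic search space
#     def add_plugin_autotune(...): ...
#
#     # Instantiation goes through op.my.add_plugin -> PluginDef.__call__, which
#     # returns a create_plugin_instance callable for TensorRT to invoke.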