tensorrt-cu12-bindings 10.14.1.48.post1__cp39-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,691 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import tensorrt as trt
19
+ import types
20
+ import typing
21
+ from typing import Callable, Tuple, List
22
+ import numpy as np
23
+ from ._plugin_class import _TemplateJITPlugin
24
+ from ._export import IS_AOT_ENABLED
25
+ if IS_AOT_ENABLED:
26
+ from ._plugin_class import _TemplateAOTPlugin
27
+ from ._validate import (
28
+ _parse_register_inputs,
29
+ _parse_register_return,
30
+ _validate_autotune,
31
+ _validate_impl,
32
+ _validate_aot_impl,
33
+ _validate_name_and_namespace,
34
+ )
35
+ from ._utils import (
36
+ _built_in_to_plugin_field_type,
37
+ _join_with,
38
+ _numpy_to_plugin_field_type,
39
+ _is_numpy_array,
40
+ _infer_numpy_type,
41
+ )
42
+
43
+ from ._export import public_api
44
+
45
+ # Namespace to which plugins are dynamically bound
46
+ # A namespace can be thought of as a library of plugins from the same author or with a common objective
47
+ class _PluginNamespace(types.ModuleType):
48
+ def __init__(self, namespace):
49
+ super().__init__("tensorrt.plugin.op." + namespace)
50
+ self._namespace = namespace
51
+
52
+ def define(self, name, plugin_def):
53
+ assert not hasattr(self, name)
54
+ setattr(self, name, plugin_def)
55
+
56
+ def __getattr__(self, name):
57
+ raise AttributeError(
58
+ f"'{self.__class__.__name__}' object '{self._namespace}' has no attribute '{name}'"
59
+ )
60
+
61
+ def __repr__(self):
62
+ return f'_PluginNamespace(namespace="{self._namespace}")'
63
+
64
+
65
+ # `tensorrt.plugin.op` module to which plugin namespaces are dynamically bound
66
+ class _Op(types.ModuleType):
67
+ def __init__(self):
68
+ super().__init__("tensorrt.plugin.op")
69
+
70
+ def define_or_get(self, namespace):
71
+ if hasattr(self, namespace):
72
+ return getattr(self, namespace)
73
+
74
+ ns = _PluginNamespace(namespace)
75
+ setattr(self, namespace, ns)
76
+
77
+ return ns
78
+
79
+ def __getattr__(self, name):
80
+ raise AttributeError(
81
+ f"'{self.__class__.__name__}' object has no attribute '{name}'"
82
+ )
83
+
84
+
85
+ op = _Op()
86
+ public_api(symbol="op")(op)
87
+
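# Hedged illustration (not part of this file): the dynamic binding performed by
# _Op.define_or_get and _PluginNamespace.define. Once a plugin "my::add_plugin"
# is registered, it becomes reachable as an attribute chain under trt.plugin.op.
import tensorrt.plugin as trtp

ns = trtp.op.define_or_get("my")   # creates the "my" namespace module on first use
assert ns is trtp.op.my            # bound as an attribute of trt.plugin.op
# After @trtp.register("my::add_plugin") runs, trtp.op.my.add_plugin is the
# corresponding PluginDef (defined further below in this module).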
88
+ QDP_CREATORS = {}
89
+ QDP_REGISTRY = {}
90
+
91
+ # Contains metadata about a registered plugin and a `__call__()` that allows a plugin instance to be created
92
+ class PluginDef:
93
+ def __init__(self):
94
+ self.plugin_id = None # includes namespace (format is ns::name)
95
+ self.register_func = None
96
+ self.impl_func = None
97
+ self.aot_impl_func = None
98
+ self.autotune_func = None
99
+ self.autotune_attr_names = None
100
+ self.input_tensor_names = None
101
+ self.input_attrs = None # map name -> type
102
+ self.impl_attr_names = None
103
+ self.aot_impl_attr_names = None
104
+ self.num_outputs = None
105
+ self.input_arg_schema = None
106
+ self.expects_tactic = None
107
+
108
+ def __call__(
109
+ self, *args, **kwargs
110
+ ) -> Tuple[List[trt.ITensor], List[trt.ITensor], trt.IPluginV3]:
111
+ namespace, name = self.plugin_id.split("::")
112
+
113
+ input_tensors = []
114
+ schema_chunks = []
115
+
116
+ for t in args:
117
+ if not isinstance(t, trt.ITensor):
118
+ raise ValueError(
119
+ f"Expected trt.ITensor but got input of type {type(t)}"
120
+ )
121
+
122
+ schema_chunks.append("ITensor")
123
+ input_tensors.append(t)
124
+
125
+ attrs = {}
126
+ for key, value in kwargs.items():
127
+ if key not in self.input_attrs:
128
+ raise ValueError(
129
+ f"Unexpected attribute {key} provided. Expected one of {self.input_attrs.keys()}."
130
+ )
131
+ attrs[key] = value
132
+ attr_annotation = self.input_attrs[key]
133
+ if isinstance(value, np.ndarray):
134
+ if typing.get_origin(attr_annotation) == np.ndarray:
135
+ np_dtype = typing.get_args(typing.get_args(attr_annotation)[1])[0]
136
+ if np.dtype(np_dtype) != np.dtype(value.dtype):
137
+ raise ValueError(
138
+ f"Unexpected dtype '{np.dtype(value.dtype)}' for attribute '{key}'. Expected '{np_dtype}'."
139
+ )
140
+ else:
141
+ if attr_annotation is not type(value):
142
+ raise ValueError(
143
+ f"Unexpected type '{type(value)}' for attribute '{key}'. Expected '{attr_annotation}'."
144
+ )
145
+
146
+ schema_chunks.append(key)
147
+
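+ # e.g. two ITensor inputs plus a "block_size" attribute yield the schema "(ITensor, ITensor, block_size)"; the call is rejected below if it does not match the registered signature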
148
+ expected_schema = (
149
+ f"({_join_with(['ITensor'] * len(self.input_tensor_names))}"
150
+ + _join_with(self.input_attrs.keys(), True)
151
+ + ")"
152
+ )
153
+ schema = f"({', '.join(schema_chunks)})"
154
+
155
+ if schema != expected_schema:
156
+ raise ValueError(
157
+ f"Unexpected schema {schema} received. Expected {expected_schema}."
158
+ )
159
+
160
+ if self.plugin_id in QDP_CREATORS:
161
+ plg_creator = trt.get_plugin_registry().get_creator(name, "1", namespace)
162
+ else:
163
+ attrs_types = {}
164
+ for key, value in kwargs.items():
165
+ if isinstance(value, np.ndarray):
166
+ attrs_types[key] = (False, value.dtype) # (builtin?, type)
167
+ else:
168
+ attrs_types[key] = (True, type(value)) # (builtin?, type)
169
+
170
+ plg_creator = _register_plugin_creator(name, namespace, attrs_types)
171
+
172
+ fields = []
173
+ for key, value in attrs.items():
174
+ if isinstance(value, np.ndarray):
175
+ np_type = np.dtype(value.dtype)
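+ # float16 arrays are serialized as raw bytes with PluginFieldType.UNKNOWN and recovered via np.frombuffer(..., dtype=np.float16) in create_plugin() below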
176
+ if np_type == np.float16:
177
+ fields.append(
178
+ trt.PluginField(
179
+ key, value.tobytes(), trt.PluginFieldType.UNKNOWN
180
+ )
181
+ )
182
+ else:
183
+ fields.append(
184
+ trt.PluginField(
185
+ key, value, _numpy_to_plugin_field_type[np_type]
186
+ )
187
+ )
188
+ elif isinstance(value, str):
189
+ fields.append(
190
+ trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR)
191
+ )
192
+ elif isinstance(value, bytes):
193
+ fields.append(trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN))
194
+ else:
195
+ fields.append(
196
+ trt.PluginField(
197
+ key,
198
+ np.array([value]),
199
+ _built_in_to_plugin_field_type[type(value)],
200
+ )
201
+ )
202
+
203
+ def create_plugin_instance(quick_plugin_creation_request: "trt.QuickPluginCreationRequest" = None):
204
+ if quick_plugin_creation_request is None:
205
+ plg = plg_creator.create_plugin(
206
+ name,
207
+ namespace,
208
+ trt.PluginFieldCollection(fields),
209
+ trt.TensorRTPhase.BUILD
210
+ )
211
+ else:
212
+ plg = plg_creator.create_plugin(
213
+ name,
214
+ namespace,
215
+ trt.PluginFieldCollection(fields),
216
+ trt.TensorRTPhase.BUILD,
217
+ quick_plugin_creation_request
218
+ )
219
+
220
+ return input_tensors, [], plg
221
+
222
+ return create_plugin_instance
223
+
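# Hedged sketch (not part of this file): instantiating a registered plugin while
# building a network. PluginDef.__call__ above validates the (tensor, attribute)
# schema and returns a deferred creation callable; how the network consumes it
# (network.add_plugin below) is an assumption about the builder-side API, and
# "my::add_plugin" refers to the docstring examples later in this module.
import tensorrt as trt
import tensorrt.plugin as trtp

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network()
inp0 = network.add_input("inp0", trt.float32, (1, 256))

# Positional arguments must be trt.ITensor objects; attributes are keyword-only.
network.add_plugin(trtp.op.my.add_plugin(inp0, block_size=64))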
224
+ class _TemplatePluginCreator(trt.IPluginCreatorV3Quick):
225
+ def __init__(self, name, namespace, attrs):
226
+ trt.IPluginCreatorV3Quick.__init__(self)
227
+ self.name = name
228
+ self.plugin_namespace = namespace
229
+ self.plugin_version = "1"
230
+ field_names = []
231
+ for name, (builtin, type_) in attrs.items():
232
+ if builtin:
233
+ if type_ is str:
234
+ field_names.append(
235
+ trt.PluginField(name, b"", trt.PluginFieldType.CHAR)
236
+ )
237
+ elif type_ is bytes:
238
+ field_names.append(
239
+ trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN)
240
+ )
241
+ else:
242
+ field_names.append(
243
+ trt.PluginField(
244
+ name, np.array([]), _built_in_to_plugin_field_type[type_]
245
+ )
246
+ )
247
+ else:
248
+ field_names.append(
249
+ trt.PluginField(
250
+ name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)]
251
+ )
252
+ )
253
+
254
+ self.field_names = trt.PluginFieldCollection(field_names)
255
+
256
+ def create_plugin(self, name, namespace, fc, phase, qpcr: "trt.QuickPluginCreationRequest" = None):
257
+ desc = QDP_REGISTRY[f"{namespace}::{name}"]
258
+ name = name
259
+ namespace = namespace
260
+
261
+ attrs = {}
262
+ for f in fc:
263
+ if f.name not in desc.input_attrs:
264
+ raise AssertionError(
265
+ f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}."
266
+ )
267
+
268
+ attr_type_annot = desc.input_attrs[f.name]
269
+ if _is_numpy_array(attr_type_annot):
270
+ np_type = _infer_numpy_type(attr_type_annot)
271
+ if np_type == np.float16:
272
+ attrs[f.name] = np.frombuffer(f.data.tobytes(), dtype=np.float16)
273
+ else:
274
+ attrs[f.name] = f.data.astype(np_type)
275
+ else:
276
+ if issubclass(attr_type_annot, str):
277
+ attrs[f.name] = f.data.tobytes().decode("utf-8")
278
+ else:
279
+ attrs[f.name] = attr_type_annot(f.data)
280
+
281
+ jit_or_aot = None # True if a JIT plugin is to be created, False if AOT. Asserted to be not None before plugin creation.
282
+
283
+ if qpcr is None:
284
+ plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)
285
+
286
+ plg.init(
287
+ desc.register_func,
288
+ attrs,
289
+ desc.impl_attr_names,
290
+ desc.impl_func,
291
+ desc.autotune_attr_names,
292
+ desc.autotune_func,
293
+ desc.expects_tactic,
294
+ )
295
+
296
+ return plg
297
+
298
+ # If there is a strict preference, that takes precedence
299
+ if qpcr == trt.QuickPluginCreationRequest.STRICT_AOT:
300
+ if desc.aot_impl_func is None:
301
+ raise ValueError(f"AOT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.aot_impl defined?")
302
+ jit_or_aot = False
303
+ elif qpcr == trt.QuickPluginCreationRequest.STRICT_JIT:
304
+ if desc.impl_func is None:
305
+ raise ValueError(f"JIT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.impl defined?")
306
+ jit_or_aot = True
307
+ else:
308
+ aot_defined = desc.aot_impl_func is not None
309
+ jit_defined = desc.impl_func is not None
310
+
311
+ # A preference must be indicated if both AOT and JIT implementations are defined
312
+ if aot_defined and jit_defined:
313
+ if qpcr == trt.QuickPluginCreationRequest.PREFER_AOT:
314
+ jit_or_aot = False
315
+ elif qpcr == trt.QuickPluginCreationRequest.PREFER_JIT:
316
+ jit_or_aot = True
317
+ else:
318
+ raise ValueError(f"Plugin '{desc.plugin_id}' has both AOT and JIT implementations. NetworkDefinitionCreationFlag.PREFER_AOT_PYTHON_PLUGINS or NetworkDefinitionCreationFlag.PREFER_JIT_PYTHON_PLUGINS should be specified.")
319
+ else:
320
+ # If only one implementation is defined, use that.
321
+ # Any preference specified is ignored. If the preference is strong, a strict flag should have been specified.
322
+ if aot_defined:
323
+ jit_or_aot = False
324
+ elif jit_defined:
325
+ jit_or_aot = True
326
+ else:
327
+ raise ValueError(f"Plugin '{desc.plugin_id}' does not have either a AOT or JIT implementation.")
328
+
329
+ assert jit_or_aot is not None
330
+
331
+ if jit_or_aot:
332
+ plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)
333
+
334
+ plg.init(
335
+ desc.register_func,
336
+ attrs,
337
+ desc.impl_attr_names,
338
+ desc.impl_func,
339
+ desc.autotune_attr_names,
340
+ desc.autotune_func,
341
+ desc.expects_tactic,
342
+ )
343
+
344
+ else:
345
+ plg = _TemplateAOTPlugin(name, namespace, desc.num_outputs)
346
+
347
+ plg.init(
348
+ desc.register_func,
349
+ attrs,
350
+ desc.aot_impl_attr_names,
351
+ desc.aot_impl_func,
352
+ desc.autotune_attr_names,
353
+ desc.autotune_func
354
+ )
355
+
356
+ # The caller can determine whether the created plugin is an AOT or JIT plugin by inspecting the interface info
357
+ return plg
358
+
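# Hedged sketch (not part of this file): driving the JIT/AOT selection above from
# the builder side. The flag name is taken from the error message in
# create_plugin(); the surrounding builder calls are assumptions.
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
# With this flag, a plugin defining both implementations resolves to AOT
# (QuickPluginCreationRequest.PREFER_AOT); a strict request would instead fail
# if the corresponding implementation is missing.
flags = 1 << int(trt.NetworkDefinitionCreationFlag.PREFER_AOT_PYTHON_PLUGINS)
network = builder.create_network(flags)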
359
+ def _register_plugin_creator(name: str, namespace: str, attrs_types):
360
+ plg_registry = trt.get_plugin_registry()
361
+ plg_creator = _TemplatePluginCreator(name, namespace, attrs_types)
362
+ plg_registry.register_creator(plg_creator, namespace)
363
+ plg_creator = plg_registry.get_creator(name, "1", namespace)
364
+ QDP_CREATORS[f"{namespace}::{name}"] = plg_creator
365
+ return plg_creator
366
+
367
+
368
+ # Decorator for `tensorrt.plugin.register`
369
+ # By default, the plugin will be immediately registered in the TRT plugin registry
370
+ # During plugin development or when building an engine, lazy registration may be used to delay plugin registration until the plugin is explicitly instantiated using `trt.plugin.op.ns.plugin_name(...)`
371
+ @public_api()
372
+ def register(plugin_id: str, lazy_register: bool = False) -> Callable:
373
+ """
374
+ Wraps a function to register and describe a TensorRT plugin's IO characteristics. In addition, a complete plugin needs at least a `trt.plugin.impl` function to be registered.
375
+
376
+ This API is only intended to be used as a decorator. The decorated function must have type hints for all inputs as well as return value.
377
+
378
+ .. code-block:: text
379
+
380
+ (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, ...) -> Union[TensorDesc, Tuple[TensorDesc]]
381
+
382
+ * Input tensors are declared first, each described by a tensor descriptor TensorDesc.
383
+ * Plugin attributes are declared next. "SupportedAttrType" must be one of:
384
+ * Supported built-in types: int, float, str, bool, bytes (Note: Lists/tuples of these types are not supported)
385
+ * 1-D Numpy arrays of the following types: int8, int16, int32, int64, float16, float32, float64, bool. These must be annotated with 'numpy.typing.NDArray[dtype]', where 'dtype' is the expected numpy dtype.
386
+ * If the plugin has only one output, the return annotation could be TensorDesc. Tuple[TensorDesc] could be used for any number of outputs.
387
+
388
+ By default, the plugin will be immediately registered in the TRT plugin registry. Use the lazy_register argument to change this.
389
+
390
+ Args:
391
+ plugin_id: An ID for the plugin in the form "{namespace}::{name}",
392
+ e.g. "my_project::add_plugin". The namespace is used to avoid collisions
393
+ so using your product/project name is recommended.
394
+
395
+ lazy_register: During plugin development or when building an engine, lazy registration may be used to delay plugin registration until the plugin is explicitly instantiated using `trt.plugin.op.ns.plugin_name(...)`
396
+
397
+ .. code-block:: python
398
+ :linenos:
399
+ :caption: Registration of an elementwise plugin (output has same characteristics as the input)
400
+
401
+ import tensorrt.plugin as trtp
402
+
403
+ @trtp.register("my::add_plugin")
404
+ def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
405
+ return inp0.like()
406
+
407
+ """
408
+
409
+ def decorator(register_func: Callable):
410
+
411
+ plugin_ns, plugin_name = plugin_id.split("::")
412
+ _validate_name_and_namespace(plugin_ns, plugin_name)
413
+
414
+ op_namespace = op.define_or_get(plugin_ns)
415
+
416
+ if hasattr(op_namespace, plugin_name):
417
+ raise ValueError(
418
+ f"'{op.__class__.__name__}' already has a defintion for '{plugin_name}'"
419
+ )
420
+
421
+ (
422
+ tensor_names,
423
+ input_attrs,
424
+ input_arg_schema,
425
+ attrs_types,
426
+ ) = _parse_register_inputs(register_func, lazy_register)
427
+
428
+ plugin_def = PluginDef()
429
+ plugin_def.plugin_id = plugin_id
430
+ plugin_def.register_func = register_func
431
+ plugin_def.input_tensor_names = tensor_names
432
+ plugin_def.input_attrs = input_attrs
433
+ plugin_def.input_arg_schema = input_arg_schema
434
+
435
+ num_outputs = _parse_register_return(register_func)
436
+
437
+ plugin_def.num_outputs = num_outputs
438
+ QDP_REGISTRY[plugin_id] = plugin_def
439
+
440
+ if not lazy_register:
441
+ _register_plugin_creator(plugin_name, plugin_ns, attrs_types)
442
+
443
+ op_namespace.define(plugin_name, plugin_def)
444
+
445
+ return register_func
446
+
447
+ return decorator
448
+
449
+
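# Hedged sketch (not part of this file): registering a plugin whose attribute is a
# 1-D numpy array, using the numpy.typing.NDArray annotation described in the
# docstring above. The id "my::scale_plugin" is made up for illustration.
from typing import Tuple
import numpy as np
import numpy.typing as npt
import tensorrt.plugin as trtp

# lazy_register=True delays creator registration until the plugin is first
# instantiated through trt.plugin.op.my.scale_plugin(...).
@trtp.register("my::scale_plugin", lazy_register=True)
def scale_plugin_desc(inp0: trtp.TensorDesc, scale: npt.NDArray[np.float32]) -> Tuple[trtp.TensorDesc]:
    return inp0.like()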
450
+ # Decorator for `tensorrt.plugin.impl`
451
+ @public_api()
452
+ def impl(plugin_id: str) -> Callable:
453
+ """
454
+ Wraps a function to define an implementation for a plugin already registered through `trt.plugin.register`.
455
+
456
+ This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value;
457
+ however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.
458
+
459
+ The schema for the function is as follows:
460
+
461
+ .. code-block:: text
462
+
463
+ (inp0: Tensor, inp1: Tensor, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[Tensor], stream: int, tactic: Optional[int]) -> None
464
+
465
+ * Input tensors are passed first, each described by a `Tensor`.
466
+ * Plugin attributes are declared next.
467
+ * Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
468
+ * Included attributes will be serialized to the TRT engine. Therefore, only attributes the plugin actually needs to perform inference (within the body of `trt.plugin.impl`) should be included.
469
+ * `tactic` is an optional argument. If the plugin is using custom tactics, it must be specified to receive the tactic value to use for the current execution of the plugin.
470
+
471
+ Args:
472
+ plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`
473
+
474
+ .. code-block:: python
475
+ :linenos:
476
+ :caption: Implementation of an elementwise plugin with an OpenAI Triton kernel
477
+
478
+ import tensorrt.plugin as trtp
479
+ import triton
480
+ import triton.language as tl
+ import torch
481
+
482
+ @triton.jit
483
+ def add_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
484
+ pid = tl.program_id(0)
485
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
486
+ mask = offsets < n_elements
487
+ x = tl.load(x_ptr + offsets, mask=mask)
488
+ tl.store(y_ptr + offsets, x + 1, mask=mask)
489
+
490
+ @trtp.register("my::add_plugin")
491
+ def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
492
+ return inp0.like()
493
+
494
+ @trtp.impl("my::add_plugin")
495
+ def add_plugin_impl(inp0: trtp.Tensor, block_size: int, outputs: Tuple[trtp.Tensor], stream: int) -> None:
496
+
497
+ n = inp0.numel()
498
+ inp0_t = torch.as_tensor(inp0, device="cuda")
499
+ out_t = torch.as_tensor(outputs[0], device="cuda")
500
+
501
+ add_kernel[(triton.cdiv(n, block_size),)](inp0_t, out_t, n, BLOCK_SIZE = block_size)
502
+ """
503
+
504
+ def decorator(impl_func: Callable):
505
+ if plugin_id not in QDP_REGISTRY:
506
+ raise ValueError(
507
+ f"Plugin {plugin_id} is not registered. Did you register it with tensorrt.plugin.register API?"
508
+ )
509
+
510
+ plugin_def = QDP_REGISTRY[plugin_id]
511
+ impl_attr_names, found_tactic = _validate_impl(impl_func, plugin_def)
512
+
513
+ plugin_def.impl_func = impl_func
514
+ plugin_def.impl_attr_names = impl_attr_names
515
+ plugin_def.expects_tactic = found_tactic
516
+ return impl_func
517
+
518
+ return decorator
519
+
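# Hedged sketch (not part of this file): an impl function that opts in to the
# optional `tactic` argument described above. It reuses the "my::add_plugin"
# attributes; the per-tactic kernel launches are left as placeholders.
from typing import Optional, Tuple
import tensorrt.plugin as trtp

@trtp.impl("my::add_plugin")
def add_plugin_impl_tuned(inp0: trtp.Tensor, block_size: int, outputs: Tuple[trtp.Tensor],
                          stream: int, tactic: Optional[int]) -> None:
    # When an autotune function advertises custom tactics, TensorRT passes the
    # selected tactic value here for each execution of the plugin.
    if tactic == 1:
        pass  # launch kernel variant 1 on `stream`
    else:
        pass  # launch kernel variant 2 on `stream`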
520
+ # Decorator for `tensorrt.plugin.aot_impl`
521
+ @public_api()
522
+ def aot_impl(plugin_id: str) -> Callable:
523
+ """
524
+ Wraps a function to define an Ahead-of-Time (AOT) implementation for a plugin already registered through `trt.plugin.register`.
525
+
526
+ This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value;
527
+ however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.
528
+
529
+ The schema for the function is as follows:
530
+
+ .. code-block:: text
531
+
532
+ (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc], tactic: Optional[int]) -> Tuple[str, str, KernelLaunchParams, SymExprs]
533
+
534
+ * Input tensors are passed first, each described by a `TensorDesc`.
535
+ * Plugin attributes are declared next.
536
+ * Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
537
+ * NOTE: Plugin attributes are not serialized into the engine when using an AOT implementation.
538
+ * `tactic` is an optional argument. If the plugin is using custom tactics, it must be specified to receive the tactic value to use for the current execution of the plugin.
539
+
540
+ Args:
541
+ plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`
542
+
543
+ :returns:
544
+ - kernel_name: The name of the kernel.
545
+ - compiled_kernel: Compiled form of the kernel. Presently, only PTX is supported.
546
+ - launch_params: The launch parameters for the kernel
547
+ - extra_args: Symbolic expressions for scalar inputs to the kernel, located after the tensor inputs and before the tensor outputs
548
+
549
+ .. code-block:: python
550
+ :linenos:
551
+ :caption: Implementation of an elementwise plugin with an OpenAI Triton kernel
552
+
553
+ import tensorrt.plugin as trtp
554
+ import triton
555
+ import triton.language as tl
556
+
557
+ @triton.jit
558
+ def add_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
559
+ pid = tl.program_id(0)
560
+ offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
561
+ mask = offsets < n_elements
562
+ x = tl.load(x_ptr + offsets, mask=mask)
563
+ tl.store(y_ptr + offsets, x + 1, mask=mask)
564
+
565
+ @trtp.register("my::add_plugin")
566
+ def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
567
+ return inp0.like()
568
+
569
+ @trtp.aot_impl("my::elemwise_add_plugin")
570
+ def add_plugin_aot_impl(
571
+ inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc], tactic: int
572
+ ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
573
+
574
+ type_str = "fp32" if inp0.dtype == trt.float32 else "fp16"
575
+
576
+ src = triton.compiler.ASTSource(
577
+ fn=add_kernel,
578
+ signature=f"*{type_str},i32,*{type_str}",
579
+ constants={
580
+ "BLOCK_SIZE": block_size,
581
+ },
582
+ )
583
+
584
+ compiled_kernel = triton.compile(src)
585
+
586
+ N = inp0.shape_expr.numel()
587
+ launch_params = trtp.KernelLaunchParams()
588
+
589
+ # grid dims
590
+ launch_params.grid_x = trtp.cdiv(N, block_size)
591
+ # block dims
592
+ launch_params.block_x = compiled_kernel.metadata.num_warps * 32
593
+ # shared memory
594
+ launch_params.shared_mem = compiled_kernel.metadata.shared
595
+
596
+ extra_args = trtp.SymIntExprs(1)
597
+ extra_args[0] = trtp.SymInt32(N)
598
+
599
+ return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args
600
+ """
601
+ def decorator(aot_impl_func: Callable):
602
+ if plugin_id not in QDP_REGISTRY:
603
+ raise ValueError(
604
+ f"Plugin {plugin_id} is not registered. Did you register it with tensorrt.plugin.register API?"
605
+ )
606
+
607
+ plugin_def = QDP_REGISTRY[plugin_id]
608
+ aot_impl_attr_names = _validate_aot_impl(aot_impl_func, plugin_def)
609
+
610
+ plugin_def.aot_impl_func = aot_impl_func
611
+ plugin_def.aot_impl_attr_names = aot_impl_attr_names
612
+ return aot_impl_func
613
+
614
+ return decorator
615
+
616
+
617
+ # Decorator for `tensorrt.plugin.autotune`
618
+ @public_api()
619
+ def autotune(plugin_id: str) -> Callable:
620
+ """
621
+ Wraps a function to define autotune logic for a plugin already registered through `trt.plugin.register`.
622
+
623
+ Autotuning is the process by which TensorRT executes the plugin over IO type/format combinations, and any custom tactics advertised as being supported by the plugin.
624
+ The (type, format, tactic) combination with the lowest latency is used to execute the plugin once the engine is built.
625
+
626
+ .. note:: An autotune function is optional. If not specified, TensorRT will assume the plugin only supports input types specified at network creation, output types specified through `trt.plugin.register`, and linear formats for all I/O.
627
+
628
+ This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value; however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.
629
+
630
+ The schema for the function is as follows:
631
+
632
+ .. code-block:: text
633
+
634
+ (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc]) -> List[AutoTuneCombination]
635
+
636
+ * Input tensors are passed first, each described by a :class:`TensorDesc`.
637
+ * Plugin attributes are declared next. Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
638
+ * The function should return a list of :class:`AutoTuneCombination`\s.
639
+
640
+ Args:
641
+ plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`
642
+
643
+ .. code-block:: python
644
+ :linenos:
645
+ :caption: An elementwise add plugin which supports both FP32 and FP16 linear I/O and wants to be tuned over 2 custom tactics.
646
+
647
+ import tensorrt.plugin as trtp
648
+
649
+ @trtp.register("my::add_plugin")
650
+ def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
651
+ return inp0.like()
652
+
653
+ @trtp.autotune("my::add_plugin")
654
+ def add_plugin_autotune(inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:
655
+
656
+ return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR", [1, 2])]
657
+
658
+ .. code-block:: python
659
+ :linenos:
660
+ :caption: Same as above example but using index-by-index construction of an `AutoTuneCombination`
661
+
662
+ import tensorrt.plugin as trtp
663
+
664
+ @trtp.register("my::add_plugin")
665
+ def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
666
+ return inp0.like()
667
+
668
+ @trtp.autotune("my::add_plugin")
669
+ def add_plugin_autotune(inp0: trtp.TensorDesc, block_size: int, outputs: Tuple[trtp.TensorDesc]) -> List[trtp.AutoTuneCombination]:
670
+ c = trtp.AutoTuneCombination()
671
+ c.pos(0, "FP32|FP16", "LINEAR")
672
+ c.pos(1, "FP32|FP16") # index 1 is the output. Omitting format is the same as declaring it to be LINEAR.
673
+ c.tactics([1, 2])
674
+ return [c]
675
+ """
676
+
677
+ def decorator(autotune_func: Callable):
678
+ if plugin_id not in QDP_REGISTRY:
679
+ raise ValueError(
680
+ f"Plugin {plugin_id} is not registered. Did you register it with tensorrt.plugin.register API?"
681
+ )
682
+
683
+ plugin_def = QDP_REGISTRY[plugin_id]
684
+ autotune_attr_names = _validate_autotune(autotune_func, plugin_def)
685
+
686
+ plugin_def.autotune_func = autotune_func
687
+ plugin_def.autotune_attr_names = autotune_attr_names
688
+
689
+ return autotune_func
690
+
691
+ return decorator