tensorrt-cu12-bindings 10.8.0.43__cp312-none-win_amd64.whl → 10.9.0.34__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tensorrt-cu12-bindings has been flagged as potentially problematic.
- tensorrt_bindings/__init__.py +13 -8
- tensorrt_bindings/plugin/_export.py +4 -1
- tensorrt_bindings/plugin/_lib.py +199 -31
- tensorrt_bindings/plugin/_plugin_class.py +176 -41
- tensorrt_bindings/plugin/_tensor.py +351 -73
- tensorrt_bindings/plugin/_validate.py +122 -4
- tensorrt_bindings/tensorrt.cp312-win_amd64.pyd +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/METADATA +1 -1
- tensorrt_cu12_bindings-10.9.0.34.dist-info/RECORD +17 -0
- tensorrt_cu12_bindings-10.8.0.43.dist-info/RECORD +0 -17
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/LICENSE.txt +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/WHEEL +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/top_level.txt +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/zip-safe +0 -0
tensorrt_bindings/__init__.py
CHANGED
@@ -31,20 +31,25 @@ else:


 if not _libs_wheel_imported and sys.platform.startswith("win"):
+    log_found_dlls = bool(int(os.environ.get("TRT_LOG_FOUND_DLLS", 0)))
     # On Windows, we need to manually open the TensorRT libraries - otherwise we are unable to
     # load the bindings. If we imported the tensorrt_libs wheel, then that should have taken care of it for us.
     def find_lib(name):
         paths = os.environ["PATH"].split(os.path.pathsep)
+
        # Add ../tensorrt.libs to the search path. This allows repackaging non-standalone TensorRT wheels as standalone
        # using delvewheel (with the --no-mangle-all flag set) to work properly.
         paths.append(os.path.join(os.path.dirname(__file__), os.pardir, "tensorrt.libs"))
+
         for path in paths:
             libpath = os.path.join(path, name)
             if os.path.isfile(libpath):
+                if log_found_dlls:
+                    print(f"Found {name} in path: {libpath}")
                 return libpath

-        if name.startswith("
-            return
+        if name.startswith("nvinfer_builder_resource"):
+            return None

         raise FileNotFoundError(
             "Could not find: {:}. Is it on your PATH?\nNote: Paths searched were:\n{:}".format(name, paths)
@@ -54,11 +59,9 @@ if not _libs_wheel_imported and sys.platform.startswith("win"):
     LIBRARIES = {
         "tensorrt": [
             "nvinfer_10.dll",
-            "cublas64_12.dll",
-            "cublasLt64_12.dll",
-            "cudnn64_##CUDNN_MAJOR##.dll",
             "nvinfer_plugin_10.dll",
             "nvonnxparser_10.dll",
+            "nvinfer_builder_resource_10.dll",
         ],
         "tensorrt_dispatch": [
             "nvinfer_dispatch_10.dll",
@@ -70,14 +73,16 @@ if not _libs_wheel_imported and sys.platform.startswith("win"):

     for lib in LIBRARIES:
         lib_path = find_lib(lib)
-        if lib_path
-
+        if not lib_path:
+            continue
+        assert os.path.isfile(lib_path)
+        ctypes.CDLL(lib_path)

 del _libs_wheel_imported

 from .tensorrt import *

-__version__ = "10.
+__version__ = "10.9.0.34"


 # Provides Python's `with` syntax
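The `__init__.py` changes add an opt-in trace of Windows DLL resolution: setting the `TRT_LOG_FOUND_DLLS` environment variable to a non-zero integer makes the import print each library as it is located. A minimal sketch of how a user would enable it (the variable name comes straight from the diff; the exact printed text is whatever `find_lib` emits):

    import os

    # Must be set before importing tensorrt; any non-zero integer enables logging.
    os.environ["TRT_LOG_FOUND_DLLS"] = "1"

    import tensorrt  # on Windows, prints "Found <name> in path: <libpath>" for each DLL found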
tensorrt_bindings/plugin/_export.py
CHANGED

@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,6 +15,7 @@
 # limitations under the License.
 #

+import tensorrt as trt
 from types import ModuleType
 import importlib

@@ -34,3 +35,5 @@ def public_api(module: ModuleType = None, symbol: str = None):
         return obj

     return export_impl
+
+IS_AOT_ENABLED = hasattr(trt, "QuickPluginCreationRequest")
tensorrt_bindings/plugin/_lib.py
CHANGED
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,13 +20,16 @@ import types
 import typing
 from typing import Callable, Tuple, List
 import numpy as np
-
-from .
+from ._plugin_class import _TemplateJITPlugin
+from ._export import IS_AOT_ENABLED
+if IS_AOT_ENABLED:
+    from ._plugin_class import _TemplateAOTPlugin
 from ._validate import (
     _parse_register_inputs,
     _parse_register_return,
     _validate_autotune,
     _validate_impl,
+    _validate_aot_impl,
     _validate_name_and_namespace,
 )
 from ._utils import (
@@ -91,11 +94,13 @@ class PluginDef:
         self.plugin_id = None  # includes namespace (format is ns::name)
         self.register_func = None
         self.impl_func = None
+        self.aot_impl_func = None
         self.autotune_func = None
         self.autotune_attr_names = None
         self.input_tensor_names = None
         self.input_attrs = None  # map name -> type
         self.impl_attr_names = None
+        self.aot_impl_attr_names = None
         self.num_outputs = None
         self.input_arg_schema = None
         self.expects_tactic = None
@@ -195,24 +200,26 @@ class PluginDef:
             )
         )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        def create_plugin_instance(quick_plugin_creation_request: "trt.QuickPluginCreationRequest" = None):
+            if quick_plugin_creation_request is None:
+                plg = plg_creator.create_plugin(
+                    name,
+                    namespace,
+                    trt.PluginFieldCollection(fields),
+                    trt.TensorRTPhase.BUILD
+                )
+            else:
+                plg = plg_creator.create_plugin(
+                    name,
+                    namespace,
+                    trt.PluginFieldCollection(fields),
+                    trt.TensorRTPhase.BUILD,
+                    quick_plugin_creation_request
+                )

-
+            return input_tensors, [], plg

+        return create_plugin_instance

 class _TemplatePluginCreator(trt.IPluginCreatorV3Quick):
     def __init__(self, name, namespace, attrs):
@@ -246,7 +253,7 @@ class _TemplatePluginCreator(trt.IPluginCreatorV3Quick):

         self.field_names = trt.PluginFieldCollection(field_names)

-    def create_plugin(self, name, namespace, fc, phase):
+    def create_plugin(self, name, namespace, fc, phase, qpcr: "trt.QuickPluginCreationRequest" = None):
         desc = QDP_REGISTRY[f"{namespace}::{name}"]
         name = name
         namespace = namespace
@@ -271,18 +278,83 @@ class _TemplatePluginCreator(trt.IPluginCreatorV3Quick):
         else:
             attrs[f.name] = attr_type_annot(f.data)

-
-
-
-
-
-
-
-
-
-
-
+        jit_or_aot = None  # True if JIT is to be created, False if AOT. Not None will be asserted before plugin creation.
+
+        if qpcr is None:
+            plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)
+
+            plg.init(
+                desc.register_func,
+                attrs,
+                desc.impl_attr_names,
+                desc.impl_func,
+                desc.autotune_attr_names,
+                desc.autotune_func,
+                desc.expects_tactic,
+            )
+
+            return plg
+
+        # If there is a strict preference, that takes precedence
+        if qpcr == trt.QuickPluginCreationRequest.STRICT_AOT:
+            if desc.aot_impl_func is None:
+                raise ValueError(f"AOT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.aot_impl defined?")
+            jit_or_aot = False
+        elif qpcr == trt.QuickPluginCreationRequest.STRICT_JIT:
+            if desc.impl_func is None:
+                raise ValueError(f"JIT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.impl defined?")
+            jit_or_aot = True
+        else:
+            aot_defined = desc.aot_impl_func is not None
+            jit_defined = desc.impl_func is not None
+
+            # A preference must be indicated if both AOT and JIT implementations are defined
+            if aot_defined and jit_defined:
+                if qpcr == trt.QuickPluginCreationRequest.PREFER_AOT:
+                    jit_or_aot = False
+                elif qpcr == trt.QuickPluginCreationRequest.PREFER_JIT:
+                    jit_or_aot = True
+                else:
+                    raise ValueError(f"Plugin '{desc.plugin_id}' has both AOT and JIT implementations. NetworkDefinitionCreationFlag.PREFER_AOT_PYTHON_PLUGINS or NetworkDefinitionCreationFlag.PREFER_JIT_PYTHON_PLUGINS should be specified.")
+            else:
+                # If only one implementation is defined, use that.
+                # Any preference specified is ignored. If the preference is strong, a strict flag should have been specified.
+                if aot_defined:
+                    jit_or_aot = False
+                elif jit_defined:
+                    jit_or_aot = True
+                else:
+                    raise ValueError(f"Plugin '{desc.plugin_id}' does not have either an AOT or JIT implementation.")
+
+        assert jit_or_aot is not None
+
+        if jit_or_aot:
+            plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)
+
+            plg.init(
+                desc.register_func,
+                attrs,
+                desc.impl_attr_names,
+                desc.impl_func,
+                desc.autotune_attr_names,
+                desc.autotune_func,
+                desc.expects_tactic,
+            )

+        else:
+            plg = _TemplateAOTPlugin(name, namespace, desc.num_outputs)
+
+            plg.init(
+                desc.register_func,
+                attrs,
+                desc.aot_impl_attr_names,
+                desc.aot_impl_func,
+                desc.autotune_attr_names,
+                desc.autotune_func
+            )
+
+        # The caller can determine if the created plugin is an AOT or JIT plugin by inspecting the interface info
+        return plg

 def _register_plugin_creator(name: str, namespace: str, attrs_types):
     plg_registry = trt.get_plugin_registry()
@@ -445,6 +517,102 @@ def impl(plugin_id: str) -> Callable:

     return decorator

+# Decorator for `tensorrt.plugin.aot_impl`
+@public_api()
+def aot_impl(plugin_id: str) -> Callable:
+    """
+    Wraps a function to define an Ahead-of-Time (AOT) implementation for a plugin already registered through `trt.plugin.register`.
+
+    This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value;
+    however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.
+
+    The schema for the function is as follows:
+
+    .. code-block:: text
+
+        (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc], tactic: Optional[int]) -> Tuple[str, str, KernelLaunchParams, SymExprs]
+
+    * Input tensors are passed first, each described by a `TensorDesc`.
+    * Plugin attributes are declared next.
+    * Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
+    * NOTE: Plugin attributes are not serialized into the engine when using an AOT implementation.
+    * `tactic` is an optional argument. If the plugin is using custom tactics, it must be specified to receive the tactic value to use for the current execution of the plugin.
+
+    Args:
+        plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`
+
+    :returns:
+        - kernel_name: The name of the kernel.
+        - compiled_kernel: Compiled form of the kernel. Presently, only PTX is supported.
+        - launch_params: The launch parameters for the kernel
+        - extra_args: Symbolic expressions for scalar inputs to the kernel, located after the tensor inputs and before the tensor outputs
+
+    .. code-block:: python
+        :linenos:
+        :caption: Implementation of an elementwise plugin with an OpenAI Triton kernel
+
+        from typing import Tuple, Union
+
+        import tensorrt as trt
+        import tensorrt.plugin as trtp
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
+            pid = tl.program_id(0)
+            offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(x_ptr + offsets, mask=mask)
+            tl.store(y_ptr + offsets, x + 1, mask=mask)
+
+        @trtp.register("my::add_plugin")
+        def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
+            return inp0.like()
+
+        @trtp.aot_impl("my::add_plugin")
+        def add_plugin_aot_impl(
+            inp0: trtp.TensorDesc, block_size: int, single_tactic: bool, outputs: Tuple[trtp.TensorDesc], tactic: int
+        ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
+
+            type_str = "fp32" if inp0.dtype == trt.float32 else "fp16"
+
+            src = triton.compiler.ASTSource(
+                fn=add_kernel,
+                signature=f"*{type_str},i32,*{type_str}",
+                constants={
+                    "BLOCK_SIZE": block_size,
+                },
+            )
+
+            compiled_kernel = triton.compile(src)
+
+            N = inp0.shape_expr.numel()
+            launch_params = trtp.KernelLaunchParams()
+
+            # grid dims
+            launch_params.grid_x = trtp.cdiv(N, block_size)
+            # block dims
+            launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+            # shared memory
+            launch_params.shared_mem = compiled_kernel.metadata.shared
+
+            extra_args = trtp.SymIntExprs(1)
+            extra_args[0] = trtp.SymInt32(N)
+
+            return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args
+    """
+    def decorator(aot_impl_func: Callable):
+        if plugin_id not in QDP_REGISTRY:
+            raise ValueError(
+                f"Plugin {plugin_id} is not registered. Did you register it with the tensorrt.plugin.register API?"
+            )
+
+        plugin_def = QDP_REGISTRY[plugin_id]
+        aot_impl_attr_names = _validate_aot_impl(aot_impl_func, plugin_def)
+
+        plugin_def.aot_impl_func = aot_impl_func
+        plugin_def.aot_impl_attr_names = aot_impl_attr_names
+        return aot_impl_func
+
+    return decorator
+

 # Decorator for `tensorrt.plugin.autotune`
 @public_api()
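The selection logic that `create_plugin` now applies reduces to a small precedence rule: a strict request (STRICT_AOT/STRICT_JIT) fails loudly when the matching implementation is missing, a PREFER_* request only matters when both implementations exist, and a single defined implementation always wins. A standalone sketch of that decision table (plain Python; the strings stand in for the `trt.QuickPluginCreationRequest` values named in the diff):

    def choose_impl(request: str, has_aot: bool, has_jit: bool) -> str:
        # Mirrors the jit_or_aot resolution in _TemplatePluginCreator.create_plugin.
        if request == "STRICT_AOT":
            if not has_aot:
                raise ValueError("AOT implementation requested, but not defined")
            return "aot"
        if request == "STRICT_JIT":
            if not has_jit:
                raise ValueError("JIT implementation requested, but not defined")
            return "jit"
        if has_aot and has_jit:
            # Both defined: an explicit preference is required.
            if request == "PREFER_AOT":
                return "aot"
            if request == "PREFER_JIT":
                return "jit"
            raise ValueError("both implementations defined; a PREFER_* flag is required")
        # Only one defined: use it, ignoring any soft preference.
        if has_aot:
            return "aot"
        if has_jit:
            return "jit"
        raise ValueError("neither an AOT nor a JIT implementation is defined")

Note that a `create_plugin` call with no request at all still takes the old path and unconditionally builds a JIT plugin.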
tensorrt_bindings/plugin/_plugin_class.py
CHANGED

@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,25 +15,27 @@
 # limitations under the License.
 #
 import tensorrt as trt
-from typing import Tuple
+from typing import Tuple, Union

 import numpy as np
 from ._utils import _numpy_to_plugin_field_type, _built_in_to_plugin_field_type
-from ._tensor import TensorDesc, Tensor, Shape, ShapeExpr, ShapeExprs
+from ._tensor import TensorDesc, Tensor, Shape, ShapeExpr, ShapeExprs, SymIntExpr, SymExprs, SymInt32
+from ._export import IS_AOT_ENABLED
+if IS_AOT_ENABLED:
+    from ._tensor import KernelLaunchParams
 from ._autotune import _TypeFormatCombination

+from ._export import public_api

-class
+class _TemplatePluginBase(
     trt.IPluginV3,
     trt.IPluginV3QuickCore,
     trt.IPluginV3QuickBuild,
-    trt.IPluginV3QuickRuntime,
 ):
     def __init__(self, name, namespace, num_outputs):
         trt.IPluginV3.__init__(self)
         trt.IPluginV3QuickCore.__init__(self)
         trt.IPluginV3QuickBuild.__init__(self)
-        trt.IPluginV3QuickRuntime.__init__(self)

         self.plugin_version = "1"
         self.input_types = []
@@ -46,28 +48,6 @@ class _TemplatePlugin(
         self.autotune_combs = []
         self.supported_combs = {}
         self.curr_comb = None
-        self.expects_tactic = False
-
-    def init(
-        self,
-        register_function,
-        attrs,
-        impl_attr_names,
-        impl_function,
-        autotune_attr_names,
-        autotune_function,
-        expects_tactic,
-    ):
-        self.register_function = register_function
-        self.impl_function = impl_function
-        self.attrs = attrs
-        self.impl_attr_names = impl_attr_names
-        self.autotune_attr_names = autotune_attr_names
-        self.autotune_function = autotune_function
-        self.expects_tactic = expects_tactic
-
-    def get_capability_interface(self, type):
-        return self

     def get_num_outputs(self):
         return self.num_outputs
@@ -140,7 +120,7 @@ class _TemplatePlugin(

     def get_output_shapes(self, inputs, shape_inputs, exprBuilder):
         assert len(shape_inputs) == 0  # Shape inputs are not yet supported for QDPs
-
+        SymIntExpr._exprBuilder = exprBuilder
         self.input_descs = []
         for i in range(len(inputs)):
             desc = TensorDesc()
@@ -247,6 +227,45 @@ class _TemplatePlugin(

         return ret_supported_combs

+    def get_aliased_input(self, output_index: int):
+        return self.aliased_map[output_index]
+
+    def get_valid_tactics(self):
+        tactics = self.supported_combs.get(self.curr_comb)
+        assert tactics is not None
+        return list(tactics)
+
+    def set_tactic(self, tactic):
+        self._tactic = tactic
+
+class _TemplateJITPlugin(_TemplatePluginBase, trt.IPluginV3QuickRuntime):
+    def __init__(self, name, namespace, num_outputs):
+        super().__init__(name, namespace, num_outputs)
+        trt.IPluginV3QuickRuntime.__init__(self)
+
+        self.expects_tactic = False
+
+    def init(
+        self,
+        register_function,
+        attrs,
+        impl_attr_names,
+        impl_function,
+        autotune_attr_names,
+        autotune_function,
+        expects_tactic,
+    ):
+        self.register_function = register_function
+        self.impl_function = impl_function
+        self.attrs = attrs
+        self.impl_attr_names = impl_attr_names
+        self.autotune_attr_names = autotune_attr_names
+        self.autotune_function = autotune_function
+        self.expects_tactic = expects_tactic
+
+    def get_capability_interface(self, type):
+        return self
+
     def enqueue(
         self,
         input_desc,
@@ -305,20 +324,136 @@ class _TemplatePlugin(
         else:
             self.impl_function(*input_tensors, *val, output_tensors, stream=stream)

-    def get_aliased_input(self, output_index: int):
-        return self.aliased_map[output_index]
-
-    def get_valid_tactics(self):
-        tactics = self.supported_combs.get(self.curr_comb)
-        assert tactics is not None
-        return list(tactics)
-
-    def set_tactic(self, tactic):
-        self._tactic = tactic
-
     def clone(self):
-        cloned_plugin =
+        cloned_plugin = _TemplateJITPlugin(
             self.plugin_name, self.plugin_namespace, self.num_outputs
         )
         cloned_plugin.__dict__.update(self.__dict__)
         return cloned_plugin
+
+if IS_AOT_ENABLED:
+    class _TemplateAOTPlugin(
+        _TemplatePluginBase,
+        trt.IPluginV3QuickAOTBuild,
+    ):
+        def __init__(self, name, namespace, num_outputs):
+            _TemplatePluginBase.__init__(self, name, namespace, num_outputs)
+            trt.IPluginV3QuickAOTBuild.__init__(self)
+            self.kernel_map = {}
+
+        def set_tactic(self, tactic):
+            self._tactic = tactic
+
+        def init(
+            self,
+            register_function,
+            attrs,
+            aot_impl_attr_names,
+            aot_impl_function,
+            autotune_attr_names,
+            autotune_function
+        ):
+            self.register_function = register_function
+            self.aot_impl_function = aot_impl_function
+            self.attrs = attrs
+            self.aot_impl_attr_names = aot_impl_attr_names
+            self.autotune_attr_names = autotune_attr_names
+            self.autotune_function = autotune_function
+
+        def get_capability_interface(self, type):
+            return self
+
+        def get_kernel(self, inputDesc, outputDesc):
+            io_types = []
+            io_formats = []
+
+            for i, desc in enumerate(inputDesc):
+                io_types.append(desc.type)
+                io_formats.append(desc.format)
+
+            for i, desc in enumerate(outputDesc):
+                io_types.append(desc.type)
+                io_formats.append(desc.format)
+
+            key = (tuple(io_types), tuple(io_formats), self._tactic)
+
+            assert key in self.kernel_map, "key {} not in kernel_map".format(key)
+
+            kernel_name, ptx = self.kernel_map[key]
+
+            return kernel_name, ptx.encode() if isinstance(ptx, str) else ptx
+
+        def get_launch_params(self, inDimsExprs, in_out, num_inputs, launchParams, symExprSetter, exprBuilder):
+
+            SymIntExpr._exprBuilder = exprBuilder
+
+            if len(self.attrs) > 0:
+                _, val = zip(*self.attrs.items())
+            else:
+                val = ()
+
+            io_types = []
+            io_formats = []
+
+            for i, desc in enumerate(in_out):
+                if i < num_inputs:
+                    self.input_descs[i]._immutable = False
+                    self.input_descs[i].shape = Shape(desc)
+                    self.input_descs[i].dtype = desc.desc.type
+                    self.input_descs[i].format = desc.desc.format
+                    self.input_descs[i].scale = desc.desc.scale
+                    io_types.append(desc.desc.type)
+                    io_formats.append(desc.desc.format)
+                    self.input_descs[i]._immutable = True
+                else:
+                    self.output_descs[i - num_inputs]._immutable = False
+                    self.output_descs[i - num_inputs].shape = Shape(desc)
+                    self.output_descs[i - num_inputs].dtype = desc.desc.type
+                    self.output_descs[i - num_inputs].format = desc.desc.format
+                    self.output_descs[i - num_inputs].scale = desc.desc.scale
+                    io_types.append(desc.desc.type)
+                    io_formats.append(desc.desc.format)
+                    self.output_descs[i - num_inputs]._immutable = True

+            kernel_name, ptx, launch_params, extra_args = self.aot_impl_function(
+                *self.input_descs, *val, self.output_descs, self._tactic
+            )
+
+            if not isinstance(kernel_name, str) and not isinstance(kernel_name, bytes):
+                raise TypeError(f"Kernel name must be a 'str' or 'bytes'. Got: {type(kernel_name)}.")
+
+            if not isinstance(ptx, str) and not isinstance(ptx, bytes):
+                raise TypeError(f"PTX/CUBIN must be a 'str' or 'bytes'. Got: {type(ptx)}.")
+
+            if not isinstance(launch_params, KernelLaunchParams):
+                raise TypeError(f"Launch params must be a 'tensorrt.plugin.KernelLaunchParams'. Got: {type(launch_params)}.")
+
+            if not isinstance(extra_args, SymExprs):
+                raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymIntExprs'. Got: {type(extra_args)}.")
+
+            launchParams.grid_x = launch_params.grid_x()
+            launchParams.grid_y = launch_params.grid_y()
+            launchParams.grid_z = launch_params.grid_z()
+            launchParams.block_x = launch_params.block_x()
+            launchParams.block_y = launch_params.block_y()
+            launchParams.block_z = launch_params.block_z()
+            launchParams.shared_mem = launch_params.shared_mem()
+
+            self.kernel_map[(tuple(io_types), tuple(io_formats), self._tactic)] = (kernel_name, ptx)
+
+            symExprSetter.nbSymExprs = len(extra_args)
+
+            for i, arg in enumerate(extra_args):
+                if not isinstance(arg, SymInt32):
+                    raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymInt32'. Got: {type(arg)}.")
+                symExprSetter[i] = arg()
+
+        def get_timing_cache_id(self):
+            return ""
+
+        def clone(self):
+            cloned_plugin = _TemplateAOTPlugin(
+                self.plugin_name, self.plugin_namespace, self.num_outputs
+            )
+            cloned_plugin.__dict__.update(self.__dict__)
+            return cloned_plugin