tensorrt-cu12-bindings 10.14.1.48.post1__cp39-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorrt_bindings/__init__.py +224 -0
- tensorrt_bindings/plugin/__init__.py +46 -0
- tensorrt_bindings/plugin/_autotune.py +270 -0
- tensorrt_bindings/plugin/_export.py +39 -0
- tensorrt_bindings/plugin/_lib.py +691 -0
- tensorrt_bindings/plugin/_plugin_class.py +459 -0
- tensorrt_bindings/plugin/_tensor.py +1128 -0
- tensorrt_bindings/plugin/_top_level.py +132 -0
- tensorrt_bindings/plugin/_utils.py +77 -0
- tensorrt_bindings/plugin/_validate.py +475 -0
- tensorrt_bindings/tensorrt.so +0 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/LICENSE.txt +180 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/METADATA +17 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/RECORD +17 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/WHEEL +5 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/top_level.txt +1 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,459 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
import tensorrt as trt
|
|
18
|
+
from typing import Tuple, Union
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
from ._utils import _numpy_to_plugin_field_type, _built_in_to_plugin_field_type
|
|
22
|
+
from ._tensor import TensorDesc, Tensor, Shape, ShapeExpr, ShapeExprs, SymIntExpr, SymExprs, SymInt32
|
|
23
|
+
from ._export import IS_AOT_ENABLED
|
|
24
|
+
if IS_AOT_ENABLED:
|
|
25
|
+
from ._tensor import KernelLaunchParams
|
|
26
|
+
from ._autotune import _TypeFormatCombination
|
|
27
|
+
|
|
28
|
+
from ._export import public_api
|
|
29
|
+
|
|
30
|
+
class _TemplatePluginBase(
|
|
31
|
+
trt.IPluginV3,
|
|
32
|
+
trt.IPluginV3QuickCore,
|
|
33
|
+
trt.IPluginV3QuickBuild,
|
|
34
|
+
):
|
|
35
|
+
def __init__(self, name, namespace, num_outputs):
|
|
36
|
+
trt.IPluginV3.__init__(self)
|
|
37
|
+
trt.IPluginV3QuickCore.__init__(self)
|
|
38
|
+
trt.IPluginV3QuickBuild.__init__(self)
|
|
39
|
+
|
|
40
|
+
self.plugin_version = "1"
|
|
41
|
+
self.input_types = []
|
|
42
|
+
self.aliased_map = {} # output index -> input index
|
|
43
|
+
|
|
44
|
+
self.plugin_namespace = namespace
|
|
45
|
+
self.plugin_name = name
|
|
46
|
+
self.num_outputs = num_outputs
|
|
47
|
+
|
|
48
|
+
self.autotune_combs = []
|
|
49
|
+
self.supported_combs = {}
|
|
50
|
+
self.curr_comb = None
|
|
51
|
+
|
|
52
|
+
def get_num_outputs(self):
|
|
53
|
+
return self.num_outputs
|
|
54
|
+
|
|
55
|
+
def get_output_data_types(self, input_types, ranks):
|
|
56
|
+
self.input_types = input_types
|
|
57
|
+
|
|
58
|
+
input_descs = [None] * len(input_types)
|
|
59
|
+
input_desc_map = {}
|
|
60
|
+
for i in range(len(input_types)):
|
|
61
|
+
input_descs[i] = TensorDesc()
|
|
62
|
+
input_descs[i].dtype = input_types[i]
|
|
63
|
+
input_descs[i].shape_expr = ShapeExprs(ranks[i], _is_dummy=True)
|
|
64
|
+
input_descs[i]._immutable = True
|
|
65
|
+
input_desc_map[id(input_descs[i])] = i
|
|
66
|
+
|
|
67
|
+
output_descs = self.register_function(*input_descs, **self.attrs)
|
|
68
|
+
if not isinstance(output_descs, Tuple):
|
|
69
|
+
output_descs = tuple([output_descs])
|
|
70
|
+
|
|
71
|
+
self.output_types = []
|
|
72
|
+
|
|
73
|
+
for i in range(len(output_descs)):
|
|
74
|
+
self.output_types.append(output_descs[i].dtype)
|
|
75
|
+
|
|
76
|
+
if output_descs[i].get_aliased() is not None:
|
|
77
|
+
self.aliased_map[i] = input_desc_map[id(output_descs[i].get_aliased())]
|
|
78
|
+
else:
|
|
79
|
+
self.aliased_map[i] = -1
|
|
80
|
+
|
|
81
|
+
return self.output_types
|
|
82
|
+
|
|
83
|
+
def get_fields_to_serialize(self):
|
|
84
|
+
fields = []
|
|
85
|
+
for key, value in self.attrs.items():
|
|
86
|
+
if key in self.impl_attr_names:
|
|
87
|
+
if isinstance(value, np.ndarray):
|
|
88
|
+
if np.dtype(value.dtype) == np.float16:
|
|
89
|
+
fields.append(
|
|
90
|
+
trt.PluginField(
|
|
91
|
+
key, value.tobytes(), trt.PluginFieldType.UNKNOWN
|
|
92
|
+
)
|
|
93
|
+
)
|
|
94
|
+
else:
|
|
95
|
+
fields.append(
|
|
96
|
+
trt.PluginField(
|
|
97
|
+
key,
|
|
98
|
+
value,
|
|
99
|
+
_numpy_to_plugin_field_type[np.dtype(value.dtype)],
|
|
100
|
+
)
|
|
101
|
+
)
|
|
102
|
+
elif isinstance(value, str):
|
|
103
|
+
fields.append(
|
|
104
|
+
trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR)
|
|
105
|
+
)
|
|
106
|
+
elif isinstance(value, bytes):
|
|
107
|
+
fields.append(
|
|
108
|
+
trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN)
|
|
109
|
+
)
|
|
110
|
+
else:
|
|
111
|
+
fields.append(
|
|
112
|
+
trt.PluginField(
|
|
113
|
+
key,
|
|
114
|
+
np.array([value]),
|
|
115
|
+
_built_in_to_plugin_field_type[type(value)],
|
|
116
|
+
)
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
return trt.PluginFieldCollection(fields)
|
|
120
|
+
|
|
121
|
+
def get_output_shapes(self, inputs, shape_inputs, exprBuilder):
|
|
122
|
+
assert len(shape_inputs) == 0 # Shape inputs are not yet supported for QDPs
|
|
123
|
+
SymIntExpr._exprBuilder = exprBuilder
|
|
124
|
+
self.input_descs = []
|
|
125
|
+
for i in range(len(inputs)):
|
|
126
|
+
desc = TensorDesc()
|
|
127
|
+
inp = inputs[i]
|
|
128
|
+
|
|
129
|
+
desc.dtype = self.input_types[i]
|
|
130
|
+
desc.shape_expr = ShapeExprs(len(inp))
|
|
131
|
+
for j in range(len(inp)):
|
|
132
|
+
desc.shape_expr[j] = ShapeExpr(inp[j])
|
|
133
|
+
desc._immutable = True
|
|
134
|
+
|
|
135
|
+
self.input_descs.append(desc)
|
|
136
|
+
|
|
137
|
+
self.output_descs = self.register_function(*self.input_descs, **self.attrs)
|
|
138
|
+
if not isinstance(self.output_descs, Tuple):
|
|
139
|
+
self.output_descs = tuple([self.output_descs])
|
|
140
|
+
|
|
141
|
+
for idx, desc in enumerate(self.output_descs):
|
|
142
|
+
if desc.is_size_tensor:
|
|
143
|
+
desc._set_index(idx)
|
|
144
|
+
|
|
145
|
+
output_exprs = []
|
|
146
|
+
for i in range(len(self.output_descs)):
|
|
147
|
+
exprs = trt.DimsExprs(len(self.output_descs[i].shape_expr))
|
|
148
|
+
for j in range(len(exprs)):
|
|
149
|
+
exprs[j] = self.output_descs[i].shape_expr[j]._expr
|
|
150
|
+
|
|
151
|
+
output_exprs.append(exprs)
|
|
152
|
+
|
|
153
|
+
return output_exprs
|
|
154
|
+
|
|
155
|
+
def configure_plugin(self, inputs, outputs):
|
|
156
|
+
self.curr_comb = _TypeFormatCombination()
|
|
157
|
+
self.curr_comb.types = [inp.desc.type for inp in inputs] + [
|
|
158
|
+
out.desc.type for out in outputs
|
|
159
|
+
]
|
|
160
|
+
self.curr_comb.layouts = [inp.desc.format for inp in inputs] + [
|
|
161
|
+
out.desc.format for out in outputs
|
|
162
|
+
]
|
|
163
|
+
|
|
164
|
+
def get_supported_format_combinations(self, in_out, num_inputs):
|
|
165
|
+
if self.autotune_function is not None:
|
|
166
|
+
if len(self.autotune_attr_names) > 0:
|
|
167
|
+
val = [self.attrs[k] for k in self.autotune_attr_names]
|
|
168
|
+
else:
|
|
169
|
+
val = ()
|
|
170
|
+
|
|
171
|
+
for i, desc in enumerate(in_out):
|
|
172
|
+
if i < num_inputs:
|
|
173
|
+
self.input_descs[i]._immutable = False
|
|
174
|
+
self.input_descs[i].shape = Shape(desc)
|
|
175
|
+
self.input_descs[i].format = desc.desc.format
|
|
176
|
+
self.input_descs[i].scale = desc.desc.scale
|
|
177
|
+
self.input_descs[i]._immutable = True
|
|
178
|
+
else:
|
|
179
|
+
self.output_descs[i - num_inputs]._immutable = False
|
|
180
|
+
self.output_descs[i - num_inputs].shape = Shape(desc)
|
|
181
|
+
self.output_descs[i - num_inputs].format = desc.desc.format
|
|
182
|
+
self.output_descs[i - num_inputs].scale = desc.desc.scale
|
|
183
|
+
self.output_descs[i - num_inputs]._immutable = True
|
|
184
|
+
|
|
185
|
+
self.autotune_combs = self.autotune_function(
|
|
186
|
+
*self.input_descs, *val, self.output_descs
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
if len(self.autotune_combs) == 0:
|
|
190
|
+
default_comb = [None] * len(in_out)
|
|
191
|
+
comb = _TypeFormatCombination(len(in_out))
|
|
192
|
+
for j in range(len(in_out)):
|
|
193
|
+
default_comb[j] = trt.PluginTensorDesc()
|
|
194
|
+
default_comb[j].type = (
|
|
195
|
+
self.input_types[j]
|
|
196
|
+
if j < num_inputs
|
|
197
|
+
else self.output_descs[j - num_inputs].dtype
|
|
198
|
+
)
|
|
199
|
+
default_comb[j].format = trt.TensorFormat.LINEAR
|
|
200
|
+
comb.types[j] = default_comb[j].type
|
|
201
|
+
comb.layouts[j] = default_comb[j].format
|
|
202
|
+
|
|
203
|
+
self.supported_combs[comb] = set()
|
|
204
|
+
|
|
205
|
+
return default_comb
|
|
206
|
+
|
|
207
|
+
all_combs = []
|
|
208
|
+
for comb in self.autotune_combs:
|
|
209
|
+
all_combs.extend(comb._get_combinations())
|
|
210
|
+
|
|
211
|
+
ret_supported_combs = []
|
|
212
|
+
self.supported_combs = {}
|
|
213
|
+
|
|
214
|
+
for i, comb in enumerate(all_combs):
|
|
215
|
+
value = self.supported_combs.get(comb)
|
|
216
|
+
if value is not None:
|
|
217
|
+
value.update(set(comb.tactics) if comb.tactics is not None else set())
|
|
218
|
+
else:
|
|
219
|
+
self.supported_combs[comb] = (
|
|
220
|
+
set(comb.tactics) if comb.tactics is not None else set()
|
|
221
|
+
)
|
|
222
|
+
for j in range(len(in_out)):
|
|
223
|
+
curr_comb = trt.PluginTensorDesc()
|
|
224
|
+
curr_comb.type = comb.types[j]
|
|
225
|
+
curr_comb.format = comb.layouts[j]
|
|
226
|
+
ret_supported_combs.append(curr_comb)
|
|
227
|
+
|
|
228
|
+
return ret_supported_combs
|
|
229
|
+
|
|
230
|
+
def get_aliased_input(self, output_index: int):
|
|
231
|
+
return self.aliased_map[output_index]
|
|
232
|
+
|
|
233
|
+
def get_valid_tactics(self):
|
|
234
|
+
tactics = self.supported_combs.get(self.curr_comb)
|
|
235
|
+
assert tactics is not None
|
|
236
|
+
return list(tactics)
|
|
237
|
+
|
|
238
|
+
def set_tactic(self, tactic):
|
|
239
|
+
self._tactic = tactic
|
|
240
|
+
|
|
241
|
+
class _TemplateJITPlugin(_TemplatePluginBase, trt.IPluginV3QuickRuntime):
|
|
242
|
+
def __init__(self, name, namespace, num_outputs):
|
|
243
|
+
super().__init__(name, namespace, num_outputs)
|
|
244
|
+
trt.IPluginV3QuickRuntime.__init__(self)
|
|
245
|
+
|
|
246
|
+
self.expects_tactic = False
|
|
247
|
+
|
|
248
|
+
def init(
|
|
249
|
+
self,
|
|
250
|
+
register_function,
|
|
251
|
+
attrs,
|
|
252
|
+
impl_attr_names,
|
|
253
|
+
impl_function,
|
|
254
|
+
autotune_attr_names,
|
|
255
|
+
autotune_function,
|
|
256
|
+
expects_tactic,
|
|
257
|
+
):
|
|
258
|
+
self.register_function = register_function
|
|
259
|
+
self.impl_function = impl_function
|
|
260
|
+
self.attrs = attrs
|
|
261
|
+
self.impl_attr_names = impl_attr_names
|
|
262
|
+
self.autotune_attr_names = autotune_attr_names
|
|
263
|
+
self.autotune_function = autotune_function
|
|
264
|
+
self.expects_tactic = expects_tactic
|
|
265
|
+
|
|
266
|
+
def get_capability_interface(self, type):
|
|
267
|
+
return self
|
|
268
|
+
|
|
269
|
+
def enqueue(
|
|
270
|
+
self,
|
|
271
|
+
input_desc,
|
|
272
|
+
output_desc,
|
|
273
|
+
inputs,
|
|
274
|
+
outputs,
|
|
275
|
+
in_strides,
|
|
276
|
+
out_strides,
|
|
277
|
+
stream,
|
|
278
|
+
):
|
|
279
|
+
input_tensors = [None] * (len(inputs))
|
|
280
|
+
aliased_input_idxs = list(self.aliased_map.values())
|
|
281
|
+
|
|
282
|
+
for i in range(len(inputs)):
|
|
283
|
+
input_tensors[i] = Tensor()
|
|
284
|
+
input_tensors[i].dtype = input_desc[i].type
|
|
285
|
+
input_tensors[i].shape = Shape(input_desc[i])
|
|
286
|
+
input_tensors[i].format = input_desc[i].format
|
|
287
|
+
input_tensors[i].scale = input_desc[i].scale
|
|
288
|
+
input_tensors[i].data_ptr = inputs[i]
|
|
289
|
+
input_tensors[i]._stream = stream
|
|
290
|
+
input_tensors[i]._read_only = i not in aliased_input_idxs
|
|
291
|
+
input_tensors[i].strides = in_strides[i]
|
|
292
|
+
|
|
293
|
+
output_tensors = [None] * (len(outputs))
|
|
294
|
+
for i in range(len(outputs)):
|
|
295
|
+
output_tensors[i] = Tensor()
|
|
296
|
+
output_tensors[i].dtype = output_desc[i].type
|
|
297
|
+
output_tensors[i].shape = Shape(output_desc[i])
|
|
298
|
+
output_tensors[i].format = output_desc[i].format
|
|
299
|
+
output_tensors[i].scale = output_desc[i].scale
|
|
300
|
+
output_tensors[i].data_ptr = outputs[i]
|
|
301
|
+
output_tensors[i]._stream = stream
|
|
302
|
+
output_tensors[i]._read_only = False
|
|
303
|
+
output_tensors[i].strides = out_strides[i]
|
|
304
|
+
|
|
305
|
+
for i, j in self.aliased_map.items():
|
|
306
|
+
output_tensors[i]._aliased_to = input_tensors[j]
|
|
307
|
+
input_tensors[j]._aliased_to = output_tensors[i]
|
|
308
|
+
|
|
309
|
+
for t in input_tensors:
|
|
310
|
+
t._immutable = True
|
|
311
|
+
|
|
312
|
+
for t in output_tensors:
|
|
313
|
+
t._immutable = True
|
|
314
|
+
|
|
315
|
+
if len(self.impl_attr_names) > 0:
|
|
316
|
+
val = [self.attrs[k] for k in self.impl_attr_names]
|
|
317
|
+
else:
|
|
318
|
+
val = ()
|
|
319
|
+
|
|
320
|
+
if self.expects_tactic:
|
|
321
|
+
self.impl_function(
|
|
322
|
+
*input_tensors, *val, output_tensors, stream, self._tactic
|
|
323
|
+
)
|
|
324
|
+
else:
|
|
325
|
+
self.impl_function(*input_tensors, *val, output_tensors, stream=stream)
|
|
326
|
+
|
|
327
|
+
def clone(self):
|
|
328
|
+
cloned_plugin = _TemplateJITPlugin(
|
|
329
|
+
self.plugin_name, self.plugin_namespace, self.num_outputs
|
|
330
|
+
)
|
|
331
|
+
cloned_plugin.__dict__.update(self.__dict__)
|
|
332
|
+
return cloned_plugin
|
|
333
|
+
|
|
334
|
+
if IS_AOT_ENABLED:
|
|
335
|
+
class _TemplateAOTPlugin(
|
|
336
|
+
_TemplatePluginBase,
|
|
337
|
+
trt.IPluginV3QuickAOTBuild,
|
|
338
|
+
):
|
|
339
|
+
def __init__(self, name, namespace, num_outputs):
|
|
340
|
+
_TemplatePluginBase.__init__(self, name, namespace, num_outputs)
|
|
341
|
+
trt.IPluginV3QuickAOTBuild.__init__(self)
|
|
342
|
+
self.kernel_map = {}
|
|
343
|
+
|
|
344
|
+
def set_tactic(self, tactic):
|
|
345
|
+
self._tactic = tactic
|
|
346
|
+
|
|
347
|
+
def init(
|
|
348
|
+
self,
|
|
349
|
+
register_function,
|
|
350
|
+
attrs,
|
|
351
|
+
aot_impl_attr_names,
|
|
352
|
+
aot_impl_function,
|
|
353
|
+
autotune_attr_names,
|
|
354
|
+
autotune_function
|
|
355
|
+
):
|
|
356
|
+
self.register_function = register_function
|
|
357
|
+
self.aot_impl_function = aot_impl_function
|
|
358
|
+
self.attrs = attrs
|
|
359
|
+
self.aot_impl_attr_names = aot_impl_attr_names
|
|
360
|
+
self.autotune_attr_names = autotune_attr_names
|
|
361
|
+
self.autotune_function = autotune_function
|
|
362
|
+
|
|
363
|
+
def get_capability_interface(self, type):
|
|
364
|
+
return self
|
|
365
|
+
|
|
366
|
+
def get_kernel(self, inputDesc, outputDesc):
|
|
367
|
+
io_types = []
|
|
368
|
+
io_formats = []
|
|
369
|
+
|
|
370
|
+
for i, desc in enumerate(inputDesc):
|
|
371
|
+
io_types.append(desc.type)
|
|
372
|
+
io_formats.append(desc.format)
|
|
373
|
+
|
|
374
|
+
for i, desc in enumerate(outputDesc):
|
|
375
|
+
io_types.append(desc.type)
|
|
376
|
+
io_formats.append(desc.format)
|
|
377
|
+
|
|
378
|
+
key = (tuple(io_types), tuple(io_formats), self._tactic)
|
|
379
|
+
|
|
380
|
+
assert key in self.kernel_map, "key {} not in kernel_map".format(key)
|
|
381
|
+
|
|
382
|
+
kernel_name, ptx = self.kernel_map[key]
|
|
383
|
+
|
|
384
|
+
return kernel_name, ptx.encode() if isinstance(ptx, str) else ptx
|
|
385
|
+
|
|
386
|
+
def get_launch_params(self, inDimsExprs, in_out, num_inputs, launchParams, symExprSetter, exprBuilder):
|
|
387
|
+
|
|
388
|
+
SymIntExpr._exprBuilder = exprBuilder
|
|
389
|
+
|
|
390
|
+
if len(self.attrs) > 0:
|
|
391
|
+
_, val = zip(*self.attrs.items())
|
|
392
|
+
else:
|
|
393
|
+
val = ()
|
|
394
|
+
|
|
395
|
+
io_types = []
|
|
396
|
+
io_formats = []
|
|
397
|
+
|
|
398
|
+
for i, desc in enumerate(in_out):
|
|
399
|
+
if i < num_inputs:
|
|
400
|
+
self.input_descs[i]._immutable = False
|
|
401
|
+
self.input_descs[i].shape = Shape(desc)
|
|
402
|
+
self.input_descs[i].dtype = desc.desc.type
|
|
403
|
+
self.input_descs[i].format = desc.desc.format
|
|
404
|
+
self.input_descs[i].scale = desc.desc.scale
|
|
405
|
+
io_types.append(desc.desc.type)
|
|
406
|
+
io_formats.append(desc.desc.format)
|
|
407
|
+
self.input_descs[i]._immutable = True
|
|
408
|
+
else:
|
|
409
|
+
self.output_descs[i - num_inputs]._immutable = False
|
|
410
|
+
self.output_descs[i - num_inputs].shape = Shape(desc)
|
|
411
|
+
self.output_descs[i - num_inputs].dtype = desc.desc.type
|
|
412
|
+
self.output_descs[i - num_inputs].format = desc.desc.format
|
|
413
|
+
self.output_descs[i - num_inputs].scale = desc.desc.scale
|
|
414
|
+
io_types.append(desc.desc.type)
|
|
415
|
+
io_formats.append(desc.desc.format)
|
|
416
|
+
self.output_descs[i - num_inputs]._immutable = True
|
|
417
|
+
|
|
418
|
+
kernel_name, ptx, launch_params, extra_args = self.aot_impl_function(
|
|
419
|
+
*self.input_descs, *val, self.output_descs, self._tactic
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
if not isinstance(kernel_name, str) and not isinstance(kernel_name, bytes):
|
|
423
|
+
raise TypeError(f"Kernel name must be a 'str' or 'bytes'. Got: {type(kernel_name)}.")
|
|
424
|
+
|
|
425
|
+
if not isinstance(ptx, str) and not isinstance(ptx, bytes):
|
|
426
|
+
raise TypeError(f"PTX/CUBIN must be a 'str' or 'bytes'. Got: {type(ptx)}.")
|
|
427
|
+
|
|
428
|
+
if not isinstance(launch_params, KernelLaunchParams):
|
|
429
|
+
raise TypeError(f"Launch params must be a 'tensorrt.plugin.KernelLaunchParams'. Got: {type(launch_params)}.")
|
|
430
|
+
|
|
431
|
+
if not isinstance(extra_args, SymExprs):
|
|
432
|
+
raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymIntExprs'. Got: {type(extra_args)}.")
|
|
433
|
+
|
|
434
|
+
launchParams.grid_x = launch_params.grid_x()
|
|
435
|
+
launchParams.grid_y = launch_params.grid_y()
|
|
436
|
+
launchParams.grid_z = launch_params.grid_z()
|
|
437
|
+
launchParams.block_x = launch_params.block_x()
|
|
438
|
+
launchParams.block_y = launch_params.block_y()
|
|
439
|
+
launchParams.block_z = launch_params.block_z()
|
|
440
|
+
launchParams.shared_mem = launch_params.shared_mem()
|
|
441
|
+
|
|
442
|
+
self.kernel_map[(tuple(io_types), tuple(io_formats), self._tactic)] = (kernel_name, ptx)
|
|
443
|
+
|
|
444
|
+
symExprSetter.nbSymExprs = len(extra_args)
|
|
445
|
+
|
|
446
|
+
for i, arg in enumerate(extra_args):
|
|
447
|
+
if not isinstance(arg, SymInt32):
|
|
448
|
+
raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymInt32'. Got: {type(arg)}.")
|
|
449
|
+
symExprSetter[i] = arg()
|
|
450
|
+
|
|
451
|
+
def get_timing_cache_id(self):
|
|
452
|
+
return ""
|
|
453
|
+
|
|
454
|
+
def clone(self):
|
|
455
|
+
cloned_plugin = _TemplateAOTPlugin(
|
|
456
|
+
self.plugin_name, self.plugin_namespace, self.num_outputs
|
|
457
|
+
)
|
|
458
|
+
cloned_plugin.__dict__.update(self.__dict__)
|
|
459
|
+
return cloned_plugin
|