tensorrt-cu12-bindings 10.13.3.9.post1__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
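The added file implements the template IPluginV3 classes behind TensorRT's decorator-based plugin API: the register_function, impl_function, autotune_function, and aot_impl_function referenced throughout are user callbacks registered via the tensorrt.plugin module. For orientation, here is a minimal sketch of a plugin written against that API; the "my::plugin" id, the scale attribute, and the kernel launch are illustrative placeholders, not contents of this package:

    from typing import Tuple
    import tensorrt.plugin as trtp

    @trtp.register("my::plugin")
    def plugin_desc(inp0: trtp.TensorDesc, scale: float) -> trtp.TensorDesc:
        # Build-time callback: declare output dtype/shape in terms of the inputs.
        return inp0.like()

    @trtp.impl("my::plugin")
    def plugin_impl(inp0: trtp.Tensor, scale: float, outputs: Tuple[trtp.Tensor], stream: int) -> None:
        # Runtime callback: launch a CUDA kernel on `stream`.
        ...

In the diff below, plugin_desc would run inside get_output_data_types/get_output_shapes, and plugin_impl inside _TemplateJITPlugin.enqueue.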
@@ -0,0 +1,459 @@
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ import tensorrt as trt
+ from typing import Tuple, Union
+
+ import numpy as np
+ from ._utils import _numpy_to_plugin_field_type, _built_in_to_plugin_field_type
+ from ._tensor import TensorDesc, Tensor, Shape, ShapeExpr, ShapeExprs, SymIntExpr, SymExprs, SymInt32
+ from ._export import IS_AOT_ENABLED
+ if IS_AOT_ENABLED:
+     from ._tensor import KernelLaunchParams
+ from ._autotune import _TypeFormatCombination
+
+ from ._export import public_api
+
+ # Shared build-time logic for plugins generated from the decorator-based
+ # plugin API: interface bookkeeping, shape/type propagation, serialization,
+ # and format-combination handling.
+ class _TemplatePluginBase(
+     trt.IPluginV3,
+     trt.IPluginV3QuickCore,
+     trt.IPluginV3QuickBuild,
+ ):
+     def __init__(self, name, namespace, num_outputs):
+         trt.IPluginV3.__init__(self)
+         trt.IPluginV3QuickCore.__init__(self)
+         trt.IPluginV3QuickBuild.__init__(self)
+
+         self.plugin_version = "1"
+         self.input_types = []
+         self.aliased_map = {}  # output index -> input index
+
+         self.plugin_namespace = namespace
+         self.plugin_name = name
+         self.num_outputs = num_outputs
+
+         self.autotune_combs = []
+         self.supported_combs = {}
+         self.curr_comb = None
+
+     def get_num_outputs(self):
+         return self.num_outputs
+
+     def get_output_data_types(self, input_types, ranks):
+         self.input_types = input_types
+
+         # Build dummy input descriptors (types and ranks only; no concrete
+         # shape expressions yet) to pass through the user's registration function.
+         input_descs = [None] * len(input_types)
+         input_desc_map = {}
+         for i in range(len(input_types)):
+             input_descs[i] = TensorDesc()
+             input_descs[i].dtype = input_types[i]
+             input_descs[i].shape_expr = ShapeExprs(ranks[i], _is_dummy=True)
+             input_descs[i]._immutable = True
+             input_desc_map[id(input_descs[i])] = i
+
+         output_descs = self.register_function(*input_descs, **self.attrs)
+         if not isinstance(output_descs, Tuple):
+             output_descs = tuple([output_descs])
+
+         self.output_types = []
+
+         for i in range(len(output_descs)):
+             self.output_types.append(output_descs[i].dtype)
+
+             # Record output -> input aliasing (-1 means not aliased).
+             if output_descs[i].get_aliased() is not None:
+                 self.aliased_map[i] = input_desc_map[id(output_descs[i].get_aliased())]
+             else:
+                 self.aliased_map[i] = -1
+
+         return self.output_types
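+
+     # Illustrative only: the `register_function` called above is the user's
+     # function registered through the tensorrt.plugin decorator API, along
+     # the lines of
+     #
+     #     @trtp.register("my::plugin")
+     #     def plugin_desc(inp0: trtp.TensorDesc, scale: float) -> trtp.TensorDesc:
+     #         return inp0.like()
+     #
+     # where "my::plugin" and `scale` are hypothetical names, not part of this file.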
+
+     def get_fields_to_serialize(self):
+         fields = []
+         for key, value in self.attrs.items():
+             if key in self.impl_attr_names:
+                 if isinstance(value, np.ndarray):
+                     # np.float16 has no direct PluginField mapping, so fp16
+                     # arrays are serialized as raw bytes.
+                     if np.dtype(value.dtype) == np.float16:
+                         fields.append(
+                             trt.PluginField(
+                                 key, value.tobytes(), trt.PluginFieldType.UNKNOWN
+                             )
+                         )
+                     else:
+                         fields.append(
+                             trt.PluginField(
+                                 key,
+                                 value,
+                                 _numpy_to_plugin_field_type[np.dtype(value.dtype)],
+                             )
+                         )
+                 elif isinstance(value, str):
+                     fields.append(
+                         trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR)
+                     )
+                 elif isinstance(value, bytes):
+                     fields.append(
+                         trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN)
+                     )
+                 else:
+                     fields.append(
+                         trt.PluginField(
+                             key,
+                             np.array([value]),
+                             _built_in_to_plugin_field_type[type(value)],
+                         )
+                     )
+
+         return trt.PluginFieldCollection(fields)
+
+     def get_output_shapes(self, inputs, shape_inputs, exprBuilder):
+         assert len(shape_inputs) == 0  # Shape inputs are not yet supported for QDPs
+         SymIntExpr._exprBuilder = exprBuilder
+         # Rebuild input descriptors, this time with real symbolic shape
+         # expressions backed by TensorRT's expression builder.
+         self.input_descs = []
+         for i in range(len(inputs)):
+             desc = TensorDesc()
+             inp = inputs[i]
+
+             desc.dtype = self.input_types[i]
+             desc.shape_expr = ShapeExprs(len(inp))
+             for j in range(len(inp)):
+                 desc.shape_expr[j] = ShapeExpr(inp[j])
+             desc._immutable = True
+
+             self.input_descs.append(desc)
+
+         self.output_descs = self.register_function(*self.input_descs, **self.attrs)
+         if not isinstance(self.output_descs, Tuple):
+             self.output_descs = tuple([self.output_descs])
+
+         for idx, desc in enumerate(self.output_descs):
+             if desc.is_size_tensor:
+                 desc._set_index(idx)
+
+         output_exprs = []
+         for i in range(len(self.output_descs)):
+             exprs = trt.DimsExprs(len(self.output_descs[i].shape_expr))
+             for j in range(len(exprs)):
+                 exprs[j] = self.output_descs[i].shape_expr[j]._expr
+
+             output_exprs.append(exprs)
+
+         return output_exprs
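+
+     # Illustrative only: inside the registration function, output shapes are
+     # expressed symbolically in terms of the inputs; a hypothetical
+     # 2x-upsampling plugin might write something like
+     #
+     #     out = inp0.like()
+     #     out.shape_expr[2] = inp0.shape_expr[2] * 2
+     #
+     # and those ShapeExpr objects are what get unpacked into trt.DimsExprs here.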
+
+     def configure_plugin(self, inputs, outputs):
+         # Remember the type/format combination the builder has selected so
+         # that get_valid_tactics() can look up the matching tactic set.
+         self.curr_comb = _TypeFormatCombination()
+         self.curr_comb.types = [inp.desc.type for inp in inputs] + [
+             out.desc.type for out in outputs
+         ]
+         self.curr_comb.layouts = [inp.desc.format for inp in inputs] + [
+             out.desc.format for out in outputs
+         ]
+
+     def get_supported_format_combinations(self, in_out, num_inputs):
+         if self.autotune_function is not None:
+             if len(self.autotune_attr_names) > 0:
+                 val = [self.attrs[k] for k in self.autotune_attr_names]
+             else:
+                 val = ()
+
+             for i, desc in enumerate(in_out):
+                 if i < num_inputs:
+                     self.input_descs[i]._immutable = False
+                     self.input_descs[i].shape = Shape(desc)
+                     self.input_descs[i].format = desc.desc.format
+                     self.input_descs[i].scale = desc.desc.scale
+                     self.input_descs[i]._immutable = True
+                 else:
+                     self.output_descs[i - num_inputs]._immutable = False
+                     self.output_descs[i - num_inputs].shape = Shape(desc)
+                     self.output_descs[i - num_inputs].format = desc.desc.format
+                     self.output_descs[i - num_inputs].scale = desc.desc.scale
+                     self.output_descs[i - num_inputs]._immutable = True
+
+             self.autotune_combs = self.autotune_function(
+                 *self.input_descs, *val, self.output_descs
+             )
+
+         # No autotune function (or no combinations returned): advertise a
+         # single combination with the registered types and LINEAR format.
+         if len(self.autotune_combs) == 0:
+             default_comb = [None] * len(in_out)
+             comb = _TypeFormatCombination(len(in_out))
+             for j in range(len(in_out)):
+                 default_comb[j] = trt.PluginTensorDesc()
+                 default_comb[j].type = (
+                     self.input_types[j]
+                     if j < num_inputs
+                     else self.output_descs[j - num_inputs].dtype
+                 )
+                 default_comb[j].format = trt.TensorFormat.LINEAR
+                 comb.types[j] = default_comb[j].type
+                 comb.layouts[j] = default_comb[j].format
+
+             self.supported_combs[comb] = set()
+
+             return default_comb
+
+         # Expand the user's combinations into concrete per-tensor descriptors,
+         # merging tactic sets for duplicate combinations.
+         all_combs = []
+         for comb in self.autotune_combs:
+             all_combs.extend(comb._get_combinations())
+
+         ret_supported_combs = []
+         self.supported_combs = {}
+
+         for comb in all_combs:
+             value = self.supported_combs.get(comb)
+             if value is not None:
+                 value.update(set(comb.tactics) if comb.tactics is not None else set())
+             else:
+                 self.supported_combs[comb] = (
+                     set(comb.tactics) if comb.tactics is not None else set()
+                 )
+             for j in range(len(in_out)):
+                 curr_comb = trt.PluginTensorDesc()
+                 curr_comb.type = comb.types[j]
+                 curr_comb.format = comb.layouts[j]
+                 ret_supported_combs.append(curr_comb)
+
+         return ret_supported_combs
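+
+     # Illustrative only: an autotune function registered through the
+     # tensorrt.plugin decorator API might look like
+     #
+     #     @trtp.autotune("my::plugin")
+     #     def plugin_autotune(inp0, outputs):
+     #         c = trtp.AutoTuneCombination()
+     #         c.pos([0, 1], "FP32|FP16")
+     #         c.tactics([1, 2])
+     #         return [c]
+     #
+     # ("my::plugin" and the combination spec are hypothetical); each returned
+     # combination is expanded by _get_combinations() above.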
+
+     def get_aliased_input(self, output_index: int):
+         return self.aliased_map[output_index]
+
+     def get_valid_tactics(self):
+         tactics = self.supported_combs.get(self.curr_comb)
+         assert tactics is not None
+         return list(tactics)
+
+     def set_tactic(self, tactic):
+         self._tactic = tactic
+
+ # JIT flavor: the user's implementation runs in-process at inference time
+ # via enqueue().
+ class _TemplateJITPlugin(_TemplatePluginBase, trt.IPluginV3QuickRuntime):
+     def __init__(self, name, namespace, num_outputs):
+         super().__init__(name, namespace, num_outputs)
+         trt.IPluginV3QuickRuntime.__init__(self)
+
+         self.expects_tactic = False
+
+     def init(
+         self,
+         register_function,
+         attrs,
+         impl_attr_names,
+         impl_function,
+         autotune_attr_names,
+         autotune_function,
+         expects_tactic,
+     ):
+         self.register_function = register_function
+         self.impl_function = impl_function
+         self.attrs = attrs
+         self.impl_attr_names = impl_attr_names
+         self.autotune_attr_names = autotune_attr_names
+         self.autotune_function = autotune_function
+         self.expects_tactic = expects_tactic
+
+     def get_capability_interface(self, type):
+         return self
+
+     def enqueue(
+         self,
+         input_desc,
+         output_desc,
+         inputs,
+         outputs,
+         in_strides,
+         out_strides,
+         stream,
+     ):
+         # Wrap the raw device pointers and descriptors into tensorrt.plugin
+         # Tensor objects before handing off to the user's impl function.
+         input_tensors = [None] * len(inputs)
+         aliased_input_idxs = list(self.aliased_map.values())
+
+         for i in range(len(inputs)):
+             input_tensors[i] = Tensor()
+             input_tensors[i].dtype = input_desc[i].type
+             input_tensors[i].shape = Shape(input_desc[i])
+             input_tensors[i].format = input_desc[i].format
+             input_tensors[i].scale = input_desc[i].scale
+             input_tensors[i].data_ptr = inputs[i]
+             input_tensors[i]._stream = stream
+             input_tensors[i]._read_only = i not in aliased_input_idxs
+             input_tensors[i].strides = in_strides[i]
+
+         output_tensors = [None] * len(outputs)
+         for i in range(len(outputs)):
+             output_tensors[i] = Tensor()
+             output_tensors[i].dtype = output_desc[i].type
+             output_tensors[i].shape = Shape(output_desc[i])
+             output_tensors[i].format = output_desc[i].format
+             output_tensors[i].scale = output_desc[i].scale
+             output_tensors[i].data_ptr = outputs[i]
+             output_tensors[i]._stream = stream
+             output_tensors[i]._read_only = False
+             output_tensors[i].strides = out_strides[i]
+
+         # Link aliased output/input pairs; a value of -1 in aliased_map means
+         # "not aliased" and must not be used as an index.
+         for i, j in self.aliased_map.items():
+             if j != -1:
+                 output_tensors[i]._aliased_to = input_tensors[j]
+                 input_tensors[j]._aliased_to = output_tensors[i]
+
+         for t in input_tensors:
+             t._immutable = True
+
+         for t in output_tensors:
+             t._immutable = True
+
+         if len(self.impl_attr_names) > 0:
+             val = [self.attrs[k] for k in self.impl_attr_names]
+         else:
+             val = ()
+
+         if self.expects_tactic:
+             self.impl_function(
+                 *input_tensors, *val, output_tensors, stream, self._tactic
+             )
+         else:
+             self.impl_function(*input_tensors, *val, output_tensors, stream=stream)
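+
+     # Illustrative only: the `impl_function` called above is the user's
+     # @trtp.impl-registered function, with a signature along the lines of
+     #
+     #     @trtp.impl("my::plugin")
+     #     def plugin_impl(inp0: trtp.Tensor, scale: float,
+     #                     outputs: Tuple[trtp.Tensor], stream: int) -> None:
+     #         ...  # launch a CUDA kernel on `stream`
+     #
+     # ("my::plugin" and `scale` are hypothetical names.)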
+
+     def clone(self):
+         cloned_plugin = _TemplateJITPlugin(
+             self.plugin_name, self.plugin_namespace, self.num_outputs
+         )
+         cloned_plugin.__dict__.update(self.__dict__)
+         return cloned_plugin
+
+ if IS_AOT_ENABLED:
+     # AOT flavor: instead of running Python at inference time, the plugin
+     # hands TensorRT a precompiled kernel (PTX/CUBIN) plus launch parameters.
+     class _TemplateAOTPlugin(
+         _TemplatePluginBase,
+         trt.IPluginV3QuickAOTBuild,
+     ):
+         def __init__(self, name, namespace, num_outputs):
+             _TemplatePluginBase.__init__(self, name, namespace, num_outputs)
+             trt.IPluginV3QuickAOTBuild.__init__(self)
+             self.kernel_map = {}  # (io types, io formats, tactic) -> (kernel name, PTX/CUBIN)
+
+         def set_tactic(self, tactic):
+             self._tactic = tactic
+
+         def init(
+             self,
+             register_function,
+             attrs,
+             aot_impl_attr_names,
+             aot_impl_function,
+             autotune_attr_names,
+             autotune_function,
+         ):
+             self.register_function = register_function
+             self.aot_impl_function = aot_impl_function
+             self.attrs = attrs
+             self.aot_impl_attr_names = aot_impl_attr_names
+             self.autotune_attr_names = autotune_attr_names
+             self.autotune_function = autotune_function
+
+         def get_capability_interface(self, type):
+             return self
+
+         def get_kernel(self, inputDesc, outputDesc):
+             # Look up the kernel recorded by get_launch_params() for this
+             # exact I/O type/format combination and tactic.
+             io_types = []
+             io_formats = []
+
+             for desc in inputDesc:
+                 io_types.append(desc.type)
+                 io_formats.append(desc.format)
+
+             for desc in outputDesc:
+                 io_types.append(desc.type)
+                 io_formats.append(desc.format)
+
+             key = (tuple(io_types), tuple(io_formats), self._tactic)
+
+             assert key in self.kernel_map, "key {} not in kernel_map".format(key)
+
+             kernel_name, ptx = self.kernel_map[key]
+
+             return kernel_name, ptx.encode() if isinstance(ptx, str) else ptx
+
+         def get_launch_params(self, inDimsExprs, in_out, num_inputs, launchParams, symExprSetter, exprBuilder):
+             SymIntExpr._exprBuilder = exprBuilder
+
+             if len(self.attrs) > 0:
+                 _, val = zip(*self.attrs.items())
+             else:
+                 val = ()
+
+             # Refresh the I/O descriptors with the concrete types/formats
+             # chosen by the builder, collecting the kernel_map key as we go.
+             io_types = []
+             io_formats = []
+
+             for i, desc in enumerate(in_out):
+                 if i < num_inputs:
+                     self.input_descs[i]._immutable = False
+                     self.input_descs[i].shape = Shape(desc)
+                     self.input_descs[i].dtype = desc.desc.type
+                     self.input_descs[i].format = desc.desc.format
+                     self.input_descs[i].scale = desc.desc.scale
+                     io_types.append(desc.desc.type)
+                     io_formats.append(desc.desc.format)
+                     self.input_descs[i]._immutable = True
+                 else:
+                     self.output_descs[i - num_inputs]._immutable = False
+                     self.output_descs[i - num_inputs].shape = Shape(desc)
+                     self.output_descs[i - num_inputs].dtype = desc.desc.type
+                     self.output_descs[i - num_inputs].format = desc.desc.format
+                     self.output_descs[i - num_inputs].scale = desc.desc.scale
+                     io_types.append(desc.desc.type)
+                     io_formats.append(desc.desc.format)
+                     self.output_descs[i - num_inputs]._immutable = True
+
+             kernel_name, ptx, launch_params, extra_args = self.aot_impl_function(
+                 *self.input_descs, *val, self.output_descs, self._tactic
+             )
+
+             if not isinstance(kernel_name, str) and not isinstance(kernel_name, bytes):
+                 raise TypeError(f"Kernel name must be a 'str' or 'bytes'. Got: {type(kernel_name)}.")
+
+             if not isinstance(ptx, str) and not isinstance(ptx, bytes):
+                 raise TypeError(f"PTX/CUBIN must be a 'str' or 'bytes'. Got: {type(ptx)}.")
+
+             if not isinstance(launch_params, KernelLaunchParams):
+                 raise TypeError(f"Launch params must be a 'tensorrt.plugin.KernelLaunchParams'. Got: {type(launch_params)}.")
+
+             if not isinstance(extra_args, SymExprs):
+                 raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymIntExprs'. Got: {type(extra_args)}.")
+
+             launchParams.grid_x = launch_params.grid_x()
+             launchParams.grid_y = launch_params.grid_y()
+             launchParams.grid_z = launch_params.grid_z()
+             launchParams.block_x = launch_params.block_x()
+             launchParams.block_y = launch_params.block_y()
+             launchParams.block_z = launch_params.block_z()
+             launchParams.shared_mem = launch_params.shared_mem()
+
+             self.kernel_map[(tuple(io_types), tuple(io_formats), self._tactic)] = (kernel_name, ptx)
+
+             symExprSetter.nbSymExprs = len(extra_args)
+
+             for i, arg in enumerate(extra_args):
+                 if not isinstance(arg, SymInt32):
+                     raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymInt32'. Got: {type(arg)}.")
+                 symExprSetter[i] = arg()
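+
+         # Illustrative only: the `aot_impl_function` called above is the
+         # user's @trtp.aot_impl-registered function; as validated here, it
+         # must return (kernel name, PTX/CUBIN, KernelLaunchParams, SymIntExprs),
+         # along the lines of
+         #
+         #     @trtp.aot_impl("my::plugin")
+         #     def plugin_aot_impl(inp0, outputs, tactic):
+         #         lp = trtp.KernelLaunchParams()
+         #         lp.grid_x = inp0.shape_expr.numel() // 256  # symbolic grid size
+         #         lp.block_x = 256
+         #         return "my_kernel", my_kernel_ptx, lp, trtp.SymIntExprs(0)
+         #
+         # where "my::plugin", "my_kernel", and `my_kernel_ptx` are hypothetical.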
+
+         def get_timing_cache_id(self):
+             return ""
+
+         def clone(self):
+             cloned_plugin = _TemplateAOTPlugin(
+                 self.plugin_name, self.plugin_namespace, self.num_outputs
+             )
+             cloned_plugin.__dict__.update(self.__dict__)
+             return cloned_plugin