tensorrt-cu12-bindings 10.8.0.43-cp38-none-win_amd64.whl → 10.9.0.34-cp38-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -31,20 +31,25 @@ else:
 
 
 if not _libs_wheel_imported and sys.platform.startswith("win"):
+    log_found_dlls = bool(int(os.environ.get("TRT_LOG_FOUND_DLLS", 0)))
     # On Windows, we need to manually open the TensorRT libraries - otherwise we are unable to
     # load the bindings. If we imported the tensorrt_libs wheel, then that should have taken care of it for us.
     def find_lib(name):
         paths = os.environ["PATH"].split(os.path.pathsep)
+
         # Add ../tensorrt.libs to the search path. This allows repackaging non-standalone TensorRT wheels as standalone
         # using delvewheel (with the --no-mangle-all flag set) to work properly.
         paths.append(os.path.join(os.path.dirname(__file__), os.pardir, "tensorrt.libs"))
+
         for path in paths:
             libpath = os.path.join(path, name)
             if os.path.isfile(libpath):
+                if log_found_dlls:
+                    print(f"Found {name} in path: {libpath}")
                 return libpath
 
-        if name.startswith("cudnn") or name.startswith("cublas"):
-            return ""
+        if name.startswith("nvinfer_builder_resource"):
+            return None
 
         raise FileNotFoundError(
             "Could not find: {:}. Is it on your PATH?\nNote: Paths searched were:\n{:}".format(name, paths)
@@ -54,11 +59,9 @@ if not _libs_wheel_imported and sys.platform.startswith("win"):
     LIBRARIES = {
         "tensorrt": [
             "nvinfer_10.dll",
-            "cublas64_12.dll",
-            "cublasLt64_12.dll",
-            "cudnn64_##CUDNN_MAJOR##.dll",
             "nvinfer_plugin_10.dll",
             "nvonnxparser_10.dll",
+            "nvinfer_builder_resource_10.dll",
         ],
         "tensorrt_dispatch": [
             "nvinfer_dispatch_10.dll",
@@ -70,14 +73,16 @@ if not _libs_wheel_imported and sys.platform.startswith("win"):
 
     for lib in LIBRARIES:
         lib_path = find_lib(lib)
-        if lib_path != "":
-            ctypes.CDLL(lib_path)
+        if not lib_path:
+            continue
+        assert os.path.isfile(lib_path)
+        ctypes.CDLL(lib_path)
 
 del _libs_wheel_imported
 
 from .tensorrt import *
 
-__version__ = "10.8.0.43"
+__version__ = "10.9.0.34"
 
 
 # Provides Python's `with` syntax
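
The remaining change in this file is the version bump; its visible effect (sketch, assuming the new wheel is installed):

    import tensorrt

    print(tensorrt.__version__)  # "10.9.0.34" with this release, "10.8.0.43" before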
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import tensorrt as trt
 from types import ModuleType
 import importlib
 
@@ -34,3 +35,5 @@ def public_api(module: ModuleType = None, symbol: str = None):
         return obj
 
     return export_impl
+
+IS_AOT_ENABLED = hasattr(trt, "QuickPluginCreationRequest")
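
IS_AOT_ENABLED is plain capability detection: the module treats AOT plugin support as present exactly when the compiled bindings expose trt.QuickPluginCreationRequest. The same check can be made from user code (sketch, runnable wherever tensorrt imports):

    import tensorrt as trt

    # Mirrors the guard above; AOT-only symbols (e.g. trt.IPluginV3QuickAOTBuild)
    # are referenced only behind this test.
    if hasattr(trt, "QuickPluginCreationRequest"):
        print("AOT Python plugin support available")
    else:
        print("JIT-only bindings")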
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,13 +20,16 @@ import types
 import typing
 from typing import Callable, Tuple, List
 import numpy as np
-
-from ._plugin_class import _TemplatePlugin
+from ._plugin_class import _TemplateJITPlugin
+from ._export import IS_AOT_ENABLED
+if IS_AOT_ENABLED:
+    from ._plugin_class import _TemplateAOTPlugin
 from ._validate import (
     _parse_register_inputs,
     _parse_register_return,
     _validate_autotune,
     _validate_impl,
+    _validate_aot_impl,
     _validate_name_and_namespace,
 )
 from ._utils import (
@@ -91,11 +94,13 @@ class PluginDef:
         self.plugin_id = None  # includes namespace (format is ns::name)
         self.register_func = None
         self.impl_func = None
+        self.aot_impl_func = None
         self.autotune_func = None
         self.autotune_attr_names = None
         self.input_tensor_names = None
         self.input_attrs = None  # map name -> type
         self.impl_attr_names = None
+        self.aot_impl_attr_names = None
         self.num_outputs = None
         self.input_arg_schema = None
         self.expects_tactic = None
@@ -195,24 +200,26 @@ class PluginDef:
             )
         )
 
-        plg = plg_creator.create_plugin(
-            name,
-            namespace,
-            trt.PluginFieldCollection(fields),
-            trt.TensorRTPhase.BUILD,
-        )
-        plg.init(
-            self.register_func,
-            attrs,
-            self.impl_attr_names,
-            self.impl_func,
-            self.autotune_attr_names,
-            self.autotune_func,
-            self.expects_tactic,
-        )
+        def create_plugin_instance(quick_plugin_creation_request: "trt.QuickPluginCreationRequest" = None):
+            if quick_plugin_creation_request is None:
+                plg = plg_creator.create_plugin(
+                    name,
+                    namespace,
+                    trt.PluginFieldCollection(fields),
+                    trt.TensorRTPhase.BUILD
+                )
+            else:
+                plg = plg_creator.create_plugin(
+                    name,
+                    namespace,
+                    trt.PluginFieldCollection(fields),
+                    trt.TensorRTPhase.BUILD,
+                    quick_plugin_creation_request
+                )
 
-        return input_tensors, [], plg
+            return input_tensors, [], plg
 
+        return create_plugin_instance
 
 class _TemplatePluginCreator(trt.IPluginCreatorV3Quick):
     def __init__(self, name, namespace, attrs):
@@ -246,7 +253,7 @@ class _TemplatePluginCreator(trt.IPluginCreatorV3Quick):
 
         self.field_names = trt.PluginFieldCollection(field_names)
 
-    def create_plugin(self, name, namespace, fc, phase):
+    def create_plugin(self, name, namespace, fc, phase, qpcr: "trt.QuickPluginCreationRequest" = None):
         desc = QDP_REGISTRY[f"{namespace}::{name}"]
         name = name
         namespace = namespace
@@ -271,18 +278,83 @@
         else:
             attrs[f.name] = attr_type_annot(f.data)
 
-        plg = _TemplatePlugin(name, namespace, desc.num_outputs)
-        plg.init(
-            desc.register_func,
-            attrs,
-            desc.impl_attr_names,
-            desc.impl_func,
-            desc.autotune_attr_names,
-            desc.autotune_func,
-            desc.expects_tactic,
-        )
-        return plg
+        jit_or_aot = None  # True if a JIT plugin is to be created, False if AOT. Asserted non-None before plugin creation.
+
+        if qpcr is None:
+            plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)
+
+            plg.init(
+                desc.register_func,
+                attrs,
+                desc.impl_attr_names,
+                desc.impl_func,
+                desc.autotune_attr_names,
+                desc.autotune_func,
+                desc.expects_tactic,
+            )
+
+            return plg
+
+        # If there is a strict preference, that takes precedence
+        if qpcr == trt.QuickPluginCreationRequest.STRICT_AOT:
+            if desc.aot_impl_func is None:
+                raise ValueError(f"AOT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.aot_impl defined?")
+            jit_or_aot = False
+        elif qpcr == trt.QuickPluginCreationRequest.STRICT_JIT:
+            if desc.impl_func is None:
+                raise ValueError(f"JIT implementation requested, but not defined for '{desc.plugin_id}'. Was @trt.plugin.impl defined?")
+            jit_or_aot = True
+        else:
+            aot_defined = desc.aot_impl_func is not None
+            jit_defined = desc.impl_func is not None
+
+            # A preference must be indicated if both AOT and JIT implementations are defined
+            if aot_defined and jit_defined:
+                if qpcr == trt.QuickPluginCreationRequest.PREFER_AOT:
+                    jit_or_aot = False
+                elif qpcr == trt.QuickPluginCreationRequest.PREFER_JIT:
+                    jit_or_aot = True
+                else:
+                    raise ValueError(f"Plugin '{desc.plugin_id}' has both AOT and JIT implementations. NetworkDefinitionCreationFlag.PREFER_AOT_PYTHON_PLUGINS or NetworkDefinitionCreationFlag.PREFER_JIT_PYTHON_PLUGINS should be specified.")
+            else:
+                # If only one implementation is defined, use that.
+                # Any preference specified is ignored. If the preference is strong, a strict flag should have been specified.
+                if aot_defined:
+                    jit_or_aot = False
+                elif jit_defined:
+                    jit_or_aot = True
+                else:
+                    raise ValueError(f"Plugin '{desc.plugin_id}' does not have either an AOT or a JIT implementation.")
+
+        assert jit_or_aot is not None
+
+        if jit_or_aot:
+            plg = _TemplateJITPlugin(name, namespace, desc.num_outputs)
+
+            plg.init(
+                desc.register_func,
+                attrs,
+                desc.impl_attr_names,
+                desc.impl_func,
+                desc.autotune_attr_names,
+                desc.autotune_func,
+                desc.expects_tactic,
+            )
 
+        else:
+            plg = _TemplateAOTPlugin(name, namespace, desc.num_outputs)
+
+            plg.init(
+                desc.register_func,
+                attrs,
+                desc.aot_impl_attr_names,
+                desc.aot_impl_func,
+                desc.autotune_attr_names,
+                desc.autotune_func
+            )
+
+        # The caller can determine whether the created plugin is an AOT or JIT plugin by inspecting the interface info.
+        return plg
 
 def _register_plugin_creator(name: str, namespace: str, attrs_types):
     plg_registry = trt.get_plugin_registry()
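
The JIT/AOT selection introduced above reduces to a small decision table. A dependency-free restatement for readability (hypothetical helper, illustration only; string literals stand in for the trt.QuickPluginCreationRequest members, and None for the legacy four-argument create_plugin call):

    def choose_jit(qpcr, jit_defined, aot_defined):
        """Mirrors create_plugin's dispatch: True -> JIT plugin, False -> AOT plugin."""
        if qpcr is None:  # legacy call path: always builds the JIT plugin
            return True
        if qpcr == "STRICT_AOT":  # strict requests take precedence
            if not aot_defined:
                raise ValueError("AOT implementation requested, but not defined")
            return False
        if qpcr == "STRICT_JIT":
            if not jit_defined:
                raise ValueError("JIT implementation requested, but not defined")
            return True
        if jit_defined and aot_defined:  # ambiguous: an explicit preference is required
            if qpcr == "PREFER_AOT":
                return False
            if qpcr == "PREFER_JIT":
                return True
            raise ValueError("both implementations defined; specify a PREFER_* flag")
        if aot_defined:  # only one side defined: any preference is ignored
            return False
        if jit_defined:
            return True
        raise ValueError("neither an AOT nor a JIT implementation is defined")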
@@ -445,6 +517,102 @@ def impl(plugin_id: str) -> Callable:
 
     return decorator
 
+# Decorator for `tensorrt.plugin.aot_impl`
+@public_api()
+def aot_impl(plugin_id: str) -> Callable:
+    """
+    Wraps a function to define an Ahead-of-Time (AOT) implementation for a plugin already registered through `trt.plugin.register`.
+
+    This API is only intended to be used as a decorator. The decorated function is not required to have type hints for input arguments or return value;
+    however, any type hints specified will be validated against the `trt.plugin.register` signature for consistency.
+
+    The schema for the function is as follows:
+    .. code-block:: text
+
+        (inp0: TensorDesc, inp1: TensorDesc, ..., attr0: SupportedAttrType, attr1: SupportedAttrType, outputs: Tuple[TensorDesc], tactic: Optional[int]) -> Tuple[str, str, KernelLaunchParams, SymExprs]
+
+    * Input tensors are passed first, each described by a `TensorDesc`.
+    * Plugin attributes are declared next.
+        * Not all attributes included in `trt.plugin.register` must be specified here -- they could be a subset.
+        * NOTE: Plugin attributes are not serialized into the engine when using an AOT implementation.
+    * `tactic` is an optional argument. If the plugin is using custom tactics, it must be specified to receive the tactic value to use for the current execution of the plugin.
+
+    Args:
+        plugin_id: The ID for the plugin in the form "{namespace}::{name}", which must match that used during `trt.plugin.register`
+
+    :returns:
+        - kernel_name: The name of the kernel.
+        - compiled_kernel: Compiled form of the kernel. Presently, only PTX is supported.
+        - launch_params: The launch parameters for the kernel.
+        - extra_args: Symbolic expressions for scalar inputs to the kernel, located after the tensor inputs and before the tensor outputs.
+
+    .. code-block:: python
+        :linenos:
+        :caption: Implementation of an elementwise plugin with an OpenAI Triton kernel
+
+        import tensorrt.plugin as trtp
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
+            pid = tl.program_id(0)
+            offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(x_ptr + offsets, mask=mask)
+            tl.store(y_ptr + offsets, x + 1, mask=mask)
+
+        @trtp.register("my::add_plugin")
+        def add_plugin_desc(inp0: trtp.TensorDesc, block_size: int) -> Tuple[trtp.TensorDesc]:
+            return inp0.like()
+
+        @trtp.aot_impl("my::add_plugin")
+        def add_plugin_aot_impl(
+            inp0: trtp.TensorDesc, block_size: int, single_tactic: bool, outputs: Tuple[trtp.TensorDesc], tactic: int
+        ) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs]:
+
+            type_str = "fp32" if inp0.dtype == trt.float32 else "fp16"
+
+            src = triton.compiler.ASTSource(
+                fn=add_kernel,
+                signature=f"*{type_str},i32,*{type_str}",
+                constants={
+                    "BLOCK_SIZE": block_size,
+                },
+            )
+
+            compiled_kernel = triton.compile(src)
+
+            N = inp0.shape_expr.numel()
+            launch_params = trtp.KernelLaunchParams()
+
+            # grid dims
+            launch_params.grid_x = trtp.cdiv(N, block_size)
+            # block dims
+            launch_params.block_x = compiled_kernel.metadata.num_warps * 32
+            # shared memory
+            launch_params.shared_mem = compiled_kernel.metadata.shared
+
+            extra_args = trtp.SymIntExprs(1)
+            extra_args[0] = trtp.SymInt32(N)
+
+            return compiled_kernel.metadata.name, compiled_kernel.asm["ptx"], launch_params, extra_args
+    """
+    def decorator(aot_impl_func: Callable):
+        if plugin_id not in QDP_REGISTRY:
+            raise ValueError(
+                f"Plugin {plugin_id} is not registered. Did you register it with the tensorrt.plugin.register API?"
+            )
+
+        plugin_def = QDP_REGISTRY[plugin_id]
+        aot_impl_attr_names = _validate_aot_impl(aot_impl_func, plugin_def)
+
+        plugin_def.aot_impl_func = aot_impl_func
+        plugin_def.aot_impl_attr_names = aot_impl_attr_names
+        return aot_impl_func
+
+    return decorator
+
 
 # Decorator for `tensorrt.plugin.autotune`
 @public_api()
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,25 +15,27 @@
 # limitations under the License.
 #
 import tensorrt as trt
-from typing import Tuple
+from typing import Tuple, Union
 
 import numpy as np
 from ._utils import _numpy_to_plugin_field_type, _built_in_to_plugin_field_type
-from ._tensor import TensorDesc, Tensor, Shape, ShapeExpr, ShapeExprs
+from ._tensor import TensorDesc, Tensor, Shape, ShapeExpr, ShapeExprs, SymIntExpr, SymExprs, SymInt32
+from ._export import IS_AOT_ENABLED
+if IS_AOT_ENABLED:
+    from ._tensor import KernelLaunchParams
 from ._autotune import _TypeFormatCombination
 
+from ._export import public_api
 
-class _TemplatePlugin(
+class _TemplatePluginBase(
     trt.IPluginV3,
     trt.IPluginV3QuickCore,
     trt.IPluginV3QuickBuild,
-    trt.IPluginV3QuickRuntime,
 ):
     def __init__(self, name, namespace, num_outputs):
         trt.IPluginV3.__init__(self)
         trt.IPluginV3QuickCore.__init__(self)
         trt.IPluginV3QuickBuild.__init__(self)
-        trt.IPluginV3QuickRuntime.__init__(self)
 
         self.plugin_version = "1"
         self.input_types = []
@@ -46,28 +48,6 @@ class _TemplatePlugin(
         self.autotune_combs = []
         self.supported_combs = {}
         self.curr_comb = None
-        self.expects_tactic = False
-
-    def init(
-        self,
-        register_function,
-        attrs,
-        impl_attr_names,
-        impl_function,
-        autotune_attr_names,
-        autotune_function,
-        expects_tactic,
-    ):
-        self.register_function = register_function
-        self.impl_function = impl_function
-        self.attrs = attrs
-        self.impl_attr_names = impl_attr_names
-        self.autotune_attr_names = autotune_attr_names
-        self.autotune_function = autotune_function
-        self.expects_tactic = expects_tactic
-
-    def get_capability_interface(self, type):
-        return self
 
     def get_num_outputs(self):
         return self.num_outputs
@@ -140,7 +120,7 @@ class _TemplatePlugin(
 
     def get_output_shapes(self, inputs, shape_inputs, exprBuilder):
         assert len(shape_inputs) == 0  # Shape inputs are not yet supported for QDPs
-        ShapeExpr._exprBuilder = exprBuilder
+        SymIntExpr._exprBuilder = exprBuilder
         self.input_descs = []
         for i in range(len(inputs)):
             desc = TensorDesc()
@@ -247,6 +227,45 @@
 
         return ret_supported_combs
 
+    def get_aliased_input(self, output_index: int):
+        return self.aliased_map[output_index]
+
+    def get_valid_tactics(self):
+        tactics = self.supported_combs.get(self.curr_comb)
+        assert tactics is not None
+        return list(tactics)
+
+    def set_tactic(self, tactic):
+        self._tactic = tactic
+
+class _TemplateJITPlugin(_TemplatePluginBase, trt.IPluginV3QuickRuntime):
+    def __init__(self, name, namespace, num_outputs):
+        super().__init__(name, namespace, num_outputs)
+        trt.IPluginV3QuickRuntime.__init__(self)
+
+        self.expects_tactic = False
+
+    def init(
+        self,
+        register_function,
+        attrs,
+        impl_attr_names,
+        impl_function,
+        autotune_attr_names,
+        autotune_function,
+        expects_tactic,
+    ):
+        self.register_function = register_function
+        self.impl_function = impl_function
+        self.attrs = attrs
+        self.impl_attr_names = impl_attr_names
+        self.autotune_attr_names = autotune_attr_names
+        self.autotune_function = autotune_function
+        self.expects_tactic = expects_tactic
+
+    def get_capability_interface(self, type):
+        return self
+
     def enqueue(
         self,
         input_desc,
@@ -305,20 +324,136 @@ class _TemplatePlugin(
         else:
             self.impl_function(*input_tensors, *val, output_tensors, stream=stream)
 
-    def get_aliased_input(self, output_index: int):
-        return self.aliased_map[output_index]
-
-    def get_valid_tactics(self):
-        tactics = self.supported_combs.get(self.curr_comb)
-        assert tactics is not None
-        return list(tactics)
-
-    def set_tactic(self, tactic):
-        self._tactic = tactic
-
     def clone(self):
-        cloned_plugin = _TemplatePlugin(
+        cloned_plugin = _TemplateJITPlugin(
             self.plugin_name, self.plugin_namespace, self.num_outputs
         )
         cloned_plugin.__dict__.update(self.__dict__)
         return cloned_plugin
+
+if IS_AOT_ENABLED:
+    class _TemplateAOTPlugin(
+        _TemplatePluginBase,
+        trt.IPluginV3QuickAOTBuild,
+    ):
+        def __init__(self, name, namespace, num_outputs):
+            _TemplatePluginBase.__init__(self, name, namespace, num_outputs)
+            trt.IPluginV3QuickAOTBuild.__init__(self)
+            self.kernel_map = {}
+
+        def set_tactic(self, tactic):
+            self._tactic = tactic
+
+        def init(
+            self,
+            register_function,
+            attrs,
+            aot_impl_attr_names,
+            aot_impl_function,
+            autotune_attr_names,
+            autotune_function
+        ):
+            self.register_function = register_function
+            self.aot_impl_function = aot_impl_function
+            self.attrs = attrs
+            self.aot_impl_attr_names = aot_impl_attr_names
+            self.autotune_attr_names = autotune_attr_names
+            self.autotune_function = autotune_function
+
+        def get_capability_interface(self, type):
+            return self
+
+        def get_kernel(self, inputDesc, outputDesc):
+            io_types = []
+            io_formats = []
+
+            for i, desc in enumerate(inputDesc):
+                io_types.append(desc.type)
+                io_formats.append(desc.format)
+
+            for i, desc in enumerate(outputDesc):
+                io_types.append(desc.type)
+                io_formats.append(desc.format)
+
+            key = (tuple(io_types), tuple(io_formats), self._tactic)
+
+            assert key in self.kernel_map, "key {} not in kernel_map".format(key)
+
+            kernel_name, ptx = self.kernel_map[key]
+
+            return kernel_name, ptx.encode() if isinstance(ptx, str) else ptx
+
+        def get_launch_params(self, inDimsExprs, in_out, num_inputs, launchParams, symExprSetter, exprBuilder):
+
+            SymIntExpr._exprBuilder = exprBuilder
+
+            if len(self.attrs) > 0:
+                _, val = zip(*self.attrs.items())
+            else:
+                val = ()
+
+            io_types = []
+            io_formats = []
+
+            for i, desc in enumerate(in_out):
+                if i < num_inputs:
+                    self.input_descs[i]._immutable = False
+                    self.input_descs[i].shape = Shape(desc)
+                    self.input_descs[i].dtype = desc.desc.type
+                    self.input_descs[i].format = desc.desc.format
+                    self.input_descs[i].scale = desc.desc.scale
+                    io_types.append(desc.desc.type)
+                    io_formats.append(desc.desc.format)
+                    self.input_descs[i]._immutable = True
+                else:
+                    self.output_descs[i - num_inputs]._immutable = False
+                    self.output_descs[i - num_inputs].shape = Shape(desc)
+                    self.output_descs[i - num_inputs].dtype = desc.desc.type
+                    self.output_descs[i - num_inputs].format = desc.desc.format
+                    self.output_descs[i - num_inputs].scale = desc.desc.scale
+                    io_types.append(desc.desc.type)
+                    io_formats.append(desc.desc.format)
+                    self.output_descs[i - num_inputs]._immutable = True
+
+            kernel_name, ptx, launch_params, extra_args = self.aot_impl_function(
+                *self.input_descs, *val, self.output_descs, self._tactic
+            )
+
+            if not isinstance(kernel_name, str) and not isinstance(kernel_name, bytes):
+                raise TypeError(f"Kernel name must be a 'str' or 'bytes'. Got: {type(kernel_name)}.")
+
+            if not isinstance(ptx, str) and not isinstance(ptx, bytes):
+                raise TypeError(f"PTX/CUBIN must be a 'str' or 'bytes'. Got: {type(ptx)}.")
+
+            if not isinstance(launch_params, KernelLaunchParams):
+                raise TypeError(f"Launch params must be a 'tensorrt.plugin.KernelLaunchParams'. Got: {type(launch_params)}.")
+
+            if not isinstance(extra_args, SymExprs):
+                raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymIntExprs'. Got: {type(extra_args)}.")
+
+            launchParams.grid_x = launch_params.grid_x()
+            launchParams.grid_y = launch_params.grid_y()
+            launchParams.grid_z = launch_params.grid_z()
+            launchParams.block_x = launch_params.block_x()
+            launchParams.block_y = launch_params.block_y()
+            launchParams.block_z = launch_params.block_z()
+            launchParams.shared_mem = launch_params.shared_mem()
+
+            self.kernel_map[(tuple(io_types), tuple(io_formats), self._tactic)] = (kernel_name, ptx)
+
+            symExprSetter.nbSymExprs = len(extra_args)
+
+            for i, arg in enumerate(extra_args):
+                if not isinstance(arg, SymInt32):
+                    raise TypeError(f"Extra args must be a 'tensorrt.plugin.SymInt32'. Got: {type(arg)}.")
+                symExprSetter[i] = arg()
+
+        def get_timing_cache_id(self):
+            return ""
+
+        def clone(self):
+            cloned_plugin = _TemplateAOTPlugin(
+                self.plugin_name, self.plugin_namespace, self.num_outputs
+            )
+            cloned_plugin.__dict__.update(self.__dict__)
+            return cloned_plugin