tensorrt-cu12-bindings 10.8.0.43-cp310-none-win_amd64.whl → 10.9.0.34-cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorrt-cu12-bindings has been flagged as potentially problematic. See the package registry's advisory page for more details.

@@ -1,5 +1,5 @@
1
1
  #
2
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,68 +18,224 @@
18
18
  import tensorrt as trt
19
19
  from typing import Tuple, Union
20
20
  import numpy as np
21
- from ._export import public_api
21
+ from ._export import public_api, IS_AOT_ENABLED
22
+ from abc import ABC, abstractmethod
22
23
 
23
- # Symbolic expression for a given dimension of a tensor
24
- @public_api()
25
- class ShapeExpr:
24
+ class SymExpr(ABC):
25
+
26
+ @abstractmethod
27
+ def _op(self, op, other):
28
+ pass
29
+
30
+ @abstractmethod
31
+ def __add__(self, other):
32
+ pass
33
+
34
+ @abstractmethod
35
+ def __sub__(self, other):
36
+ pass
37
+
38
+ @abstractmethod
39
+ def __mul__(self, other):
40
+ pass
41
+
42
+ @abstractmethod
43
+ def __floordiv__(self, other):
44
+ pass
45
+
46
+ @abstractmethod
47
+ def __eq__(self, other):
48
+ pass
49
+
50
+ @abstractmethod
51
+ def __lt__(self, other):
52
+ pass
53
+
54
+ @abstractmethod
55
+ def __repr__(self):
56
+ pass
57
+
58
+ @property
59
+ @abstractmethod
60
+ def is_constant(self) -> bool:
61
+ pass
62
+
63
+ @property
64
+ @abstractmethod
65
+ def constant_value(self) -> int:
66
+ pass
67
+
68
+ # Evaluate the underlying trt.IDimensionExpr; evaluation is performed lazily where applicable
69
+ @property
70
+ @abstractmethod
71
+ def _expr(self):
72
+ pass
73
+
74
+ class SymIntExprMeta(type(SymExpr)):
75
+ pass
76
+
77
+ class SymIntExpr(SymExpr, metaclass=SymIntExprMeta):
26
78
  """
27
- Symbolic expression for single dimension of a tensor
79
+ Symbolic integer (scalar) expression
28
80
  """
29
81
  _exprBuilder = None # trt.IExprBuilder instance. Populated when a shape-calculation context is entered.
30
82
 
31
- def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr"] = None):
83
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
32
84
  """
33
85
  Args:
34
- value (Union[int, trt.IDimensionExpr, ShapeExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
86
+ value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
35
87
  """
36
- self._is_dummy = False
37
- self._dim_expr = None
38
- self._is_size_tensor = False
39
- if value is None:
40
- self._is_dummy = True
41
- elif isinstance(value, int):
42
- if self._exprBuilder is None:
43
- self._dim_expr = None
44
- self._is_dummy = True
88
+ self._int_expr = None
89
+ if isinstance(value, int):
90
+ if SymIntExpr._exprBuilder is None:
91
+ self._int_expr = None
45
92
  else:
46
- self._dim_expr = ShapeExpr._exprBuilder.constant(value)
93
+ self._int_expr = SymIntExpr._exprBuilder.constant(value)
47
94
  elif isinstance(value, trt.IDimensionExpr):
48
- self._dim_expr = value
49
- elif isinstance(value, ShapeExpr):
50
- self._dim_expr = value._dim_expr
51
- self._is_dummy = value._is_dummy
52
- self._is_size_tensor = value._is_size_tensor
95
+ self._int_expr = value
96
+ elif isinstance(value, SymIntExpr):
97
+ self._int_expr = value._int_expr
53
98
 
54
- def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
55
- if self._is_size_tensor:
56
- raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
57
- if self._is_dummy:
58
- return ShapeExpr()
99
+ def _op(self, op: trt.DimensionOperation, other: Union[int, "SymIntExpr"]):
59
100
  if isinstance(other, int):
60
- other = ShapeExpr(other)
61
- return ShapeExpr(ShapeExpr._exprBuilder.operation(op, self._expr, other._expr))
101
+ other = SymIntExpr(other)
102
+ return SymIntExpr(SymIntExpr._exprBuilder.operation(op, self._expr, other._expr))
62
103
 
63
104
  # Binary operations for +, -, *, //, ==, <
64
105
  # Those for ceil_div, max and min are provided as top-level functions of tensorrt.plugin
65
- def __add__(self, other: Union[int, "ShapeExpr"]):
106
+ def __add__(self, other: Union[int, "SymIntExpr"]):
66
107
  return self._op(trt.DimensionOperation.SUM, other)
67
108
 
68
- def __sub__(self, other: Union[int, "ShapeExpr"]):
109
+ def __sub__(self, other: Union[int, "SymIntExpr"]):
69
110
  return self._op(trt.DimensionOperation.SUB, other)
70
111
 
71
- def __mul__(self, other: Union[int, "ShapeExpr"]):
112
+ def __mul__(self, other: Union[int, "SymIntExpr"]):
72
113
  return self._op(trt.DimensionOperation.PROD, other)
73
114
 
74
- def __floordiv__(self, other: Union[int, "ShapeExpr"]):
115
+ def __floordiv__(self, other: Union[int, "SymIntExpr"]):
75
116
  return self._op(trt.DimensionOperation.FLOOR_DIV, other)
76
117
 
77
- def __eq__(self, other: Union[int, "ShapeExpr"]):
118
+ def __eq__(self, other: Union[int, "SymIntExpr"]):
78
119
  return self._op(trt.DimensionOperation.EQUAL, other)
79
120
 
80
- def __lt__(self, other: Union[int, "ShapeExpr"]):
121
+ def __lt__(self, other: Union[int, "SymIntExpr"]):
81
122
  return self._op(trt.DimensionOperation.LESS, other)
82
123
 
124
+ def __repr__(self):
125
+ if self._is_dummy:
126
+ return f"FakeSymIntExpr[id={id(self)}]"
127
+ elif not self.is_constant:
128
+ return f"SymIntExpr[id={id(self)}]"
129
+ return f"SymIntExpr[{self._expr.get_constant_value()}]"
130
+
131
+ @property
132
+ def is_constant(self) -> bool:
133
+ """
134
+ `True` if this integer expression is a build-time constant, `False` otherwise.
135
+
136
+ Raises:
137
+ RuntimeError: For fake :class:`SymIntExpr`\s. Check :attr:`is_fake` to determine accessibility.
138
+ """
139
+ if self._is_dummy:
140
+ raise RuntimeError(
141
+ "Not accessible for fake 'SymIntExpr's. Check is_fake to determine accessibility."
142
+ )
143
+ return self._expr.is_constant()
144
+
145
+ def constant_value(self) -> int:
146
+ """
147
+ Return value of the constant integer expression.
148
+
149
+ Raises:
150
+ RuntimeError: For non-constant integer expressions. Check :attr:`is_constant` to determine accessibility.
151
+ """
152
+ if not self.is_constant:
153
+ raise RuntimeError(
154
+ "Not accessible for non-constant integer expressions. Check is_constant to determine accessibility."
155
+ )
156
+ return self._expr.get_constant_value()
157
+
158
+ # Evaluate the underlying trt.IDimensionExpr; evaluation is performed lazily where applicable
159
+ @property
160
+ def _expr(self):
161
+ return self._int_expr
162
+
163
+ def _clone(self):
164
+ return SymIntExpr(self + 0)
165
+
166
+ class SymExprImpl(trt.ISymExpr):
167
+ def __init__(self, expr: SymIntExpr):
168
+ trt.ISymExpr.__init__(self)
169
+ self.type = trt.PluginArgType.INT
170
+ if isinstance(expr, SymInt32):
171
+ self.dtype = trt.PluginArgDataType.INT32
172
+ elif isinstance(expr, SymInt16):
173
+ self.dtype = trt.PluginArgDataType.INT16
174
+ elif isinstance(expr, SymInt8):
175
+ self.dtype = trt.PluginArgDataType.INT8
176
+ else:
177
+ raise ValueError(f"Unknown SymIntExpr type {type(expr)}")
178
+
179
+ self.expr = expr._expr
180
+ @public_api()
181
+ class SymInt32(SymIntExpr):
182
+ """
183
+ Symbolic expression for a 32-bit integer
184
+ """
185
+ def __init__(self, value: Union[int, trt.IDimensionExpr, SymIntExpr] = None):
186
+ super().__init__(value)
187
+
188
+ def __call__(self):
189
+ return SymExprImpl(self)
190
+ @public_api()
191
+ class SymInt8(SymIntExpr):
192
+ """
193
+ Symbolic expression for an 8-bit integer
194
+ """
195
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
196
+ super().__init__(value)
197
+
198
+ @public_api()
199
+ class SymInt16(SymIntExpr):
200
+ """
201
+ Symbolic expression for a 16-bit integer
202
+ """
203
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
204
+ super().__init__(value)
205
+
206
+ # Symbolic expression for a given dimension of a tensor
207
+ @public_api()
208
+ class ShapeExpr(SymInt32):
209
+ """
210
+ Symbolic expression for single dimension of a tensor
211
+ """
212
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr", SymIntExpr] = None):
213
+ """
214
+ Args:
215
+ value (Union[int, trt.IDimensionExpr, ShapeExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
216
+ """
217
+ super().__init__(value)
218
+ self._exprBuilder = SymIntExpr._exprBuilder
219
+ self._is_dummy = False
220
+ self._is_size_tensor = False
221
+ if value is None:
222
+ self._is_dummy = True
223
+ elif isinstance(value, int):
224
+ if self._exprBuilder is None:
225
+ self._is_dummy = True
226
+ elif isinstance(value, ShapeExpr):
227
+ self._is_dummy = value._is_dummy
228
+ self._is_size_tensor = value._is_size_tensor
229
+ elif isinstance(value, SymIntExpr):
230
+ pass
231
+
232
+ def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
233
+ if self._is_size_tensor:
234
+ raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
235
+ if self._is_dummy:
236
+ return ShapeExpr()
237
+ return ShapeExpr(super()._op(op, other))
238
+
83
239
  def __repr__(self):
84
240
  if self._is_dummy:
85
241
  return f"FakeShapeExpr[id={id(self)}]"
@@ -116,7 +272,7 @@ class ShapeExpr:
116
272
  raise RuntimeError(
117
273
  "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
118
274
  )
119
- return self._expr.is_constant()
275
+ return super().is_constant
120
276
 
121
277
  def constant_value(self) -> int:
122
278
  """
@@ -129,12 +285,7 @@ class ShapeExpr:
129
285
  raise RuntimeError(
130
286
  "Not accessible for non-constant shape expressions. Check is_constant to determine accessibility."
131
287
  )
132
- return self._expr.get_constant_value()
133
-
134
- # Evaluate the underlying trt.IDimensionExpr, if so done lazily
135
- @property
136
- def _expr(self):
137
- return self._dim_expr
288
+ return super().constant_value()
138
289
 
139
290
  def _clone(self):
140
291
  ret = ShapeExpr(self + 0)
@@ -172,73 +323,163 @@ class SizeTensorShapeExpr(ShapeExpr):
172
323
 
173
324
  @property
174
325
  def _expr(self):
175
- if self._dim_expr is not None:
176
- return self._dim_expr
177
-
178
- self._dim_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
179
- return self._dim_expr
326
+ if self._int_expr is not None:
327
+ return self._int_expr
180
328
 
329
+ self._int_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
330
+ return self._int_expr
331
+
181
332
  def __repr__(self):
182
333
  return f"ShapeExpr[is_size_tensor = True, id={id(self)}]"
183
334
 
335
+ def _from_scalar(s):
336
+ if isinstance(s, int):
337
+ return SymInt32(s)
338
+ elif isinstance(s, float):
339
+ raise ValueError("Float symbolic expressions are not supported")
340
+ else:
341
+ raise ValueError(f"Unsupported type: '{type(s)}'")
342
+
184
343
  # Iterable holding `ShapeExpr`s
185
344
  @public_api()
186
- class ShapeExprs:
187
- def __init__(self, length: int, _is_dummy: bool = False):
345
+ class SymExprs:
346
+ def __init__(self, length: int):
188
347
  """
189
- Iterable holding :class:`ShapeExpr`\s
348
+ Iterable holding symbolic expressions
190
349
 
191
350
  Args:
192
351
  length (int): Number of dimensions of the tensor
193
352
  """
194
353
  self._length = length
354
+ self._exprs = [None] * length
355
+
356
+ @classmethod
357
+ def from_tuple(cls, shape_exprs: Tuple[Union[SymExpr, int]]) -> "SymExprs":
358
+ """
359
+ Args:
360
+ shape_exprs (Tuple[Union[SymExpr, int]]): Tuple to construct :class:`SymExprs` from
361
+ """
362
+
363
+ shape_exprs_ = tuple([e if isinstance(e, SymExpr) else _from_scalar(e) for e in shape_exprs])
364
+ inst = cls(len(shape_exprs_))
365
+ inst._exprs = list(shape_exprs_)
366
+ return inst
367
+
368
+ def __iter__(self):
369
+ return iter(self._exprs)
370
+
371
+ def __getitem__(self, index):
372
+ return self._exprs[index]
373
+
374
+ def __len__(self):
375
+ return self._length
376
+
377
+ def __setitem__(self, index, expr):
378
+ if index >= self._length:
379
+ raise IndexError("Index out of range")
380
+
381
+ if not isinstance(expr, SymExpr):
382
+ expr = _from_scalar(expr)
383
+
384
+ self._exprs[index] = expr
385
+
386
+ def __repr__(self):
387
+ return f"SymExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
388
+
389
+ @public_api()
390
+ class ShapeExprs(SymExprs):
391
+ def __init__(self, length, _is_dummy = False):
392
+ """
393
+ Iterable holding :class:`ShapeExpr`\s, representing a tensor shape
394
+
395
+ Args:
396
+ length (int): Number of dimensions of the tensor
397
+ """
398
+ if length > trt.Dims.MAX_DIMS:
399
+ raise ValueError(f"ShapeExprs can only support up to trt.Dims.MAX_DIMS = {trt.Dims.MAX_DIMS} dimensions. {length} given.")
400
+
401
+ super().__init__(length)
402
+
195
403
  self._is_dummy = _is_dummy
196
404
  if _is_dummy:
197
- self._shapes = [ShapeExpr()] * length
198
- else:
199
- self._shapes = [None] * length
405
+ self._exprs = [ShapeExpr()] * length
200
406
 
201
407
  @classmethod
202
- def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExprs":
408
+ def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExpr":
203
409
  """
204
410
  Args:
205
411
  shape_exprs (Tuple[Union[ShapeExpr, int]]): Tuple to construct :class:`ShapeExprs` from
206
412
  """
413
+
207
414
  shape_exprs_ = tuple([e if isinstance(e, ShapeExpr) else ShapeExpr(e) for e in shape_exprs])
208
415
  inst = cls(len(shape_exprs_))
209
- inst._shapes = list(shape_exprs_)
416
+ inst._exprs = list(shape_exprs_)
210
417
  return inst
211
-
418
+
212
419
  def numel(self) -> ShapeExpr:
213
420
  """
214
421
  Returns a symbolic expression for the number of elements
215
422
  """
216
423
  ret = ShapeExpr(1)
217
- for s in self._shapes:
424
+ for s in self._exprs:
218
425
  ret *= s
219
426
  return ret
220
427
 
221
- def __iter__(self):
222
- return iter(self._shapes)
223
-
224
- def __getitem__(self, index):
225
- return self._shapes[index]
226
-
227
- def __len__(self):
228
- return self._length
229
-
230
- def __setitem__(self, index, shape):
428
+ def __setitem__(self, index, value):
231
429
  if index >= self._length:
232
430
  raise IndexError("Index out of range")
233
- self._shapes[index] = shape
234
-
431
+
432
+ if not isinstance(value, ShapeExpr):
433
+ if not isinstance(value, int):
434
+ raise ValueError(f"Value should be int or ShapeExpr. Got '{type(value)}'")
435
+ value = ShapeExpr(value)
436
+
437
+ self._exprs[index] = value
438
+
235
439
  def __repr__(self):
236
- return f"ShapeExprs[{', '.join([s.__repr__() for s in self._shapes])}]"
440
+ return f"ShapeExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
237
441
 
238
442
  def _clone(self):
239
- ret = ShapeExprs.from_tuple((e._clone() for e in self._shapes))
443
+ ret = ShapeExprs.from_tuple((e._clone() for e in self._exprs))
240
444
  ret._is_dummy = self._is_dummy
241
445
  return ret
446
+
447
+ @public_api()
448
+ class SymIntExprs(SymExprs):
449
+ def __init__(self, length):
450
+ """
451
+ Iterable holding :class:`SymIntExpr`\s
452
+
453
+ Args:
454
+ length (int): Number of symbolic expressions in the iterable
455
+ """
456
+ super().__init__(length)
457
+
458
+ @classmethod
459
+ def from_tuple(cls, shape_exprs: Tuple[Union[SymIntExpr, int]]) -> "SymIntExpr":
460
+ """
461
+ Args:
462
+ shape_exprs (Tuple[Union[SymIntExpr, int]]): Tuple to construct :class:`SymIntExprs` from
463
+ """
464
+
465
+ shape_exprs_ = tuple([e if isinstance(e, SymIntExpr) else SymIntExpr(e) for e in shape_exprs])
466
+ inst = cls(len(shape_exprs_))
467
+ inst._exprs = list(shape_exprs_)
468
+ return inst
469
+
470
+ def __setitem__(self, index, value):
471
+ if index >= self._length:
472
+ raise IndexError("Index out of range")
473
+
474
+ if not isinstance(value, SymIntExpr):
475
+ if not isinstance(value, int):
476
+ raise ValueError(f"Value should be int or SymIntExpr. Got '{type(value)}'")
477
+ value = SymIntExpr(value)
478
+
479
+ self._exprs[index] = value
480
+
481
+ def __repr__(self):
482
+ return f"SymIntExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
242
483
 
243
484
  # Numerical representation of a tensor shape
244
485
  @public_api()
@@ -247,7 +488,7 @@ class Shape:
247
488
  Numerical representation of a tensor shape
248
489
  """
249
490
  def __init__(
250
- self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc]
491
+ self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc] = None
251
492
  ):
252
493
  self._is_dynamic = None # set lazily
253
494
  if isinstance(tensor_desc, trt.DynamicPluginTensorDesc):
@@ -353,6 +594,7 @@ class Shape:
353
594
  ret.__dict__.update(self.__dict__)
354
595
  return ret
355
596
 
597
+
356
598
  # Descriptor for a tensor
357
599
  # A `TensorDesc` never contains nor refers to any tensor data.
358
600
  @public_api()
@@ -848,3 +1090,39 @@ class Tensor:
848
1090
 
849
1091
  cloned._aliased_to = self
850
1092
  return cloned
1093
+
1094
+ if IS_AOT_ENABLED:
1095
+ @public_api()
1096
+ class KernelLaunchParams:
1097
+ """
1098
+ Args:
1099
+ grid_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid x dimension. Defaults to 1.
1100
+ grid_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid y dimension. Defaults to 1.
1101
+ grid_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid z dimension. Defaults to 1.
1102
+ block_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The x dimension of each thread block. Defaults to 1.
1103
+ block_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The y dimension of each thread block. Defaults to 1.
1104
+ block_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The z dimension of each thread block. Defaults to 1.
1105
+ shared_mem (Union[int, trt.IDimensionExpr, SymInt32], optional): Shared-memory per thread block in bytes. Defaults to 0.
1106
+ """
1107
+ def __init__(self,
1108
+ grid_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1109
+ grid_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1110
+ grid_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1111
+ block_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1112
+ block_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1113
+ block_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1114
+ shared_mem: Union[int, trt.IDimensionExpr, SymInt32] = 0):
1115
+ self.grid_x = SymInt32(grid_x)
1116
+ self.grid_y = SymInt32(grid_y)
1117
+ self.grid_z = SymInt32(grid_z)
1118
+ self.block_x = SymInt32(block_x)
1119
+ self.block_y = SymInt32(block_y)
1120
+ self.block_z = SymInt32(block_z)
1121
+ self.shared_mem = SymInt32(shared_mem)
1122
+
1123
+
1124
+ def __setattr__(self, name, value):
1125
+ if name in ["grid_x", "grid_y", "grid_z", "block_x", "block_y", "block_z", "shared_mem"]:
1126
+ self.__dict__[name] = SymInt32(value)
1127
+ else:
1128
+ raise AttributeError(f"KernelLaunchParams object has no attribute '{name}'")
@@ -1,5 +1,5 @@
1
1
  #
2
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,9 +18,13 @@
18
18
  import inspect
19
19
  import numpy as np
20
20
  import typing
21
+ import types
21
22
 
22
23
  from ._utils import _is_numpy_array, _join_with, _infer_numpy_type, _is_npt_ndarray
23
- from ._tensor import TensorDesc, Tensor
24
+ from ._tensor import TensorDesc, Tensor, SymExprs
25
+ from ._export import IS_AOT_ENABLED
26
+ if IS_AOT_ENABLED:
27
+ from ._tensor import KernelLaunchParams
24
28
  from ._autotune import AutoTuneCombination
25
29
 
26
30
  SERIALIZABLE_BUILTIN_TYPES = (int, float, bytes, bool, str)
@@ -93,7 +97,7 @@ def _parse_register_inputs(register_func, lazy_register):
93
97
  raise ValueError(
94
98
  f"Argument {name} is not a positional-or-keyword or keyword-only arg"
95
99
  )
96
-
100
+
97
101
  # Type annotations are manadatory for `tensorrt.plugin.register` args
98
102
  if param.annotation == inspect.Parameter.empty:
99
103
  raise ValueError(
@@ -105,7 +109,7 @@ def _parse_register_inputs(register_func, lazy_register):
105
109
  raise ValueError(
106
110
  f"Argument {name} has a default value. Default values are not supported yet."
107
111
  )
108
-
112
+
109
113
 
110
114
  if issubclass(param.annotation, TensorDesc):
111
115
  if saw_first_attr:
@@ -275,6 +279,120 @@ def _validate_impl(impl_func, plugin_def):
275
279
 
276
280
  return impl_attr_names, found_tactic
277
281
 
282
+ def _validate_aot_impl(aot_impl_func, plugin_def):
283
+ aot_impl_attr_names = []
284
+
285
+ sig = inspect.signature(aot_impl_func)
286
+ registered_attr_names = plugin_def.input_attrs.keys()
287
+
288
+ # input arg annotations are optional, but we will validate if provided
289
+ for name, param in sig.parameters.items():
290
+ if param.annotation != inspect.Parameter.empty:
291
+ if name == "outputs":
292
+ if typing.get_origin(param.annotation) is not tuple:
293
+ raise ValueError(
294
+ f"'outputs' should be of type Tuple[TensorDesc]. Received {param.annotation}."
295
+ )
296
+ args = typing.get_args(param.annotation)
297
+ for arg in args:
298
+ if not issubclass(arg, TensorDesc):
299
+ raise ValueError(
300
+ f"Argument for receiving output TensorDesc, '{name}' contains a {param.annotation}. '{name}' should be a Tuple[TensorDesc]."
301
+ )
302
+ elif name == "tactic":
303
+ if not issubclass(param.annotation, int):
304
+ raise ValueError("'tactic' input argument should be an int")
305
+ elif issubclass(param.annotation, TensorDesc):
306
+ if name not in plugin_def.input_tensor_names:
307
+ raise ValueError(
308
+ f"Unexpected tensor '{name}' specified in autotune function. Expected one of {plugin_def.input_tensor_names}."
309
+ )
310
+ else:
311
+ if name not in plugin_def.input_attrs:
312
+ raise ValueError(
313
+ f"Unexpected attribute '{name}' specified in aot_impl function. Expected one of {list(registered_attr_names)}."
314
+ )
315
+
316
+ if param.annotation != plugin_def.input_attrs[name]:
317
+ raise ValueError(
318
+ f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
319
+ )
320
+
321
+ aot_impl_attr_names.append(name)
322
+ else:
323
+ if name in plugin_def.input_attrs:
324
+ aot_impl_attr_names.append(name)
325
+
326
+ # Expected attribute schema should be constructed in the order they appeared in the register function
327
+ expected_attr_schema_chunks = [
328
+ n for n in registered_attr_names if n in aot_impl_attr_names
329
+ ]
330
+
331
+ expected_schema = (
332
+ "("
333
+ + _join_with(plugin_def.input_tensor_names)
334
+ + _join_with(expected_attr_schema_chunks, True)
335
+ + ", outputs, tactic)"
336
+ )
337
+
338
+ if f"({', '.join(sig.parameters.keys())})" != expected_schema:
339
+ raise ValueError(
340
+ f"Signature of the aot_impl function '{sig}' does not match the expected input arg schema: {expected_schema}"
341
+ )
342
+
343
+ ret_annotation = sig.return_annotation
344
+
345
+ if ret_annotation == inspect.Parameter.empty:
346
+ raise ValueError(
347
+ f"No return annotation found for aot_impl function. Received signature {sig}."
348
+ )
349
+
350
+ expected_return_schema = "tuple[str | bytes, str | bytes, tensorrt.plugin.KernelLaunchParams, tensorrt.plugin.SymIntExprs]"
351
+
352
+ # Return annotation is optional, but we will validate if one is specified
353
+ if ret_annotation != inspect.Parameter.empty:
354
+ if typing.get_origin(ret_annotation) is not tuple:
355
+ raise ValueError(
356
+ f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
357
+ )
358
+ else:
359
+ args = typing.get_args(ret_annotation)
360
+
361
+ if len(args) != 4:
362
+ raise ValueError(
363
+ f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
364
+ )
365
+
366
+ def validate_union_str_or_bytes(index):
367
+ def validate_str_or_bytes(arg_):
368
+ if (arg_ is not str) and (arg_ is not bytes):
369
+ raise ValueError(
370
+ f"Return annotation for argument at {index} is '{arg_}'. Expected 'str' or 'bytes'."
371
+ )
372
+
373
+ orig = typing.get_origin(args[index])
374
+ # orig is `typing.Union` when annotation uses typing module (e.g, Union[str, bytes])
375
+ # orig is `types.UnionType` when annotation is of the new (3.10+) native syntax (e.g, str | bytes)
376
+ if orig is typing.Union or orig is types.UnionType:
377
+ for a in typing.get_args(args[index]):
378
+ validate_str_or_bytes(a)
379
+ else:
380
+ # when annotated with `str` or `bytes`
381
+ validate_str_or_bytes(args[index])
382
+
383
+ # kernel name should be str or bytes encoding
384
+ validate_union_str_or_bytes(0)
385
+ # kernel PTX should be str or bytes encoding
386
+ validate_union_str_or_bytes(1)
387
+
388
+ if not issubclass(args[2], KernelLaunchParams):
389
+ raise ValueError(f"Argument at index 2 of return annotation is '{args[2]}'. Expected 'tensorrt.plugin.KernelLaunchParams'.")
390
+
391
+ if not issubclass(args[3], SymExprs):
392
+ raise ValueError(f"Argument at index 3 of return annotation is '{args[3]}'. Expected a descendent of tensorrt.plugin.SymExprs.")
393
+
394
+ return aot_impl_attr_names
395
+
278
396
 
279
397
  def _validate_autotune(autotune_func, plugin_def):
280
398