tensorrt-cu12-bindings 10.8.0.43__cp39-none-win_amd64.whl → 10.9.0.34__cp39-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tensorrt-cu12-bindings might be problematic; see the registry's advisory page for details.
- tensorrt_bindings/__init__.py +13 -8
- tensorrt_bindings/plugin/_export.py +4 -1
- tensorrt_bindings/plugin/_lib.py +199 -31
- tensorrt_bindings/plugin/_plugin_class.py +176 -41
- tensorrt_bindings/plugin/_tensor.py +351 -73
- tensorrt_bindings/plugin/_validate.py +122 -4
- tensorrt_bindings/tensorrt.cp39-win_amd64.pyd +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/METADATA +1 -1
- tensorrt_cu12_bindings-10.9.0.34.dist-info/RECORD +17 -0
- tensorrt_cu12_bindings-10.8.0.43.dist-info/RECORD +0 -17
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/LICENSE.txt +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/WHEEL +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/top_level.txt +0 -0
- {tensorrt_cu12_bindings-10.8.0.43.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/zip-safe +0 -0
tensorrt_bindings/plugin/_tensor.py

@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,68 +18,224 @@
 import tensorrt as trt
 from typing import Tuple, Union
 import numpy as np
-from ._export import public_api
+from ._export import public_api, IS_AOT_ENABLED
+from abc import ABC, abstractmethod
 
-
-
-
+class SymExpr(ABC):
+
+    @abstractmethod
+    def _op(self, op, other):
+        pass
+
+    @abstractmethod
+    def __add__(self, other):
+        pass
+
+    @abstractmethod
+    def __sub__(self, other):
+        pass
+
+    @abstractmethod
+    def __mul__(self, other):
+        pass
+
+    @abstractmethod
+    def __floordiv__(self, other):
+        pass
+
+    @abstractmethod
+    def __eq__(self, other):
+        pass
+
+    @abstractmethod
+    def __lt__(self, other):
+        pass
+
+    @abstractmethod
+    def __repr__(self):
+        pass
+
+    @property
+    @abstractmethod
+    def is_constant(self) -> bool:
+        pass
+
+    @property
+    @abstractmethod
+    def constant_value(self) -> int:
+        pass
+
+    # Evaluate the underlying trt.IDimensionExpr, if so done lazily
+    @property
+    @abstractmethod
+    def _expr(self):
+        pass
+
+class SymIntExprMeta(type(SymExpr)):
+    pass
+
+class SymIntExpr(SymExpr, metaclass=SymIntExprMeta):
     """
-    Symbolic expression for single dimension of a tensor
+    Symbolic integer (scalar) expression
     """
     _exprBuilder = None # trt.IExprBuilder instance. Populated when a shape-calculation context is entered.
 
-    def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr"] = None):
+    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
         """
         Args:
-            value (Union[int, trt.IDimensionExpr, ShapeExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
+            value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
         """
-        self._dim_expr = None
-        self._is_dummy = False
-        self._is_size_tensor = False
-        if value is None:
-            self._is_dummy = True
-        elif isinstance(value, int):
-            if self._exprBuilder is None:
-                self._dim_expr = None
-                self._is_dummy = True
+        self._int_expr = None
+        if isinstance(value, int):
+            if SymIntExpr._exprBuilder is None:
+                self._int_expr = None
             else:
-                self._dim_expr = self._exprBuilder.constant(value)
+                self._int_expr = SymIntExpr._exprBuilder.constant(value)
         elif isinstance(value, trt.IDimensionExpr):
-            self._dim_expr = value
-        elif isinstance(value, ShapeExpr):
-            self._dim_expr = value._dim_expr
-            self._is_dummy = value._is_dummy
-            self._is_size_tensor = value._is_size_tensor
+            self._int_expr = value
+        elif isinstance(value, SymIntExpr):
+            self._int_expr = value._int_expr
 
-    def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
-        if self._is_size_tensor:
-            raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
-        if self._is_dummy:
-            return ShapeExpr()
+    def _op(self, op: trt.DimensionOperation, other: Union[int, "SymIntExpr"]):
         if isinstance(other, int):
-            other = ShapeExpr(other)
-        return ShapeExpr(self._exprBuilder.operation(op, self._expr, other._expr))
+            other = SymIntExpr(other)
+        return SymIntExpr(SymIntExpr._exprBuilder.operation(op, self._expr, other._expr))
 
     # Binary operations for +, -, *, //, ==. <
     # Those for ceil_div, max and min are provided as top-level functions of tensorrt.plugin
-    def __add__(self, other: Union[int, "ShapeExpr"]):
+    def __add__(self, other: Union[int, "SymIntExpr"]):
         return self._op(trt.DimensionOperation.SUM, other)
 
-    def __sub__(self, other: Union[int, "ShapeExpr"]):
+    def __sub__(self, other: Union[int, "SymIntExpr"]):
         return self._op(trt.DimensionOperation.SUB, other)
 
-    def __mul__(self, other: Union[int, "ShapeExpr"]):
+    def __mul__(self, other: Union[int, "SymIntExpr"]):
         return self._op(trt.DimensionOperation.PROD, other)
 
-    def __floordiv__(self, other: Union[int, "ShapeExpr"]):
+    def __floordiv__(self, other: Union[int, "SymIntExpr"]):
         return self._op(trt.DimensionOperation.FLOOR_DIV, other)
 
-    def __eq__(self, other: Union[int, "ShapeExpr"]):
+    def __eq__(self, other: Union[int, "SymIntExpr"]):
         return self._op(trt.DimensionOperation.EQUAL, other)
 
-    def __lt__(self, other: Union[int, "ShapeExpr"]):
+    def __lt__(self, other: Union[int, "SymIntExpr"]):
         return self._op(trt.DimensionOperation.LESS, other)
 
+    def __repr__(self):
+        if self._is_dummy:
+            return f"FakeSymIntExpr[id={id(self)}]"
+        elif not self.is_constant:
+            return f"SymIntExpr[id={id(self)}]"
+        return f"SymIntExpr[{self._expr.get_constant_value()}]"
+
+    @property
+    def is_constant(self) -> bool:
+        """
+        `True` if this integer expression is a build-time constant, `False` otherwise.
+
+        Raises:
+            RuntimeError: For fake :class:`SymIntExpr`\s. Check :attr:`is_fake` to determine accessibility.
+        """
+        if self._is_dummy:
+            raise RuntimeError(
+                "Not accessible for fake 'SymIntExpr's. Check is_fake to determine accessibility."
+            )
+        return self._expr.is_constant()
+
+    def constant_value(self) -> int:
+        """
+        Return value of the constant integer expression.
+
+        Raises:
+            RuntimeError: For non-constant integer expressions. Check :attr:`is_constant` to determine accessibility.
+        """
+        if not self.is_constant:
+            raise RuntimeError(
+                "Not accessible for non-constant integer expressions. Check is_constant to determine accessibility."
+            )
+        return self._expr.get_constant_value()
+
+    # Evaluate the underlying trt.IDimensionExpr, if so done lazily
+    @property
+    def _expr(self):
+        return self._int_expr
+
+    def _clone(self):
+        return SymIntExpr(self + 0)
+
+class SymExprImpl(trt.ISymExpr):
+    def __init__(self, expr: SymIntExpr):
+        trt.ISymExpr.__init__(self)
+        self.type = trt.PluginArgType.INT
+        if isinstance(expr, SymInt32):
+            self.dtype = trt.PluginArgDataType.INT32
+        elif isinstance(expr, SymInt16):
+            self.dtype = trt.PluginArgDataType.INT16
+        elif isinstance(expr, SymInt8):
+            self.dtype = trt.PluginArgDataType.INT8
+        else:
+            raise ValueError(f"Unknown SymIntExpr type {type(expr)}")
+
+        self.expr = expr._expr
+@public_api()
+class SymInt32(SymIntExpr):
+    """
+    Symbolic expression for a 32-bit integer
+    """
+    def __init__(self, value: Union[int, trt.IDimensionExpr, SymIntExpr] = None):
+        super().__init__(value)
+
+    def __call__(self):
+        return SymExprImpl(self)
+@public_api()
+class SymInt8(SymIntExpr):
+    """
+    Symbolic expression for an 8-bit integer
+    """
+    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
+        super().__init__(value)
+
+@public_api()
+class SymInt16(SymIntExpr):
+    """
+    Symbolic expression for a 16-bit integer
+    """
+    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
+        super().__init__(value)
+
+# Symbolic expression for a given dimension of a tensor
+@public_api()
+class ShapeExpr(SymInt32):
+    """
+    Symbolic expression for single dimension of a tensor
+    """
+    def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr", SymIntExpr] = None):
+        """
+        Args:
+            value (Union[int, trt.IDimensionExpr, ShapeExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
+        """
+        super().__init__(value)
+        self._exprBuilder = SymIntExpr._exprBuilder
+        self._is_dummy = False
+        self._is_size_tensor = False
+        if value is None:
+            self._is_dummy = True
+        elif isinstance(value, int):
+            if self._exprBuilder is None:
+                self._is_dummy = True
+        elif isinstance(value, ShapeExpr):
+            self._is_dummy = value._is_dummy
+            self._is_size_tensor = value._is_size_tensor
+        elif isinstance(value, SymIntExpr):
+            pass
+
+    def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
+        if self._is_size_tensor:
+            raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
+        if self._is_dummy:
+            return ShapeExpr()
+        return ShapeExpr(super()._op(op, other))
+
     def __repr__(self):
         if self._is_dummy:
             return f"FakeShapeExpr[id={id(self)}]"
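The hunk above replaces the old ShapeExpr-only arithmetic with a SymExpr/SymIntExpr hierarchy: each Python operator lowers to a trt.IExprBuilder operation, and ShapeExpr is now a SymInt32 specialized for tensor dimensions. As a rough sketch of how this arithmetic is exercised (not part of the diff; the plugin name and `pad` attribute are hypothetical), a `tensorrt.plugin.register` shape function builds output dimensions from input ones while TensorRT has populated `SymIntExpr._exprBuilder`:

import tensorrt.plugin as trtp

@trtp.register("example::pad_1d")
def pad_1d_desc(inp0: trtp.TensorDesc, pad: int) -> trtp.TensorDesc:
    out = inp0.like()  # copy dtype and symbolic shape from the input
    # ShapeExpr + int lowers to a trt.DimensionOperation.SUM on the builder
    out.shape_expr[-1] = inp0.shape_expr[-1] + 2 * pad
    return out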
@@ -116,7 +272,7 @@ class ShapeExpr:
             raise RuntimeError(
                 "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
             )
-        return self._expr.is_constant()
+        return super().is_constant
 
     def constant_value(self) -> int:
         """
@@ -129,12 +285,7 @@ class ShapeExpr:
             raise RuntimeError(
                 "Not accessible for non-constant shape expressions. Check is_constant to determine accessibility."
             )
-        return self._expr.get_constant_value()
-
-    # Evaluate the underlying trt.IDimensionExpr, if so done lazily
-    @property
-    def _expr(self):
-        return self._dim_expr
+        return super().constant_value()
 
     def _clone(self):
         ret = ShapeExpr(self + 0)
@@ -172,73 +323,163 @@ class SizeTensorShapeExpr(ShapeExpr):
 
     @property
     def _expr(self):
-        if self._dim_expr is not None:
-            return self._dim_expr
-
-        self._dim_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
-        return self._dim_expr
+        if self._int_expr is not None:
+            return self._int_expr
 
+        self._int_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
+        return self._int_expr
+
     def __repr__(self):
         return f"ShapeExpr[is_size_tensor = True, id={id(self)}]"
 
+def _from_scalar(s):
+    if isinstance(s, int):
+        return SymInt32(s)
+    elif isinstance(s, float):
+        raise ValueError("Float symbolic expressions are not supported")
+    else:
+        raise ValueError(f"Unsupported type: '{type(s)}'")
+
 # Iterable holding `ShapeExpr`s
 @public_api()
-class ShapeExprs:
-    def __init__(self, length: int, _is_dummy: bool = False):
+class SymExprs:
+    def __init__(self, length: int):
         """
-        Iterable holding :class:`ShapeExpr`\s, representing a tensor shape
+        Iterable holding symbolic expressions
 
         Args:
            length (int): Number of dimensions of the tensor
         """
         self._length = length
+        self._exprs = [None] * length
+
+    @classmethod
+    def from_tuple(cls, shape_exprs: Tuple[Union[SymExpr, int]]) -> "SymExprs":
+        """
+        Args:
+            shape_exprs (Tuple[Union[SymExpr, int]]): Tuple to construct :class:`SymExprs` from
+        """
+
+        shape_exprs_ = tuple([e if isinstance(e, SymExpr) else _from_scalar(e) for e in shape_exprs])
+        inst = cls(len(shape_exprs_))
+        inst._exprs = list(shape_exprs_)
+        return inst
+
+    def __iter__(self):
+        return iter(self._exprs)
+
+    def __getitem__(self, index):
+        return self._exprs[index]
+
+    def __len__(self):
+        return self._length
+
+    def __setitem__(self, index, expr):
+        if index >= self._length:
+            raise IndexError("Index out of range")
+
+        if not isinstance(expr, SymExpr):
+            expr = _from_scalar(expr)
+
+        self._exprs[index] = expr
+
+    def __repr__(self):
+        return f"SymExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
+
+@public_api()
+class ShapeExprs(SymExprs):
+    def __init__(self, length, _is_dummy = False):
+        """
+        Iterable holding :class:`ShapeExpr`\s, representing a tensor shape
+
+        Args:
+            length (int): Number of dimensions of the tensor
+        """
+        if length > trt.Dims.MAX_DIMS:
+            raise ValueError(f"ShapeExprs can only support up to trt.Dims.MAX_DIMS = {trt.Dims.MAX_DIMS} dimensions. {length} given.")
+
+        super().__init__(length)
+
         self._is_dummy = _is_dummy
         if _is_dummy:
-            self._shapes = [ShapeExpr()] * length
-        else:
-            self._shapes = [None] * length
+            self._exprs = [ShapeExpr()] * length
 
     @classmethod
-    def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExprs":
+    def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExpr":
         """
         Args:
             shape_exprs (Tuple[Union[ShapeExpr, int]]): Tuple to construct :class:`ShapeExprs` from
         """
+
         shape_exprs_ = tuple([e if isinstance(e, ShapeExpr) else ShapeExpr(e) for e in shape_exprs])
         inst = cls(len(shape_exprs_))
-        inst._shapes = list(shape_exprs_)
+        inst._exprs = list(shape_exprs_)
         return inst
-
+
     def numel(self) -> ShapeExpr:
         """
         Returns a symbolic expression for the number of elements
         """
         ret = ShapeExpr(1)
-        for s in self._shapes:
+        for s in self._exprs:
             ret *= s
         return ret
 
-    def __iter__(self):
-        return iter(self._shapes)
-
-    def __getitem__(self, index):
-        return self._shapes[index]
-
-    def __len__(self):
-        return self._length
-
-    def __setitem__(self, index, shape):
+    def __setitem__(self, index, value):
         if index >= self._length:
             raise IndexError("Index out of range")
-        self._shapes[index] = shape
-
+
+        if not isinstance(value, ShapeExpr):
+            if not isinstance(value, int):
+                raise ValueError(f"Value should be int or ShapeExpr. Got '{type(value)}'")
+            value = ShapeExpr(value)
+
+        self._exprs[index] = value
+
     def __repr__(self):
-        return f"ShapeExprs[{', '.join([s.__repr__() for s in self._shapes])}]"
+        return f"ShapeExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
 
     def _clone(self):
-        ret = ShapeExprs.from_tuple((e._clone() for e in self._shapes))
+        ret = ShapeExprs.from_tuple((e._clone() for e in self._exprs))
         ret._is_dummy = self._is_dummy
         return ret
+
+@public_api()
+class SymIntExprs(SymExprs):
+    def __init__(self, length):
+        """
+        Iterable holding :class:`SymIntExpr`\s
+
+        Args:
+            length (int): Number of symbolic expressions in the iterable
+        """
+        super().__init__(length)
+
+    @classmethod
+    def from_tuple(cls, shape_exprs: Tuple[Union[SymIntExpr, int]]) -> "SymIntExpr":
+        """
+        Args:
+            shape_exprs (Tuple[Union[SymIntExpr, int]]): Tuple to construct :class:`SymIntExprs` from
+        """
+
+        shape_exprs_ = tuple([e if isinstance(e, SymIntExpr) else SymIntExpr(e) for e in shape_exprs])
+        inst = cls(len(shape_exprs_))
+        inst._exprs = list(shape_exprs_)
+        return inst
+
+    def __setitem__(self, index, value):
+        if index >= self._length:
+            raise IndexError("Index out of range")
+
+        if not isinstance(value, SymIntExpr):
+            if not isinstance(value, int):
+                raise ValueError(f"Value should be int or SymIntExpr. Got '{type(value)}'")
+            value = SymIntExpr(value)
+
+        self._exprs[index] = value
+
+    def __repr__(self):
+        return f"SymIntExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
 
 # Numerical representation of a tensor shape
 @public_api()
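The container refactor above hoists the iterable machinery into SymExprs, so ShapeExprs (tensor shapes, capped at trt.Dims.MAX_DIMS) and SymIntExprs (plain integer expressions) differ mainly in what `__setitem__` coerces values to. A small illustrative sketch (hypothetical values; outside a shape-calculation context every entry is a "fake" placeholder expression):

import tensorrt.plugin as trtp

dims = trtp.ShapeExprs.from_tuple((2, 3, 4))  # ints are coerced to ShapeExpr
assert len(dims) == 3
dims[0] = 8            # __setitem__ wraps the int in a ShapeExpr
n = dims.numel()       # symbolic product of all dimensions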
@@ -247,7 +488,7 @@ class Shape:
     Numerical representation of a tensor shape
     """
     def __init__(
-        self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc]
+        self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc] = None
     ):
         self._is_dynamic = None # set lazily
         if isinstance(tensor_desc, trt.DynamicPluginTensorDesc):
@@ -353,6 +594,7 @@ class Shape:
         ret.__dict__.update(self.__dict__)
         return ret
 
+
 # Descriptor for a tensor
 # A `TensorDesc` never contains nor refers to any tensor data.
 @public_api()
@@ -848,3 +1090,39 @@ class Tensor:
 
         cloned._aliased_to = self
         return cloned
+
+if IS_AOT_ENABLED:
+    @public_api()
+    class KernelLaunchParams:
+        """
+        Args:
+            grid_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid x dimension. Defaults to 1.
+            grid_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid y dimension. Defaults to 1.
+            grid_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid z dimension. Defaults to 1.
+            block_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The x dimension of each thread block. Defaults to 1.
+            block_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The y dimension of each thread block. Defaults to 1.
+            block_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The z dimension of each thread block. Defaults to 1.
+            shared_mem (Union[int, trt.IDimensionExpr, SymInt32], optional): Shared-memory per thread block in bytes. Defaults to 0.
+        """
+        def __init__(self,
+            grid_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
+            grid_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
+            grid_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
+            block_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
+            block_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
+            block_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
+            shared_mem: Union[int, trt.IDimensionExpr, SymInt32] = 0):
+            self.grid_x = SymInt32(grid_x)
+            self.grid_y = SymInt32(grid_y)
+            self.grid_z = SymInt32(grid_z)
+            self.block_x = SymInt32(block_x)
+            self.block_y = SymInt32(block_y)
+            self.block_z = SymInt32(block_z)
+            self.shared_mem = SymInt32(shared_mem)
+
+
+        def __setattr__(self, name, value):
+            if name in ["grid_x", "grid_y", "grid_z", "block_x", "block_y", "block_z", "shared_mem"]:
+                self.__dict__[name] = SymInt32(value)
+            else:
+                raise AttributeError(f"KernelLaunchParams object has no attribute '{name}'")
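KernelLaunchParams only exists when IS_AOT_ENABLED is true, and its `__setattr__` funnels every field through SymInt32, so grid/block sizes may be plain ints or symbolic expressions. A sketch of filling it inside an ahead-of-time implementation, where the expression builder is live (`out_desc` is a hypothetical TensorDesc available in the surrounding function):

import tensorrt.plugin as trtp

num_elems = out_desc.shape_expr.numel()   # symbolic element count
launch = trtp.KernelLaunchParams()
launch.block_x = 256                      # stored as SymInt32(256)
launch.grid_x = (num_elems + 255) // 256  # ceil-division built from + and //
launch.shared_mem = 0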
tensorrt_bindings/plugin/_validate.py

@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,9 +18,13 @@
 import inspect
 import numpy as np
 import typing
+import types
 
 from ._utils import _is_numpy_array, _join_with, _infer_numpy_type, _is_npt_ndarray
-from ._tensor import TensorDesc, Tensor
+from ._tensor import TensorDesc, Tensor, SymExprs
+from ._export import IS_AOT_ENABLED
+if IS_AOT_ENABLED:
+    from ._tensor import KernelLaunchParams
 from ._autotune import AutoTuneCombination
 
 SERIALIZABLE_BUILTIN_TYPES = (int, float, bytes, bool, str)
@@ -93,7 +97,7 @@ def _parse_register_inputs(register_func, lazy_register):
            raise ValueError(
                f"Argument {name} is not a positional-or-keyword or keyword-only arg"
            )
-
+
        # Type annotations are manadatory for `tensorrt.plugin.register` args
        if param.annotation == inspect.Parameter.empty:
            raise ValueError(
@@ -105,7 +109,7 @@ def _parse_register_inputs(register_func, lazy_register):
            raise ValueError(
                f"Argument {name} has a default value. Default values are not supported yet."
            )
-
+
 
        if issubclass(param.annotation, TensorDesc):
            if saw_first_attr:
@@ -275,6 +279,120 @@ def _validate_impl(impl_func, plugin_def):
 
     return impl_attr_names, found_tactic
 
+def _validate_aot_impl(aot_impl_func, plugin_def):
+    aot_impl_attr_names = []
+
+    sig = inspect.signature(aot_impl_func)
+    registered_attr_names = plugin_def.input_attrs.keys()
+
+    # input arg annotations are optional, but we will validate if provided
+    for name, param in sig.parameters.items():
+        if param.annotation != inspect.Parameter.empty:
+            if name == "outputs":
+                if typing.get_origin(param.annotation) is not tuple:
+                    raise ValueError(
+                        f"'outputs' should be of type Tuple[TensorDesc]. Received {param.annotation}."
+                    )
+                args = typing.get_args(param.annotation)
+                for arg in args:
+                    if not issubclass(arg, TensorDesc):
+                        raise ValueError(
+                            f"Argument for receiving output TensorDesc, '{name}' contains a {param.annotation}. '{name}' should be a Tuple[TensorDesc]."
+                        )
+            elif name == "tactic":
+                if not issubclass(param.annotation, int):
+                    raise ValueError("'tactic' input argument should be an int")
+            elif issubclass(param.annotation, TensorDesc):
+                if name not in plugin_def.input_tensor_names:
+                    raise ValueError(
+                        f"Unexpected tensor '{name}' specified in autotune function. Expected one of {plugin_def.input_tensor_names}."
+                    )
+            else:
+                if name not in plugin_def.input_attrs:
+                    raise ValueError(
+                        f"Unexpected attribute '{name}' specified in aot_impl function. Expected one of {list(registered_attr_names)}."
+                    )
+
+                if param.annotation != plugin_def.input_attrs[name]:
+                    raise ValueError(
+                        f"Attribute '{name}' has a type annotation different from the one specified at registration. Expected '{plugin_def.input_attrs[name]}'."
+                    )
+
+                aot_impl_attr_names.append(name)
+        else:
+            if name in plugin_def.input_attrs:
+                aot_impl_attr_names.append(name)
+
+    # Expected attribute schema should be constructed in the order they appeared in the register function
+    expected_attr_schema_chunks = [
+        n for n in registered_attr_names if n in aot_impl_attr_names
+    ]
+
+    expected_schema = (
+        "("
+        + _join_with(plugin_def.input_tensor_names)
+        + _join_with(expected_attr_schema_chunks, True)
+        + ", outputs, tactic)"
+    )
+
+    if f"({', '.join(sig.parameters.keys())})" != expected_schema:
+        raise ValueError(
+            f"Signature of the aot_impl function '{sig}' does not match the expected input arg schema: {expected_schema}"
+        )
+
+    ret_annotation = sig.return_annotation
+
+    if ret_annotation == inspect.Parameter.empty:
+        raise ValueError(
+            f"No return annotation found for aot_impl function. Received signature {sig}."
+        )
+
+    expected_return_schema = "tuple[str | bytes, str | bytes, tensorrt.plugin.KernelLaunchParams, tensorrt.plugin.SymIntExprs]"
+
+    # Return annotation is optional, but we will validate if one is specified
+    if ret_annotation != inspect.Parameter.empty:
+        if typing.get_origin(ret_annotation) is not tuple:
+            raise ValueError(
+                f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
+            )
+        else:
+            args = typing.get_args(ret_annotation)
+
+            if len(args) != 4:
+                raise ValueError(
+                    f"Return annotation is {ret_annotation}. Expected {expected_return_schema}."
+                )
+
+            def validate_union_str_or_bytes(index):
+                def validate_str_or_bytes(arg_):
+                    if (arg_ is not str) and (arg_ is not bytes):
+                        raise ValueError(
+                            f"Return annotation for argument at {index} is '{arg_}'. Expected 'str' or 'bytes'."
+                        )
+
+                orig = typing.get_origin(args[index])
+                # orig is `typing.Union` when annotation uses typing module (e.g, Union[str, bytes])
+                # orig is `types.UnionType` when annotation is of the new (3.10+) native syntax (e.g, str | bytes)
+                if orig is typing.Union or orig is types.UnionType:
+                    for a in typing.get_args(args[index]):
+                        validate_str_or_bytes(a)
+                else:
+                    # when annoted with `str` or `bytes`
+                    validate_str_or_bytes(args[index])
+
+            # kernel name should be str or bytes encoding
+            validate_union_str_or_bytes(0)
+            # kernel PTX should be str or bytes encoding
+            validate_union_str_or_bytes(1)
+
+            if not issubclass(args[2], KernelLaunchParams):
+                raise ValueError(f"Argument at index 2 of return annotation is '{args[2]}'. Expected 'tensorrt.plugin.KernelLaunchParams'.")
+
+            if not issubclass(args[3], SymExprs):
+                raise ValueError(f"Argument at index 3 of return annotation is '{args[3]}'. Expected a descendent of tensorrt.plugin.SymExprs.")
+
+    return aot_impl_attr_names
+
 
 def _validate_autotune(autotune_func, plugin_def):
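_validate_aot_impl enforces a `(input tensors..., attributes..., outputs, tactic)` parameter schema and, when a return annotation is present, the shape `tuple[str | bytes, str | bytes, KernelLaunchParams, SymIntExprs]` (kernel name, compiled kernel/PTX, launch parameters, extra symbolic kernel arguments). A sketch of a function these checks would accept, reusing the hypothetical `example::pad_1d` plugin from earlier and assuming the decorator is exposed as `tensorrt.plugin.aot_impl`:

import tensorrt.plugin as trtp
from typing import Tuple, Union

@trtp.aot_impl("example::pad_1d")
def pad_1d_aot_impl(
    inp0: trtp.TensorDesc, pad: int, outputs: Tuple[trtp.TensorDesc], tactic: int
) -> Tuple[Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymIntExprs]:
    kernel_name = "pad_1d_kernel"
    ptx = "..."  # compiled PTX for the kernel (elided)
    launch = trtp.KernelLaunchParams()
    launch.block_x = 256
    launch.grid_x = (outputs[0].shape_expr.numel() + 255) // 256
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(inp0.shape_expr[-1])  # pass a runtime dim to the kernel
    return kernel_name, ptx, launch, extra_args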