tensorrt-cu12-bindings 10.7.0.post1__cp38-none-win_amd64.whl → 10.9.0.34__cp38-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorrt-cu12-bindings might be problematic. Click here for more details.
- tensorrt_bindings/__init__.py +17 -8
- tensorrt_bindings/plugin/__init__.py +1 -1
- tensorrt_bindings/plugin/_export.py +4 -1
- tensorrt_bindings/plugin/_lib.py +199 -31
- tensorrt_bindings/plugin/_plugin_class.py +178 -40
- tensorrt_bindings/plugin/_tensor.py +375 -72
- tensorrt_bindings/plugin/_validate.py +122 -4
- tensorrt_bindings/tensorrt.cp38-win_amd64.pyd +0 -0
- {tensorrt_cu12_bindings-10.7.0.post1.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/METADATA +1 -1
- tensorrt_cu12_bindings-10.9.0.34.dist-info/RECORD +17 -0
- tensorrt_cu12_bindings-10.7.0.post1.dist-info/RECORD +0 -17
- {tensorrt_cu12_bindings-10.7.0.post1.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/LICENSE.txt +0 -0
- {tensorrt_cu12_bindings-10.7.0.post1.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/WHEEL +0 -0
- {tensorrt_cu12_bindings-10.7.0.post1.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/top_level.txt +0 -0
- {tensorrt_cu12_bindings-10.7.0.post1.dist-info → tensorrt_cu12_bindings-10.9.0.34.dist-info}/zip-safe +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
#
|
|
2
|
-
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
#
|
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -18,68 +18,224 @@
|
|
|
18
18
|
import tensorrt as trt
|
|
19
19
|
from typing import Tuple, Union
|
|
20
20
|
import numpy as np
|
|
21
|
-
from ._export import public_api
|
|
21
|
+
from ._export import public_api, IS_AOT_ENABLED
|
|
22
|
+
from abc import ABC, abstractmethod
|
|
22
23
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
24
|
+
class SymExpr(ABC):
|
|
25
|
+
|
|
26
|
+
@abstractmethod
|
|
27
|
+
def _op(self, op, other):
|
|
28
|
+
pass
|
|
29
|
+
|
|
30
|
+
@abstractmethod
|
|
31
|
+
def __add__(self, other):
|
|
32
|
+
pass
|
|
33
|
+
|
|
34
|
+
@abstractmethod
|
|
35
|
+
def __sub__(self, other):
|
|
36
|
+
pass
|
|
37
|
+
|
|
38
|
+
@abstractmethod
|
|
39
|
+
def __mul__(self, other):
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
@abstractmethod
|
|
43
|
+
def __floordiv__(self, other):
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
@abstractmethod
|
|
47
|
+
def __eq__(self, other):
|
|
48
|
+
pass
|
|
49
|
+
|
|
50
|
+
@abstractmethod
|
|
51
|
+
def __lt__(self, other):
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
@abstractmethod
|
|
55
|
+
def __repr__(self):
|
|
56
|
+
pass
|
|
57
|
+
|
|
58
|
+
@property
|
|
59
|
+
@abstractmethod
|
|
60
|
+
def is_constant(self) -> bool:
|
|
61
|
+
pass
|
|
62
|
+
|
|
63
|
+
@property
|
|
64
|
+
@abstractmethod
|
|
65
|
+
def constant_value(self) -> int:
|
|
66
|
+
pass
|
|
67
|
+
|
|
68
|
+
# Evaluate the underlying trt.IDimensionExpr, if so done lazily
|
|
69
|
+
@property
|
|
70
|
+
@abstractmethod
|
|
71
|
+
def _expr(self):
|
|
72
|
+
pass
|
|
73
|
+
|
|
74
|
+
class SymIntExprMeta(type(SymExpr)):
|
|
75
|
+
pass
|
|
76
|
+
|
|
77
|
+
class SymIntExpr(SymExpr, metaclass=SymIntExprMeta):
|
|
26
78
|
"""
|
|
27
|
-
Symbolic
|
|
79
|
+
Symbolic integer (scalar) expression
|
|
28
80
|
"""
|
|
29
81
|
_exprBuilder = None # trt.IExprBuilder instance. Populated when a shape-calculation context is entered.
|
|
30
82
|
|
|
31
|
-
def __init__(self, value: Union[int, trt.IDimensionExpr, "
|
|
83
|
+
def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
|
|
32
84
|
"""
|
|
33
85
|
Args:
|
|
34
|
-
value (Union[int, trt.IDimensionExpr,
|
|
86
|
+
value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
|
|
35
87
|
"""
|
|
36
|
-
self.
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
self._is_dummy = True
|
|
41
|
-
elif isinstance(value, int):
|
|
42
|
-
if self._exprBuilder is None:
|
|
43
|
-
self._dim_expr = None
|
|
44
|
-
self._is_dummy = True
|
|
88
|
+
self._int_expr = None
|
|
89
|
+
if isinstance(value, int):
|
|
90
|
+
if SymIntExpr._exprBuilder is None:
|
|
91
|
+
self._int_expr = None
|
|
45
92
|
else:
|
|
46
|
-
self.
|
|
93
|
+
self._int_expr = SymIntExpr._exprBuilder.constant(value)
|
|
47
94
|
elif isinstance(value, trt.IDimensionExpr):
|
|
48
|
-
self.
|
|
49
|
-
elif isinstance(value,
|
|
50
|
-
self.
|
|
51
|
-
self._is_dummy = value._is_dummy
|
|
52
|
-
self._is_size_tensor = value._is_size_tensor
|
|
95
|
+
self._int_expr = value
|
|
96
|
+
elif isinstance(value, SymIntExpr):
|
|
97
|
+
self._int_expr = value._int_expr
|
|
53
98
|
|
|
54
|
-
def _op(self, op: trt.DimensionOperation, other: Union[int, "
|
|
55
|
-
if self._is_size_tensor:
|
|
56
|
-
raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
|
|
57
|
-
if self._is_dummy:
|
|
58
|
-
return ShapeExpr()
|
|
99
|
+
def _op(self, op: trt.DimensionOperation, other: Union[int, "SymIntExpr"]):
|
|
59
100
|
if isinstance(other, int):
|
|
60
|
-
other =
|
|
61
|
-
return
|
|
101
|
+
other = SymIntExpr(other)
|
|
102
|
+
return SymIntExpr(SymIntExpr._exprBuilder.operation(op, self._expr, other._expr))
|
|
62
103
|
|
|
63
104
|
# Binary operations for +, -, *, //, ==. <
|
|
64
105
|
# Those for ceil_div, max and min are provided as top-level functions of tensorrt.plugin
|
|
65
|
-
def __add__(self, other: Union[int, "
|
|
106
|
+
def __add__(self, other: Union[int, "SymIntExpr"]):
|
|
66
107
|
return self._op(trt.DimensionOperation.SUM, other)
|
|
67
108
|
|
|
68
|
-
def __sub__(self, other: Union[int, "
|
|
109
|
+
def __sub__(self, other: Union[int, "SymIntExpr"]):
|
|
69
110
|
return self._op(trt.DimensionOperation.SUB, other)
|
|
70
111
|
|
|
71
|
-
def __mul__(self, other: Union[int, "
|
|
112
|
+
def __mul__(self, other: Union[int, "SymIntExpr"]):
|
|
72
113
|
return self._op(trt.DimensionOperation.PROD, other)
|
|
73
114
|
|
|
74
|
-
def __floordiv__(self, other: Union[int, "
|
|
115
|
+
def __floordiv__(self, other: Union[int, "SymIntExpr"]):
|
|
75
116
|
return self._op(trt.DimensionOperation.FLOOR_DIV, other)
|
|
76
117
|
|
|
77
|
-
def __eq__(self, other: Union[int, "
|
|
118
|
+
def __eq__(self, other: Union[int, "SymIntExpr"]):
|
|
78
119
|
return self._op(trt.DimensionOperation.EQUAL, other)
|
|
79
120
|
|
|
80
|
-
def __lt__(self, other: Union[int, "
|
|
121
|
+
def __lt__(self, other: Union[int, "SymIntExpr"]):
|
|
81
122
|
return self._op(trt.DimensionOperation.LESS, other)
|
|
82
123
|
|
|
124
|
+
def __repr__(self):
|
|
125
|
+
if self._is_dummy:
|
|
126
|
+
return f"FakeSymIntExpr[id={id(self)}]"
|
|
127
|
+
elif not self.is_constant:
|
|
128
|
+
return f"SymIntExpr[id={id(self)}]"
|
|
129
|
+
return f"SymIntExpr[{self._expr.get_constant_value()}]"
|
|
130
|
+
|
|
131
|
+
@property
|
|
132
|
+
def is_constant(self) -> bool:
|
|
133
|
+
"""
|
|
134
|
+
`True` if this integer expression is a build-time constant, `False` otherwise.
|
|
135
|
+
|
|
136
|
+
Raises:
|
|
137
|
+
RuntimeError: For fake :class:`SymIntExpr`\s. Check :attr:`is_fake` to determine accessibility.
|
|
138
|
+
"""
|
|
139
|
+
if self._is_dummy:
|
|
140
|
+
raise RuntimeError(
|
|
141
|
+
"Not accessible for fake 'SymIntExpr's. Check is_fake to determine accessibility."
|
|
142
|
+
)
|
|
143
|
+
return self._expr.is_constant()
|
|
144
|
+
|
|
145
|
+
def constant_value(self) -> int:
|
|
146
|
+
"""
|
|
147
|
+
Return value of the constant integer expression.
|
|
148
|
+
|
|
149
|
+
Raises:
|
|
150
|
+
RuntimeError: For non-constant integer expressions. Check :attr:`is_constant` to determine accessibility.
|
|
151
|
+
"""
|
|
152
|
+
if not self.is_constant:
|
|
153
|
+
raise RuntimeError(
|
|
154
|
+
"Not accessible for non-constant integer expressions. Check is_constant to determine accessibility."
|
|
155
|
+
)
|
|
156
|
+
return self._expr.get_constant_value()
|
|
157
|
+
|
|
158
|
+
# Evaluate the underlying trt.IDimensionExpr, if so done lazily
|
|
159
|
+
@property
|
|
160
|
+
def _expr(self):
|
|
161
|
+
return self._int_expr
|
|
162
|
+
|
|
163
|
+
def _clone(self):
|
|
164
|
+
return SymIntExpr(self + 0)
|
|
165
|
+
|
|
166
|
+
class SymExprImpl(trt.ISymExpr):
|
|
167
|
+
def __init__(self, expr: SymIntExpr):
|
|
168
|
+
trt.ISymExpr.__init__(self)
|
|
169
|
+
self.type = trt.PluginArgType.INT
|
|
170
|
+
if isinstance(expr, SymInt32):
|
|
171
|
+
self.dtype = trt.PluginArgDataType.INT32
|
|
172
|
+
elif isinstance(expr, SymInt16):
|
|
173
|
+
self.dtype = trt.PluginArgDataType.INT16
|
|
174
|
+
elif isinstance(expr, SymInt8):
|
|
175
|
+
self.dtype = trt.PluginArgDataType.INT8
|
|
176
|
+
else:
|
|
177
|
+
raise ValueError(f"Unknown SymIntExpr type {type(expr)}")
|
|
178
|
+
|
|
179
|
+
self.expr = expr._expr
|
|
180
|
+
@public_api()
|
|
181
|
+
class SymInt32(SymIntExpr):
|
|
182
|
+
"""
|
|
183
|
+
Symbolic expression for a 32-bit integer
|
|
184
|
+
"""
|
|
185
|
+
def __init__(self, value: Union[int, trt.IDimensionExpr, SymIntExpr] = None):
|
|
186
|
+
super().__init__(value)
|
|
187
|
+
|
|
188
|
+
def __call__(self):
|
|
189
|
+
return SymExprImpl(self)
|
|
190
|
+
@public_api()
|
|
191
|
+
class SymInt8(SymIntExpr):
|
|
192
|
+
"""
|
|
193
|
+
Symbolic expression for an 8-bit integer
|
|
194
|
+
"""
|
|
195
|
+
def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
|
|
196
|
+
super().__init__(value)
|
|
197
|
+
|
|
198
|
+
@public_api()
|
|
199
|
+
class SymInt16(SymIntExpr):
|
|
200
|
+
"""
|
|
201
|
+
Symbolic expression for a 16-bit integer
|
|
202
|
+
"""
|
|
203
|
+
def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
|
|
204
|
+
super().__init__(value)
|
|
205
|
+
|
|
206
|
+
# Symbolic expression for a given dimension of a tensor
|
|
207
|
+
@public_api()
|
|
208
|
+
class ShapeExpr(SymInt32):
|
|
209
|
+
"""
|
|
210
|
+
Symbolic expression for single dimension of a tensor
|
|
211
|
+
"""
|
|
212
|
+
def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr", SymIntExpr] = None):
|
|
213
|
+
"""
|
|
214
|
+
Args:
|
|
215
|
+
value (Union[int, trt.IDimensionExpr, ShapeExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
|
|
216
|
+
"""
|
|
217
|
+
super().__init__(value)
|
|
218
|
+
self._exprBuilder = SymIntExpr._exprBuilder
|
|
219
|
+
self._is_dummy = False
|
|
220
|
+
self._is_size_tensor = False
|
|
221
|
+
if value is None:
|
|
222
|
+
self._is_dummy = True
|
|
223
|
+
elif isinstance(value, int):
|
|
224
|
+
if self._exprBuilder is None:
|
|
225
|
+
self._is_dummy = True
|
|
226
|
+
elif isinstance(value, ShapeExpr):
|
|
227
|
+
self._is_dummy = value._is_dummy
|
|
228
|
+
self._is_size_tensor = value._is_size_tensor
|
|
229
|
+
elif isinstance(value, SymIntExpr):
|
|
230
|
+
pass
|
|
231
|
+
|
|
232
|
+
def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
|
|
233
|
+
if self._is_size_tensor:
|
|
234
|
+
raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
|
|
235
|
+
if self._is_dummy:
|
|
236
|
+
return ShapeExpr()
|
|
237
|
+
return ShapeExpr(super()._op(op, other))
|
|
238
|
+
|
|
83
239
|
def __repr__(self):
|
|
84
240
|
if self._is_dummy:
|
|
85
241
|
return f"FakeShapeExpr[id={id(self)}]"
|
|
@@ -116,7 +272,7 @@ class ShapeExpr:
|
|
|
116
272
|
raise RuntimeError(
|
|
117
273
|
"Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
|
|
118
274
|
)
|
|
119
|
-
return
|
|
275
|
+
return super().is_constant
|
|
120
276
|
|
|
121
277
|
def constant_value(self) -> int:
|
|
122
278
|
"""
|
|
@@ -129,12 +285,13 @@ class ShapeExpr:
|
|
|
129
285
|
raise RuntimeError(
|
|
130
286
|
"Not accessible for non-constant shape expressions. Check is_constant to determine accessibility."
|
|
131
287
|
)
|
|
132
|
-
return
|
|
288
|
+
return super().constant_value()
|
|
133
289
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
290
|
+
def _clone(self):
|
|
291
|
+
ret = ShapeExpr(self + 0)
|
|
292
|
+
ret._is_dummy = self._is_dummy
|
|
293
|
+
ret._is_size_tensor = self._is_size_tensor
|
|
294
|
+
return ret
|
|
138
295
|
|
|
139
296
|
@public_api()
|
|
140
297
|
class SizeTensorShapeExpr(ShapeExpr):
|
|
@@ -166,69 +323,163 @@ class SizeTensorShapeExpr(ShapeExpr):
|
|
|
166
323
|
|
|
167
324
|
@property
|
|
168
325
|
def _expr(self):
|
|
169
|
-
if self.
|
|
170
|
-
return self.
|
|
171
|
-
|
|
172
|
-
self._dim_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
|
|
173
|
-
return self._dim_expr
|
|
326
|
+
if self._int_expr is not None:
|
|
327
|
+
return self._int_expr
|
|
174
328
|
|
|
329
|
+
self._int_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
|
|
330
|
+
return self._int_expr
|
|
331
|
+
|
|
175
332
|
def __repr__(self):
|
|
176
333
|
return f"ShapeExpr[is_size_tensor = True, id={id(self)}]"
|
|
177
334
|
|
|
335
|
+
def _from_scalar(s):
|
|
336
|
+
if isinstance(s, int):
|
|
337
|
+
return SymInt32(s)
|
|
338
|
+
elif isinstance(s, float):
|
|
339
|
+
raise ValueError("Float symbolic expressions are not supported")
|
|
340
|
+
else:
|
|
341
|
+
raise ValueError(f"Unsupported type: '{type(s)}'")
|
|
342
|
+
|
|
178
343
|
# Iterable holding `ShapeExpr`s
|
|
179
344
|
@public_api()
|
|
180
|
-
class
|
|
181
|
-
def __init__(self, length: int
|
|
345
|
+
class SymExprs:
|
|
346
|
+
def __init__(self, length: int):
|
|
182
347
|
"""
|
|
183
|
-
Iterable holding
|
|
348
|
+
Iterable holding symbolic expressions
|
|
184
349
|
|
|
185
350
|
Args:
|
|
186
351
|
length (int): Number of dimensions of the tensor
|
|
187
352
|
"""
|
|
188
353
|
self._length = length
|
|
354
|
+
self._exprs = [None] * length
|
|
355
|
+
|
|
356
|
+
@classmethod
|
|
357
|
+
def from_tuple(cls, shape_exprs: Tuple[Union[SymExpr, int]]) -> "SymExprs":
|
|
358
|
+
"""
|
|
359
|
+
Args:
|
|
360
|
+
shape_exprs (Tuple[Union[SymExpr, int]]): Tuple to construct :class:`SymExprs` from
|
|
361
|
+
"""
|
|
362
|
+
|
|
363
|
+
shape_exprs_ = tuple([e if isinstance(e, SymExpr) else _from_scalar(e) for e in shape_exprs])
|
|
364
|
+
inst = cls(len(shape_exprs_))
|
|
365
|
+
inst._exprs = list(shape_exprs_)
|
|
366
|
+
return inst
|
|
367
|
+
|
|
368
|
+
def __iter__(self):
|
|
369
|
+
return iter(self._exprs)
|
|
370
|
+
|
|
371
|
+
def __getitem__(self, index):
|
|
372
|
+
return self._exprs[index]
|
|
373
|
+
|
|
374
|
+
def __len__(self):
|
|
375
|
+
return self._length
|
|
376
|
+
|
|
377
|
+
def __setitem__(self, index, expr):
|
|
378
|
+
if index >= self._length:
|
|
379
|
+
raise IndexError("Index out of range")
|
|
380
|
+
|
|
381
|
+
if not isinstance(expr, SymExpr):
|
|
382
|
+
expr = _from_scalar(expr)
|
|
383
|
+
|
|
384
|
+
self._exprs[index] = expr
|
|
385
|
+
|
|
386
|
+
def __repr__(self):
|
|
387
|
+
return f"SymExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
|
|
388
|
+
|
|
389
|
+
@public_api()
|
|
390
|
+
class ShapeExprs(SymExprs):
|
|
391
|
+
def __init__(self, length, _is_dummy = False):
|
|
392
|
+
"""
|
|
393
|
+
Iterable holding :class:`ShapeExpr`\s, representing a tensor shape
|
|
394
|
+
|
|
395
|
+
Args:
|
|
396
|
+
length (int): Number of dimensions of the tensor
|
|
397
|
+
"""
|
|
398
|
+
if length > trt.Dims.MAX_DIMS:
|
|
399
|
+
raise ValueError(f"ShapeExprs can only support up to trt.Dims.MAX_DIMS = {trt.Dims.MAX_DIMS} dimensions. {length} given.")
|
|
400
|
+
|
|
401
|
+
super().__init__(length)
|
|
402
|
+
|
|
189
403
|
self._is_dummy = _is_dummy
|
|
190
404
|
if _is_dummy:
|
|
191
|
-
self.
|
|
192
|
-
else:
|
|
193
|
-
self._shapes = [None] * length
|
|
405
|
+
self._exprs = [ShapeExpr()] * length
|
|
194
406
|
|
|
195
407
|
@classmethod
|
|
196
|
-
def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "
|
|
408
|
+
def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExpr":
|
|
197
409
|
"""
|
|
198
410
|
Args:
|
|
199
411
|
shape_exprs (Tuple[Union[ShapeExpr, int]]): Tuple to construct :class:`ShapeExprs` from
|
|
200
412
|
"""
|
|
413
|
+
|
|
201
414
|
shape_exprs_ = tuple([e if isinstance(e, ShapeExpr) else ShapeExpr(e) for e in shape_exprs])
|
|
202
415
|
inst = cls(len(shape_exprs_))
|
|
203
|
-
inst.
|
|
416
|
+
inst._exprs = list(shape_exprs_)
|
|
204
417
|
return inst
|
|
205
|
-
|
|
418
|
+
|
|
206
419
|
def numel(self) -> ShapeExpr:
|
|
207
420
|
"""
|
|
208
421
|
Returns a symbolic expression for the number of elements
|
|
209
422
|
"""
|
|
210
423
|
ret = ShapeExpr(1)
|
|
211
|
-
for s in self.
|
|
424
|
+
for s in self._exprs:
|
|
212
425
|
ret *= s
|
|
213
426
|
return ret
|
|
214
427
|
|
|
215
|
-
def
|
|
216
|
-
|
|
428
|
+
def __setitem__(self, index, value):
|
|
429
|
+
if index >= self._length:
|
|
430
|
+
raise IndexError("Index out of range")
|
|
431
|
+
|
|
432
|
+
if not isinstance(value, ShapeExpr):
|
|
433
|
+
if not isinstance(value, int):
|
|
434
|
+
raise ValueError(f"Value should be int or ShapeExpr. Got '{type(value)}'")
|
|
435
|
+
value = ShapeExpr(value)
|
|
436
|
+
|
|
437
|
+
self._exprs[index] = value
|
|
438
|
+
|
|
439
|
+
def __repr__(self):
|
|
440
|
+
return f"ShapeExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
|
|
217
441
|
|
|
218
|
-
def
|
|
219
|
-
|
|
442
|
+
def _clone(self):
|
|
443
|
+
ret = ShapeExprs.from_tuple((e._clone() for e in self._exprs))
|
|
444
|
+
ret._is_dummy = self._is_dummy
|
|
445
|
+
return ret
|
|
446
|
+
|
|
447
|
+
@public_api()
|
|
448
|
+
class SymIntExprs(SymExprs):
|
|
449
|
+
def __init__(self, length):
|
|
450
|
+
"""
|
|
451
|
+
Iterable holding :class:`SymIntExpr`\s
|
|
220
452
|
|
|
221
|
-
|
|
222
|
-
|
|
453
|
+
Args:
|
|
454
|
+
length (int): Number of symbolic expressions in the iterable
|
|
455
|
+
"""
|
|
456
|
+
super().__init__(length)
|
|
457
|
+
|
|
458
|
+
@classmethod
|
|
459
|
+
def from_tuple(cls, shape_exprs: Tuple[Union[SymIntExpr, int]]) -> "SymIntExpr":
|
|
460
|
+
"""
|
|
461
|
+
Args:
|
|
462
|
+
shape_exprs (Tuple[Union[SymIntExpr, int]]): Tuple to construct :class:`SymIntExprs` from
|
|
463
|
+
"""
|
|
464
|
+
|
|
465
|
+
shape_exprs_ = tuple([e if isinstance(e, SymIntExpr) else SymIntExpr(e) for e in shape_exprs])
|
|
466
|
+
inst = cls(len(shape_exprs_))
|
|
467
|
+
inst._exprs = list(shape_exprs_)
|
|
468
|
+
return inst
|
|
223
469
|
|
|
224
|
-
def __setitem__(self, index,
|
|
470
|
+
def __setitem__(self, index, value):
|
|
225
471
|
if index >= self._length:
|
|
226
472
|
raise IndexError("Index out of range")
|
|
227
|
-
|
|
228
|
-
|
|
473
|
+
|
|
474
|
+
if not isinstance(value, SymIntExpr):
|
|
475
|
+
if not isinstance(value, int):
|
|
476
|
+
raise ValueError(f"Value should be int or SymIntExpr. Got '{type(value)}'")
|
|
477
|
+
value = SymIntExpr(value)
|
|
478
|
+
|
|
479
|
+
self._exprs[index] = value
|
|
480
|
+
|
|
229
481
|
def __repr__(self):
|
|
230
|
-
return f"
|
|
231
|
-
|
|
482
|
+
return f"SymIntExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
|
|
232
483
|
|
|
233
484
|
# Numerical representation of a tensor shape
|
|
234
485
|
@public_api()
|
|
@@ -237,7 +488,7 @@ class Shape:
|
|
|
237
488
|
Numerical representation of a tensor shape
|
|
238
489
|
"""
|
|
239
490
|
def __init__(
|
|
240
|
-
self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc]
|
|
491
|
+
self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc] = None
|
|
241
492
|
):
|
|
242
493
|
self._is_dynamic = None # set lazily
|
|
243
494
|
if isinstance(tensor_desc, trt.DynamicPluginTensorDesc):
|
|
@@ -250,6 +501,9 @@ class Shape:
|
|
|
250
501
|
elif isinstance(tensor_desc, tuple):
|
|
251
502
|
self._shapes = trt.Dims(tensor_desc)
|
|
252
503
|
self._length = len(self._shapes)
|
|
504
|
+
elif tensor_desc is None:
|
|
505
|
+
self._length = 0
|
|
506
|
+
self._shapes = trt.Dims(0)
|
|
253
507
|
else:
|
|
254
508
|
raise ValueError("Unsupported type used for constructing trt.plugin.Shape! tensor_desc must be a Tuple[int], trt.DynamicPluginTensorDesc, or trt.PluginTensorDesc")
|
|
255
509
|
|
|
@@ -335,6 +589,11 @@ class Shape:
|
|
|
335
589
|
raise IndexError("Index out of range")
|
|
336
590
|
self._shapes[index] = val
|
|
337
591
|
|
|
592
|
+
def _clone(self):
|
|
593
|
+
ret = Shape()
|
|
594
|
+
ret.__dict__.update(self.__dict__)
|
|
595
|
+
return ret
|
|
596
|
+
|
|
338
597
|
|
|
339
598
|
# Descriptor for a tensor
|
|
340
599
|
# A `TensorDesc` never contains nor refers to any tensor data.
|
|
@@ -416,8 +675,7 @@ class TensorDesc:
|
|
|
416
675
|
def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
|
|
417
676
|
return inp.like()
|
|
418
677
|
"""
|
|
419
|
-
cloned =
|
|
420
|
-
cloned.__dict__.update(self.__dict__)
|
|
678
|
+
cloned = self._clone()
|
|
421
679
|
cloned._immutable = False
|
|
422
680
|
return cloned
|
|
423
681
|
|
|
@@ -436,10 +694,19 @@ class TensorDesc:
|
|
|
436
694
|
def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
|
|
437
695
|
return inp.aliased()
|
|
438
696
|
"""
|
|
697
|
+
cloned = self._clone()
|
|
698
|
+
cloned._immutable = False
|
|
699
|
+
cloned._aliased_to = self
|
|
700
|
+
cloned._immutable = True
|
|
701
|
+
return cloned
|
|
702
|
+
|
|
703
|
+
def _clone(self) -> "TensorDesc":
|
|
439
704
|
cloned = TensorDesc()
|
|
440
705
|
cloned.__dict__.update(self.__dict__)
|
|
441
706
|
cloned._immutable = False
|
|
442
|
-
cloned.
|
|
707
|
+
cloned._shape_expr = self._shape_expr._clone()
|
|
708
|
+
if self._shape is not None:
|
|
709
|
+
cloned._shape = self._shape._clone()
|
|
443
710
|
cloned._immutable = True
|
|
444
711
|
return cloned
|
|
445
712
|
|
|
@@ -823,3 +1090,39 @@ class Tensor:
|
|
|
823
1090
|
|
|
824
1091
|
cloned._aliased_to = self
|
|
825
1092
|
return cloned
|
|
1093
|
+
|
|
1094
|
+
if IS_AOT_ENABLED:
|
|
1095
|
+
@public_api()
|
|
1096
|
+
class KernelLaunchParams:
|
|
1097
|
+
"""
|
|
1098
|
+
Args:
|
|
1099
|
+
grid_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid x dimension. Defaults to 1.
|
|
1100
|
+
grid_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid y dimension. Defaults to 1.
|
|
1101
|
+
grid_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid z dimension. Defaults to 1.
|
|
1102
|
+
block_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The x dimension of each thread block. Defaults to 1.
|
|
1103
|
+
block_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The y dimension of each thread block. Defaults to 1.
|
|
1104
|
+
block_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The z dimension of each thread block. Defaults to 1.
|
|
1105
|
+
shared_mem (Union[int, trt.IDimensionExpr, SymInt32], optional): Shared-memory per thread block in bytes. Defaults to 0.
|
|
1106
|
+
"""
|
|
1107
|
+
def __init__(self,
|
|
1108
|
+
grid_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1109
|
+
grid_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1110
|
+
grid_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1111
|
+
block_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1112
|
+
block_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1113
|
+
block_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1114
|
+
shared_mem: Union[int, trt.IDimensionExpr, SymInt32] = 0):
|
|
1115
|
+
self.grid_x = SymInt32(grid_x)
|
|
1116
|
+
self.grid_y = SymInt32(grid_y)
|
|
1117
|
+
self.grid_z = SymInt32(grid_z)
|
|
1118
|
+
self.block_x = SymInt32(block_x)
|
|
1119
|
+
self.block_y = SymInt32(block_y)
|
|
1120
|
+
self.block_z = SymInt32(block_z)
|
|
1121
|
+
self.shared_mem = SymInt32(shared_mem)
|
|
1122
|
+
|
|
1123
|
+
|
|
1124
|
+
def __setattr__(self, name, value):
|
|
1125
|
+
if name in ["grid_x", "grid_y", "grid_z", "block_x", "block_y", "block_z", "shared_mem"]:
|
|
1126
|
+
self.__dict__[name] = SymInt32(value)
|
|
1127
|
+
else:
|
|
1128
|
+
raise AttributeError(f"KernelLaunchParams object has no attribute '{name}'")
|