tensorrt-cu12-bindings 10.14.1.48.post1__cp39-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorrt_bindings/__init__.py +224 -0
- tensorrt_bindings/plugin/__init__.py +46 -0
- tensorrt_bindings/plugin/_autotune.py +270 -0
- tensorrt_bindings/plugin/_export.py +39 -0
- tensorrt_bindings/plugin/_lib.py +691 -0
- tensorrt_bindings/plugin/_plugin_class.py +459 -0
- tensorrt_bindings/plugin/_tensor.py +1128 -0
- tensorrt_bindings/plugin/_top_level.py +132 -0
- tensorrt_bindings/plugin/_utils.py +77 -0
- tensorrt_bindings/plugin/_validate.py +475 -0
- tensorrt_bindings/tensorrt.so +0 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/LICENSE.txt +180 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/METADATA +17 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/RECORD +17 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/WHEEL +5 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/top_level.txt +1 -0
- tensorrt_cu12_bindings-10.14.1.48.post1.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,1128 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
|
|
18
|
+
import tensorrt as trt
|
|
19
|
+
from typing import Tuple, Union
|
|
20
|
+
import numpy as np
|
|
21
|
+
from ._export import public_api, IS_AOT_ENABLED
|
|
22
|
+
from abc import ABC, abstractmethod
|
|
23
|
+
|
|
24
|
+
class SymExpr(ABC):
    """Abstract interface for symbolic scalar expressions.

    Concrete subclasses provide arithmetic/comparison operators that build new
    symbolic expressions, plus constant inspection and lazy access to the
    underlying expression object.
    """

    @abstractmethod
    def _op(self, op, other):
        """Apply binary operation `op` between this expression and `other`."""

    @abstractmethod
    def __add__(self, other):
        """Return a symbolic sum."""

    @abstractmethod
    def __sub__(self, other):
        """Return a symbolic difference."""

    @abstractmethod
    def __mul__(self, other):
        """Return a symbolic product."""

    @abstractmethod
    def __floordiv__(self, other):
        """Return a symbolic floor division."""

    @abstractmethod
    def __eq__(self, other):
        """Return a symbolic equality comparison (not a plain bool)."""

    @abstractmethod
    def __lt__(self, other):
        """Return a symbolic less-than comparison."""

    @abstractmethod
    def __repr__(self):
        """Debug representation."""

    @property
    @abstractmethod
    def is_constant(self) -> bool:
        """Whether the expression is a build-time constant."""

    @property
    @abstractmethod
    def constant_value(self) -> int:
        """Value of a constant expression."""

    # Evaluate the underlying trt.IDimensionExpr, if so done lazily
    @property
    @abstractmethod
    def _expr(self):
        """Underlying expression object (may be evaluated lazily)."""
|
|
73
|
+
|
|
74
|
+
class SymIntExprMeta(type(SymExpr)):
    """Metaclass for SymIntExpr; derives from SymExpr's metaclass (ABCMeta) so ABC machinery is preserved."""
|
|
76
|
+
|
|
77
|
+
class SymIntExpr(SymExpr, metaclass=SymIntExprMeta):
    """
    Symbolic integer (scalar) expression

    Wraps a lazily-evaluated `trt.IDimensionExpr`. Arithmetic and comparison
    operators build new symbolic expressions through the shared class-level
    `trt.IExprBuilder` rather than computing concrete values.
    """
    _exprBuilder = None # trt.IExprBuilder instance. Populated when a shape-calculation context is entered.

    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
        """
        Args:
            value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
        """
        self._int_expr = None
        if isinstance(value, int):
            # Without an active expression builder the constant cannot be materialized yet
            if SymIntExpr._exprBuilder is None:
                self._int_expr = None
            else:
                self._int_expr = SymIntExpr._exprBuilder.constant(value)
        elif isinstance(value, trt.IDimensionExpr):
            self._int_expr = value
        elif isinstance(value, SymIntExpr):
            # Share the underlying expression object with the other symbolic expression
            self._int_expr = value._int_expr

    def _op(self, op: trt.DimensionOperation, other: Union[int, "SymIntExpr"]):
        """Build a new symbolic expression applying `op` between self and `other` (ints are coerced)."""
        if isinstance(other, int):
            other = SymIntExpr(other)
        return SymIntExpr(SymIntExpr._exprBuilder.operation(op, self._expr, other._expr))

    # Binary operations for +, -, *, //, ==. <
    # Those for ceil_div, max and min are provided as top-level functions of tensorrt.plugin
    # NOTE: defining __eq__ sets __hash__ to None implicitly, so instances are unhashable.
    def __add__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.SUM, other)

    def __sub__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.SUB, other)

    def __mul__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.PROD, other)

    def __floordiv__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.FLOOR_DIV, other)

    def __eq__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.EQUAL, other)

    def __lt__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.LESS, other)

    def __repr__(self):
        # NOTE(review): `_is_dummy` is only assigned by the ShapeExpr subclass' __init__;
        # a plain SymIntExpr instance would raise AttributeError here — confirm intended usage.
        if self._is_dummy:
            return f"FakeSymIntExpr[id={id(self)}]"
        elif not self.is_constant:
            return f"SymIntExpr[id={id(self)}]"
        return f"SymIntExpr[{self._expr.get_constant_value()}]"

    @property
    def is_constant(self) -> bool:
        """
        `True` if this integer expression is a build-time constant, `False` otherwise.

        Raises:
            RuntimeError: For fake :class:`SymIntExpr`\s. Check :attr:`is_fake` to determine accessibility.
        """
        # NOTE(review): as in __repr__, `_is_dummy` is set only by the ShapeExpr subclass.
        if self._is_dummy:
            raise RuntimeError(
                "Not accessible for fake 'SymIntExpr's. Check is_fake to determine accessibility."
            )
        return self._expr.is_constant()

    def constant_value(self) -> int:
        """
        Return value of the constant integer expression.

        Raises:
            RuntimeError: For non-constant integer expressions. Check :attr:`is_constant` to determine accessibility.
        """
        # NOTE(review): plain method here, although the SymExpr base declares
        # `constant_value` as an abstract property — callers in this file invoke it as a method.
        if not self.is_constant:
            raise RuntimeError(
                "Not accessible for non-constant integer expressions. Check is_constant to determine accessibility."
            )
        return self._expr.get_constant_value()

    # Evaluate the underlying trt.IDimensionExpr, if so done lazily
    @property
    def _expr(self):
        return self._int_expr

    def _clone(self):
        # `self + 0` forces creation of a fresh underlying expression via the builder
        return SymIntExpr(self + 0)
|
|
165
|
+
|
|
166
|
+
class SymExprImpl(trt.ISymExpr):
    """Adapter exposing a SymIntExpr as a trt.ISymExpr, tagged with the matching plugin arg dtype."""

    def __init__(self, expr: SymIntExpr):
        trt.ISymExpr.__init__(self)
        self.type = trt.PluginArgType.INT
        # Map the concrete SymIntExpr subclass to its plugin dtype tag.
        # The three classes have no subclass relationship among themselves,
        # so the check order does not affect the outcome.
        for klass, dtype in (
            (SymInt32, trt.PluginArgDataType.INT32),
            (SymInt16, trt.PluginArgDataType.INT16),
            (SymInt8, trt.PluginArgDataType.INT8),
        ):
            if isinstance(expr, klass):
                self.dtype = dtype
                break
        else:
            raise ValueError(f"Unknown SymIntExpr type {type(expr)}")

        self.expr = expr._expr
|
|
180
|
+
@public_api()
class SymInt32(SymIntExpr):
    """
    Symbolic expression for a 32-bit integer
    """

    def __init__(self, value: Union[int, trt.IDimensionExpr, SymIntExpr] = None):
        """
        Args:
            value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression.
        """
        super().__init__(value)

    def __call__(self):
        # Wrap this expression as a trt.ISymExpr; SymExprImpl tags it INT32.
        wrapped = SymExprImpl(self)
        return wrapped
|
|
190
|
+
@public_api()
class SymInt8(SymIntExpr):
    """
    Symbolic expression for an 8-bit integer
    """
    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
        """
        Args:
            value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression.
        """
        super().__init__(value)

    def __call__(self):
        # Consistency fix: SymInt32 exposes __call__ to wrap itself as a trt.ISymExpr,
        # and SymExprImpl already maps SymInt8 to trt.PluginArgDataType.INT8,
        # but this class previously offered no way to produce that wrapper.
        return SymExprImpl(self)
|
|
197
|
+
|
|
198
|
+
@public_api()
class SymInt16(SymIntExpr):
    """
    Symbolic expression for a 16-bit integer
    """
    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
        """
        Args:
            value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression.
        """
        super().__init__(value)

    def __call__(self):
        # Consistency fix: SymInt32 exposes __call__ to wrap itself as a trt.ISymExpr,
        # and SymExprImpl already maps SymInt16 to trt.PluginArgDataType.INT16,
        # but this class previously offered no way to produce that wrapper.
        return SymExprImpl(self)
|
|
205
|
+
|
|
206
|
+
# Symbolic expression for a given dimension of a tensor
|
|
207
|
+
@public_api()
class ShapeExpr(SymInt32):
    """
    Symbolic expression for single dimension of a tensor

    Adds "fake" (dummy) semantics on top of SymInt32: outside a shape-calculation
    context there is no expression builder, so instances become placeholders.
    """
    def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr", SymIntExpr] = None):
        """
        Args:
            value (Union[int, trt.IDimensionExpr, ShapeExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
        """
        super().__init__(value)
        # Snapshot the builder active at construction time onto the instance
        self._exprBuilder = SymIntExpr._exprBuilder
        self._is_dummy = False
        self._is_size_tensor = False
        if value is None:
            # No value at all: always a fake placeholder
            self._is_dummy = True
        elif isinstance(value, int):
            # Int constant without an active builder cannot be materialized -> fake
            if self._exprBuilder is None:
                self._is_dummy = True
        elif isinstance(value, ShapeExpr):
            # Copy fake/size-tensor flags from the source expression
            self._is_dummy = value._is_dummy
            self._is_size_tensor = value._is_size_tensor
        elif isinstance(value, SymIntExpr):
            # Plain SymIntExpr carries no extra flags; underlying expr copied by super()
            pass

    def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
        """Apply `op`, propagating fakeness; size tensors reject binary ops."""
        if self._is_size_tensor:
            raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
        if self._is_dummy:
            # Operating on a fake expression yields another fake expression
            return ShapeExpr()
        return ShapeExpr(super()._op(op, other))

    def __repr__(self):
        if self._is_dummy:
            return f"FakeShapeExpr[id={id(self)}]"
        elif not self.is_constant:
            return f"ShapeExpr[id={id(self)}]"
        return f"ShapeExpr[{self._expr.get_constant_value()}]"

    # A ShapeExpr may be "fake" when it is accessed in a non-shape calculation context. Fake `ShapeExpr`s are externally indistinguishable unless `is_constant` or `constant_value` is required.
    # Therefore, constant checks/access must occur conditionally after evaluating `is_fake`.
    @property
    def is_fake(self) -> bool:
        """
        A ShapeExpr may be "fake" when it is accessed in a non-shape calculation context.
        Fake `ShapeExpr`s are externally indistinguishable unless `is_constant` or `constant_value` is required.
        """
        return self._is_dummy

    @property
    def is_size_tensor(self) -> bool:
        """
        `True` if this represents a size tensor, `False` otherwise.
        """
        return self._is_size_tensor

    @property
    def is_constant(self) -> bool:
        """
        `True` if this shape expression is a build-time constant, `False` otherwise.

        Raises:
            RuntimeError: For fake :class:`ShapeExpr`\s. Check :attr:`is_fake` to determine accessibility.
        """
        if self._is_dummy:
            raise RuntimeError(
                "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
            )
        return super().is_constant

    def constant_value(self) -> int:
        """
        Return value of the constant shape expression.

        Raises:
            RuntimeError: For non-constant shape expressions. Check :attr:`is_constant` to determine accessibility.
        """
        if not self.is_constant:
            raise RuntimeError(
                "Not accessible for non-constant shape expressions. Check is_constant to determine accessibility."
            )
        return super().constant_value()

    def _clone(self):
        # `self + 0` builds a fresh underlying expression; flags must be re-applied
        # because _op() does not preserve them on the result.
        ret = ShapeExpr(self + 0)
        ret._is_dummy = self._is_dummy
        ret._is_size_tensor = self._is_size_tensor
        return ret
|
|
295
|
+
|
|
296
|
+
@public_api()
class SizeTensorShapeExpr(ShapeExpr):
    """
    Extends :class:`ShapeExpr`

    A shape expression that represent a size tensor

    """
    def __init__(self, size_tensor_desc: "SizeTensorDesc"):
        """
        .. note:: It is recommended to use :attr:`SizeTensorDesc.expr` to get a :class:`SizeTensorShapeExpr` representing a size tensor
        """
        super().__init__()
        self._is_size_tensor = True
        # Fakeness mirrors the descriptor's `opt` expression: without a real
        # builder context, the size tensor cannot be declared.
        self._is_dummy = size_tensor_desc.opt.is_fake
        self._size_tensor_desc = size_tensor_desc

    def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
        # Unconditional rejection (unlike ShapeExpr._op, which only rejects when flagged)
        raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # TRT limitation

    @property
    def is_constant(self):
        # A size tensor's extent is data-dependent, hence never a build-time constant.
        if self._is_dummy:
            raise RuntimeError(
                "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
            )
        return False

    @property
    def _expr(self):
        # Lazily declare the size tensor with the builder on first access and
        # cache the result; subsequent accesses return the cached expression.
        if self._int_expr is not None:
            return self._int_expr

        self._int_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
        return self._int_expr

    def __repr__(self):
        return f"ShapeExpr[is_size_tensor = True, id={id(self)}]"
|
|
334
|
+
|
|
335
|
+
def _from_scalar(s):
    """Wrap a plain Python scalar in a symbolic expression.

    Ints become SymInt32; floats (and any other type) are rejected.
    """
    if isinstance(s, int):
        return SymInt32(s)
    if isinstance(s, float):
        raise ValueError("Float symbolic expressions are not supported")
    raise ValueError(f"Unsupported type: '{type(s)}'")
|
|
342
|
+
|
|
343
|
+
# Iterable holding `ShapeExpr`s
|
|
344
|
+
@public_api()
class SymExprs:
    """Fixed-length, mutable sequence of symbolic expressions."""

    def __init__(self, length: int):
        """
        Iterable holding symbolic expressions

        Args:
            length (int): Number of dimensions of the tensor
        """
        self._length = length
        self._exprs = [None] * length

    @classmethod
    def from_tuple(cls, shape_exprs: Tuple[Union[SymExpr, int]]) -> "SymExprs":
        """
        Args:
            shape_exprs (Tuple[Union[SymExpr, int]]): Tuple to construct :class:`SymExprs` from
        """
        # Coerce plain scalars to symbolic expressions up front
        coerced = [e if isinstance(e, SymExpr) else _from_scalar(e) for e in shape_exprs]
        inst = cls(len(coerced))
        inst._exprs = coerced
        return inst

    def __iter__(self):
        return iter(self._exprs)

    def __getitem__(self, index):
        return self._exprs[index]

    def __len__(self):
        return self._length

    def __setitem__(self, index, expr):
        if index >= self._length:
            raise IndexError("Index out of range")
        item = expr if isinstance(expr, SymExpr) else _from_scalar(expr)
        self._exprs[index] = item

    def __repr__(self):
        inner = ", ".join(repr(e) for e in self._exprs)
        return f"SymExprs[{inner}]"
|
|
388
|
+
|
|
389
|
+
@public_api()
class ShapeExprs(SymExprs):
    def __init__(self, length, _is_dummy = False):
        """
        Iterable holding :class:`ShapeExpr`\s, representing a tensor shape

        Args:
            length (int): Number of dimensions of the tensor

        Raises:
            ValueError: If `length` exceeds trt.Dims.MAX_DIMS.
        """
        if length > trt.Dims.MAX_DIMS:
            raise ValueError(f"ShapeExprs can only support up to trt.Dims.MAX_DIMS = {trt.Dims.MAX_DIMS} dimensions. {length} given.")

        super().__init__(length)

        self._is_dummy = _is_dummy
        if _is_dummy:
            # Fix: use one fresh fake ShapeExpr per slot. `[ShapeExpr()] * length`
            # aliased a single mutable instance into every position, so mutating
            # one slot's flags would silently affect all of them.
            self._exprs = [ShapeExpr() for _ in range(length)]

    @classmethod
    def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExprs":
        """
        Args:
            shape_exprs (Tuple[Union[ShapeExpr, int]]): Tuple to construct :class:`ShapeExprs` from
        """
        # Fix: return annotation previously said "ShapeExpr"; this returns a ShapeExprs.
        shape_exprs_ = tuple([e if isinstance(e, ShapeExpr) else ShapeExpr(e) for e in shape_exprs])
        inst = cls(len(shape_exprs_))
        inst._exprs = list(shape_exprs_)
        return inst

    def numel(self) -> ShapeExpr:
        """
        Returns a symbolic expression for the number of elements
        """
        ret = ShapeExpr(1)
        for s in self._exprs:
            ret *= s
        return ret

    def __setitem__(self, index, value):
        if index >= self._length:
            raise IndexError("Index out of range")

        if not isinstance(value, ShapeExpr):
            if not isinstance(value, int):
                raise ValueError(f"Value should be int or ShapeExpr. Got '{type(value)}'")
            value = ShapeExpr(value)

        self._exprs[index] = value

    def __repr__(self):
        return f"ShapeExprs[{', '.join([s.__repr__() for s in self._exprs])}]"

    def _clone(self):
        # Deep-clone each element, then carry over the dummy flag
        ret = ShapeExprs.from_tuple(tuple(e._clone() for e in self._exprs))
        ret._is_dummy = self._is_dummy
        return ret
|
|
446
|
+
|
|
447
|
+
@public_api()
class SymIntExprs(SymExprs):
    def __init__(self, length):
        """
        Iterable holding :class:`SymIntExpr`\s

        Args:
            length (int): Number of symbolic expressions in the iterable
        """
        super().__init__(length)

    @classmethod
    def from_tuple(cls, shape_exprs: Tuple[Union[SymIntExpr, int]]) -> "SymIntExprs":
        """
        Args:
            shape_exprs (Tuple[Union[SymIntExpr, int]]): Tuple to construct :class:`SymIntExprs` from
        """
        # Fix: return annotation previously said "SymIntExpr"; this returns a SymIntExprs.
        shape_exprs_ = tuple([e if isinstance(e, SymIntExpr) else SymIntExpr(e) for e in shape_exprs])
        inst = cls(len(shape_exprs_))
        inst._exprs = list(shape_exprs_)
        return inst

    def __setitem__(self, index, value):
        """Set the expression at `index`, coercing a plain int to SymIntExpr.

        Raises:
            IndexError: If `index` is beyond the fixed length.
            ValueError: If `value` is neither int nor SymIntExpr.
        """
        if index >= self._length:
            raise IndexError("Index out of range")

        if not isinstance(value, SymIntExpr):
            if not isinstance(value, int):
                raise ValueError(f"Value should be int or SymIntExpr. Got '{type(value)}'")
            value = SymIntExpr(value)

        self._exprs[index] = value

    def __repr__(self):
        return f"SymIntExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
|
|
483
|
+
|
|
484
|
+
# Numerical representation of a tensor shape
|
|
485
|
+
@public_api()
class Shape:
    """
    Numerical representation of a tensor shape
    """
    def __init__(
        self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc] = None
    ):
        """
        Args:
            tensor_desc (Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc], optional): Source of the dimensions. `None` yields an empty shape.

        Raises:
            ValueError: If `tensor_desc` is of an unsupported type.
        """
        self._is_dynamic = None  # cached lazily by the is_dynamic property
        if isinstance(tensor_desc, trt.DynamicPluginTensorDesc):
            self._length = len(tensor_desc.desc.dims)
            self._shapes = tensor_desc.desc.dims
            # Keep the descriptor so min/opt/max bounds remain queryable
            self._desc = tensor_desc
        elif isinstance(tensor_desc, trt.PluginTensorDesc):
            self._length = len(tensor_desc.dims)
            self._shapes = tensor_desc.dims
        elif isinstance(tensor_desc, tuple):
            self._shapes = trt.Dims(tensor_desc)
            self._length = len(self._shapes)
        elif tensor_desc is None:
            self._length = 0
            self._shapes = trt.Dims(0)
        else:
            raise ValueError("Unsupported type used for constructing trt.plugin.Shape! tensor_desc must be a Tuple[int], trt.DynamicPluginTensorDesc, or trt.PluginTensorDesc")

    def numel(self) -> int:
        """
        Number of elements contained

        Raises:
            ValueError: When :attr:`is_dynamic` is `True`
        """
        if self.is_dynamic:
            raise ValueError("Shape has at least one dynamic dimension.")
        return int(np.prod(self._shapes))

    def __iter__(self):
        yield from self._shapes

    def __getitem__(self, index):
        return self._shapes[index]

    def __len__(self):
        return self._length

    def __str__(self):
        return "Shape" + str(tuple(self))

    @property
    def is_dynamic(self) -> bool:
        """
        `True` if this tensor has at least one dynamic dimension, `False` otherwise.
        """
        if self._is_dynamic is None:
            # A dimension of -1 marks a dynamic (runtime-determined) extent.
            # any() short-circuits instead of scanning all dims as before.
            self._is_dynamic = any(d == -1 for d in self._shapes)
        return self._is_dynamic

    def _validate_dynamic_desc(self, prop: str) -> None:
        """Shared guard for opt/min/max: shape must be dynamic and bounds info present."""
        if not self.is_dynamic:
            raise ValueError(f"{prop} property is only accessible if is_dynamic is true")
        if not hasattr(self, "_desc"):
            raise AttributeError(
                f"Shape object has at least one dynamic dimension, but no information is available on '{prop}' property."
            )

    @property
    def opt(self) -> Tuple[int]:
        """
        Optimum value of dimensions specified for auto-tuning.
        """
        self._validate_dynamic_desc("opt")
        return tuple(self._desc.opt)

    @property
    def min(self) -> Tuple[int]:
        """
        Lower bounds on tensor's dimensions.
        """
        self._validate_dynamic_desc("min")
        return tuple(self._desc.min)

    @property
    def max(self) -> Tuple[int]:
        """
        Upper bounds on tensor's dimensions.
        """
        self._validate_dynamic_desc("max")
        return tuple(self._desc.max)

    def __setitem__(self, index, val):
        if index >= self._length:
            raise IndexError("Index out of range")
        self._shapes[index] = val

    def _clone(self):
        # Shallow clone: _shapes/_desc are shared with the source
        ret = Shape()
        ret.__dict__.update(self.__dict__)
        return ret
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
# Descriptor for a tensor
|
|
599
|
+
# A `TensorDesc` never contains nor refers to any tensor data.
|
|
600
|
+
@public_api()
class TensorDesc:
    """
    Descriptor for a tensor
    A `TensorDesc` never contains nor refers to any tensor data.
    """
    def __init__(self, shape_expr: ShapeExprs = None, dtype: trt.DataType = None, format: trt.TensorFormat = None, scale: float = None):
        """
        Args:
            shape_expr (ShapeExprs): The data with which to initialize the tensor.
            dtype (trt.DataType): The data type of the tensor.
            format (trt.TensorFormat): Format (layout) of the tensor.
            scale (float): Scale for INT8 data type.

        .. code-block:: python
            :linenos:
            :caption: Creates a TensorDesc with constant shape expressions

            tensor = trt.TensorDesc((10, 2, 32, 32), dtype=trt.float32)

        .. code-block:: python
            :linenos:
            :caption: Creates a TensorDesc from shape expression of another TensorDesc

            tensor = trt.from_shape_expr(other.shape_expr, dtype=trt.float32)
        """

        # `TensorDesc` may or may not have `Shape` information but always has symbolic shape expressions and dtype
        self._shape_expr = shape_expr
        self._dtype = dtype

        # `shape`, `format`, and `scale` are only accessible if `has_shape`. Presently, this would be inside autotune.
        self._shape = None
        self._format = format
        self._scale = scale

        # Another TensorDesc this one shares a data buffer with (None if not aliased)
        self._aliased_to = None
        # When True, __setattr__ blocks all attribute writes except `_immutable` itself
        self._immutable = False

    def numel(self) -> int:
        """
        Returns:
            Returns an int with the number of elements of the tensor.

        .. warning::
            Should only be called when TensorDesc.has_shape is true. If a symbolic expression for the number of elements is required, query TensorDesc.shape_expr.numel().
        """
        if not self.has_shape:
            raise ValueError(
                "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability."
            )
        return int(np.prod(self.shape))

    @property
    def ndim(self) -> int:
        """
        Number of dimensions
        """
        return len(self._shape_expr)

    @property
    def is_size_tensor(self):
        # Plain TensorDescs are never size tensors
        return False

    # Return a `TensorDesc` that has identical properties to `self` but is mutable
    def like(self) -> "TensorDesc":
        """
        Returns:
            Returns a TensorDesc which has identical properties to this tensor, and is mutable.

        .. code-block:: python
            :linenos:
            :caption: Communicate that output tensor has identical properties to the input tensor

            @tensorrt.plugin.register("my::plugin")
            def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
                return inp.like()
        """
        cloned = self._clone()
        # _clone() returns an immutable clone; unlock it for the caller
        cloned._immutable = False
        return cloned

    # Return a `TensorDesc` that has identical properties to `self` AND is aliased to `self` (would result in a `Tensor` during enqueue sharing the same data buffer)
    def aliased(self) -> "TensorDesc":
        """
        Returns:
            Returns a TensorDesc which has identical properties and is aliased to this tensor (would result in a `Tensor` during enqueue sharing the same data buffer).
            Returned TensorDesc is immutable.

        .. code-block:: python
            :linenos:
            :caption: Communicate that output tensor has identical properties to the input tensor

            @tensorrt.plugin.register("my::plugin")
            def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
                return inp.aliased()
        """
        cloned = self._clone()
        # Temporarily unlock so _aliased_to can be set, then re-freeze
        cloned._immutable = False
        cloned._aliased_to = self
        cloned._immutable = True
        return cloned

    def _clone(self) -> "TensorDesc":
        # Copy all attributes, then deep-clone shape data; must stay mutable
        # during the copy because __setattr__ enforces `_immutable`.
        cloned = TensorDesc()
        cloned.__dict__.update(self.__dict__)
        cloned._immutable = False
        cloned._shape_expr = self._shape_expr._clone()
        if self._shape is not None:
            cloned._shape = self._shape._clone()
        # Clones are frozen by default; use like() for a mutable copy
        cloned._immutable = True
        return cloned

    def get_aliased(self) -> "TensorDesc":
        """
        Returns:
            Returns a TensorDesc for the tensor which this tensor is aliased to. Returns None is this tensor is not aliased to any other tensor.
        """
        return self._aliased_to

    def _validate_has_shape(self) -> None:
        # Guard for properties that require concrete shape information
        if not self.has_shape:
            raise ValueError(
                "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability."
            )

    def _validate_not_immutable(self):
        # hasattr guard: during __init__ `_immutable` may not exist yet
        if hasattr(self, "_immutable") and self._immutable:
            raise ValueError("Cannot modify immutable TensorDesc")

    @property
    def shape_expr(self) -> ShapeExprs:
        """
        Symbolic expressions for the tensor shape.
        """
        return self._shape_expr

    @property
    def dtype(self) -> trt.DataType:
        """
        Data type of the tensor.
        """
        return self._dtype

    @property
    def shape(self) -> Shape:
        """
        The (concrete) shape of the tensor.

        .. warning::
            Only accessible when TensorDesc.has_shape is true.
        """
        self._validate_has_shape()
        return self._shape

    @property
    def format(self) -> trt.TensorFormat:
        """
        The format of the tensor.

        .. warning::
            Only accessible when TensorDesc.has_shape is true.
        """
        self._validate_has_shape()
        return self._format

    @property
    def scale(self) -> float:
        """
        Scale for INT8 data type.

        .. warning::
            Only accessible when TensorDesc.has_shape is true.
        """
        self._validate_has_shape()
        return self._scale


    @shape_expr.setter
    def shape_expr(self, value):
        # NOTE(review): unlike shape/format/scale, this setter skips
        # _validate_not_immutable(); __setattr__ still enforces immutability.
        self._shape_expr = value

    @dtype.setter
    def dtype(self, value):
        # NOTE(review): same as shape_expr — relies on __setattr__ for the guard.
        self._dtype = value

    @shape.setter
    def shape(self, value):
        self._validate_not_immutable()
        self._shape = value

    @format.setter
    def format(self, value):
        self._validate_not_immutable()
        self._format = value

    @scale.setter
    def scale(self, value):
        self._validate_not_immutable()
        self._scale = value

    @property
    def is_aliased(self) -> bool:
        """
        True if this tensor is aliased to another tensor, False otherwise.
        """
        return self._aliased_to is not None

    @property
    def has_shape(self) -> bool:
        """
        True if this tensor has concrete shape information, False otherwise.
        """
        return self._shape is not None

    @property
    def is_dynamic(self) -> bool:
        """
        `True` if this tensor has at least one dynamic dimension, `False` otherwise.
        """
        if not self.has_shape:
            raise ValueError(
                "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability."
            )
        return self.shape.is_dynamic

    @property
    def has_shape_expr(self) -> bool:
        """
        True if this tensor has symbolic shape expressions, False otherwise.
        """
        return self.shape_expr is not None

    def __setattr__(self, name, value):
        # Enforce immutability on every attribute write; `_immutable` itself is
        # always writable so clones can be locked/unlocked, and the hasattr
        # guard lets __init__ assign before the flag exists.
        if hasattr(self, "_immutable") and self._immutable and name != "_immutable":
            raise ValueError("Cannot modify immutable TensorDesc properties")
        super().__setattr__(name, value)
|
|
837
|
+
|
|
838
|
+
@public_api()
class SizeTensorDesc(TensorDesc):
    """
    Extends :class:`TensorDesc`

    Descriptor for a size tensor: a scalar of either INT32 or INT64 data type used to express the extent of a data-dependent dimension.
    """

    def __init__(self, opt: ShapeExpr, upper_bound: ShapeExpr):
        """
        Args:
            opt (ShapeExpr): Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build
            upper_bound (ShapeExpr): Symbolic expression for the upper-bound of this size tensor

        .. note:: It is recommended to construct a size tensor using :func:`size_tensor` instead of using this constructor directly
        """
        # A size tensor is scalar-valued: rank-0 shape, INT32 element type.
        super().__init__(ShapeExprs(0), trt.int32)
        self._opt = opt
        self._upper_bound = upper_bound
        self._index = None  # output index; assigned later through _set_index()
        self._expr = SizeTensorShapeExpr(self)  # symbolic handle for this size tensor

    @property
    def is_size_tensor(self):
        # This descriptor always denotes a size tensor.
        return True

    @property
    def opt(self) -> ShapeExpr:
        """
        Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build
        """
        return self._opt

    @property
    def upper_bound(self) -> ShapeExpr:
        """
        Symbolic expression for the upper-bound of this size tensor
        """
        return self._upper_bound

    @property
    def index(self) -> int:
        """
        Output index at which this size tensor resides
        """
        return self._index

    def _set_index(self, idx: int):
        # Record the output position once it is known to the plugin machinery.
        self._index = idx

    def expr(self) -> SizeTensorShapeExpr:
        """
        Symbolic expression for this size tensor
        """
        return self._expr
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
# A tensor representation that carries data
@public_api()
class Tensor:
    """
    Representation of a tensor that carries data

    :class:`Tensor` objects are strictly *descriptors* of a tensor with an underlying data buffer. `tensorrt.plugin` does not provide any APIs that perform standard data-altering operations on :class:`Tensor`\s.

    Supports `__cuda_array_interface__` for interoperability with other frameworks.

    """

    def __init__(self):
        # Buffer pointer and layout metadata. These are populated by the
        # plugin machinery after construction, not passed to __init__.
        self._data_ptr = None
        self._shape = None
        self._format = None
        self._dtype = None
        self._scale = None
        self._strides = None

        self._aliased_to = None  # Tensor this one shares a data buffer with, if any
        self._stream = None  # stream handle exported via __cuda_array_interface__
        self._read_only = None
        self._immutable = False  # once True, __setattr__ rejects modifications

    @property
    def ndim(self) -> int:
        """
        Number of dimensions
        """
        return len(self._shape)

    @property
    def data_ptr(self) -> int:
        """
        Pointer to the data buffer of this tensor
        """
        return self._data_ptr

    @property
    def dtype(self) -> trt.DataType:
        """
        Data type of the tensor.
        """
        return self._dtype

    @property
    def shape(self) -> Shape:
        """
        The (concrete) shape of the tensor.
        """
        return self._shape

    @property
    def format(self) -> trt.TensorFormat:
        """
        The format of the tensor.
        """
        return self._format

    @property
    def scale(self) -> float:
        """
        Scale for INT8 data type.
        """
        return self._scale

    @property
    def strides(self) -> Tuple[int]:
        """
        Strides of this tensor.
        """
        return self._strides

    @data_ptr.setter
    def data_ptr(self, value):
        self._data_ptr = value

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @shape.setter
    def shape(self, value):
        self._shape = value

    @format.setter
    def format(self, value):
        self._format = value

    @scale.setter
    def scale(self, value):
        self._scale = value

    @strides.setter
    def strides(self, value):
        self._strides = value

    def numel(self) -> int:
        """
        Returns the number of elements of the tensor

        Raises:
            ValueError: If the tensor has a data-dependent dimension. Examine :attr:`is_data_dependent` to determine whether the tensor is data-dependent.

        Returns:
            int: Number of elements of the tensor
        """
        if self.is_data_dependent:
            raise ValueError(
                "Tensor has a data-dependent dimension. Examine Tensor.shape to determine wildcards (representing data-dependent dimensions)."
            )
        return int(np.prod(self._shape))

    @property
    def __cuda_array_interface__(self):
        # These dtypes have no numpy equivalent, so they cannot be expressed
        # through the CUDA array interface 'typestr'.
        if self._dtype in [trt.DataType.BF16, trt.DataType.FP8, trt.DataType.INT4]:
            raise ValueError(
                f"Handling {self._dtype} via '__cuda_array_interface__' is not supported"
            )

        desc = {
            "shape": tuple(self._shape),
            "typestr": np.dtype(trt.nptype(self._dtype)).str,
        }
        desc["stream"] = self._stream
        desc["version"] = 3
        desc["data"] = (
            self._data_ptr,
            False,
        )  # torch does not support read_only flag. Always set to False -- it is user's responsibility to respect implied read-write restriction(s).
        # The protocol expects byte strides; self._strides holds element strides.
        desc["strides"] = tuple(
            [s * np.dtype(trt.nptype(self._dtype)).itemsize for s in self._strides]
        )

        return desc

    def __setattr__(self, name, value):
        # `_immutable` itself must remain writable so the flag can be toggled.
        if hasattr(self, "_immutable") and self._immutable and name != "_immutable":
            raise ValueError("Cannot modify immutable Tensor properties")
        super().__setattr__(name, value)

    def get_aliased(self) -> "Tensor":
        """
        Returns:
            Returns :class:`Tensor` of the tensor which this tensor is aliased to. Returns None is this tensor is not aliased to any other tensor.
        """
        return self._aliased_to

    @property
    def is_aliased(self):
        """
        True if this tensor is aliased to another tensor, False otherwise.
        """
        # Fixed: previously returned `self._aliased_to is None`, which inverted
        # the documented semantics and disagreed with TensorDesc.is_aliased.
        return self._aliased_to is not None

    @property
    def is_data_dependent(self):
        """
        True if this tensor contains at least one data-dependent dimension, False otherwise.
        """
        return self._shape.is_dynamic

    # Return a `Tensor` which has the same `data_ptr` as `self` but has the provided shape.
    def aliased(self, shape: Union[Shape, Tuple[int], trt.PluginTensorDesc] = None) -> "Tensor":
        """
        Return a :class:`Tensor` which has the same :attr:`data_ptr` as this but has the provided `shape`.

        Args:
            shape (Union[Shape, Tuple[int], trt.PluginTensorDesc], optional): Required shape of the new tensor (must have the same volume). Defaults to same shape.

        Raises:
            ValueError: If `shape` is not a supported type or if it does not have the same volume
        """
        # Shallow-copy all descriptor state, then unfreeze the clone so its
        # shape/alias fields can be rewritten below.
        cloned = Tensor()
        cloned.__dict__.update(self.__dict__)
        cloned._immutable = False
        if isinstance(shape, trt.PluginTensorDesc):
            cloned._shape = Shape(shape)
        elif isinstance(shape, Shape):
            cloned._shape = shape
        elif isinstance(shape, tuple):
            # Build a descriptor that carries over this tensor's dtype/format/scale.
            desc = trt.PluginTensorDesc()
            desc.dims = shape
            desc.type = self._dtype
            desc.format = self._format
            desc.scale = self._scale
            cloned._shape = Shape(desc)
        elif shape is None:
            pass
        else:
            raise ValueError("Unsupported type for 'shape'")

        # If either the `shape` or self._shape has a wildcard, we allow aliasing.
        # Volumes are only comparable when both shapes are concrete.
        # Fixed: the check previously ran when `cloned` *was* data-dependent,
        # where numel() is not meaningful.
        if not self.is_data_dependent and not cloned.is_data_dependent:
            if cloned._shape.numel() > self.numel():
                raise ValueError("Volume of this tensor is less than the provided 'shape'.")

        cloned._aliased_to = self
        return cloned
|
|
1093
|
+
|
|
1094
|
+
if IS_AOT_ENABLED:
|
|
1095
|
+
@public_api()
|
|
1096
|
+
class KernelLaunchParams:
|
|
1097
|
+
"""
|
|
1098
|
+
Args:
|
|
1099
|
+
grid_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid x dimension. Defaults to 1.
|
|
1100
|
+
grid_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid y dimension. Defaults to 1.
|
|
1101
|
+
grid_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid z dimension. Defaults to 1.
|
|
1102
|
+
block_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The x dimension of each thread block. Defaults to 1.
|
|
1103
|
+
block_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The y dimension of each thread block. Defaults to 1.
|
|
1104
|
+
block_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The z dimension of each thread block. Defaults to 1.
|
|
1105
|
+
shared_mem (Union[int, trt.IDimensionExpr, SymInt32], optional): Shared-memory per thread block in bytes. Defaults to 0.
|
|
1106
|
+
"""
|
|
1107
|
+
def __init__(self,
|
|
1108
|
+
grid_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1109
|
+
grid_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1110
|
+
grid_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1111
|
+
block_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1112
|
+
block_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1113
|
+
block_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
|
|
1114
|
+
shared_mem: Union[int, trt.IDimensionExpr, SymInt32] = 0):
|
|
1115
|
+
self.grid_x = SymInt32(grid_x)
|
|
1116
|
+
self.grid_y = SymInt32(grid_y)
|
|
1117
|
+
self.grid_z = SymInt32(grid_z)
|
|
1118
|
+
self.block_x = SymInt32(block_x)
|
|
1119
|
+
self.block_y = SymInt32(block_y)
|
|
1120
|
+
self.block_z = SymInt32(block_z)
|
|
1121
|
+
self.shared_mem = SymInt32(shared_mem)
|
|
1122
|
+
|
|
1123
|
+
|
|
1124
|
+
def __setattr__(self, name, value):
|
|
1125
|
+
if name in ["grid_x", "grid_y", "grid_z", "block_x", "block_y", "block_z", "shared_mem"]:
|
|
1126
|
+
self.__dict__[name] = SymInt32(value)
|
|
1127
|
+
else:
|
|
1128
|
+
raise AttributeError(f"KernelLaunchParams object has no attribute '{name}'")
|