tensorrt-cu12-bindings 10.13.3.9.post1__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1128 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import tensorrt as trt
19
+ from typing import Tuple, Union
20
+ import numpy as np
21
+ from ._export import public_api, IS_AOT_ENABLED
22
+ from abc import ABC, abstractmethod
23
+
24
class SymExpr(ABC):
    """Abstract interface for a symbolic (build-time) integer expression.

    Concrete subclasses wrap a lazily evaluated ``trt.IDimensionExpr`` and
    overload arithmetic/comparison operators so that combining expressions
    yields new symbolic expressions.

    NOTE(review): declaring ``__eq__`` (even as abstract) implicitly sets
    ``__hash__`` to ``None``, making instances unhashable. Presumably
    intentional, since ``__eq__`` here builds a symbolic expression rather
    than returning a bool — confirm before relying on set/dict membership.
    """

    @abstractmethod
    def _op(self, op, other):
        # Apply the binary operation `op` to (self, other) and return a new
        # symbolic expression.
        pass

    @abstractmethod
    def __add__(self, other):
        pass

    @abstractmethod
    def __sub__(self, other):
        pass

    @abstractmethod
    def __mul__(self, other):
        pass

    @abstractmethod
    def __floordiv__(self, other):
        pass

    @abstractmethod
    def __eq__(self, other):
        pass

    @abstractmethod
    def __lt__(self, other):
        pass

    @abstractmethod
    def __repr__(self):
        pass

    @property
    @abstractmethod
    def is_constant(self) -> bool:
        # True when the expression is a build-time constant.
        pass

    @property
    @abstractmethod
    def constant_value(self) -> int:
        # Value of a constant expression.
        pass

    # Evaluate the underlying trt.IDimensionExpr, if so done lazily
    @property
    @abstractmethod
    def _expr(self):
        pass
73
+
74
class SymIntExprMeta(type(SymExpr)):
    # Metaclass for SymIntExpr. SymExpr's metaclass is ABCMeta, so deriving
    # from type(SymExpr) keeps subclasses compatible with the ABC machinery
    # while leaving a hook point for future metaclass customization.
    pass
76
+
77
class SymIntExpr(SymExpr, metaclass=SymIntExprMeta):
    """
    Symbolic integer (scalar) expression
    """
    _exprBuilder = None  # trt.IExprBuilder instance. Populated when a shape-calculation context is entered.

    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
        """
        Args:
            value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
        """
        self._int_expr = None
        if isinstance(value, int):
            if SymIntExpr._exprBuilder is None:
                # Outside a shape-calculation context the constant cannot be
                # materialized yet: the expression stays "fake" (None).
                self._int_expr = None
            else:
                self._int_expr = SymIntExpr._exprBuilder.constant(value)
        elif isinstance(value, trt.IDimensionExpr):
            self._int_expr = value
        elif isinstance(value, SymIntExpr):
            self._int_expr = value._int_expr

    @property
    def _is_fake(self) -> bool:
        # Subclasses like ShapeExpr track fakeness explicitly via `_is_dummy`.
        # Plain SymIntExpr instances never set that attribute, so `__repr__`
        # and `is_constant` used to raise AttributeError on them; fall back to
        # "no underlying expression" as the fakeness criterion in that case.
        return getattr(self, "_is_dummy", self._expr is None)

    def _op(self, op: trt.DimensionOperation, other: Union[int, "SymIntExpr"]):
        # Promote a plain int operand to a symbolic constant first.
        if isinstance(other, int):
            other = SymIntExpr(other)
        return SymIntExpr(SymIntExpr._exprBuilder.operation(op, self._expr, other._expr))

    # Binary operations for +, -, *, //, ==. <
    # Those for ceil_div, max and min are provided as top-level functions of tensorrt.plugin
    def __add__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.SUM, other)

    def __sub__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.SUB, other)

    def __mul__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.PROD, other)

    def __floordiv__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.FLOOR_DIV, other)

    def __eq__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.EQUAL, other)

    def __lt__(self, other: Union[int, "SymIntExpr"]):
        return self._op(trt.DimensionOperation.LESS, other)

    def __repr__(self):
        if self._is_fake:
            return f"FakeSymIntExpr[id={id(self)}]"
        elif not self.is_constant:
            return f"SymIntExpr[id={id(self)}]"
        return f"SymIntExpr[{self._expr.get_constant_value()}]"

    @property
    def is_constant(self) -> bool:
        """
        `True` if this integer expression is a build-time constant, `False` otherwise.

        Raises:
            RuntimeError: For fake :class:`SymIntExpr`\s. Check :attr:`is_fake` to determine accessibility.
        """
        if self._is_fake:
            raise RuntimeError(
                "Not accessible for fake 'SymIntExpr's. Check is_fake to determine accessibility."
            )
        return self._expr.is_constant()

    def constant_value(self) -> int:
        """
        Return value of the constant integer expression.

        Raises:
            RuntimeError: For non-constant integer expressions. Check :attr:`is_constant` to determine accessibility.
        """
        # NOTE(review): the SymExpr ABC declares `constant_value` as a
        # property, but it is exposed here as a method. Changing it now would
        # break callers, so the inconsistency is only flagged.
        if not self.is_constant:
            raise RuntimeError(
                "Not accessible for non-constant integer expressions. Check is_constant to determine accessibility."
            )
        return self._expr.get_constant_value()

    # Evaluate the underlying trt.IDimensionExpr, if so done lazily
    @property
    def _expr(self):
        return self._int_expr

    def _clone(self):
        # Adding 0 forces construction of a fresh expression node.
        return SymIntExpr(self + 0)
165
+
166
class SymExprImpl(trt.ISymExpr):
    """Adapter exposing a :class:`SymIntExpr` through the ``trt.ISymExpr``
    binding interface, tagging it with the matching plugin arg data type."""

    def __init__(self, expr: SymIntExpr):
        trt.ISymExpr.__init__(self)
        self.type = trt.PluginArgType.INT
        # Resolve the dtype tag from the concrete SymIntExpr subclass.
        # Order matters: SymInt32 is checked first (ShapeExpr derives from it).
        dispatch = (
            (SymInt32, trt.PluginArgDataType.INT32),
            (SymInt16, trt.PluginArgDataType.INT16),
            (SymInt8, trt.PluginArgDataType.INT8),
        )
        for klass, dtype in dispatch:
            if isinstance(expr, klass):
                self.dtype = dtype
                break
        else:
            raise ValueError(f"Unknown SymIntExpr type {type(expr)}")

        self.expr = expr._expr
180
@public_api()
class SymInt32(SymIntExpr):
    """
    Symbolic expression for a 32-bit integer
    """
    def __init__(self, value: Union[int, trt.IDimensionExpr, SymIntExpr] = None):
        super().__init__(value)

    def __call__(self):
        # Wrap this expression as a trt.ISymExpr so it can be passed through
        # the plugin argument bindings.
        return SymExprImpl(self)
190
@public_api()
class SymInt8(SymIntExpr):
    """
    Symbolic expression for an 8-bit integer
    """
    # NOTE(review): unlike SymInt32, no __call__ adapter to SymExprImpl is
    # defined here, although SymExprImpl does handle SymInt8 — confirm whether
    # that omission is intentional.
    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
        super().__init__(value)
197
+
198
@public_api()
class SymInt16(SymIntExpr):
    """
    Symbolic expression for a 16-bit integer
    """
    # NOTE(review): unlike SymInt32, no __call__ adapter to SymExprImpl is
    # defined here, although SymExprImpl does handle SymInt16 — confirm whether
    # that omission is intentional.
    def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
        super().__init__(value)
205
+
206
# Symbolic expression for a given dimension of a tensor
@public_api()
class ShapeExpr(SymInt32):
    """
    Symbolic expression for single dimension of a tensor
    """
    def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr", SymIntExpr] = None):
        """
        Args:
            value (Union[int, trt.IDimensionExpr, ShapeExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
        """
        super().__init__(value)
        # Snapshot the active builder (may be None outside a shape-calculation
        # context); instance attribute shadows the SymIntExpr class attribute.
        self._exprBuilder = SymIntExpr._exprBuilder
        self._is_dummy = False
        self._is_size_tensor = False
        if value is None:
            # No value at all: a "fake" placeholder expression.
            self._is_dummy = True
        elif isinstance(value, int):
            # An int constant can only be materialized inside a builder context.
            if self._exprBuilder is None:
                self._is_dummy = True
        elif isinstance(value, ShapeExpr):
            # Propagate fakeness/size-tensor flags from the source expression.
            self._is_dummy = value._is_dummy
            self._is_size_tensor = value._is_size_tensor
        elif isinstance(value, SymIntExpr):
            pass

    def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
        if self._is_size_tensor:
            raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
        if self._is_dummy:
            # Operations on fake expressions yield new fake expressions.
            return ShapeExpr()
        return ShapeExpr(super()._op(op, other))

    def __repr__(self):
        if self._is_dummy:
            return f"FakeShapeExpr[id={id(self)}]"
        elif not self.is_constant:
            return f"ShapeExpr[id={id(self)}]"
        return f"ShapeExpr[{self._expr.get_constant_value()}]"

    # A ShapeExpr may be "fake" when it is accessed in a non-shape calculation context. Fake `ShapeExpr`s are externally indistinguishable unless `is_constant` or `constant_value` is required.
    # Therefore, constant checks/access must occur conditionally after evaluating `is_fake`.
    @property
    def is_fake(self) -> bool:
        """
        A ShapeExpr may be "fake" when it is accessed in a non-shape calculation context.
        Fake `ShapeExpr`s are externally indistinguishable unless `is_constant` or `constant_value` is required.
        """
        return self._is_dummy

    @property
    def is_size_tensor(self) -> bool:
        """
        `True` if this represents a size tensor, `False` otherwise.
        """
        return self._is_size_tensor

    @property
    def is_constant(self) -> bool:
        """
        `True` if this shape expression is a build-time constant, `False` otherwise.

        Raises:
            RuntimeError: For fake :class:`ShapeExpr`\s. Check :attr:`is_fake` to determine accessibility.
        """
        if self._is_dummy:
            raise RuntimeError(
                "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
            )
        return super().is_constant

    def constant_value(self) -> int:
        """
        Return value of the constant shape expression.

        Raises:
            RuntimeError: For non-constant shape expressions. Check :attr:`is_constant` to determine accessibility.
        """
        if not self.is_constant:
            raise RuntimeError(
                "Not accessible for non-constant shape expressions. Check is_constant to determine accessibility."
            )
        return super().constant_value()

    def _clone(self):
        # `self + 0` builds a fresh expression node (or a fresh fake one);
        # the flags are then copied over explicitly.
        ret = ShapeExpr(self + 0)
        ret._is_dummy = self._is_dummy
        ret._is_size_tensor = self._is_size_tensor
        return ret
295
+
296
@public_api()
class SizeTensorShapeExpr(ShapeExpr):
    """
    Extends :class:`ShapeExpr`

    A shape expression that represent a size tensor

    """
    def __init__(self, size_tensor_desc: "SizeTensorDesc"):
        """
        Args:
            size_tensor_desc (SizeTensorDesc): Descriptor of the size tensor this expression stands for.

        .. note:: It is recommended to use :attr:`SizeTensorDesc.expr` to get a :class:`SizeTensorShapeExpr` representing a size tensor
        """
        super().__init__()
        self._is_size_tensor = True
        # Fake whenever the descriptor's autotune extent is itself fake
        # (constructed outside a shape-calculation context).
        self._is_dummy = size_tensor_desc.opt.is_fake
        self._size_tensor_desc = size_tensor_desc

    def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
        raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # TRT limitation

    @property
    def is_constant(self):
        # Size tensors are never reported as build-time constants.
        if self._is_dummy:
            raise RuntimeError(
                "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
            )
        return False

    @property
    def _expr(self):
        # Lazily declare the size tensor with the expression builder on first
        # access and cache the resulting trt.IDimensionExpr thereafter.
        if self._int_expr is not None:
            return self._int_expr

        self._int_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
        return self._int_expr

    def __repr__(self):
        return f"ShapeExpr[is_size_tensor = True, id={id(self)}]"
334
+
335
+ def _from_scalar(s):
336
+ if isinstance(s, int):
337
+ return SymInt32(s)
338
+ elif isinstance(s, float):
339
+ raise ValueError("Float symbolic expressions are not supported")
340
+ else:
341
+ raise ValueError(f"Unsupported type: '{type(s)}'")
342
+
343
# Iterable holding `ShapeExpr`s
@public_api()
class SymExprs:
    def __init__(self, length: int):
        """
        Iterable holding symbolic expressions

        Args:
            length (int): Number of dimensions of the tensor
        """
        self._length = length
        self._exprs = [None] * length

    @classmethod
    def from_tuple(cls, shape_exprs: Tuple[Union[SymExpr, int]]) -> "SymExprs":
        """
        Args:
            shape_exprs (Tuple[Union[SymExpr, int]]): Tuple to construct :class:`SymExprs` from
        """

        # Scalars are promoted to symbolic expressions up front.
        converted = [
            e if isinstance(e, SymExpr) else _from_scalar(e) for e in shape_exprs
        ]
        inst = cls(len(converted))
        inst._exprs = converted
        return inst

    def __iter__(self):
        return iter(self._exprs)

    def __getitem__(self, index):
        return self._exprs[index]

    def __len__(self):
        return self._length

    def __setitem__(self, index, expr):
        if index >= self._length:
            raise IndexError("Index out of range")

        # Promote scalars, mirroring from_tuple.
        if not isinstance(expr, SymExpr):
            expr = _from_scalar(expr)

        self._exprs[index] = expr

    def __repr__(self):
        body = ", ".join(repr(s) for s in self._exprs)
        return f"SymExprs[{body}]"
388
+
389
@public_api()
class ShapeExprs(SymExprs):
    def __init__(self, length, _is_dummy = False):
        """
        Iterable holding :class:`ShapeExpr`\s, representing a tensor shape

        Args:
            length (int): Number of dimensions of the tensor
            _is_dummy (bool): Internal flag; when True, the container is pre-filled with fake :class:`ShapeExpr`\s.

        Raises:
            ValueError: If ``length`` exceeds ``trt.Dims.MAX_DIMS``.
        """
        if length > trt.Dims.MAX_DIMS:
            raise ValueError(f"ShapeExprs can only support up to trt.Dims.MAX_DIMS = {trt.Dims.MAX_DIMS} dimensions. {length} given.")

        super().__init__(length)

        self._is_dummy = _is_dummy
        if _is_dummy:
            # One fresh fake ShapeExpr per slot. The previous
            # `[ShapeExpr()] * length` aliased a single mutable instance
            # (ShapeExpr carries mutable _is_dummy/_is_size_tensor flags)
            # across every dimension, so mutating one slot's flags would
            # silently leak into all the others.
            self._exprs = [ShapeExpr() for _ in range(length)]

    @classmethod
    def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExprs":
        """
        Args:
            shape_exprs (Tuple[Union[ShapeExpr, int]]): Tuple to construct :class:`ShapeExprs` from
        """
        # (Return annotation fixed: this returns a ShapeExprs, not a ShapeExpr.)
        shape_exprs_ = tuple([e if isinstance(e, ShapeExpr) else ShapeExpr(e) for e in shape_exprs])
        inst = cls(len(shape_exprs_))
        inst._exprs = list(shape_exprs_)
        return inst

    def numel(self) -> ShapeExpr:
        """
        Returns a symbolic expression for the number of elements
        """
        ret = ShapeExpr(1)
        for s in self._exprs:
            ret *= s
        return ret

    def __setitem__(self, index, value):
        if index >= self._length:
            raise IndexError("Index out of range")

        if not isinstance(value, ShapeExpr):
            if not isinstance(value, int):
                raise ValueError(f"Value should be int or ShapeExpr. Got '{type(value)}'")
            value = ShapeExpr(value)

        self._exprs[index] = value

    def __repr__(self):
        return f"ShapeExprs[{', '.join([s.__repr__() for s in self._exprs])}]"

    def _clone(self):
        # Deep-clone each dimension expression, then carry over the dummy flag.
        ret = ShapeExprs.from_tuple((e._clone() for e in self._exprs))
        ret._is_dummy = self._is_dummy
        return ret
446
+
447
@public_api()
class SymIntExprs(SymExprs):
    def __init__(self, length):
        """
        Iterable holding :class:`SymIntExpr`\s

        Args:
            length (int): Number of symbolic expressions in the iterable
        """
        super().__init__(length)

    @classmethod
    def from_tuple(cls, shape_exprs: Tuple[Union[SymIntExpr, int]]) -> "SymIntExprs":
        """
        Args:
            shape_exprs (Tuple[Union[SymIntExpr, int]]): Tuple to construct :class:`SymIntExprs` from
        """
        # (Return annotation fixed: this returns a SymIntExprs, not a SymIntExpr.)
        shape_exprs_ = tuple([e if isinstance(e, SymIntExpr) else SymIntExpr(e) for e in shape_exprs])
        inst = cls(len(shape_exprs_))
        inst._exprs = list(shape_exprs_)
        return inst

    def __setitem__(self, index, value):
        if index >= self._length:
            raise IndexError("Index out of range")

        if not isinstance(value, SymIntExpr):
            if not isinstance(value, int):
                raise ValueError(f"Value should be int or SymIntExpr. Got '{type(value)}'")
            value = SymIntExpr(value)

        self._exprs[index] = value

    def __repr__(self):
        return f"SymIntExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
483
+
484
# Numerical representation of a tensor shape
@public_api()
class Shape:
    """
    Numerical representation of a tensor shape
    """
    def __init__(
        self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc] = None
    ):
        """
        Args:
            tensor_desc: Source of the dimensions — a tuple of ints, a
                trt.DynamicPluginTensorDesc (also retains min/opt/max bounds),
                or a trt.PluginTensorDesc. None yields an empty shape.

        Raises:
            ValueError: For any other ``tensor_desc`` type.
        """
        self._is_dynamic = None  # cached by `is_dynamic`; None means "not computed yet"
        if isinstance(tensor_desc, trt.DynamicPluginTensorDesc):
            self._length = len(tensor_desc.desc.dims)
            self._shapes = tensor_desc.desc.dims
            self._desc = tensor_desc  # retained so min/opt/max are available
        elif isinstance(tensor_desc, trt.PluginTensorDesc):
            self._length = len(tensor_desc.dims)
            self._shapes = tensor_desc.dims
        elif isinstance(tensor_desc, tuple):
            self._shapes = trt.Dims(tensor_desc)
            self._length = len(self._shapes)
        elif tensor_desc is None:
            self._length = 0
            self._shapes = trt.Dims(0)
        else:
            raise ValueError("Unsupported type used for constructing trt.plugin.Shape! tensor_desc must be a Tuple[int], trt.DynamicPluginTensorDesc, or trt.PluginTensorDesc")

    def numel(self) -> int:
        """
        Number of elements contained

        Raises:
            ValueError: When :attr:`is_dynamic` is `True`
        """
        if self.is_dynamic:
            raise ValueError("Shape has at least one dynamic dimension.")
        return int(np.prod(self._shapes))

    def __iter__(self):
        yield from self._shapes

    def __getitem__(self, index):
        return self._shapes[index]

    def __len__(self):
        return self._length

    def __str__(self):
        return "Shape" + str(tuple(self))

    @property
    def is_dynamic(self) -> bool:
        """
        `True` if this tensor has at least one dynamic dimension, `False` otherwise.
        """
        if self._is_dynamic is not None:
            return self._is_dynamic

        self._is_dynamic = False
        for d in self._shapes:
            if d == -1:  # -1 is the wildcard marking a dynamic dimension
                self._is_dynamic = True
                break  # one wildcard is enough; no need to scan further

        return self._is_dynamic

    @property
    def opt(self) -> Tuple[int]:
        """
        Optimum value of dimensions specified for auto-tuning.
        """
        if not self.is_dynamic:
            raise ValueError("opt property is only accessible if is_dynamic is true")
        if not hasattr(self, "_desc"):
            raise AttributeError(
                "Shape object has at least one dynamic dimension, but no information is available on 'opt' property."
            )
        return tuple(self._desc.opt)

    @property
    def min(self) -> Tuple[int]:
        """
        Lower bounds on tensor's dimensions.
        """
        if not self.is_dynamic:
            raise ValueError("min property is only accessible if is_dynamic is true")
        if not hasattr(self, "_desc"):
            raise AttributeError(
                "Shape object has at least one dynamic dimension, but no information is available on 'min' property."
            )
        return tuple(self._desc.min)

    @property
    def max(self) -> Tuple[int]:
        """
        Upper bounds on tensor's dimensions.
        """
        if not self.is_dynamic:
            raise ValueError("max property is only accessible if is_dynamic is true")
        if not hasattr(self, "_desc"):
            raise AttributeError(
                "Shape object has at least one dynamic dimension, but no information is available on 'max' property."
            )
        return tuple(self._desc.max)

    def __setitem__(self, index, val):
        if index >= self._length:
            raise IndexError("Index out of range")
        self._shapes[index] = val
        # Bug fix: the mutation may add or remove a wildcard (-1) dimension,
        # so the lazily cached `is_dynamic` result must be invalidated here;
        # previously it went stale after any __setitem__.
        self._is_dynamic = None

    def _clone(self):
        ret = Shape()
        ret.__dict__.update(self.__dict__)
        return ret
596
+
597
+
598
# Descriptor for a tensor
# A `TensorDesc` never contains nor refers to any tensor data.
@public_api()
class TensorDesc:
    """
    Descriptor for a tensor
    A `TensorDesc` never contains nor refers to any tensor data.
    """
    def __init__(self, shape_expr: ShapeExprs = None, dtype: trt.DataType = None, format: trt.TensorFormat = None, scale: float = None):
        """
        Args:
            shape_expr (ShapeExprs): Symbolic expressions for the tensor's shape.
            dtype (trt.DataType): The data type of the tensor.
            format (trt.TensorFormat): Format (layout) of the tensor.
            scale (float): Scale for INT8 data type.

        .. code-block:: python
            :linenos:
            :caption: Creates a TensorDesc with constant shape expressions

            tensor = trt.TensorDesc((10, 2, 32, 32), dtype=trt.float32)

        .. code-block:: python
            :linenos:
            :caption: Creates a TensorDesc from shape expression of another TensorDesc

            tensor = trt.from_shape_expr(other.shape_expr, dtype=trt.float32)
        """

        # `TensorDesc` may or may not have `Shape` information but always has symbolic shape expressions and dtype
        self._shape_expr = shape_expr
        self._dtype = dtype

        # `shape`, `format`, and `scale` are only accessible if `has_shape`. Presently, this would be inside autotune.
        self._shape = None
        self._format = format
        self._scale = scale

        self._aliased_to = None  # TensorDesc this descriptor is aliased to, if any
        self._immutable = False  # write guard enforced by __setattr__ below

    def numel(self) -> int:
        """
        Returns:
            Returns an int with the number of elements of the tensor.

        .. warning::
            Should only be called when TensorDesc.has_shape is true. If a symbolic expression for the number of elements is required, query TensorDesc.shape_expr.numel().
        """
        if not self.has_shape:
            raise ValueError(
                "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability."
            )
        return int(np.prod(self.shape))

    @property
    def ndim(self) -> int:
        """
        Number of dimensions
        """
        return len(self._shape_expr)

    @property
    def is_size_tensor(self):
        # Overridden to True by SizeTensorDesc.
        return False

    # Return a `TensorDesc` that has identical properties to `self` but is mutable
    def like(self) -> "TensorDesc":
        """
        Returns:
            Returns a TensorDesc which has identical properties to this tensor, and is mutable.

        .. code-block:: python
            :linenos:
            :caption: Communicate that output tensor has identical properties to the input tensor

            @tensorrt.plugin.register("my::plugin")
            def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
                return inp.like()
        """
        cloned = self._clone()
        # _clone() returns an immutable copy; re-enable writes on it.
        cloned._immutable = False
        return cloned

    # Return a `TensorDesc` that has identical properties to `self` AND is aliased to `self` (would result in a `Tensor` during enqueue sharing the same data buffer)
    def aliased(self) -> "TensorDesc":
        """
        Returns:
            Returns a TensorDesc which has identical properties and is aliased to this tensor (would result in a `Tensor` during enqueue sharing the same data buffer).
            Returned TensorDesc is immutable.

        .. code-block:: python
            :linenos:
            :caption: Communicate that output tensor has identical properties to the input tensor

            @tensorrt.plugin.register("my::plugin")
            def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
                return inp.aliased()
        """
        cloned = self._clone()
        # Temporarily lift immutability (enforced by __setattr__) so the
        # alias link can be recorded, then lock the clone again.
        cloned._immutable = False
        cloned._aliased_to = self
        cloned._immutable = True
        return cloned

    def _clone(self) -> "TensorDesc":
        # Shallow-copy all attributes, then deep-clone the (mutable) shape
        # information. The _immutable toggling order matters: the dict update
        # may have copied a True flag, which would block the writes below.
        cloned = TensorDesc()
        cloned.__dict__.update(self.__dict__)
        cloned._immutable = False
        cloned._shape_expr = self._shape_expr._clone()
        if self._shape is not None:
            cloned._shape = self._shape._clone()
        # Clones are handed out locked; like() re-enables mutation explicitly.
        cloned._immutable = True
        return cloned

    def get_aliased(self) -> "TensorDesc":
        """
        Returns:
            Returns a TensorDesc for the tensor which this tensor is aliased to. Returns None if this tensor is not aliased to any other tensor.
        """
        return self._aliased_to

    def _validate_has_shape(self) -> None:
        # Shared guard for the shape/format/scale accessors below.
        if not self.has_shape:
            raise ValueError(
                "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability."
            )

    def _validate_not_immutable(self):
        # Shared guard for the guarded setters below.
        if hasattr(self, "_immutable") and self._immutable:
            raise ValueError("Cannot modify immutable TensorDesc")

    @property
    def shape_expr(self) -> ShapeExprs:
        """
        Symbolic expressions for the tensor shape.
        """
        return self._shape_expr

    @property
    def dtype(self) -> trt.DataType:
        """
        Data type of the tensor.
        """
        return self._dtype

    @property
    def shape(self) -> Shape:
        """
        The (concrete) shape of the tensor.

        .. warning::
            Only accessible when TensorDesc.has_shape is true.
        """
        self._validate_has_shape()
        return self._shape

    @property
    def format(self) -> trt.TensorFormat:
        """
        The format of the tensor.

        .. warning::
            Only accessible when TensorDesc.has_shape is true.
        """
        self._validate_has_shape()
        return self._format

    @property
    def scale(self) -> float:
        """
        Scale for INT8 data type.

        .. warning::
            Only accessible when TensorDesc.has_shape is true.
        """
        self._validate_has_shape()
        return self._scale


    @shape_expr.setter
    def shape_expr(self, value):
        self._shape_expr = value

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @shape.setter
    def shape(self, value):
        self._validate_not_immutable()
        self._shape = value

    @format.setter
    def format(self, value):
        self._validate_not_immutable()
        self._format = value

    @scale.setter
    def scale(self, value):
        self._validate_not_immutable()
        self._scale = value

    @property
    def is_aliased(self) -> bool:
        """
        True if this tensor is aliased to another tensor, False otherwise.
        """
        return self._aliased_to is not None

    @property
    def has_shape(self) -> bool:
        """
        True if this tensor has concrete shape information, False otherwise.
        """
        return self._shape is not None

    @property
    def is_dynamic(self) -> bool:
        """
        `True` if this tensor has at least one dynamic dimension, `False` otherwise.
        """
        if not self.has_shape:
            raise ValueError(
                "TensorDesc has no shape information available at this stage. Inspect TensorDesc.has_shape to determine availability."
            )
        return self.shape.is_dynamic

    @property
    def has_shape_expr(self) -> bool:
        """
        True if this tensor has symbolic shape expressions, False otherwise.
        """
        return self.shape_expr is not None

    def __setattr__(self, name, value):
        # Enforce the immutability flag for every attribute write except the
        # flag itself (which must stay settable to unlock/relock the object).
        if hasattr(self, "_immutable") and self._immutable and name != "_immutable":
            raise ValueError("Cannot modify immutable TensorDesc properties")
        super().__setattr__(name, value)
837
+
838
@public_api()
class SizeTensorDesc(TensorDesc):
    """
    Extends :class:`TensorDesc`

    Descriptor for a size tensor: a scalar of either INT32 or INT64 data type used to express the extent of a data-dependent dimension.
    """
    def __init__(self, opt: ShapeExpr, upper_bound: ShapeExpr):
        """
        Args:
            opt (ShapeExpr): Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build
            upper_bound (ShapeExpr): Symbolic expression for the upper-bound of this size tensor

        .. note:: It is recommended to construct a size tensor using :func:`size_tensor` instead of using this constructor directly
        """
        # A size tensor is a scalar (0-D shape).
        # NOTE(review): dtype is always declared INT32 here even though the
        # class doc mentions INT32 or INT64 — confirm.
        super().__init__(ShapeExprs(0), trt.int32)
        self._opt = opt
        self._upper_bound = upper_bound
        self._index = None  # output index; filled in later via _set_index
        self._expr = SizeTensorShapeExpr(self)

    @property
    def is_size_tensor(self):
        return True

    @property
    def opt(self) -> ShapeExpr:
        """
        Symbolic expression for the extent of this size tensor to use in the autotune process of the engine build
        """
        return self._opt

    @property
    def upper_bound(self) -> ShapeExpr:
        """
        Symbolic expression for the upper-bound of this size tensor
        """
        return self._upper_bound

    @property
    def index(self) -> int:
        """
        Output index at which this size tensor resides
        """
        return self._index

    def _set_index(self, idx: int):
        self._index = idx

    def expr(self) -> SizeTensorShapeExpr:
        """
        Symbolic expression for this size tensor
        """
        # NOTE(review): exposed as a method, though docs elsewhere refer to it
        # as :attr:`SizeTensorDesc.expr` — confirm intended access form.
        return self._expr
892
+
893
+
894
+ # A tensor representation that carries data
895
+ @public_api()
896
+ class Tensor:
897
+ """
898
+ Representation of a tensor that carries data
899
+
900
+ :class:`Tensor` objects are strictly *descriptors* of a tensor with an underlying data buffer. `tensorrt.plugin` does not provide any APIs that perform standard data-altering operations on :class:`Tensor`\s.
901
+
902
+ Supports `__cuda_array_interface__` for interoperability with other frameworks.
903
+
904
+ """
905
    def __init__(self):
        # Buffer/layout descriptors. All start as None/False; presumably the
        # plugin runtime fills these in via the setters — confirm with callers.
        self._data_ptr = None
        self._shape = None
        self._format = None
        self._dtype = None
        self._scale = None
        self._strides = None

        self._aliased_to = None  # Tensor this one shares a data buffer with, if any
        self._stream = None      # stream handle exported via __cuda_array_interface__
        self._read_only = None
        self._immutable = False  # write guard checked by __setattr__
917
+
918
    @property
    def ndim(self) -> int:
        """
        Number of dimensions
        """
        return len(self._shape)

    @property
    def data_ptr(self) -> int:
        """
        Pointer to the data buffer of this tensor
        """
        return self._data_ptr

    @property
    def dtype(self) -> trt.DataType:
        """
        Data type of the tensor.
        """
        return self._dtype

    @property
    def shape(self) -> Shape:
        """
        The (concrete) shape of the tensor.
        """
        return self._shape

    @property
    def format(self) -> trt.TensorFormat:
        """
        The format of the tensor.
        """
        return self._format

    @property
    def scale(self) -> float:
        """
        Scale for INT8 data type.
        """
        return self._scale

    @property
    def strides(self) -> Tuple[int]:
        """
        Strides of this tensor, in elements (not bytes) — they are multiplied
        by the dtype's itemsize when exported via `__cuda_array_interface__`.
        """
        return self._strides
966
+
967
    # Plain pass-through setters. Unlike TensorDesc's guarded setters, no
    # per-setter immutability check is done here; Tensor.__setattr__ enforces
    # the `_immutable` flag for every attribute write instead.
    @data_ptr.setter
    def data_ptr(self, value):
        self._data_ptr = value

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @shape.setter
    def shape(self, value):
        self._shape = value

    @format.setter
    def format(self, value):
        self._format = value

    @scale.setter
    def scale(self, value):
        self._scale = value

    @strides.setter
    def strides(self, value):
        self._strides = value
990
+
991
    def numel(self) -> int:
        """
        Returns the number of elements of the tensor

        Raises:
            ValueError: If the tensor has a data-dependent dimension. Examine :attr:`is_data_dependent` to determine whether the tensor is data-dependent.

        Returns:
            int: Number of elements of the tensor
        """
        # A data-dependent shape contains wildcard (-1) dims, whose product
        # would be meaningless.
        if self.is_data_dependent:
            raise ValueError(
                "Tensor has a data-dependent dimension. Examine Tensor.shape to determine wildcards (representing data-dependent dimensions)."
            )
        return int(np.prod(self._shape))
1006
+
1007
+ @property
1008
+ def __cuda_array_interface__(self):
1009
+ if self._dtype in [trt.DataType.BF16, trt.DataType.FP8, trt.DataType.INT4]:
1010
+ raise ValueError(
1011
+ f"Handling {self._dtype} via '__cuda_array_interface__' is not supported"
1012
+ )
1013
+
1014
+ desc = {
1015
+ "shape": tuple(self._shape),
1016
+ "typestr": np.dtype(trt.nptype(self._dtype)).str,
1017
+ }
1018
+ desc["stream"] = self._stream
1019
+ desc["version"] = 3
1020
+ desc["data"] = (
1021
+ self._data_ptr,
1022
+ False,
1023
+ ) # torch does not support read_only flag. Always set to False -- it is user's responsibility to respect implied read-write restriction(s).
1024
+ desc["strides"] = tuple(
1025
+ [s * np.dtype(trt.nptype(self._dtype)).itemsize for s in self._strides]
1026
+ )
1027
+
1028
+ return desc
1029
+
1030
+ def __setattr__(self, name, value):
1031
+ if hasattr(self, "_immutable") and self._immutable and name != "_immutable":
1032
+ raise ValueError("Cannot modify immutable Tensor properties")
1033
+ super().__setattr__(name, value)
1034
+
1035
+ def get_aliased(self) -> "Tensor":
1036
+ """
1037
+ Returns:
1038
+ Returns :class:`Tensor` of the tensor which this tensor is aliased to. Returns None is this tensor is not aliased to any other tensor.
1039
+ """
1040
+ return self._aliased_to
1041
+
1042
+ @property
1043
+ def is_aliased(self):
1044
+ """
1045
+ True if this tensor is aliased to another tensor, False otherwise.
1046
+ """
1047
+ return self._aliased_to is None
1048
+
1049
+ @property
1050
+ def is_data_dependent(self):
1051
+ """
1052
+ True if this tensor contains at least one data-dependent dimension, False otherwise.
1053
+ """
1054
+ return self._shape.is_dynamic
1055
+
1056
    # Return a `Tensor` which has the same `data_ptr` as `self` but has the provided shape.
    def aliased(self, shape: Union[Shape, Tuple[int, ...], trt.PluginTensorDesc] = None) -> "Tensor":
        """
        Return a :class:`Tensor` which has the same :attr:`data_ptr` as this but has the provided `shape`.

        Args:
            shape (Union[Shape, Tuple[int, ...], trt.PluginTensorDesc], optional): Required shape of the new tensor (must have the same volume). Defaults to same shape.

        Raises:
            ValueError: If `shape` is not a supported type or if it does not have the same volume
        """
        # Shallow-copy all tensor state (data_ptr, dtype, strides, ...) and
        # unfreeze the clone so its shape can be overwritten below.
        cloned = Tensor()
        cloned.__dict__.update(self.__dict__)
        cloned._immutable = False
        if isinstance(shape, trt.PluginTensorDesc):
            cloned._shape = Shape(shape)
        elif isinstance(shape, Shape):
            cloned._shape = shape
        elif isinstance(shape, tuple):
            # Wrap bare dims in a full PluginTensorDesc, inheriting this
            # tensor's dtype/format/scale, before converting to a Shape.
            desc = trt.PluginTensorDesc()
            desc.dims = shape
            desc.type = self._dtype
            desc.format = self._format
            desc.scale = self._scale
            cloned._shape = Shape(desc)
        elif shape is None:
            pass  # keep the (copied) original shape
        else:
            raise ValueError("Unsupported type for 'shape'")

        # If either the `shape` or self._shape has a wildcard, we allow aliasing
        # NOTE(review): the guard below only fires when self is concrete but
        # the clone is data-dependent; two fully concrete shapes are never
        # volume-checked here. Confirm this asymmetry (and that Shape.numel()
        # is meaningful on a data-dependent Shape) is intentional.
        if not self.is_data_dependent and cloned.is_data_dependent:
            if cloned._shape.numel() > self.numel():
                raise ValueError("Volume of this tensor is less than the provided 'shape'.")

        cloned._aliased_to = self
        return cloned
1093
+
1094
if IS_AOT_ENABLED:
    @public_api()
    class KernelLaunchParams:
        """
        Kernel launch configuration: grid/block dimensions and shared memory.

        Args:
            grid_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid x dimension. Defaults to 1.
            grid_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid y dimension. Defaults to 1.
            grid_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid z dimension. Defaults to 1.
            block_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The x dimension of each thread block. Defaults to 1.
            block_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The y dimension of each thread block. Defaults to 1.
            block_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The z dimension of each thread block. Defaults to 1.
            shared_mem (Union[int, trt.IDimensionExpr, SymInt32], optional): Shared-memory per thread block in bytes. Defaults to 0.
        """

        def __init__(self,
                     grid_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
                     grid_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
                     grid_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
                     block_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
                     block_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
                     block_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
                     shared_mem: Union[int, trt.IDimensionExpr, SymInt32] = 0):
            # Each setattr funnels through __setattr__ below, exactly like the
            # original per-field assignments (value is wrapped in SymInt32 here
            # and again in __setattr__, preserving the original behavior).
            initial = (("grid_x", grid_x), ("grid_y", grid_y), ("grid_z", grid_z),
                       ("block_x", block_x), ("block_y", block_y),
                       ("block_z", block_z), ("shared_mem", shared_mem))
            for field, dim in initial:
                setattr(self, field, SymInt32(dim))

        def __setattr__(self, name, value):
            """Coerce writes to the seven launch fields to SymInt32; reject any other name."""
            if name not in ("grid_x", "grid_y", "grid_z", "block_x", "block_y", "block_z", "shared_mem"):
                raise AttributeError(f"KernelLaunchParams object has no attribute '{name}'")
            self.__dict__[name] = SymInt32(value)