tensorrt-cu12-bindings 10.7.0.post1__cp39-none-win_amd64.whl → 10.9.0.34__cp39-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorrt-cu12-bindings might be problematic. See the registry's page for this release for more details.

@@ -1,5 +1,5 @@
1
1
  #
2
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,68 +18,224 @@
18
18
  import tensorrt as trt
19
19
  from typing import Tuple, Union
20
20
  import numpy as np
21
- from ._export import public_api
21
+ from ._export import public_api, IS_AOT_ENABLED
22
+ from abc import ABC, abstractmethod
22
23
 
23
- # Symbolic expression for a given dimension of a tensor
24
- @public_api()
25
- class ShapeExpr:
24
+ class SymExpr(ABC):
25
+
26
+ @abstractmethod
27
+ def _op(self, op, other):
28
+ pass
29
+
30
+ @abstractmethod
31
+ def __add__(self, other):
32
+ pass
33
+
34
+ @abstractmethod
35
+ def __sub__(self, other):
36
+ pass
37
+
38
+ @abstractmethod
39
+ def __mul__(self, other):
40
+ pass
41
+
42
+ @abstractmethod
43
+ def __floordiv__(self, other):
44
+ pass
45
+
46
+ @abstractmethod
47
+ def __eq__(self, other):
48
+ pass
49
+
50
+ @abstractmethod
51
+ def __lt__(self, other):
52
+ pass
53
+
54
+ @abstractmethod
55
+ def __repr__(self):
56
+ pass
57
+
58
+ @property
59
+ @abstractmethod
60
+ def is_constant(self) -> bool:
61
+ pass
62
+
63
+ @property
64
+ @abstractmethod
65
+ def constant_value(self) -> int:
66
+ pass
67
+
68
+ # Evaluate the underlying trt.IDimensionExpr, if so done lazily
69
+ @property
70
+ @abstractmethod
71
+ def _expr(self):
72
+ pass
73
+
74
+ class SymIntExprMeta(type(SymExpr)):
75
+ pass
76
+
77
+ class SymIntExpr(SymExpr, metaclass=SymIntExprMeta):
26
78
  """
27
- Symbolic expression for single dimension of a tensor
79
+ Symbolic integer (scalar) expression
28
80
  """
29
81
  _exprBuilder = None # trt.IExprBuilder instance. Populated when a shape-calculation context is entered.
30
82
 
31
- def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr"] = None):
83
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
32
84
  """
33
85
  Args:
34
- value (Union[int, trt.IDimensionExpr, ShapeExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
86
+ value (Union[int, trt.IDimensionExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
35
87
  """
36
- self._is_dummy = False
37
- self._dim_expr = None
38
- self._is_size_tensor = False
39
- if value is None:
40
- self._is_dummy = True
41
- elif isinstance(value, int):
42
- if self._exprBuilder is None:
43
- self._dim_expr = None
44
- self._is_dummy = True
88
+ self._int_expr = None
89
+ if isinstance(value, int):
90
+ if SymIntExpr._exprBuilder is None:
91
+ self._int_expr = None
45
92
  else:
46
- self._dim_expr = ShapeExpr._exprBuilder.constant(value)
93
+ self._int_expr = SymIntExpr._exprBuilder.constant(value)
47
94
  elif isinstance(value, trt.IDimensionExpr):
48
- self._dim_expr = value
49
- elif isinstance(value, ShapeExpr):
50
- self._dim_expr = value._dim_expr
51
- self._is_dummy = value._is_dummy
52
- self._is_size_tensor = value._is_size_tensor
95
+ self._int_expr = value
96
+ elif isinstance(value, SymIntExpr):
97
+ self._int_expr = value._int_expr
53
98
 
54
- def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
55
- if self._is_size_tensor:
56
- raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
57
- if self._is_dummy:
58
- return ShapeExpr()
99
+ def _op(self, op: trt.DimensionOperation, other: Union[int, "SymIntExpr"]):
59
100
  if isinstance(other, int):
60
- other = ShapeExpr(other)
61
- return ShapeExpr(ShapeExpr._exprBuilder.operation(op, self._expr, other._expr))
101
+ other = SymIntExpr(other)
102
+ return SymIntExpr(SymIntExpr._exprBuilder.operation(op, self._expr, other._expr))
62
103
 
63
104
  # Binary operations for +, -, *, //, ==. <
64
105
  # Those for ceil_div, max and min are provided as top-level functions of tensorrt.plugin
65
- def __add__(self, other: Union[int, "ShapeExpr"]):
106
+ def __add__(self, other: Union[int, "SymIntExpr"]):
66
107
  return self._op(trt.DimensionOperation.SUM, other)
67
108
 
68
- def __sub__(self, other: Union[int, "ShapeExpr"]):
109
+ def __sub__(self, other: Union[int, "SymIntExpr"]):
69
110
  return self._op(trt.DimensionOperation.SUB, other)
70
111
 
71
- def __mul__(self, other: Union[int, "ShapeExpr"]):
112
+ def __mul__(self, other: Union[int, "SymIntExpr"]):
72
113
  return self._op(trt.DimensionOperation.PROD, other)
73
114
 
74
- def __floordiv__(self, other: Union[int, "ShapeExpr"]):
115
+ def __floordiv__(self, other: Union[int, "SymIntExpr"]):
75
116
  return self._op(trt.DimensionOperation.FLOOR_DIV, other)
76
117
 
77
- def __eq__(self, other: Union[int, "ShapeExpr"]):
118
+ def __eq__(self, other: Union[int, "SymIntExpr"]):
78
119
  return self._op(trt.DimensionOperation.EQUAL, other)
79
120
 
80
- def __lt__(self, other: Union[int, "ShapeExpr"]):
121
+ def __lt__(self, other: Union[int, "SymIntExpr"]):
81
122
  return self._op(trt.DimensionOperation.LESS, other)
82
123
 
124
+ def __repr__(self):
125
+ if self._is_dummy:
126
+ return f"FakeSymIntExpr[id={id(self)}]"
127
+ elif not self.is_constant:
128
+ return f"SymIntExpr[id={id(self)}]"
129
+ return f"SymIntExpr[{self._expr.get_constant_value()}]"
130
+
131
+ @property
132
+ def is_constant(self) -> bool:
133
+ """
134
+ `True` if this integer expression is a build-time constant, `False` otherwise.
135
+
136
+ Raises:
137
+ RuntimeError: For fake :class:`SymIntExpr`\s. Check :attr:`is_fake` to determine accessibility.
138
+ """
139
+ if self._is_dummy:
140
+ raise RuntimeError(
141
+ "Not accessible for fake 'SymIntExpr's. Check is_fake to determine accessibility."
142
+ )
143
+ return self._expr.is_constant()
144
+
145
+ def constant_value(self) -> int:
146
+ """
147
+ Return value of the constant integer expression.
148
+
149
+ Raises:
150
+ RuntimeError: For non-constant integer expressions. Check :attr:`is_constant` to determine accessibility.
151
+ """
152
+ if not self.is_constant:
153
+ raise RuntimeError(
154
+ "Not accessible for non-constant integer expressions. Check is_constant to determine accessibility."
155
+ )
156
+ return self._expr.get_constant_value()
157
+
158
+ # Evaluate the underlying trt.IDimensionExpr, if so done lazily
159
+ @property
160
+ def _expr(self):
161
+ return self._int_expr
162
+
163
+ def _clone(self):
164
+ return SymIntExpr(self + 0)
165
+
166
+ class SymExprImpl(trt.ISymExpr):
167
+ def __init__(self, expr: SymIntExpr):
168
+ trt.ISymExpr.__init__(self)
169
+ self.type = trt.PluginArgType.INT
170
+ if isinstance(expr, SymInt32):
171
+ self.dtype = trt.PluginArgDataType.INT32
172
+ elif isinstance(expr, SymInt16):
173
+ self.dtype = trt.PluginArgDataType.INT16
174
+ elif isinstance(expr, SymInt8):
175
+ self.dtype = trt.PluginArgDataType.INT8
176
+ else:
177
+ raise ValueError(f"Unknown SymIntExpr type {type(expr)}")
178
+
179
+ self.expr = expr._expr
180
+ @public_api()
181
+ class SymInt32(SymIntExpr):
182
+ """
183
+ Symbolic expression for a 32-bit integer
184
+ """
185
+ def __init__(self, value: Union[int, trt.IDimensionExpr, SymIntExpr] = None):
186
+ super().__init__(value)
187
+
188
+ def __call__(self):
189
+ return SymExprImpl(self)
190
+ @public_api()
191
+ class SymInt8(SymIntExpr):
192
+ """
193
+ Symbolic expression for an 8-bit integer
194
+ """
195
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
196
+ super().__init__(value)
197
+
198
+ @public_api()
199
+ class SymInt16(SymIntExpr):
200
+ """
201
+ Symbolic expression for a 16-bit integer
202
+ """
203
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "SymIntExpr"] = None):
204
+ super().__init__(value)
205
+
206
+ # Symbolic expression for a given dimension of a tensor
207
+ @public_api()
208
+ class ShapeExpr(SymInt32):
209
+ """
210
+ Symbolic expression for single dimension of a tensor
211
+ """
212
+ def __init__(self, value: Union[int, trt.IDimensionExpr, "ShapeExpr", SymIntExpr] = None):
213
+ """
214
+ Args:
215
+ value (Union[int, trt.IDimensionExpr, ShapeExpr, SymIntExpr], optional): Constant or another symbolic expression. Defaults to creating a fake shape expression.
216
+ """
217
+ super().__init__(value)
218
+ self._exprBuilder = SymIntExpr._exprBuilder
219
+ self._is_dummy = False
220
+ self._is_size_tensor = False
221
+ if value is None:
222
+ self._is_dummy = True
223
+ elif isinstance(value, int):
224
+ if self._exprBuilder is None:
225
+ self._is_dummy = True
226
+ elif isinstance(value, ShapeExpr):
227
+ self._is_dummy = value._is_dummy
228
+ self._is_size_tensor = value._is_size_tensor
229
+ elif isinstance(value, SymIntExpr):
230
+ pass
231
+
232
+ def _op(self, op: trt.DimensionOperation, other: Union[int, "ShapeExpr"]):
233
+ if self._is_size_tensor:
234
+ raise ValueError("It is not permitted to perform binary operations on size tensor expressions") # trt limitation
235
+ if self._is_dummy:
236
+ return ShapeExpr()
237
+ return ShapeExpr(super()._op(op, other))
238
+
83
239
  def __repr__(self):
84
240
  if self._is_dummy:
85
241
  return f"FakeShapeExpr[id={id(self)}]"
@@ -116,7 +272,7 @@ class ShapeExpr:
116
272
  raise RuntimeError(
117
273
  "Not accessible for fake 'ShapeExpr's. Check is_fake to determine accessibility."
118
274
  )
119
- return self._expr.is_constant()
275
+ return super().is_constant
120
276
 
121
277
  def constant_value(self) -> int:
122
278
  """
@@ -129,12 +285,13 @@ class ShapeExpr:
129
285
  raise RuntimeError(
130
286
  "Not accessible for non-constant shape expressions. Check is_constant to determine accessibility."
131
287
  )
132
- return self._expr.get_constant_value()
288
+ return super().constant_value()
133
289
 
134
- # Evaluate the underlying trt.IDimensionExpr, if so done lazily
135
- @property
136
- def _expr(self):
137
- return self._dim_expr
290
+ def _clone(self):
291
+ ret = ShapeExpr(self + 0)
292
+ ret._is_dummy = self._is_dummy
293
+ ret._is_size_tensor = self._is_size_tensor
294
+ return ret
138
295
 
139
296
  @public_api()
140
297
  class SizeTensorShapeExpr(ShapeExpr):
@@ -166,69 +323,163 @@ class SizeTensorShapeExpr(ShapeExpr):
166
323
 
167
324
  @property
168
325
  def _expr(self):
169
- if self._dim_expr is not None:
170
- return self._dim_expr
171
-
172
- self._dim_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
173
- return self._dim_expr
326
+ if self._int_expr is not None:
327
+ return self._int_expr
174
328
 
329
+ self._int_expr = super()._exprBuilder.declare_size_tensor(self._size_tensor_desc.index, self._size_tensor_desc.opt._expr, self._size_tensor_desc.upper_bound._expr)
330
+ return self._int_expr
331
+
175
332
  def __repr__(self):
176
333
  return f"ShapeExpr[is_size_tensor = True, id={id(self)}]"
177
334
 
335
+ def _from_scalar(s):
336
+ if isinstance(s, int):
337
+ return SymInt32(s)
338
+ elif isinstance(s, float):
339
+ raise ValueError("Float symbolic expressions are not supported")
340
+ else:
341
+ raise ValueError(f"Unsupported type: '{type(s)}'")
342
+
178
343
  # Iterable holding `ShapeExpr`s
179
344
  @public_api()
180
- class ShapeExprs:
181
- def __init__(self, length: int, _is_dummy: bool = False):
345
+ class SymExprs:
346
+ def __init__(self, length: int):
182
347
  """
183
- Iterable holding :class:`ShapeExpr`\s
348
+ Iterable holding symbolic expressions
184
349
 
185
350
  Args:
186
351
  length (int): Number of dimensions of the tensor
187
352
  """
188
353
  self._length = length
354
+ self._exprs = [None] * length
355
+
356
+ @classmethod
357
+ def from_tuple(cls, shape_exprs: Tuple[Union[SymExpr, int]]) -> "SymExprs":
358
+ """
359
+ Args:
360
+ shape_exprs (Tuple[Union[SymExpr, int]]): Tuple to construct :class:`SymExprs` from
361
+ """
362
+
363
+ shape_exprs_ = tuple([e if isinstance(e, SymExpr) else _from_scalar(e) for e in shape_exprs])
364
+ inst = cls(len(shape_exprs_))
365
+ inst._exprs = list(shape_exprs_)
366
+ return inst
367
+
368
+ def __iter__(self):
369
+ return iter(self._exprs)
370
+
371
+ def __getitem__(self, index):
372
+ return self._exprs[index]
373
+
374
+ def __len__(self):
375
+ return self._length
376
+
377
+ def __setitem__(self, index, expr):
378
+ if index >= self._length:
379
+ raise IndexError("Index out of range")
380
+
381
+ if not isinstance(expr, SymExpr):
382
+ expr = _from_scalar(expr)
383
+
384
+ self._exprs[index] = expr
385
+
386
+ def __repr__(self):
387
+ return f"SymExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
388
+
389
+ @public_api()
390
+ class ShapeExprs(SymExprs):
391
+ def __init__(self, length, _is_dummy = False):
392
+ """
393
+ Iterable holding :class:`ShapeExpr`\s, representing a tensor shape
394
+
395
+ Args:
396
+ length (int): Number of dimensions of the tensor
397
+ """
398
+ if length > trt.Dims.MAX_DIMS:
399
+ raise ValueError(f"ShapeExprs can only support up to trt.Dims.MAX_DIMS = {trt.Dims.MAX_DIMS} dimensions. {length} given.")
400
+
401
+ super().__init__(length)
402
+
189
403
  self._is_dummy = _is_dummy
190
404
  if _is_dummy:
191
- self._shapes = [ShapeExpr()] * length
192
- else:
193
- self._shapes = [None] * length
405
+ self._exprs = [ShapeExpr()] * length
194
406
 
195
407
  @classmethod
196
- def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExprs":
408
+ def from_tuple(cls, shape_exprs: Tuple[Union[ShapeExpr, int]]) -> "ShapeExpr":
197
409
  """
198
410
  Args:
199
411
  shape_exprs (Tuple[Union[ShapeExpr, int]]): Tuple to construct :class:`ShapeExprs` from
200
412
  """
413
+
201
414
  shape_exprs_ = tuple([e if isinstance(e, ShapeExpr) else ShapeExpr(e) for e in shape_exprs])
202
415
  inst = cls(len(shape_exprs_))
203
- inst._shapes = list(shape_exprs_)
416
+ inst._exprs = list(shape_exprs_)
204
417
  return inst
205
-
418
+
206
419
  def numel(self) -> ShapeExpr:
207
420
  """
208
421
  Returns a symbolic expression for the number of elements
209
422
  """
210
423
  ret = ShapeExpr(1)
211
- for s in self._shapes:
424
+ for s in self._exprs:
212
425
  ret *= s
213
426
  return ret
214
427
 
215
- def __iter__(self):
216
- return iter(self._shapes)
428
+ def __setitem__(self, index, value):
429
+ if index >= self._length:
430
+ raise IndexError("Index out of range")
431
+
432
+ if not isinstance(value, ShapeExpr):
433
+ if not isinstance(value, int):
434
+ raise ValueError(f"Value should be int or ShapeExpr. Got '{type(value)}'")
435
+ value = ShapeExpr(value)
436
+
437
+ self._exprs[index] = value
438
+
439
+ def __repr__(self):
440
+ return f"ShapeExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
217
441
 
218
- def __getitem__(self, index):
219
- return self._shapes[index]
442
+ def _clone(self):
443
+ ret = ShapeExprs.from_tuple((e._clone() for e in self._exprs))
444
+ ret._is_dummy = self._is_dummy
445
+ return ret
446
+
447
+ @public_api()
448
+ class SymIntExprs(SymExprs):
449
+ def __init__(self, length):
450
+ """
451
+ Iterable holding :class:`SymIntExpr`\s
220
452
 
221
- def __len__(self):
222
- return self._length
453
+ Args:
454
+ length (int): Number of symbolic expressions in the iterable
455
+ """
456
+ super().__init__(length)
457
+
458
+ @classmethod
459
+ def from_tuple(cls, shape_exprs: Tuple[Union[SymIntExpr, int]]) -> "SymIntExpr":
460
+ """
461
+ Args:
462
+ shape_exprs (Tuple[Union[SymIntExpr, int]]): Tuple to construct :class:`SymIntExprs` from
463
+ """
464
+
465
+ shape_exprs_ = tuple([e if isinstance(e, SymIntExpr) else SymIntExpr(e) for e in shape_exprs])
466
+ inst = cls(len(shape_exprs_))
467
+ inst._exprs = list(shape_exprs_)
468
+ return inst
223
469
 
224
- def __setitem__(self, index, shape):
470
+ def __setitem__(self, index, value):
225
471
  if index >= self._length:
226
472
  raise IndexError("Index out of range")
227
- self._shapes[index] = shape
228
-
473
+
474
+ if not isinstance(value, SymIntExpr):
475
+ if not isinstance(value, int):
476
+ raise ValueError(f"Value should be int or SymIntExpr. Got '{type(value)}'")
477
+ value = SymIntExpr(value)
478
+
479
+ self._exprs[index] = value
480
+
229
481
  def __repr__(self):
230
- return f"ShapeExprs[{', '.join([s.__repr__() for s in self._shapes])}]"
231
-
482
+ return f"SymIntExprs[{', '.join([s.__repr__() for s in self._exprs])}]"
232
483
 
233
484
  # Numerical representation of a tensor shape
234
485
  @public_api()
@@ -237,7 +488,7 @@ class Shape:
237
488
  Numerical representation of a tensor shape
238
489
  """
239
490
  def __init__(
240
- self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc]
491
+ self, tensor_desc: Union[Tuple[int], trt.DynamicPluginTensorDesc, trt.PluginTensorDesc] = None
241
492
  ):
242
493
  self._is_dynamic = None # set lazily
243
494
  if isinstance(tensor_desc, trt.DynamicPluginTensorDesc):
@@ -250,6 +501,9 @@ class Shape:
250
501
  elif isinstance(tensor_desc, tuple):
251
502
  self._shapes = trt.Dims(tensor_desc)
252
503
  self._length = len(self._shapes)
504
+ elif tensor_desc is None:
505
+ self._length = 0
506
+ self._shapes = trt.Dims(0)
253
507
  else:
254
508
  raise ValueError("Unsupported type used for constructing trt.plugin.Shape! tensor_desc must be a Tuple[int], trt.DynamicPluginTensorDesc, or trt.PluginTensorDesc")
255
509
 
@@ -335,6 +589,11 @@ class Shape:
335
589
  raise IndexError("Index out of range")
336
590
  self._shapes[index] = val
337
591
 
592
+ def _clone(self):
593
+ ret = Shape()
594
+ ret.__dict__.update(self.__dict__)
595
+ return ret
596
+
338
597
 
339
598
  # Descriptor for a tensor
340
599
  # A `TensorDesc` never contains nor refers to any tensor data.
@@ -416,8 +675,7 @@ class TensorDesc:
416
675
  def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
417
676
  return inp.like()
418
677
  """
419
- cloned = TensorDesc()
420
- cloned.__dict__.update(self.__dict__)
678
+ cloned = self._clone()
421
679
  cloned._immutable = False
422
680
  return cloned
423
681
 
@@ -436,10 +694,19 @@ class TensorDesc:
436
694
  def _(inp: tensorrt.plugin.TensorDesc) -> tensorrt.plugin.TensorDesc:
437
695
  return inp.aliased()
438
696
  """
697
+ cloned = self._clone()
698
+ cloned._immutable = False
699
+ cloned._aliased_to = self
700
+ cloned._immutable = True
701
+ return cloned
702
+
703
+ def _clone(self) -> "TensorDesc":
439
704
  cloned = TensorDesc()
440
705
  cloned.__dict__.update(self.__dict__)
441
706
  cloned._immutable = False
442
- cloned._aliased_to = self
707
+ cloned._shape_expr = self._shape_expr._clone()
708
+ if self._shape is not None:
709
+ cloned._shape = self._shape._clone()
443
710
  cloned._immutable = True
444
711
  return cloned
445
712
 
@@ -823,3 +1090,39 @@ class Tensor:
823
1090
 
824
1091
  cloned._aliased_to = self
825
1092
  return cloned
1093
+
1094
+ if IS_AOT_ENABLED:
1095
+ @public_api()
1096
+ class KernelLaunchParams:
1097
+ """
1098
+ Args:
1099
+ grid_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid x dimension. Defaults to 1.
1100
+ grid_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid y dimension. Defaults to 1.
1101
+ grid_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The grid z dimension. Defaults to 1.
1102
+ block_x (Union[int, trt.IDimensionExpr, SymInt32], optional): The x dimension of each thread block. Defaults to 1.
1103
+ block_y (Union[int, trt.IDimensionExpr, SymInt32], optional): The y dimension of each thread block. Defaults to 1.
1104
+ block_z (Union[int, trt.IDimensionExpr, SymInt32], optional): The z dimension of each thread block. Defaults to 1.
1105
+ shared_mem (Union[int, trt.IDimensionExpr, SymInt32], optional): Shared-memory per thread block in bytes. Defaults to 0.
1106
+ """
1107
+ def __init__(self,
1108
+ grid_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1109
+ grid_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1110
+ grid_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1111
+ block_x: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1112
+ block_y: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1113
+ block_z: Union[int, trt.IDimensionExpr, SymInt32] = 1,
1114
+ shared_mem: Union[int, trt.IDimensionExpr, SymInt32] = 0):
1115
+ self.grid_x = SymInt32(grid_x)
1116
+ self.grid_y = SymInt32(grid_y)
1117
+ self.grid_z = SymInt32(grid_z)
1118
+ self.block_x = SymInt32(block_x)
1119
+ self.block_y = SymInt32(block_y)
1120
+ self.block_z = SymInt32(block_z)
1121
+ self.shared_mem = SymInt32(shared_mem)
1122
+
1123
+
1124
+ def __setattr__(self, name, value):
1125
+ if name in ["grid_x", "grid_y", "grid_z", "block_x", "block_y", "block_z", "shared_mem"]:
1126
+ self.__dict__[name] = SymInt32(value)
1127
+ else:
1128
+ raise AttributeError(f"KernelLaunchParams object has no attribute '{name}'")