tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/utils/validate_args_kwargs.py
@@ -0,0 +1,1114 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass, field
+ from typing import List, Optional, TYPE_CHECKING, Union
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ import torch.fx.node
+
+ from tico.utils.utils import enforce_type
+
+ """
+ This file includes OpArgs classes that provide arguments with type annotations.
+ - Each class provides type-checked arguments for the aten op in the comment.
+ - Class names are determined by the following priority:
+     1. Torch spec (aten/src/ATen/native/native_functions.yaml in the pytorch repo)
+     2. pytorch doc (https://pytorch.org/docs/stable/index.html)
+ """
+
+
+ @enforce_type
+ @dataclass
+ class AddTensorArgs:
+     """
+     add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, float, int, torch.Tensor]
+     other: Union[torch.fx.Node, float, int, torch.Tensor]
+
+
+ @enforce_type
+ @dataclass
+ class AddmmArgs:
+     """
+     addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+     """
+
+     input: torch.fx.Node
+     mat1: torch.fx.Node
+     mat2: torch.fx.Node
+     beta: Union[int, float] = 1
+     alpha: Union[int, float] = 1
+
+
+ @enforce_type
+ @dataclass
+ class AliasCopyArgs:
+     """
+     alias_copy(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class AnyArgs:
+     """
+     any(Tensor self) -> Tensor
+     any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+     any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: Union[int, tuple, None] = None
+     keepdim: bool = False
+
+
+ @enforce_type
+ @dataclass
+ class ArangeStartStepArgs:
+     """
+     arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+     """
+
+     start: Union[int, float]
+     end: Union[int, float]
+     step: Union[int, float] = 1
+
+
+ @enforce_type
+ @dataclass
+ class ArgMaxArgs:
+     """
+     argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+     """
+
+     tensor: Union[torch.fx.Node, torch.Tensor]
+     dim: Union[int, None] = None
+
+
+ @enforce_type
+ @dataclass
+ class AvgPool2dArgs:
+     """
+     avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
+     """
+
+     input: torch.fx.Node
+     kernel_size: List[int]
+     stride: List[int] = field(default_factory=list)
+     padding: List[int] = field(default_factory=lambda: [0, 0])
+     ceil_mode: bool = False
+     count_include_pad: bool = True
+     divisor_override: Optional[int] = None
+
+     def __post_init__(self):
+         assert len(self.kernel_size) == 2, len(self.kernel_size)
+         assert len(self.stride) == 2, len(self.stride)
+         if self.padding is not None:
+             assert len(self.padding) == 2, len(self.padding)
+         if self.divisor_override is not None:
+             assert isinstance(self.divisor_override, int), type(self.divisor_override)
+             assert self.divisor_override != 0, "divisor_override must not be zero."
+
+
+ @enforce_type
+ @dataclass
+ class AdaptiveAvgPool2dArgs:
+     """
+     adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+     """
+
+     input: torch.fx.Node
+     output_size: List[int]
+
+
+ @enforce_type
+ @dataclass
+ class BmmArgs:
+     """
+     bmm(Tensor self, Tensor mat2) -> Tensor
+     """
+
+     input: torch.fx.Node
+     mat2: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class CatArgs:
+     """
+     cat(Tensor[] tensors, int dim=0) -> Tensor
+     """
+
+     tensors: List[torch.fx.Node]
+     dim: int = 0
+
+
+ @enforce_type
+ @dataclass
+ class ClampArgs:
+     """
+     clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     min: Optional[Union[int, float]] = None
+     max: Optional[Union[int, float]] = None
+
+
+ @enforce_type
+ @dataclass
+ class CloneArgs:
+     """
+     clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     memory_format: Optional[torch.memory_format] = None
+
+
+ @enforce_type
+ @dataclass
+ class ConstantPadNdArgs:
+     """
+     constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
+     """
+
+     input: torch.fx.Node
+     pad: List[int]
+     value: Union[int, float]
+
+
+ @enforce_type
+ @dataclass
+ class Conv2DArgs:
+     """
+     conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+     conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+     """
+
+     input: torch.fx.Node
+     weight: torch.fx.Node
+     bias: Union[torch.fx.Node, None] = None
+     stride: List[int] = field(default_factory=lambda: [1, 1])
+     padding: Union[List[int], str] = field(default_factory=lambda: [0, 0])
+     dilation: List[int] = field(default_factory=lambda: [1, 1])
+     groups: int = 1
+
+     def __post_init__(self):
+         assert len(self.stride) == 2, len(self.stride)
+         assert len(self.dilation) == 2, len(self.dilation)
+
+
+ @enforce_type
+ @dataclass
+ class Conv1DArgs:
+     """
+     conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+     conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+     """
+
+     input: torch.fx.Node
+     weight: torch.fx.Node
+     bias: Union[torch.fx.Node, None] = None
+     stride: List[int] = field(default_factory=lambda: [1])
+     padding: Union[List[int], str] = field(default_factory=lambda: [0])
+     dilation: List[int] = field(default_factory=lambda: [1])
+     groups: int = 1
+
+     def __post_init__(self):
+         assert len(self.stride) == 1, len(self.stride)
+         assert len(self.dilation) == 1, len(self.dilation)
+
+
+ @enforce_type
+ @dataclass
+ class CopyArgs:
+     """
+     copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
+     """
+
+     dst: torch.fx.Node
+     src: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class CosArgs:
+     """
+     cos(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class CumsumArgs:
+     """
+     cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: int
+
+
+ @enforce_type
+ @dataclass
+ class DequantizePerChannelArgs:
+     """
+     quantized_decomposed.dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     scales: torch.fx.Node
+     zero_points: torch.fx.Node
+     axis: int
+     quant_min: int
+     quant_max: int
+     dtype: torch.dtype
+
+
+ @enforce_type
+ @dataclass
+ class DequantizePerTensorArgs:
+     """
+     quantized_decomposed.dequantize_per_tensor(input: TensorBox, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype) -> TensorBox
+     """
+
+     input: torch.fx.Node
+     scale: float
+     zero_point: int
+     quant_min: int
+     quant_max: int
+     dtype: torch.dtype
+
+
+ @enforce_type
+ @dataclass
+ class DivTensorArgs:
+     """
+     div.Tensor(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, float, int, torch.Tensor]
+     other: Union[torch.fx.Node, float, int, torch.Tensor]
+
+
+ @enforce_type
+ @dataclass
+ class EmbeddingArgs:
+     """
+     embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+     """
+
+     weight: torch.fx.Node
+     indices: torch.fx.Node
+     padding_idx: int = -1
+     scale_grad_by_freq: bool = False
+     sparse: bool = False
+
+
+ @enforce_type
+ @dataclass
+ class EqArgs:
+     """
+     eq.Scalar(Tensor self, Scalar other) -> Tensor
+     eq.Tensor(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, float, int]
+     other: Union[torch.fx.Node, torch.Tensor, float, int]
+
+
+ @enforce_type
+ @dataclass
+ class ExpArgs:
+     """
+     exp(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class ExpandArgs:
+     """
+     expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+     expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
+     """
+
+     input: torch.fx.Node
+     size: List[int]
+
+
+ @enforce_type
+ @dataclass
+ class FakeQuantizePerChannelArgs:
+     """
+     fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
+     fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+     """
+
+     input: torch.fx.Node
+     scale: torch.fx.Node
+     zero_point: torch.fx.Node
+     axis: int
+     quant_min: int
+     quant_max: int
+
+
+ @enforce_type
+ @dataclass
+ class FakeQuantizePerTensorTQParamArgs:
+     """
+     fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
+     """
+
+     input: torch.fx.Node
+     scale: torch.fx.Node
+     zero_point: torch.fx.Node
+     quant_min: int
+     quant_max: int
+
+
+ @enforce_type
+ @dataclass
+ class FullLikeArgs:
+     """
+     full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     fill_value: Union[int, float, bool]
+     pin_memory: Optional[bool] = None
+     memory_format: Optional[torch.memory_format] = None
+
+
+ @enforce_type
+ @dataclass
+ class FullArgs:
+     """
+     full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+     """
+
+     size: Union[list, tuple, torch.Size]
+     fill_value: Union[int, float, bool]
+
+
+ @enforce_type
+ @dataclass
+ class GeArgs:
+     """
+     ge.Scalar(Tensor self, Scalar other) -> Tensor
+     ge.Tensor(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, float, int]
+     other: Union[torch.fx.Node, torch.Tensor, float, int]
+
+
+ @enforce_type
+ @dataclass
+ class GeluArgs:
+     """
+     gelu(Tensor self, *, str approximate='none') -> Tensor
+     """
+
+     input: torch.fx.Node
+     approximate: Optional[str] = "none"
+
+
+ @enforce_type
+ @dataclass
+ class GtArgs:
+     """
+     gt.Scalar(Tensor self, Scalar other) -> Tensor
+     gt.Tensor(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, float, int]
+     other: Union[torch.fx.Node, torch.Tensor, float, int]
+
+
+ @enforce_type
+ @dataclass
+ class HardTanhArgs:
+     """
+     hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor]
+     min_val: Union[float, int] = -1
+     max_val: Union[float, int] = 1
+
+
+ @enforce_type
+ @dataclass
+ class IndexSelectArgs:
+     """
+     index_select(Tensor self, int dim, Tensor index) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: int
+     index: Union[torch.fx.Node, torch.Tensor]
+
+
+ @enforce_type
+ @dataclass
+ class IndexArgs:
+     """
+     index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+     """
+
+     input: torch.fx.Node
+     indices: List[Union[torch.fx.Node, torch.Tensor, int, None]]
+
+
+ @enforce_type
+ @dataclass
+ class InstanceNormArgs:
+     """
+     instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
+     """
+
+     input: torch.fx.Node
+     weight: Optional[torch.fx.Node]
+     bias: Optional[torch.fx.Node]
+     running_mean: Optional[torch.fx.Node]
+     running_var: Optional[torch.fx.Node]
+     use_input_stats: bool
+     momentum: float
+     eps: float
+     cudnn_enabled: bool
+
+
+ @enforce_type
+ @dataclass
+ class LinearArgs:
+     """
+     linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     weight: torch.fx.Node
+     bias: Optional[torch.fx.Node] = None
+
+
+ @enforce_type
+ @dataclass
+ class LogArgs:
+     """
+     log(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class Log1pArgs:
+     """
+     log1p(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class LogicalAndArgs:
+     """
+     logical_and(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: torch.fx.Node
+     other: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class LogicalNotArgs:
+     """
+     logical_not(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class LtArgs:
+     """
+     lt.Tensor(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: torch.fx.Node
+     other: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class MaxPool2dWithIndicesArgs:
+     """
+     max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+     """
+
+     input: torch.fx.Node
+     kernel_size: List[int]
+     stride: List[int] = field(default_factory=list)
+     padding: List[int] = field(default_factory=lambda: [0, 0])
+     dilation: List[int] = field(default_factory=lambda: [1, 1])
+     ceil_mode: bool = False
+
+     def __post_init__(self):
+         assert len(self.kernel_size) == 2, len(self.kernel_size)
+         assert len(self.stride) == 2, len(self.stride)
+         if self.padding is not None:
+             assert len(self.padding) == 2, len(self.padding)
+         if self.dilation is not None:
+             assert len(self.dilation) == 2, len(self.dilation)
+
+
+ @enforce_type
+ @dataclass
+ class MeanDimArgs:
+     """
+     mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: List[int]
+     keep_dims: bool = False
+     dtype: Optional[torch.dtype] = None
+
+
+ @enforce_type
+ @dataclass
+ class MatmulArgs:
+     """
+     mm(Tensor self, Tensor mat2) -> Tensor
+     """
+
+     input: torch.fx.Node
+     other: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class MaximumArgs:
+     """
+     maximum(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor]
+     other: Union[torch.fx.Node, torch.Tensor]
+
+
+ @enforce_type
+ @dataclass
+ class MinimumArgs:
+     """
+     minimum(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor]
+     other: Union[torch.fx.Node, torch.Tensor]
+
+
+ @enforce_type
+ @dataclass
+ class MulTensorArgs:
+     """
+     mul.Tensor(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, int, float]
+     other: Union[torch.fx.Node, torch.Tensor, int, float]
+
+
+ @enforce_type
+ @dataclass
+ class MulScalarArgs:
+     """
+     mul.Scalar(Tensor self, Scalar other) -> Tensor
+     """
+
+     input: torch.fx.Node
+     other: Union[int, float]
+
+
+ @enforce_type
+ @dataclass
+ class NeScalarArgs:
+     """
+     ne.Scalar(Tensor self, Scalar other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+     other: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+
+
+ @enforce_type
+ @dataclass
+ class NativeBatchNormLegitNoTrainingArgs:
+     """
+     _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+     """
+
+     input: torch.fx.Node
+     weight: Optional[torch.fx.Node]
+     bias: Optional[torch.fx.Node]
+     running_mean: Optional[torch.fx.Node]
+     running_var: Optional[torch.fx.Node]
+     momentum: float
+     eps: float
+
+
+ @enforce_type
+ @dataclass
+ class NativeGroupNormArgs:
+     """
+     native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
+     """
+
+     input: torch.fx.Node
+     weight: Optional[torch.fx.Node]
+     bias: Optional[torch.fx.Node]
+     N: int
+     C: int
+     HxW: int
+     group: int
+     eps: float
+
+
+ @enforce_type
+ @dataclass
+ class NativeLayerNormArgs:
+     """
+     native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+     """
+
+     input: torch.fx.Node
+     normalized_shape: Union[tuple, list]
+     weight: Optional[torch.fx.Node]
+     bias: Optional[torch.fx.Node]
+     eps: float
+
+
+ @enforce_type
+ @dataclass
+ class NeTensorArgs:
+     """
+     ne.Tensor(Tensor self, Tensor other) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+     other: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+
+
+ @enforce_type
+ @dataclass
+ class NegArgs:
+     """
+     neg(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class PermuteArgs:
+     """
+     permute(Tensor(a) self, int[] dims) -> Tensor(a)
+     """
+
+     input: torch.fx.Node
+     dims: List[int]
+
+
+ @enforce_type
+ @dataclass
+ class PowTensorTensorArgs:
+     """
+     pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+     """
+
+     input: torch.fx.Node
+     exponent: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class PowTensorScalarArgs:
+     """
+     pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+     """
+
+     input: torch.fx.Node
+     exponent: Union[float, int]
+
+
+ @enforce_type
+ @dataclass
+ class PReLUArgs:
+     """
+     prelu(Tensor self, Tensor weight) -> Tensor
+     """
+
+     input: torch.fx.Node
+     weight: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class QuantizePerTensorArgs:
+     """
+     quantized_decomposed.quantize_per_tensor(input: TensorBox, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype) -> TensorBox
+     """
+
+     tensor: torch.fx.Node
+     scale: float
+     zero_p: int
+     quant_min: int
+     quant_max: int
+     dtype: torch.dtype
+
+
+ @enforce_type
+ @dataclass
+ class ReciprocalArgs:
+     """
+     reciprocal(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class ReluArgs:
+     """
+     relu(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class Relu6Args:
+     """
+     relu6(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class RepeatArgs:
+     """
+     repeat(Tensor self, SymInt[] repeats) -> Tensor
+     """
+
+     input: torch.fx.Node
+     repeats: List[int]
+
+
+ @enforce_type
+ @dataclass
+ class ReshapeArgs:
+     """
+     reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+     """
+
+     input: torch.fx.Node
+     size: List[int]
+
+
+ @enforce_type
+ @dataclass
+ class ResizeNearestNeighborArgs:
+     """
+     Maps from the `torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode='nearest')` case.
+     """
+
+     input: torch.fx.Node
+     size: List[int]
+
+
+ @enforce_type
+ @dataclass
+ class RsqrtArgs:
+     """
+     rsqrt(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class ScalarTensorArgs:
+     """
+     scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+     """
+
+     scalar: Union[int, float]
+
+
+ @enforce_type
+ @dataclass
+ class SelectCopyIntArgs:
+     """
+     select_copy.int(Tensor self, int dim, SymInt index) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: int
+     index: int
+
+
+ @enforce_type
+ @dataclass
+ class SigmoidArgs:
+     """
+     sigmoid(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class SinArgs:
+     """
+     sin(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class SliceArgs:
+     """
+     slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+     slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: int = 0
+     start: Optional[int] = None
+     end: Optional[int] = None
+     step: Optional[int] = 1
+
+
+ @enforce_type
+ @dataclass
+ class SafeSoftmaxArgs:
+     """
+     _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: int
+     dtype: Optional[torch.dtype] = None
+
+
+ @enforce_type
+ @dataclass
+ class SoftmaxArgs:
+     """
+     _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: int
+     half_to_float: bool
+
+
+ @enforce_type
+ @dataclass
+ class SplitWithSizesArgs:
+     """
+     split_with_sizes(Tensor(a->*) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+     """
+
+     input: torch.fx.Node
+     split_sizes: List[int]
+     dim: int = 0
+
+
+ @enforce_type
+ @dataclass
+ class SqrtArgs:
+     """
+     sqrt(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class SqueezeArgs:
+     """
+     squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+     squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dims: List[int] = field(default_factory=list)
+
+
+ @enforce_type
+ @dataclass
+ class SubTensorArgs:
+     """
+     sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, float, int]
+     other: Union[torch.fx.Node, torch.Tensor, float, int]
+     alpha: Optional[Union[int, float]] = None
+
+
+ @enforce_type
+ @dataclass
+ class SumDimIntListArgs:
+     """
+     sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+     """
+
+     input: Union[torch.fx.Node, torch.Tensor, float, int]
+     dim: List[int] = field(default_factory=list)
+     keepdim: bool = False
+     dtype: Optional[torch.dtype] = None
+
+
+ @enforce_type
+ @dataclass
+ class TanhArgs:
+     """
+     tanh(Tensor self) -> Tensor
+     """
+
+     input: torch.fx.Node
+
+
+ @enforce_type
+ @dataclass
+ class ToCopyArgs:
+     """
+     _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dtype: Optional[torch.dtype] = None
+     layout: Optional[torch.layout] = None
+     device: Optional[torch.device] = None
+     pin_memory: Optional[bool] = None
+     non_blocking: Optional[bool] = False
+     memory_format: Optional[torch.memory_format] = None
+
+
+ @enforce_type
+ @dataclass
+ class ToDtypeArgs:
+     """
+     to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+     """
+
+     input: torch.fx.Node
+     dtype: Optional[torch.dtype] = None
+     non_blocking: Optional[bool] = False
+     copy: Optional[bool] = False
+     memory_format: Optional[torch.memory_format] = None
+
+
+ @enforce_type
+ @dataclass
+ class ToDtypeLayoutArgs:
+     """
+     to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+     """
+
+     input: torch.fx.Node
+     dtype: Optional[torch.dtype] = None
+     layout: Optional[torch.layout] = None
+     device: Optional[torch.device] = None
+     pin_memory: Optional[bool] = None
+     non_blocking: Optional[bool] = False
+     copy: Optional[bool] = False
+     memory_format: Optional[torch.memory_format] = None
+
+
+ @enforce_type
+ @dataclass
+ class UnSqueezeArgs:
+     """
+     unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
+     unsqueeze_copy(Tensor self, int dim) -> Tensor
+     """
+
+     input: torch.fx.Node
+     dim: int
+
+
+ @enforce_type
+ @dataclass
+ class UpsampleNearest2DVecArgs:
+     """
+     upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+     """
+
+     input: torch.fx.Node
+     output_size: Optional[List[int]]
+     scale_factors: Optional[List[float]]
+
+
+ @enforce_type
+ @dataclass
+ class ViewArgs:
+     """
+     view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+     view_copy(Tensor self, SymInt[] size) -> Tensor
+     """
+
+     input: torch.fx.Node
+     size: List[int]
+
+
+ @enforce_type
+ @dataclass
+ class WhereSelfArgs:
+     """
+     where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
+     """
+
+     condition: torch.fx.Node
+     input: Union[torch.fx.Node, torch.Tensor]
+     other: Union[torch.fx.Node, torch.Tensor]
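
For orientation only, a minimal usage sketch (not part of the package contents above): it shows how these OpArgs classes are typically consumed when serializing a torch.fx graph, assuming `enforce_type` (defined in tico/utils/utils.py, not shown in this hunk) checks each field against its annotation at construction time and raises on a mismatch. The visitor function and `node` variable are hypothetical, written in the style suggested by tico/serialize/operators/op_add.py.

import torch.fx

from tico.utils.validate_args_kwargs import AddTensorArgs


def visit_add(node: torch.fx.Node):
    # Hypothetical visitor: unpack the FX node's positional and keyword
    # arguments into the typed dataclass. Wrong arity fails in the generated
    # __init__; a wrongly typed value is expected to be rejected by
    # @enforce_type before any serialization logic runs.
    args = AddTensorArgs(*node.args, **node.kwargs)
    return args.input, args.other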