tico-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/utils/validate_args_kwargs.py
@@ -0,0 +1,1149 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import List, Optional, TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+import torch.fx.node
+
+from tico.utils.utils import enforce_type
+
+ """
27
+ This file includes OpArgs classes that provide arguments with type annotations.
28
+ - Each class provides type-checked arguments for the aten Op in the comment.
29
+ - Class name is determined by the follwoing priority.
30
+ 1. Torch spec (aten/src/ATen/native/native_functions.yaml in pytorch repo)
31
+ 2. pytorch doc (https://pytorch.org/docs/stable/index.html)
32
+ """
33
+
+
+@enforce_type
+@dataclass
+class AbsArgs:
+    """
+    abs(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class AddTensorArgs:
+    """
+    add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, float, int, torch.Tensor]
+    other: Union[torch.fx.Node, float, int, torch.Tensor]
+
+
+@enforce_type
+@dataclass
+class AddmmArgs:
+    """
+    addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    """
+
+    input: torch.fx.Node
+    mat1: torch.fx.Node
+    mat2: torch.fx.Node
+    beta: Union[int, float] = 1
+    alpha: Union[int, float] = 1
+
+
+@enforce_type
+@dataclass
+class AliasCopyArgs:
+    """
+    alias_copy(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class AnyArgs:
+    """
+    any(Tensor self) -> Tensor
+    any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+    any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: Union[int, tuple, None] = None
+    keepdim: bool = False
+
+
+@enforce_type
+@dataclass
+class ArangeStartStepArgs:
+    """
+    arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    """
+
+    start: Union[int, float]
+    end: Union[int, float]
+    step: Union[int, float] = 1
+
+
+@enforce_type
+@dataclass
+class ArgMaxArgs:
+    """
+    argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+    """
+
+    tensor: Union[torch.fx.Node, torch.Tensor]
+    dim: Union[int, None] = None
+
+
+@enforce_type
+@dataclass
+class AvgPool2dArgs:
+    """
+    avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
+    """
+
+    input: torch.fx.Node
+    kernel_size: List[int]
+    stride: Optional[List[int]] = None
+    padding: List[int] = field(default_factory=lambda: [0, 0])
+    ceil_mode: bool = field(default=False)
+    count_include_pad: bool = field(default=True)
+    divisor_override: Optional[int] = None
+
+    def __post_init__(self):
+        assert len(self.kernel_size) == 2, len(self.kernel_size)
+        self.stride = self.kernel_size if not self.stride else self.stride
+        assert len(self.stride) == 2, len(self.stride)
+        if self.padding is not None:
+            assert len(self.padding) == 2, len(self.padding)
+        if self.divisor_override is not None:
+            assert isinstance(self.divisor_override, int), type(self.divisor_override)
+            assert self.divisor_override != 0, "divisor_override must be non-zero."
+
+
+@enforce_type
+@dataclass
+class AdaptiveAvgPool2dArgs:
+    """
+    adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+    """
+
+    input: torch.fx.Node
+    output_size: List[int]
+
+
+@enforce_type
+@dataclass
+class BmmArgs:
+    """
+    bmm(Tensor self, Tensor mat2) -> Tensor
+    """
+
+    input: torch.fx.Node
+    mat2: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class CatArgs:
+    """
+    cat(Tensor[] tensors, int dim=0) -> Tensor
+    """
+
+    tensors: List[torch.fx.Node]
+    dim: int = 0
+
+
+@enforce_type
+@dataclass
+class ClampArgs:
+    """
+    clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    min: Optional[Union[int, float]] = None
+    max: Optional[Union[int, float]] = None
+
+
+@enforce_type
+@dataclass
+class CloneArgs:
+    """
+    clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    memory_format: Optional[torch.memory_format] = None
+
+
+@enforce_type
+@dataclass
+class ConstantPadNdArgs:
+    """
+    constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
+    """
+
+    input: torch.fx.Node
+    pad: List[int]
+    value: Union[int, float]
+
+
+@enforce_type
+@dataclass
+class Conv2DArgs:
+    """
+    conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+    conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: torch.fx.Node
+    bias: Union[torch.fx.Node, None] = None
+    stride: List[int] = field(default_factory=lambda: [1, 1])
+    padding: Union[List[int], str] = field(default_factory=lambda: [0, 0])
+    dilation: List[int] = field(default_factory=lambda: [1, 1])
+    groups: int = 1
+
+    def __post_init__(self):
+        assert len(self.stride) == 2, len(self.stride)
+        assert len(self.dilation) == 2, len(self.dilation)
+
+
+@enforce_type
+@dataclass
+class Conv1DArgs:
+    """
+    conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+    conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: torch.fx.Node
+    bias: Union[torch.fx.Node, None] = None
+    stride: List[int] = field(default_factory=lambda: [1])
+    padding: Union[List[int], str] = field(default_factory=lambda: [0])
+    dilation: List[int] = field(default_factory=lambda: [1])
+    groups: int = 1
+
+    def __post_init__(self):
+        assert len(self.stride) == 1, len(self.stride)
+        assert len(self.dilation) == 1, len(self.dilation)
+
+
+@enforce_type
+@dataclass
+class CopyArgs:
+    """
+    copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
+    """
+
+    dst: torch.fx.Node
+    src: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class CosArgs:
+    """
+    cos(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class CumsumArgs:
+    """
+    cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: int
+
+
+@enforce_type
+@dataclass
+class DequantizePerChannelArgs:
+    """
+    quantized_decomposed.dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    scales: torch.fx.Node
+    zero_points: torch.fx.Node
+    axis: int
+    quant_min: int
+    quant_max: int
+    dtype: torch.dtype
+
+
+@enforce_type
+@dataclass
+class DequantizePerTensorArgs:
+    """
+    quantized_decomposed.dequantize_per_tensor(input: TensorBox, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype) -> TensorBox
+    """
+
+    input: torch.fx.Node
+    scale: float
+    zero_point: int
+    quant_min: int
+    quant_max: int
+    dtype: torch.dtype
+
+
+@enforce_type
+@dataclass
+class DivTensorArgs:
+    """
+    div.Tensor(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, float, int, torch.Tensor]
+    other: Union[torch.fx.Node, float, int, torch.Tensor]
+
+
+@enforce_type
+@dataclass
+class EmbeddingArgs:
+    """
+    embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+    """
+
+    weight: torch.fx.Node
+    indices: torch.fx.Node
+    padding_idx: int = -1
+    scale_grad_by_freq: bool = False
+    sparse: bool = False
+
+
+@enforce_type
+@dataclass
+class EqArgs:
+    """
+    eq.Scalar(Tensor self, Scalar other) -> Tensor
+    eq.Tensor(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, float, int]
+    other: Union[torch.fx.Node, torch.Tensor, float, int]
+
+
+@enforce_type
+@dataclass
+class ExpArgs:
+    """
+    exp(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class ExpandArgs:
+    """
+    expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+    expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
+    """
+
+    input: torch.fx.Node
+    size: List[int]
+
+
+@enforce_type
+@dataclass
+class FakeQuantizePerChannelArgs:
+    """
+    fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
+    fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+    """
+
+    input: torch.fx.Node
+    scale: torch.fx.Node
+    zero_point: torch.fx.Node
+    axis: int
+    quant_min: int
+    quant_max: int
+
+
+@enforce_type
+@dataclass
+class FakeQuantizePerTensorTQParamArgs:
+    """
+    fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
+    """
+
+    input: torch.fx.Node
+    scale: torch.fx.Node
+    zero_point: torch.fx.Node
+    quant_min: int
+    quant_max: int
+
+
+@enforce_type
+@dataclass
+class FullLikeArgs:
+    """
+    full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    fill_value: Union[int, float, bool]
+    pin_memory: Optional[bool] = None
+    memory_format: Optional[torch.memory_format] = None
+
+
+@enforce_type
+@dataclass
+class FullArgs:
+    """
+    full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    """
+
+    size: Union[list, tuple, torch.Size]
+    fill_value: Union[int, float, bool]
+
+
+@enforce_type
+@dataclass
+class GeArgs:
+    """
+    ge.Scalar(Tensor self, Scalar other) -> Tensor
+    ge.Tensor(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, float, int]
+    other: Union[torch.fx.Node, torch.Tensor, float, int]
+
+
+@enforce_type
+@dataclass
+class GeluArgs:
+    """
+    gelu(Tensor self, *, str approximate='none') -> Tensor
+    """
+
+    input: torch.fx.Node
+    approximate: Optional[str] = "none"
+
+
+@enforce_type
+@dataclass
+class GtArgs:
+    """
+    gt.Scalar(Tensor self, Scalar other) -> Tensor
+    gt.Tensor(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, float, int]
+    other: Union[torch.fx.Node, torch.Tensor, float, int]
+
+
+@enforce_type
+@dataclass
+class HardTanhArgs:
+    """
+    hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor]
+    min_val: Union[float, int] = -1
+    max_val: Union[float, int] = 1
+
+
+@enforce_type
+@dataclass
+class IndexSelectArgs:
+    """
+    index_select(Tensor self, int dim, Tensor index) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: int
+    index: Union[torch.fx.Node, torch.Tensor]
+
+
+@enforce_type
+@dataclass
+class IndexArgs:
+    """
+    index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+    """
+
+    input: torch.fx.Node
+    indices: List[Union[torch.fx.Node, torch.Tensor, int, None]]
+
+
+@enforce_type
+@dataclass
+class InstanceNormArgs:
+    """
+    instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: Optional[torch.fx.Node]
+    bias: Optional[torch.fx.Node]
+    running_mean: Optional[torch.fx.Node]
+    running_var: Optional[torch.fx.Node]
+    use_input_stats: bool
+    momentum: float
+    eps: float
+    cudnn_enabled: bool
+
+
+@enforce_type
+@dataclass
+class LeakyReluArgs:
+    """
+    leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
+    """
+
+    input: torch.fx.Node
+    negative_slope: float = 0.01
+
+
+@enforce_type
+@dataclass
+class LinearArgs:
+    """
+    linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: torch.fx.Node
+    bias: Optional[torch.fx.Node] = None
+
+
+@enforce_type
+@dataclass
+class LogArgs:
+    """
+    log(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class Log1pArgs:
+    """
+    log1p(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class LogicalAndArgs:
+    """
+    logical_and(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: torch.fx.Node
+    other: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class LogicalNotArgs:
+    """
+    logical_not(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class LtArgs:
+    """
+    lt.Tensor(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: torch.fx.Node
+    other: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class MaxDimArgs:
+    """
+    max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+    """
+
+    input: torch.fx.Node
+    dim: int
+    keepdim: bool = False
+
+
+@enforce_type
+@dataclass
+class MaxPool2dWithIndicesArgs:
+    """
+    max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+    """
+
+    input: torch.fx.Node
+    kernel_size: List[int]
+    stride: Optional[List[int]] = None
+    padding: List[int] = field(default_factory=lambda: [0, 0])
+    dilation: List[int] = field(default_factory=lambda: [1, 1])
+    ceil_mode: bool = field(default=False)
+
+    def __post_init__(self):
+        assert len(self.kernel_size) == 2, len(self.kernel_size)
+        self.stride = self.kernel_size if not self.stride else self.stride
+        assert len(self.stride) == 2, len(self.stride)
+        if self.padding is not None:
+            assert len(self.padding) == 2, len(self.padding)
+        if self.dilation is not None:
+            assert len(self.dilation) == 2, len(self.dilation)
+
+
+@enforce_type
+@dataclass
+class MeanDimArgs:
+    """
+    mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: List[int]
+    keepdim: bool = False
+    dtype: Optional[torch.dtype] = None
+
+
+@enforce_type
+@dataclass
+class MatmulArgs:
+    """
+    mm(Tensor self, Tensor mat2) -> Tensor
+    """
+
+    input: torch.fx.Node
+    other: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class MaximumArgs:
+    """
+    maximum(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor]
+    other: Union[torch.fx.Node, torch.Tensor]
+
+
+@enforce_type
+@dataclass
+class MinimumArgs:
+    """
+    minimum(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor]
+    other: Union[torch.fx.Node, torch.Tensor]
+
+
+@enforce_type
+@dataclass
+class MulTensorArgs:
+    """
+    mul.Tensor(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, int, float]
+    other: Union[torch.fx.Node, torch.Tensor, int, float]
+
+
+@enforce_type
+@dataclass
+class MulScalarArgs:
+    """
+    mul.Scalar(Tensor self, Scalar other) -> Tensor
+    """
+
+    input: torch.fx.Node
+    other: Union[int, float]
+
+
+@enforce_type
+@dataclass
+class NeScalarArgs:
+    """
+    ne.Scalar(Tensor self, Scalar other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+    other: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+
+
+@enforce_type
+@dataclass
+class NativeBatchNormLegitNoTrainingArgs:
+    """
+    _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+    """
+
+    input: torch.fx.Node
+    weight: Optional[torch.fx.Node]
+    bias: Optional[torch.fx.Node]
+    running_mean: Optional[torch.fx.Node]
+    running_var: Optional[torch.fx.Node]
+    momentum: float
+    eps: float
+
+
+@enforce_type
+@dataclass
+class NativeGroupNormArgs:
+    """
+    native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
+    """
+
+    input: torch.fx.Node
+    weight: Optional[torch.fx.Node]
+    bias: Optional[torch.fx.Node]
+    N: int
+    C: int
+    HxW: int
+    group: int
+    eps: float
+
+
+@enforce_type
+@dataclass
+class NativeLayerNormArgs:
+    """
+    native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+    """
+
+    input: torch.fx.Node
+    normalized_shape: Union[tuple, list]
+    weight: Optional[torch.fx.Node]
+    bias: Optional[torch.fx.Node]
+    eps: float
+
+
+@enforce_type
+@dataclass
+class NeTensorArgs:
+    """
+    ne.Tensor(Tensor self, Tensor other) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+    other: Union[torch.fx.Node, torch.Tensor, float, int, bool]
+
+
+@enforce_type
+@dataclass
+class NegArgs:
+    """
+    neg(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class PermuteArgs:
+    """
+    permute(Tensor(a) self, int[] dims) -> Tensor(a)
+    """
+
+    input: torch.fx.Node
+    dims: List[int]
+
+
+@enforce_type
+@dataclass
+class PowTensorTensorArgs:
+    """
+    pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+    """
+
+    input: torch.fx.Node
+    exponent: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class PowTensorScalarArgs:
+    """
+    pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+    """
+
+    input: torch.fx.Node
+    exponent: Union[float, int]
+
+
+@enforce_type
+@dataclass
+class PReLUArgs:
+    """
+    prelu(Tensor self, Tensor weight) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class QuantizePerTensorArgs:
+    """
+    quantized_decomposed.quantize_per_tensor(input: TensorBox, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype) -> TensorBox
+    """
+
+    tensor: torch.fx.Node
+    scale: float
+    zero_p: int
+    quant_min: int
+    quant_max: int
+    dtype: torch.dtype
+
+
+@enforce_type
+@dataclass
+class ReciprocalArgs:
+    """
+    reciprocal(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class ReluArgs:
+    """
+    relu(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class Relu6Args:
+    """
+    relu6(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class RepeatArgs:
+    """
+    repeat(Tensor self, SymInt[] repeats) -> Tensor
+    """
+
+    input: torch.fx.Node
+    repeats: List[int]
+
+
+@enforce_type
+@dataclass
+class ReshapeArgs:
+    """
+    reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+    """
+
+    input: torch.fx.Node
+    shape: List[int]
+
+
+@enforce_type
+@dataclass
+class ResizeNearestNeighborArgs:
+    """
+    # Mapped from `torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode='nearest')` case.
+    """
+
+    input: torch.fx.Node
+    size: List[int]
+
+
+@enforce_type
+@dataclass
+class RsqrtArgs:
+    """
+    rsqrt(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class ScalarTensorArgs:
+    """
+    scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+    """
+
+    scalar: Union[int, float]
+
+
+@enforce_type
+@dataclass
+class SelectCopyIntArgs:
+    """
+    select_copy.int(Tensor self, int dim, SymInt index) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: int
+    index: int
+
+
+@enforce_type
+@dataclass
+class SigmoidArgs:
+    """
+    sigmoid(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class SinArgs:
+    """
+    sin(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class SliceArgs:
+    """
+    slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+    slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: int = 0
+    start: Optional[int] = None
+    end: Optional[int] = None
+    step: Optional[int] = 1
+
+
+@enforce_type
+@dataclass
+class SafeSoftmaxArgs:
+    """
+    _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: int
+    dtype: Optional[torch.dtype] = None
+
+
+@enforce_type
+@dataclass
+class SoftmaxArgs:
+    """
+    _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: int
+    half_to_float: bool
+
+
+@enforce_type
+@dataclass
+class SplitWithSizesArgs:
+    """
+    split_with_sizes(Tensor(a->*) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+    """
+
+    input: torch.fx.Node
+    split_sizes: List[int]
+    dim: int = 0
+
+
+@enforce_type
+@dataclass
+class SqrtArgs:
+    """
+    sqrt(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class SqueezeArgs:
+    """
+    squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+    squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dims: List[int] = field(default_factory=list)
+
+
+@enforce_type
+@dataclass
+class SubTensorArgs:
+    """
+    sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, float, int]
+    other: Union[torch.fx.Node, torch.Tensor, float, int]
+    alpha: Optional[int] = None
+
+
+@enforce_type
+@dataclass
+class SumDimIntListArgs:
+    """
+    sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+    """
+
+    input: Union[torch.fx.Node, torch.Tensor, float, int]
+    dim: List[int] = field(default_factory=list)
+    keepdim: bool = False
+    dtype: Optional[torch.dtype] = None
+
+
+@enforce_type
+@dataclass
+class TanhArgs:
+    """
+    tanh(Tensor self) -> Tensor
+    """
+
+    input: torch.fx.Node
+
+
+@enforce_type
+@dataclass
+class ToCopyArgs:
+    """
+    _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dtype: Optional[torch.dtype] = None
+    layout: Optional[torch.layout] = None
+    device: Optional[torch.device] = None
+    pin_memory: Optional[bool] = None
+    non_blocking: Optional[bool] = False
+    memory_format: Optional[torch.memory_format] = None
+
+
+@enforce_type
+@dataclass
+class ToDtypeArgs:
+    """
+    to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    """
+
+    input: torch.fx.Node
+    dtype: Optional[torch.dtype] = None
+    non_blocking: Optional[bool] = False
+    copy: Optional[bool] = False
+    memory_format: Optional[torch.memory_format] = None
+
+
+@enforce_type
+@dataclass
+class ToDtypeLayoutArgs:
+    """
+    to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    """
+
+    input: torch.fx.Node
+    dtype: Optional[torch.dtype] = None
+    layout: Optional[torch.layout] = None
+    device: Optional[torch.device] = None
+    pin_memory: Optional[bool] = None
+    non_blocking: Optional[bool] = False
+    copy: Optional[bool] = False
+    memory_format: Optional[torch.memory_format] = None
+
+
+@enforce_type
+@dataclass
+class UnSqueezeArgs:
+    """
+    unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
+    unsqueeze_copy(Tensor self, int dim) -> Tensor
+    """
+
+    input: torch.fx.Node
+    dim: int
+
+
+@enforce_type
+@dataclass
+class UpsampleNearest2DVecArgs:
+    """
+    upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+    """
+
+    input: torch.fx.Node
+    output_size: Optional[List[int]]
+    scale_factors: Optional[List[float]]
+
+
+@enforce_type
+@dataclass
+class ViewArgs:
+    """
+    view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+    view_copy(Tensor self, SymInt[] size) -> Tensor
+    """
+
+    input: torch.fx.Node
+    size: List[int]
+
+
+@enforce_type
+@dataclass
+class WhereSelfArgs:
+    """
+    where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
+    """
+
+    condition: torch.fx.Node
+    input: Union[torch.fx.Node, torch.Tensor]
+    other: Union[torch.fx.Node, torch.Tensor]
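
As the module docstring notes, each class pairs an aten schema (quoted in its docstring) with type-annotated fields, so that @enforce_type can reject a malformed argument list at the moment a visitor unpacks an FX node, rather than deep inside Circle serialization. The sketch below illustrates that pattern under one assumption: that enforce_type raises on a type mismatch at construction time. The traced module and the loop are illustrative only, not tico's actual lowering code.

import operator

import torch
import torch.fx

from tico.utils.validate_args_kwargs import AddTensorArgs


class TinyAdd(torch.nn.Module):
    def forward(self, x, y):
        return x + y


gm = torch.fx.symbolic_trace(TinyAdd())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        # Packing node.args/node.kwargs into the dataclass runs the
        # enforce_type check: an argument that is not a Node, Tensor,
        # or scalar would raise here, at a well-labeled boundary.
        add_args = AddTensorArgs(*node.args, **node.kwargs)
        print(add_args.input, add_args.other)

The __post_init__ hooks serve the same boundary: they normalize defaults the schema leaves empty (AvgPool2dArgs and MaxPool2dWithIndicesArgs fall back to kernel_size when stride is not given) and assert shape invariants such as two-element stride and dilation lists.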