tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1114 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from dataclasses import dataclass, field
|
16
|
+
from typing import List, Optional, TYPE_CHECKING, Union
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch._ops
|
20
|
+
import torch.fx
|
21
|
+
import torch
|
22
|
+
import torch.fx.node
|
23
|
+
|
24
|
+
from tico.utils.utils import enforce_type
|
25
|
+
|
26
|
+
"""
|
27
|
+
This file includes OpArgs classes that provide arguments with type annotations.
|
28
|
+
- Each class provides type-checked arguments for the aten Op in the comment.
|
29
|
+
- Class name is determined by the following priority.
|
30
|
+
1. Torch spec (aten/src/ATen/native/native_functions.yaml in pytorch repo)
|
31
|
+
2. pytorch doc (https://pytorch.org/docs/stable/index.html)
|
32
|
+
"""
|
33
|
+
|
34
|
+
|
35
|
+
@enforce_type
@dataclass
class AddTensorArgs:
    """
    Positional arguments of ``aten.add.Tensor``.

    add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

    Only the two operands are modelled; the keyword-only ``alpha`` has no field.
    """

    # Schema ``self``: a graph node, a Python scalar, or a constant tensor.
    input: Union[torch.fx.Node, float, int, torch.Tensor]
    # Schema ``other``: accepts the same kinds as ``input``.
    other: Union[torch.fx.Node, float, int, torch.Tensor]
|
44
|
+
|
45
|
+
|
46
|
+
@enforce_type
@dataclass
class AddmmArgs:
    """
    Arguments of ``aten.addmm``.

    addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    mat1: torch.fx.Node
    mat2: torch.fx.Node
    # Keyword-only scalars; both default to 1 as in the schema.
    beta: Union[int, float] = 1
    alpha: Union[int, float] = 1
|
58
|
+
|
59
|
+
|
60
|
+
@enforce_type
@dataclass
class AliasCopyArgs:
    """
    Arguments of ``aten.alias_copy``.

    alias_copy(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
|
68
|
+
|
69
|
+
|
70
|
+
@enforce_type
@dataclass
class AnyArgs:
    """
    Arguments shared by the ``aten.any`` overloads.

    any(Tensor self) -> Tensor
    any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
    any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor

    ``dim`` stays ``None`` when the plain overload is used.
    """

    input: torch.fx.Node  # schema ``self``
    # A single axis (int), several axes (tuple), or None.
    dim: Optional[Union[int, tuple]] = None
    keepdim: bool = False
|
82
|
+
|
83
|
+
|
84
|
+
@enforce_type
@dataclass
class ArangeStartStepArgs:
    """
    Numeric arguments of ``aten.arange.start_step``.

    arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    The keyword-only creation options (dtype/layout/device/pin_memory) have no fields.
    """

    start: Union[int, float]
    end: Union[int, float]
    step: Union[int, float] = 1
|
94
|
+
|
95
|
+
|
96
|
+
@enforce_type
@dataclass
class ArgMaxArgs:
    """
    Arguments of ``aten.argmax``.

    argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor

    NOTE: the operand field is named ``tensor`` (not ``input``), and the schema's
    ``keepdim`` has no field in this class.
    """

    tensor: Union[torch.fx.Node, torch.Tensor]  # schema ``self``
    dim: Optional[int] = None
|
105
|
+
|
106
|
+
|
107
|
+
@enforce_type
@dataclass
class AvgPool2dArgs:
    """
    Arguments of ``aten.avg_pool2d``.

    avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)

    The schema default ``stride=[]`` means "use ``kernel_size``". It is
    normalized in ``__post_init__``, so every instance carries an explicit
    2-element stride.
    """

    input: torch.fx.Node  # schema ``self``
    kernel_size: List[int]
    stride: List[int] = field(default_factory=list)
    padding: List[int] = field(default_factory=lambda: [0, 0])
    ceil_mode: bool = field(default=False)
    count_include_pad: bool = field(default=True)
    divisor_override: Optional[int] = None

    def __post_init__(self):
        assert len(self.kernel_size) == 2, len(self.kernel_size)
        # ``[]`` is the schema's own default for stride. Without this
        # normalization, the length check below would reject it.
        if not self.stride:
            self.stride = list(self.kernel_size)
        assert len(self.stride) == 2, len(self.stride)
        if self.padding is not None:
            assert len(self.padding) == 2, len(self.padding)
        if self.divisor_override is not None:
            assert isinstance(self.divisor_override, int), type(self.divisor_override)
            assert self.divisor_override != 0, "Divisor must be not zero."
|
130
|
+
|
131
|
+
|
132
|
+
@enforce_type
@dataclass
class AdaptiveAvgPool2dArgs:
    """
    Arguments of ``aten.adaptive_avg_pool2d``.

    adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    output_size: List[int]  # length is not checked by this class
|
141
|
+
|
142
|
+
|
143
|
+
@enforce_type
@dataclass
class BmmArgs:
    """
    Arguments of ``aten.bmm``.

    bmm(Tensor self, Tensor mat2) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    mat2: torch.fx.Node
|
152
|
+
|
153
|
+
|
154
|
+
@enforce_type
@dataclass
class CatArgs:
    """
    Arguments of ``aten.cat``.

    cat(Tensor[] tensors, int dim=0) -> Tensor
    """

    tensors: List[torch.fx.Node]  # every operand must be a graph node
    dim: int = 0
|
163
|
+
|
164
|
+
|
165
|
+
@enforce_type
@dataclass
class ClampArgs:
    """
    Arguments of ``aten.clamp``.

    clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    # Each bound is optional; None means that side is not given.
    min: Union[int, float, None] = None
    max: Union[int, float, None] = None
|
175
|
+
|
176
|
+
|
177
|
+
@enforce_type
@dataclass
class CloneArgs:
    """
    Arguments of ``aten.clone``.

    clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    memory_format: Union[torch.memory_format, None] = None
|
186
|
+
|
187
|
+
|
188
|
+
@enforce_type
@dataclass
class ConstantPadNdArgs:
    """
    Arguments of ``aten.constant_pad_nd``.

    constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    pad: List[int]
    # Uses ``typing.Union`` like the rest of this file. A PEP 604 ``int | float``
    # annotation is evaluated at class creation and needs Python 3.10+. The
    # default mirrors the schema so a node that omits ``value`` is still accepted.
    value: Union[int, float] = 0
|
198
|
+
|
199
|
+
|
200
|
+
@enforce_type
@dataclass
class Conv2DArgs:
    """
    Arguments of the ``aten.conv2d`` overloads.

    conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
    conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor

    ``padding`` is a list for the first overload and a string for ``.padding``.
    """

    input: torch.fx.Node
    weight: torch.fx.Node
    bias: Optional[torch.fx.Node] = None
    stride: List[int] = field(default_factory=lambda: [1, 1])
    padding: Union[List[int], str] = field(default_factory=lambda: [0, 0])
    dilation: List[int] = field(default_factory=lambda: [1, 1])
    groups: int = 1

    def __post_init__(self):
        # Only stride and dilation are length-checked; padding may be a string.
        for attr_name in ("stride", "dilation"):
            values = getattr(self, attr_name)
            assert len(values) == 2, len(values)
|
219
|
+
|
220
|
+
|
221
|
+
@enforce_type
@dataclass
class Conv1DArgs:
    """
    Arguments of the ``aten.conv1d`` overloads.

    conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
    conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor

    ``padding`` is a list for the first overload and a string for ``.padding``.
    """

    input: torch.fx.Node
    weight: torch.fx.Node
    bias: Optional[torch.fx.Node] = None
    stride: List[int] = field(default_factory=lambda: [1])
    padding: Union[List[int], str] = field(default_factory=lambda: [0])
    dilation: List[int] = field(default_factory=lambda: [1])
    groups: int = 1

    def __post_init__(self):
        # Only stride and dilation are length-checked; padding may be a string.
        for attr_name in ("stride", "dilation"):
            values = getattr(self, attr_name)
            assert len(values) == 1, len(values)
|
240
|
+
|
241
|
+
|
242
|
+
@enforce_type
@dataclass
class CopyArgs:
    """
    Arguments of ``aten.copy``.

    copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor

    The schema's ``non_blocking`` flag has no field in this class.
    """

    dst: torch.fx.Node  # schema ``self``
    src: torch.fx.Node
|
251
|
+
|
252
|
+
|
253
|
+
@enforce_type
@dataclass
class CosArgs:
    """
    Arguments of ``aten.cos``.

    cos(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
|
261
|
+
|
262
|
+
|
263
|
+
@enforce_type
@dataclass
class CumsumArgs:
    """
    Arguments of ``aten.cumsum``.

    cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor

    The keyword-only ``dtype`` has no field in this class.
    """

    input: torch.fx.Node  # schema ``self``
    dim: int
|
272
|
+
|
273
|
+
|
274
|
+
@enforce_type
@dataclass
class DequantizePerChannelArgs:
    """
    Arguments of ``quantized_decomposed.dequantize_per_channel``.

    quantized_decomposed.dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor

    NOTE: ``zero_points`` is required here even though the schema marks it
    optional, and ``out_dtype`` has no field.
    """

    input: torch.fx.Node
    scales: torch.fx.Node
    zero_points: torch.fx.Node
    axis: int
    quant_min: int
    quant_max: int
    dtype: torch.dtype
|
288
|
+
|
289
|
+
|
290
|
+
@enforce_type
@dataclass
class DequantizePerTensorArgs:
    """
    Arguments of ``quantized_decomposed.dequantize_per_tensor``.

    quantized_decomposed.dequantize_per_tensor(input: TensorBox, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype) -> TensorBox
    """

    input: torch.fx.Node
    # Scalar quantization parameters (one set for the whole tensor).
    scale: float
    zero_point: int
    quant_min: int
    quant_max: int
    dtype: torch.dtype
|
303
|
+
|
304
|
+
|
305
|
+
@enforce_type
@dataclass
class DivTensorArgs:
    """
    Arguments of ``aten.div.Tensor``.

    div.Tensor(Tensor self, Tensor other) -> Tensor
    """

    # Schema ``self``: a graph node, a Python scalar, or a constant tensor.
    input: Union[torch.fx.Node, float, int, torch.Tensor]
    # Schema ``other``: accepts the same kinds as ``input``.
    other: Union[torch.fx.Node, float, int, torch.Tensor]
|
314
|
+
|
315
|
+
|
316
|
+
@enforce_type
@dataclass
class EmbeddingArgs:
    """
    Arguments of ``aten.embedding``.

    embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
    """

    weight: torch.fx.Node
    indices: torch.fx.Node
    # Must match the schema default above (-1), not 1.
    padding_idx: int = -1
    scale_grad_by_freq: bool = False
    sparse: bool = False
|
328
|
+
|
329
|
+
|
330
|
+
@enforce_type
@dataclass
class EqArgs:
    """
    Arguments shared by the ``aten.eq`` overloads.

    eq.Scalar(Tensor self, Scalar other) -> Tensor
    eq.Tensor(Tensor self, Tensor other) -> Tensor
    """

    input: Union[torch.fx.Node, torch.Tensor, float, int]  # schema ``self``
    other: Union[torch.fx.Node, torch.Tensor, float, int]
|
340
|
+
|
341
|
+
|
342
|
+
@enforce_type
@dataclass
class ExpArgs:
    """
    Arguments of ``aten.exp``.

    exp(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
|
350
|
+
|
351
|
+
|
352
|
+
@enforce_type
@dataclass
class ExpandArgs:
    """
    Arguments shared by ``aten.expand`` and ``aten.expand_copy``.

    expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
    expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor

    The keyword-only ``implicit`` flag has no field in this class.
    """

    input: torch.fx.Node  # schema ``self``
    size: List[int]
|
362
|
+
|
363
|
+
|
364
|
+
@enforce_type
@dataclass
class FakeQuantizePerChannelArgs:
    """
    Arguments shared by the per-channel fake-quantize ops.

    fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
    fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
    """

    input: torch.fx.Node  # schema ``self``
    scale: torch.fx.Node
    zero_point: torch.fx.Node
    axis: int
    quant_min: int
    quant_max: int
|
378
|
+
|
379
|
+
|
380
|
+
@enforce_type
@dataclass
class FakeQuantizePerTensorTQParamArgs:
    """
    Arguments of ``aten.fake_quantize_per_tensor_affine.tensor_qparams``.

    fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor

    Unlike the plain overload, ``scale`` and ``zero_point`` arrive as graph nodes.
    """

    input: torch.fx.Node  # schema ``self``
    scale: torch.fx.Node
    zero_point: torch.fx.Node
    quant_min: int
    quant_max: int
|
392
|
+
|
393
|
+
|
394
|
+
@enforce_type
@dataclass
class FullLikeArgs:
    """
    Arguments of ``aten.full_like``.

    full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    Only ``pin_memory`` and ``memory_format`` of the keyword-only options have fields.
    """

    input: torch.fx.Node  # schema ``self``
    fill_value: Union[int, float, bool]
    pin_memory: Union[bool, None] = None
    memory_format: Union[torch.memory_format, None] = None
|
405
|
+
|
406
|
+
|
407
|
+
@enforce_type
@dataclass
class FullArgs:
    """
    Arguments of ``aten.full``.

    full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    The keyword-only creation options have no fields in this class.
    """

    size: Union[list, tuple, torch.Size]  # any sequence-like shape container
    fill_value: Union[int, float, bool]
|
416
|
+
|
417
|
+
|
418
|
+
@enforce_type
@dataclass
class GeArgs:
    """
    Arguments shared by the ``aten.ge`` overloads.

    ge.Scalar(Tensor self, Scalar other) -> Tensor
    ge.Tensor(Tensor self, Tensor other) -> Tensor
    """

    input: Union[torch.fx.Node, torch.Tensor, float, int]  # schema ``self``
    other: Union[torch.fx.Node, torch.Tensor, float, int]
|
428
|
+
|
429
|
+
|
430
|
+
@enforce_type
@dataclass
class GeluArgs:
    """
    Arguments of ``aten.gelu``.

    gelu(Tensor self, *, str approximate='none') -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    approximate: Union[str, None] = "none"
|
439
|
+
|
440
|
+
|
441
|
+
@enforce_type
@dataclass
class GtArgs:
    """
    Arguments shared by the ``aten.gt`` overloads.

    gt.Scalar(Tensor self, Scalar other) -> Tensor
    gt.Tensor(Tensor self, Tensor other) -> Tensor
    """

    input: Union[torch.fx.Node, torch.Tensor, float, int]  # schema ``self``
    other: Union[torch.fx.Node, torch.Tensor, float, int]
|
451
|
+
|
452
|
+
|
453
|
+
@enforce_type
@dataclass
class HardTanhArgs:
    """
    Arguments of ``aten.hardtanh``.

    hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
    """

    input: Union[torch.fx.Node, torch.Tensor]  # schema ``self``
    # Defaults mirror the schema (-1 and 1).
    min_val: Union[float, int] = -1
    max_val: Union[float, int] = 1
|
463
|
+
|
464
|
+
|
465
|
+
@enforce_type
@dataclass
class IndexSelectArgs:
    """
    Arguments of ``aten.index_select``.

    index_select(Tensor self, int dim, Tensor index) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    dim: int
    index: Union[torch.fx.Node, torch.Tensor]  # graph node or constant tensor
|
475
|
+
|
476
|
+
|
477
|
+
@enforce_type
@dataclass
class IndexArgs:
    """
    Arguments of ``aten.index.Tensor``.

    index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    # One entry per indexed dimension; an entry may be None (schema ``Tensor?[]``).
    indices: List[Union[torch.fx.Node, torch.Tensor, int, None]]
|
486
|
+
|
487
|
+
|
488
|
+
@enforce_type
@dataclass
class InstanceNormArgs:
    """
    Arguments of ``aten.instance_norm``.

    instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor

    Every argument is required positionally; the optional tensors may be None.
    """

    input: torch.fx.Node
    weight: Union[torch.fx.Node, None]
    bias: Union[torch.fx.Node, None]
    running_mean: Union[torch.fx.Node, None]
    running_var: Union[torch.fx.Node, None]
    use_input_stats: bool
    momentum: float
    eps: float
    cudnn_enabled: bool
|
504
|
+
|
505
|
+
|
506
|
+
@enforce_type
@dataclass
class LinearArgs:
    """
    Arguments of ``aten.linear``.

    linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
    """

    input: torch.fx.Node
    weight: torch.fx.Node
    bias: Union[torch.fx.Node, None] = None  # absent when the layer has no bias
|
516
|
+
|
517
|
+
|
518
|
+
@enforce_type
@dataclass
class LogArgs:
    """
    Arguments of ``aten.log``.

    log(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
|
526
|
+
|
527
|
+
|
528
|
+
@enforce_type
@dataclass
class Log1pArgs:
    """
    Arguments of ``aten.log1p``.

    log1p(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
|
536
|
+
|
537
|
+
|
538
|
+
@enforce_type
@dataclass
class LogicalAndArgs:
    """
    Arguments of ``aten.logical_and``.

    logical_and(Tensor self, Tensor other) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    other: torch.fx.Node
|
547
|
+
|
548
|
+
|
549
|
+
@enforce_type
@dataclass
class LogicalNotArgs:
    """
    Arguments of ``aten.logical_not``.

    logical_not(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
|
557
|
+
|
558
|
+
|
559
|
+
@enforce_type
@dataclass
class LtArgs:
    """
    Arguments of ``aten.lt.Tensor``.

    lt.Tensor(Tensor self, Tensor other) -> Tensor

    Both operands must be graph nodes (scalars are not accepted here).
    """

    input: torch.fx.Node  # schema ``self``
    other: torch.fx.Node
|
568
|
+
|
569
|
+
|
570
|
+
@enforce_type
@dataclass
class MaxPool2dWithIndicesArgs:
    """
    Arguments of ``aten.max_pool2d_with_indices``.

    max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)

    The schema default ``stride=[]`` means "use ``kernel_size``". It is
    normalized in ``__post_init__``, so every instance carries an explicit
    2-element stride.
    """

    input: torch.fx.Node  # schema ``self``
    kernel_size: List[int]
    stride: List[int] = field(default_factory=list)
    padding: List[int] = field(default_factory=lambda: [0, 0])
    dilation: List[int] = field(default_factory=lambda: [1, 1])
    ceil_mode: bool = field(default=False)

    def __post_init__(self):
        assert len(self.kernel_size) == 2, len(self.kernel_size)
        # ``[]`` is the schema's own default for stride. Without this
        # normalization, the length check below would reject it.
        if not self.stride:
            self.stride = list(self.kernel_size)
        assert len(self.stride) == 2, len(self.stride)
        if self.padding is not None:
            assert len(self.padding) == 2, len(self.padding)
        if self.dilation is not None:
            assert len(self.dilation) == 2, len(self.dilation)
|
591
|
+
|
592
|
+
|
593
|
+
@enforce_type
@dataclass
class MeanDimArgs:
    """
    Arguments of ``aten.mean.dim``.

    mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor

    NOTE: the schema's ``keepdim`` is stored in the field ``keep_dims``, so it
    binds positionally but not by the schema keyword name.
    """

    input: torch.fx.Node  # schema ``self``
    dim: List[int]
    keep_dims: bool = False
    dtype: Union[torch.dtype, None] = None
|
604
|
+
|
605
|
+
|
606
|
+
@enforce_type
@dataclass
class MatmulArgs:
    """
    Arguments of ``aten.mm``.

    mm(Tensor self, Tensor mat2) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    other: torch.fx.Node  # schema ``mat2``
|
615
|
+
|
616
|
+
|
617
|
+
@enforce_type
@dataclass
class MaximumArgs:
    """
    Arguments of ``aten.maximum``.

    maximum(Tensor self, Tensor other) -> Tensor
    """

    input: Union[torch.fx.Node, torch.Tensor]  # schema ``self``
    other: Union[torch.fx.Node, torch.Tensor]
|
626
|
+
|
627
|
+
|
628
|
+
@enforce_type
@dataclass
class MinimumArgs:
    """
    Arguments of ``aten.minimum``.

    minimum(Tensor self, Tensor other) -> Tensor
    """

    input: Union[torch.fx.Node, torch.Tensor]  # schema ``self``
    other: Union[torch.fx.Node, torch.Tensor]
|
637
|
+
|
638
|
+
|
639
|
+
@enforce_type
@dataclass
class MulTensorArgs:
    """
    Arguments of ``aten.mul.Tensor``.

    mul.Tensor(Tensor self, Tensor other) -> Tensor
    """

    # Either operand may be a graph node, a constant tensor, or a Python scalar.
    input: Union[torch.fx.Node, torch.Tensor, int, float]  # schema ``self``
    other: Union[torch.fx.Node, torch.Tensor, int, float]
|
648
|
+
|
649
|
+
|
650
|
+
@enforce_type
@dataclass
class MulScalarArgs:
    """
    Arguments of ``aten.mul.Scalar``.

    mul.Scalar(Tensor self, Scalar other) -> Tensor
    """

    input: torch.fx.Node  # schema ``self``
    other: Union[int, float]  # Python scalar multiplier
|
659
|
+
|
660
|
+
|
661
|
+
@enforce_type
@dataclass
class NeScalarArgs:
    """
    Arguments of ``aten.ne.Scalar``.

    ne.Scalar(Tensor self, Scalar other) -> Tensor
    """

    input: Union[torch.fx.Node, torch.Tensor, float, int, bool]  # schema ``self``
    other: Union[torch.fx.Node, torch.Tensor, float, int, bool]
|
670
|
+
|
671
|
+
|
672
|
+
@enforce_type
@dataclass
class NativeBatchNormLegitNoTrainingArgs:
    """
    Arguments of ``aten._native_batch_norm_legit_no_training``.

    _native_batch_norm_legit_no_training (Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)

    NOTE: the running statistics are typed as optional here even though the
    schema marks them as required tensors.
    """

    input: torch.fx.Node
    weight: Union[torch.fx.Node, None]
    bias: Union[torch.fx.Node, None]
    running_mean: Union[torch.fx.Node, None]
    running_var: Union[torch.fx.Node, None]
    momentum: float
    eps: float
|
686
|
+
|
687
|
+
|
688
|
+
@enforce_type
@dataclass
class NativeGroupNormArgs:
    """
    Arguments of ``aten.native_group_norm``.

    native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
    """

    input: torch.fx.Node
    weight: Union[torch.fx.Node, None]
    bias: Union[torch.fx.Node, None]
    # Size arguments as named in the schema (N, C, HxW).
    N: int
    C: int
    HxW: int
    group: int
    eps: float
|
703
|
+
|
704
|
+
|
705
|
+
@enforce_type
@dataclass
class NativeLayerNormArgs:
    """
    Arguments of ``native_layer_norm``.

    Schema:
        native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
    """

    input: torch.fx.Node
    normalized_shape: Union[tuple, list]  # accepted as either a tuple or a list
    weight: Optional[torch.fx.Node]
    bias: Optional[torch.fx.Node]
    eps: float
|
717
|
+
|
718
|
+
|
719
|
+
@enforce_type
@dataclass
class NeTensorArgs:
    """
    Arguments of the ``ne.Tensor`` overload.

    Schema: ne.Tensor(Tensor self, Tensor other) -> Tensor

    Both operands also accept float/int/bool scalars.
    """

    input: Union[torch.fx.Node, torch.Tensor, float, int, bool]  # ``self``
    other: Union[torch.fx.Node, torch.Tensor, float, int, bool]
|
728
|
+
|
729
|
+
|
730
|
+
@enforce_type
@dataclass
class NegArgs:
    """
    Arguments of the ``neg`` operator.

    Schema: neg(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
738
|
+
|
739
|
+
|
740
|
+
@enforce_type
@dataclass
class PermuteArgs:
    """
    Arguments of the ``permute`` operator.

    Schema: permute(Tensor(a) self, int[] dims) -> Tensor(a)
    """

    input: torch.fx.Node  # ``self`` in the schema
    dims: List[int]  # the requested dimension order
|
749
|
+
|
750
|
+
|
751
|
+
@enforce_type
@dataclass
class PowTensorTensorArgs:
    """
    Arguments of the ``pow.Tensor_Tensor`` overload.

    Schema: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    # Previously spelled ``Union[torch.fx.Node]``; a single-member Union
    # collapses to its member, so this annotation is equivalent.
    exponent: torch.fx.Node
|
760
|
+
|
761
|
+
|
762
|
+
@enforce_type
@dataclass
class PowTensorScalarArgs:
    """
    Arguments of the ``pow.Tensor_Scalar`` overload.

    Schema: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    exponent: Union[float, int]  # the Scalar exponent
|
771
|
+
|
772
|
+
|
773
|
+
@enforce_type
@dataclass
class PReLUArgs:
    """
    Arguments of the ``prelu`` operator.

    Schema: prelu(Tensor self, Tensor weight) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    weight: torch.fx.Node
|
782
|
+
|
783
|
+
|
784
|
+
@enforce_type
@dataclass
class QuantizePerTensorArgs:
    """
    Arguments of ``quantized_decomposed.quantize_per_tensor``.

    Schema:
        quantized_decomposed.quantize_per_tensor(input: TensorBox, scale: float, zero_point: int, quant_min: int, quant_max: int, dtype: torch.dtype) -> TensorBox
    """

    tensor: torch.fx.Node  # ``input`` in the schema
    scale: float
    zero_p: int  # ``zero_point`` in the schema
    quant_min: int
    quant_max: int
    dtype: torch.dtype
|
797
|
+
|
798
|
+
|
799
|
+
@enforce_type
@dataclass
class ReciprocalArgs:
    """
    Arguments of the ``reciprocal`` operator.

    Schema: reciprocal(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
807
|
+
|
808
|
+
|
809
|
+
@enforce_type
@dataclass
class ReluArgs:
    """
    Arguments of the ``relu`` operator.

    Schema: relu(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
817
|
+
|
818
|
+
|
819
|
+
@enforce_type
@dataclass
class Relu6Args:
    """
    Arguments of the ``relu6`` operator.

    Schema: relu6(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
827
|
+
|
828
|
+
|
829
|
+
@enforce_type
@dataclass
class RepeatArgs:
    """
    Arguments of the ``repeat`` operator.

    Schema: repeat(Tensor self, SymInt[] repeats) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    repeats: List[int]  # SymInt[] stored as plain ints
|
838
|
+
|
839
|
+
|
840
|
+
@enforce_type
@dataclass
class ReshapeArgs:
    """
    Arguments of the ``reshape`` operator.

    Schema: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
    """

    input: torch.fx.Node  # ``self`` in the schema
    size: List[int]  # ``shape`` in the schema
|
849
|
+
|
850
|
+
|
851
|
+
@enforce_type
@dataclass
class ResizeNearestNeighborArgs:
    """
    Arguments of a nearest-neighbor resize.

    Corresponds to the
    ``torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode='nearest')``
    case.
    """

    input: torch.fx.Node
    # NOTE(review): presumably the target output size; confirm against the
    # code that builds this object.
    size: List[int]
|
860
|
+
|
861
|
+
|
862
|
+
@enforce_type
@dataclass
class RsqrtArgs:
    """
    Arguments of the ``rsqrt`` operator.

    Schema: rsqrt(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
870
|
+
|
871
|
+
|
872
|
+
@enforce_type
@dataclass
class ScalarTensorArgs:
    """
    Arguments of the ``scalar_tensor`` operator.

    Schema:
        scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    Only the scalar value is captured; the keyword-only options are not
    represented here.
    """

    scalar: Union[int, float]  # ``s`` in the schema
|
880
|
+
|
881
|
+
|
882
|
+
@enforce_type
@dataclass
class SelectCopyIntArgs:
    """
    Arguments of the ``select_copy.int`` overload.

    Schema: select_copy.int(Tensor self, int dim, SymInt index) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    dim: int
    index: int
|
892
|
+
|
893
|
+
|
894
|
+
@enforce_type
@dataclass
class SigmoidArgs:
    """
    Arguments of the ``sigmoid`` operator.

    Schema: sigmoid(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
902
|
+
|
903
|
+
|
904
|
+
@enforce_type
@dataclass
class SinArgs:
    """
    Arguments of the ``sin`` operator.

    Schema: sin(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
912
|
+
|
913
|
+
|
914
|
+
@enforce_type
@dataclass
class SliceArgs:
    """
    Arguments shared by ``slice.Tensor`` and ``slice_copy.Tensor``.

    Schemas:
        slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
        slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    # Defaults mirror the schema defaults.
    dim: int = 0
    start: Optional[int] = None
    end: Optional[int] = None
    step: Optional[int] = 1
|
927
|
+
|
928
|
+
|
929
|
+
@enforce_type
@dataclass
class SafeSoftmaxArgs:
    """
    Arguments of the ``_safe_softmax`` operator.

    Schema: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    dim: int
    dtype: Optional[torch.dtype] = None
|
939
|
+
|
940
|
+
|
941
|
+
@enforce_type
@dataclass
class SoftmaxArgs:
    """
    Arguments of the ``_softmax`` operator.

    Schema: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    dim: int
    half_to_float: bool
|
951
|
+
|
952
|
+
|
953
|
+
@enforce_type
@dataclass
class SplitWithSizesArgs:
    """
    Arguments of the ``split_with_sizes`` operator.

    Schema: split_with_sizes(Tensor(a->*) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
    """

    input: torch.fx.Node  # ``self`` in the schema
    split_sizes: List[int]
    dim: int = 0  # same default as the schema
|
963
|
+
|
964
|
+
|
965
|
+
@enforce_type
@dataclass
class SqrtArgs:
    """
    Arguments of the ``sqrt`` operator.

    Schema: sqrt(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
973
|
+
|
974
|
+
|
975
|
+
@enforce_type
@dataclass
class SqueezeArgs:
    """
    Arguments shared by ``squeeze.dims`` and ``squeeze_copy.dims``.

    Schemas:
        squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
        squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    # ``dim`` in the schema; each instance gets its own fresh empty list
    # (``list`` is the idiomatic factory, same as ``lambda: []``).
    dims: List[int] = field(default_factory=list)
|
985
|
+
|
986
|
+
|
987
|
+
@enforce_type
@dataclass
class SubTensorArgs:
    """
    Arguments of the ``sub.Tensor`` overload.

    Schema: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

    Note: the schema declares ``alpha`` as a Scalar defaulting to 1. Here it
    is annotated ``Optional[int]`` and defaults to ``None``.
    """

    input: Union[torch.fx.Node, torch.Tensor, float, int]  # ``self``
    other: Union[torch.fx.Node, torch.Tensor, float, int]
    alpha: Optional[int] = None
|
997
|
+
|
998
|
+
|
999
|
+
@enforce_type
@dataclass
class SumDimIntListArgs:
    """
    Arguments of the ``sum.dim_IntList`` overload.

    Schema:
        sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor

    Note: the schema allows ``dim`` to be None. Here it is a list that
    defaults to empty.
    """

    input: Union[torch.fx.Node, torch.Tensor, float, int]  # ``self``
    dim: List[int] = field(default_factory=list)  # fresh list per instance
    keepdim: bool = False
    dtype: Optional[torch.dtype] = None
|
1010
|
+
|
1011
|
+
|
1012
|
+
@enforce_type
@dataclass
class TanhArgs:
    """
    Arguments of the ``tanh`` operator.

    Schema: tanh(Tensor self) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
|
1020
|
+
|
1021
|
+
|
1022
|
+
@enforce_type
@dataclass
class ToCopyArgs:
    """
    Arguments of the ``_to_copy`` operator.

    Schema:
        _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    # Keyword-only options of the schema; defaults mirror the schema.
    dtype: Optional[torch.dtype] = None
    layout: Optional[torch.layout] = None
    device: Optional[torch.device] = None
    pin_memory: Optional[bool] = None
    non_blocking: Optional[bool] = False
    memory_format: Optional[torch.memory_format] = None
|
1036
|
+
|
1037
|
+
|
1038
|
+
@enforce_type
@dataclass
class ToDtypeArgs:
    """
    Arguments of the ``to.dtype`` overload.

    Schema:
        to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)

    Note: ``dtype`` is required in the schema but is optional (default
    ``None``) here.
    """

    input: torch.fx.Node  # ``self`` in the schema
    dtype: Optional[torch.dtype] = None
    non_blocking: Optional[bool] = False
    copy: Optional[bool] = False
    memory_format: Optional[torch.memory_format] = None
|
1050
|
+
|
1051
|
+
|
1052
|
+
@enforce_type
@dataclass
class ToDtypeLayoutArgs:
    """
    Arguments of the ``to.dtype_layout`` overload.

    Schema:
        to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
    """

    input: torch.fx.Node  # ``self`` in the schema
    # Keyword-only options of the schema; defaults mirror the schema.
    dtype: Optional[torch.dtype] = None
    layout: Optional[torch.layout] = None
    device: Optional[torch.device] = None
    pin_memory: Optional[bool] = None
    non_blocking: Optional[bool] = False
    copy: Optional[bool] = False
    memory_format: Optional[torch.memory_format] = None
|
1067
|
+
|
1068
|
+
|
1069
|
+
@enforce_type
@dataclass
class UnSqueezeArgs:
    """
    Arguments shared by ``unsqueeze`` and ``unsqueeze_copy``.

    Schemas:
        unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
        unsqueeze_copy(Tensor self, int dim) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    dim: int
|
1079
|
+
|
1080
|
+
|
1081
|
+
@enforce_type
@dataclass
class UpsampleNearest2DVecArgs:
    """
    Arguments of the ``upsample_nearest2d.vec`` overload.

    Schema:
        upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
    """

    input: torch.fx.Node
    # Both may be None, but neither has a default, so both must be passed.
    output_size: Optional[List[int]]
    scale_factors: Optional[List[float]]
|
1091
|
+
|
1092
|
+
|
1093
|
+
@enforce_type
@dataclass
class ViewArgs:
    """
    Arguments shared by ``view`` and ``view_copy``.

    Schemas:
        view(Tensor(a) self, SymInt[] size) -> Tensor(a)
        view_copy(Tensor self, SymInt[] size) -> Tensor
    """

    input: torch.fx.Node  # ``self`` in the schema
    size: List[int]  # SymInt[] stored as plain ints
|
1103
|
+
|
1104
|
+
|
1105
|
+
@enforce_type
|
1106
|
+
@dataclass
|
1107
|
+
class WhereSelfArgs:
|
1108
|
+
"""
|
1109
|
+
where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
|
1110
|
+
"""
|
1111
|
+
|
1112
|
+
condition: torch.fx.Node
|
1113
|
+
input: Union[torch.fx.Node, torch.Tensor]
|
1114
|
+
other: Union[torch.fx.Node, torch.Tensor]
|