tico 0.1.0.dev250609__py3-none-any.whl → 0.1.0.dev250611__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/passes/cast_aten_where_arg_type.py +65 -62
- tico/passes/cast_mixed_type_args.py +2 -5
- tico/passes/convert_conv1d_to_conv2d.py +3 -4
- tico/passes/convert_repeat_to_expand_copy.py +5 -9
- tico/passes/decompose_addmm.py +41 -48
- tico/passes/decompose_batch_norm.py +97 -99
- tico/passes/decompose_fake_quantize.py +4 -6
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -4
- tico/passes/decompose_group_norm.py +8 -7
- tico/passes/decompose_grouped_conv2d.py +2 -3
- tico/passes/decompose_slice_scatter.py +2 -4
- tico/passes/extract_dtype_kwargs.py +2 -1
- tico/passes/fuse_leading_unsqueeze_reshape.py +107 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +3 -5
- tico/passes/legalize_causal_mask_value.py +2 -7
- tico/passes/legalize_predefined_layout_operators.py +2 -3
- tico/passes/lower_pow2_to_mul.py +5 -7
- tico/passes/lower_to_resize_nearest_neighbor.py +6 -10
- tico/passes/lower_to_slice.py +3 -9
- tico/passes/merge_consecutive_cat.py +2 -4
- tico/passes/remove_nop.py +2 -3
- tico/passes/remove_redundant_assert_nodes.py +2 -1
- tico/passes/remove_redundant_expand.py +5 -9
- tico/passes/remove_redundant_permute.py +6 -5
- tico/passes/remove_redundant_reshape.py +17 -34
- tico/passes/remove_redundant_slice.py +2 -4
- tico/passes/remove_redundant_to_copy.py +2 -4
- tico/passes/segment_index_select.py +2 -4
- tico/serialize/operators/op_where.py +2 -2
- tico/utils/convert.py +2 -0
- tico/utils/utils.py +26 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/RECORD +38 -37
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250609.dist-info → tico-0.1.0.dev250611.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250611"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -26,7 +26,8 @@ from tico.utils.trace_decorators import (
|
|
26
26
|
trace_const_diff_on_pass,
|
27
27
|
trace_graph_diff_on_pass,
|
28
28
|
)
|
29
|
-
from tico.utils.utils import set_new_meta_val
|
29
|
+
from tico.utils.utils import is_target_node, set_new_meta_val
|
30
|
+
from tico.utils.validate_args_kwargs import WhereSelfArgs
|
30
31
|
|
31
32
|
|
32
33
|
dtype_ranking = {
|
@@ -114,69 +115,71 @@ class CastATenWhereArgType(PassBase):
|
|
114
115
|
modified = False
|
115
116
|
|
116
117
|
for node in graph.nodes:
|
117
|
-
if
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
cast = graph_module.graph.call_function(
|
159
|
-
torch.ops.aten._to_copy.default,
|
160
|
-
args=(to_cast,),
|
161
|
-
kwargs={"dtype": dtype_to_cast},
|
118
|
+
if not is_target_node(node, torch.ops.aten.where.self):
|
119
|
+
continue
|
120
|
+
|
121
|
+
where_args = WhereSelfArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
122
|
+
result_true, result_false = where_args.input, where_args.other
|
123
|
+
if not isinstance(result_true, torch.fx.Node) or not isinstance(
|
124
|
+
result_false, torch.fx.Node
|
125
|
+
):
|
126
|
+
continue
|
127
|
+
|
128
|
+
ep = exported_program
|
129
|
+
assert isinstance(result_true, torch.fx.Node)
|
130
|
+
assert isinstance(result_false, torch.fx.Node)
|
131
|
+
if not (
|
132
|
+
result_true.name in ep.graph_signature.inputs_to_buffers
|
133
|
+
and result_false.name in ep.graph_signature.inputs_to_buffers
|
134
|
+
):
|
135
|
+
continue
|
136
|
+
|
137
|
+
# Check if they have different data types
|
138
|
+
true_dtype = extract_torch_dtype(result_true)
|
139
|
+
false_dtype = extract_torch_dtype(result_false)
|
140
|
+
if true_dtype == false_dtype:
|
141
|
+
continue
|
142
|
+
|
143
|
+
node_to_dtype = {result_true: true_dtype, result_false: false_dtype}
|
144
|
+
|
145
|
+
not_to_cast, to_cast = sort_by_dtype(result_true, result_false)
|
146
|
+
|
147
|
+
buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
|
148
|
+
buf_name = ep.graph_signature.inputs_to_buffers[to_cast.name]
|
149
|
+
buf_data = buf_name_to_data[buf_name]
|
150
|
+
|
151
|
+
assert isinstance(buf_data, torch.Tensor)
|
152
|
+
|
153
|
+
dtype_to_cast = node_to_dtype[not_to_cast]
|
154
|
+
|
155
|
+
if dtype_to_cast == torch.float32:
|
156
|
+
if not check_if_covered_by_float(buf_data):
|
157
|
+
raise RuntimeError(
|
158
|
+
f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
|
162
159
|
)
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
node_dtype_ori = extract_torch_dtype(node)
|
169
|
-
set_new_meta_val(node)
|
170
|
-
node_dtype = extract_torch_dtype(node)
|
171
|
-
assert (
|
172
|
-
node_dtype == node_dtype_ori
|
173
|
-
), f"Type casting doesn't change node's dtype."
|
174
|
-
|
175
|
-
logger.debug(
|
176
|
-
f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
|
160
|
+
with graph_module.graph.inserting_after(to_cast):
|
161
|
+
cast = graph_module.graph.call_function(
|
162
|
+
torch.ops.aten._to_copy.default,
|
163
|
+
args=(to_cast,),
|
164
|
+
kwargs={"dtype": dtype_to_cast},
|
177
165
|
)
|
178
|
-
|
179
|
-
|
166
|
+
# set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
|
167
|
+
set_new_meta_val(cast)
|
168
|
+
node.update_arg(node.args.index(to_cast), cast)
|
169
|
+
|
170
|
+
# check if type promotion is valid.
|
171
|
+
node_dtype_ori = extract_torch_dtype(node)
|
172
|
+
set_new_meta_val(node)
|
173
|
+
node_dtype = extract_torch_dtype(node)
|
174
|
+
assert (
|
175
|
+
node_dtype == node_dtype_ori
|
176
|
+
), f"Type casting doesn't change node's dtype."
|
177
|
+
|
178
|
+
logger.debug(
|
179
|
+
f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
|
180
|
+
)
|
181
|
+
|
182
|
+
modified = True
|
180
183
|
|
181
184
|
graph.eliminate_dead_code()
|
182
185
|
graph.lint()
|
@@ -28,7 +28,7 @@ from tico.serialize.circle_mapping import extract_torch_dtype
|
|
28
28
|
from tico.utils import logging
|
29
29
|
from tico.utils.passes import PassBase, PassResult
|
30
30
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
31
|
-
from tico.utils.utils import set_new_meta_val
|
31
|
+
from tico.utils.utils import is_target_node, set_new_meta_val
|
32
32
|
|
33
33
|
|
34
34
|
ops_to_promote = {
|
@@ -96,10 +96,7 @@ class CastMixedTypeArgs(PassBase):
|
|
96
96
|
graph = graph_module.graph
|
97
97
|
modified = False
|
98
98
|
for node in graph.nodes:
|
99
|
-
if not node.
|
100
|
-
continue
|
101
|
-
|
102
|
-
if node.target not in ops_to_promote:
|
99
|
+
if not is_target_node(node, list(ops_to_promote.keys())):
|
103
100
|
continue
|
104
101
|
|
105
102
|
assert len(node.args) == 2
|
@@ -24,6 +24,7 @@ from tico.utils import logging
|
|
24
24
|
from tico.utils.errors import NotYetSupportedError
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.utils import is_target_node
|
27
28
|
from tico.utils.validate_args_kwargs import Conv1DArgs
|
28
29
|
|
29
30
|
|
@@ -131,17 +132,15 @@ class ConvertConv1dToConv2d(PassBase):
|
|
131
132
|
return modified
|
132
133
|
|
133
134
|
def call(self, exported_program: ExportedProgram) -> PassResult:
|
134
|
-
target_conv_op =
|
135
|
+
target_conv_op = [torch.ops.aten.conv1d.default, torch.ops.aten.conv1d.padding]
|
135
136
|
|
136
137
|
graph_module = exported_program.graph_module
|
137
138
|
graph = graph_module.graph
|
138
139
|
modified = False
|
139
140
|
for node in graph.nodes:
|
140
|
-
if not node
|
141
|
+
if not is_target_node(node, target_conv_op):
|
141
142
|
continue
|
142
143
|
|
143
|
-
if node.target not in target_conv_op:
|
144
|
-
continue
|
145
144
|
modified |= self.convert(exported_program, node)
|
146
145
|
|
147
146
|
graph.eliminate_dead_code()
|
@@ -22,6 +22,8 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.utils import logging
|
23
23
|
from tico.utils.passes import PassBase, PassResult
|
24
24
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
|
+
from tico.utils.utils import is_target_node
|
26
|
+
from tico.utils.validate_args_kwargs import RepeatArgs
|
25
27
|
|
26
28
|
|
27
29
|
@trace_graph_diff_on_pass
|
@@ -42,17 +44,11 @@ class ConvertRepeatToExpandCopy(PassBase):
|
|
42
44
|
graph = graph_module.graph
|
43
45
|
modified = False
|
44
46
|
for node in graph.nodes:
|
45
|
-
if not node.
|
47
|
+
if not is_target_node(node, torch.ops.aten.repeat.default):
|
46
48
|
continue
|
47
49
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
assert len(node.args) == 2
|
52
|
-
|
53
|
-
tensor, repeats = node.args
|
54
|
-
assert isinstance(tensor, torch.fx.Node)
|
55
|
-
assert isinstance(repeats, list)
|
50
|
+
reshape_args = RepeatArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
51
|
+
tensor, repeats = reshape_args.input, reshape_args.repeats
|
56
52
|
|
57
53
|
tensor_shape: List[int] = [int(dim) for dim in tensor.meta["val"].shape]
|
58
54
|
|
tico/passes/decompose_addmm.py
CHANGED
@@ -24,7 +24,7 @@ from tico.utils import logging
|
|
24
24
|
from tico.utils.graph import add_placeholder
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
-
from tico.utils.utils import set_new_meta_val
|
27
|
+
from tico.utils.utils import is_target_node, set_new_meta_val
|
28
28
|
from tico.utils.validate_args_kwargs import AddmmArgs
|
29
29
|
|
30
30
|
|
@@ -66,59 +66,52 @@ class DecomposeAddmm(PassBase):
|
|
66
66
|
modified = False
|
67
67
|
|
68
68
|
for node in graph.nodes:
|
69
|
-
if node.
|
69
|
+
if not is_target_node(node, torch.ops.aten.addmm.default):
|
70
70
|
continue
|
71
71
|
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
72
|
+
args = AddmmArgs(*node.args, **node.kwargs)
|
73
|
+
input = args.input
|
74
|
+
mat1 = args.mat1
|
75
|
+
mat2 = args.mat2
|
76
|
+
beta = args.beta
|
77
|
+
alpha = args.alpha
|
78
|
+
|
79
|
+
with graph.inserting_before(node):
|
80
|
+
# out = beta * input + alpha * (mat1 @ mat2)
|
81
|
+
matmul = graph.call_function(torch.ops.aten.mm.default, (mat1, mat2))
|
82
|
+
set_new_meta_val(matmul)
|
83
|
+
|
84
|
+
if beta == 1:
|
85
|
+
bias: torch.fx.Node | torch.Tensor = input
|
86
|
+
elif beta == 0:
|
87
|
+
bias = add_placeholder(
|
88
|
+
exported_program,
|
89
|
+
torch.zeros(extract_shape(input)),
|
90
|
+
f"{node.name}_beta_zeros",
|
86
91
|
)
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
else:
|
98
|
-
bias = graph.call_function(
|
99
|
-
torch.ops.aten.mul.Tensor, (input, beta)
|
100
|
-
)
|
101
|
-
|
102
|
-
if alpha == 1:
|
103
|
-
scaled_matmul: torch.fx.Node | torch.Tensor = matmul
|
104
|
-
elif alpha == 0:
|
105
|
-
scaled_matmul = add_placeholder(
|
106
|
-
exported_program,
|
107
|
-
torch.zeros(extract_shape(matmul)),
|
108
|
-
f"{node.name}_alpha_zeros",
|
109
|
-
)
|
110
|
-
else:
|
111
|
-
scaled_matmul = graph.call_function(
|
112
|
-
torch.ops.aten.mul.Tensor, (matmul, alpha)
|
113
|
-
)
|
114
|
-
|
115
|
-
result = graph.call_function(
|
116
|
-
torch.ops.aten.add.Tensor, (bias, scaled_matmul)
|
92
|
+
else:
|
93
|
+
bias = graph.call_function(torch.ops.aten.mul.Tensor, (input, beta))
|
94
|
+
|
95
|
+
if alpha == 1:
|
96
|
+
scaled_matmul: torch.fx.Node | torch.Tensor = matmul
|
97
|
+
elif alpha == 0:
|
98
|
+
scaled_matmul = add_placeholder(
|
99
|
+
exported_program,
|
100
|
+
torch.zeros(extract_shape(matmul)),
|
101
|
+
f"{node.name}_alpha_zeros",
|
117
102
|
)
|
103
|
+
else:
|
104
|
+
scaled_matmul = graph.call_function(
|
105
|
+
torch.ops.aten.mul.Tensor, (matmul, alpha)
|
106
|
+
)
|
107
|
+
|
108
|
+
result = graph.call_function(
|
109
|
+
torch.ops.aten.add.Tensor, (bias, scaled_matmul)
|
110
|
+
)
|
118
111
|
|
119
|
-
|
112
|
+
node.replace_all_uses_with(result, propagate_meta=True)
|
120
113
|
|
121
|
-
|
114
|
+
modified = True
|
122
115
|
|
123
116
|
gm.graph.eliminate_dead_code()
|
124
117
|
gm.graph.lint()
|
@@ -32,7 +32,7 @@ from tico.utils.graph import (
|
|
32
32
|
)
|
33
33
|
from tico.utils.passes import PassBase, PassResult
|
34
34
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
35
|
-
from tico.utils.utils import fill_meta_val
|
35
|
+
from tico.utils.utils import fill_meta_val, is_target_node
|
36
36
|
from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
|
37
37
|
|
38
38
|
|
@@ -87,109 +87,107 @@ class DecomposeBatchNorm(PassBase):
|
|
87
87
|
modified = False
|
88
88
|
|
89
89
|
for node in graph.nodes:
|
90
|
-
if
|
90
|
+
if not is_target_node(
|
91
|
+
node, torch.ops.aten._native_batch_norm_legit_no_training.default
|
92
|
+
):
|
91
93
|
continue
|
92
94
|
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
95
|
+
args = NativeBatchNormLegitNoTrainingArgs(*node.args)
|
96
|
+
input_ = args.input
|
97
|
+
weight = args.weight
|
98
|
+
bias = args.bias
|
99
|
+
running_mean = args.running_mean
|
100
|
+
running_var = args.running_var
|
101
|
+
eps = args.eps
|
102
|
+
|
103
|
+
if not running_mean:
|
104
|
+
raise NotYetSupportedError(f"running_mean=None is not supported yet")
|
105
|
+
if not running_var:
|
106
|
+
raise NotYetSupportedError(f"running_var=None is not supported yet")
|
107
|
+
|
108
|
+
"""
|
109
|
+
Only support the cases generated from torch.nn.BatchNorm2d module,
|
110
|
+
for which, let's checks if weight and bias are parameters and
|
111
|
+
running_mean and running_var are buffers.
|
112
|
+
"""
|
113
|
+
if weight and not is_torch_param(weight, exported_program):
|
114
|
+
continue
|
115
|
+
if bias and not is_torch_param(bias, exported_program):
|
116
|
+
continue
|
117
|
+
if not is_torch_buffer(running_mean, exported_program):
|
118
|
+
continue
|
119
|
+
if not is_torch_buffer(running_var, exported_program):
|
120
|
+
continue
|
121
|
+
|
122
|
+
input_shape = extract_shape(input_)
|
123
|
+
assert len(input_shape) == 4
|
124
|
+
C = input_shape[1]
|
125
|
+
|
126
|
+
weight_value = (
|
127
|
+
get_torch_param_value(weight, exported_program)
|
128
|
+
if weight
|
129
|
+
else torch.tensor([1] * C)
|
130
|
+
)
|
131
|
+
bias_value = (
|
132
|
+
get_torch_param_value(bias, exported_program)
|
133
|
+
if bias
|
134
|
+
else torch.tensor([0] * C)
|
135
|
+
)
|
136
|
+
mean_value = get_torch_buffer_value(running_mean, exported_program)
|
137
|
+
var_value = get_torch_buffer_value(running_var, exported_program)
|
138
|
+
|
139
|
+
assert isinstance(weight_value, torch.Tensor)
|
140
|
+
assert isinstance(bias_value, torch.Tensor)
|
141
|
+
assert isinstance(mean_value, torch.Tensor)
|
142
|
+
assert isinstance(var_value, torch.Tensor)
|
143
|
+
|
144
|
+
assert (
|
145
|
+
weight_value.shape
|
146
|
+
== bias_value.shape
|
147
|
+
== mean_value.shape
|
148
|
+
== var_value.shape
|
149
|
+
)
|
150
|
+
# Calculate constants for mul and add
|
151
|
+
mul_const = weight_value / torch.sqrt(var_value + eps)
|
152
|
+
add_const = bias_value - (mul_const * mean_value)
|
153
|
+
# N, C, H, W
|
154
|
+
assert len(mul_const) == len(add_const) == C
|
155
|
+
# reshape along with channel dimension
|
156
|
+
mul_const = mul_const.view(1, mul_const.shape[0], 1, 1)
|
157
|
+
add_const = add_const.view(1, add_const.shape[0], 1, 1)
|
158
|
+
|
159
|
+
# Placeholder nodes must be the first N nodes in the nodes list of a graph.
|
160
|
+
# Therefore, insert the newly created placeholders at the start of the node list.
|
161
|
+
with exported_program.graph.inserting_before(
|
162
|
+
get_first_user_input(exported_program)
|
163
|
+
):
|
164
|
+
mul_const_node = add_placeholder(
|
165
|
+
exported_program,
|
166
|
+
mul_const,
|
167
|
+
prefix=f"{node.name}_mul_const",
|
168
|
+
)
|
169
|
+
add_const_node = add_placeholder(
|
170
|
+
exported_program,
|
171
|
+
add_const,
|
172
|
+
prefix=f"{node.name}_add_const",
|
133
173
|
)
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
174
|
+
|
175
|
+
with gm.graph.inserting_before(node):
|
176
|
+
mul = graph.call_function(
|
177
|
+
torch.ops.aten.mul.Tensor,
|
178
|
+
args=(input_, mul_const_node),
|
138
179
|
)
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
assert isinstance(weight_value, torch.Tensor)
|
143
|
-
assert isinstance(bias_value, torch.Tensor)
|
144
|
-
assert isinstance(mean_value, torch.Tensor)
|
145
|
-
assert isinstance(var_value, torch.Tensor)
|
146
|
-
|
147
|
-
assert (
|
148
|
-
weight_value.shape
|
149
|
-
== bias_value.shape
|
150
|
-
== mean_value.shape
|
151
|
-
== var_value.shape
|
180
|
+
add = graph.call_function(
|
181
|
+
torch.ops.aten.add.Tensor,
|
182
|
+
args=(mul, add_const_node),
|
152
183
|
)
|
153
|
-
#
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
add_const = add_const.view(1, add_const.shape[0], 1, 1)
|
161
|
-
|
162
|
-
# Placeholder nodes must be the first N nodes in the nodes list of a graph.
|
163
|
-
# Therefore, insert the newly created placeholders at the start of the node list.
|
164
|
-
with exported_program.graph.inserting_before(
|
165
|
-
get_first_user_input(exported_program)
|
166
|
-
):
|
167
|
-
mul_const_node = add_placeholder(
|
168
|
-
exported_program,
|
169
|
-
mul_const,
|
170
|
-
prefix=f"{node.name}_mul_const",
|
171
|
-
)
|
172
|
-
add_const_node = add_placeholder(
|
173
|
-
exported_program,
|
174
|
-
add_const,
|
175
|
-
prefix=f"{node.name}_add_const",
|
176
|
-
)
|
177
|
-
|
178
|
-
with gm.graph.inserting_before(node):
|
179
|
-
mul = graph.call_function(
|
180
|
-
torch.ops.aten.mul.Tensor,
|
181
|
-
args=(input_, mul_const_node),
|
182
|
-
)
|
183
|
-
add = graph.call_function(
|
184
|
-
torch.ops.aten.add.Tensor,
|
185
|
-
args=(mul, add_const_node),
|
186
|
-
)
|
187
|
-
# Not set meta for propagating replacing get_item's meta.
|
188
|
-
get_item, *_ = node.users.keys()
|
189
|
-
get_item.replace_all_uses_with(add, propagate_meta=True)
|
190
|
-
|
191
|
-
fill_meta_val(exported_program)
|
192
|
-
modified = True
|
184
|
+
# Not set meta for propagating replacing get_item's meta.
|
185
|
+
get_item, *_ = node.users.keys()
|
186
|
+
get_item.replace_all_uses_with(add, propagate_meta=True)
|
187
|
+
|
188
|
+
fill_meta_val(exported_program)
|
189
|
+
logger.debug(f"{node.name} is decomposed to {mul.name} and {add.name}")
|
190
|
+
modified = True
|
193
191
|
|
194
192
|
gm.graph.eliminate_dead_code()
|
195
193
|
gm.graph.lint()
|
@@ -71,9 +71,9 @@ class DecomposeFakeQuantize(PassBase):
|
|
71
71
|
gm = exported_program.graph_module
|
72
72
|
qd = torch.ops.quantized_decomposed # type: ignore[return]
|
73
73
|
for node in gm.graph.nodes:
|
74
|
-
if node.op
|
75
|
-
|
76
|
-
]:
|
74
|
+
if node.op != "call_function":
|
75
|
+
continue
|
76
|
+
if node.target in [torch.ops.aten.fake_quantize_per_tensor_affine.default]:
|
77
77
|
# tensor, scale, zero_p, quant_min, quant_max
|
78
78
|
assert len(node.args) == 5
|
79
79
|
_, _, _, quant_min, quant_max = node.args
|
@@ -97,9 +97,7 @@ class DecomposeFakeQuantize(PassBase):
|
|
97
97
|
node.replace_all_uses_with(dequnt, propagate_meta=True)
|
98
98
|
modified = True
|
99
99
|
|
100
|
-
if node.
|
101
|
-
torch.ops.aten.fake_quantize_per_channel_affine.default
|
102
|
-
]:
|
100
|
+
if node.target in [torch.ops.aten.fake_quantize_per_channel_affine.default]:
|
103
101
|
fq_args = FakeQuantizePerChannelArgs(*node.args, **node.kwargs)
|
104
102
|
quant_min = fq_args.quant_min
|
105
103
|
quant_max = fq_args.quant_max
|
@@ -202,9 +202,10 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
202
202
|
gm = exported_program.graph_module
|
203
203
|
qd = torch.ops.quantized_decomposed # type: ignore[return]
|
204
204
|
for node in gm.graph.nodes:
|
205
|
+
if node.op != "call_function":
|
206
|
+
continue
|
205
207
|
if (
|
206
|
-
node.
|
207
|
-
and node.target
|
208
|
+
node.target
|
208
209
|
== torch.ops.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams.default
|
209
210
|
):
|
210
211
|
# tensor, scale, zero_p, fake_quant_enabled, quant_min, quant_max
|
@@ -247,8 +248,7 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
247
248
|
mask_user.args = ((mask_user.args[0][0],),)
|
248
249
|
modified = True
|
249
250
|
if (
|
250
|
-
node.
|
251
|
-
and node.target
|
251
|
+
node.target
|
252
252
|
== torch.ops.aten.fake_quantize_per_tensor_affine.tensor_qparams
|
253
253
|
):
|
254
254
|
fq_args = FakeQuantizePerTensorTQParamArgs(*node.args, **node.kwargs)
|
@@ -25,6 +25,7 @@ from tico.serialize.circle_mapping import extract_shape
|
|
25
25
|
from tico.utils import logging
|
26
26
|
from tico.utils.passes import PassBase, PassResult
|
27
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.utils import is_target_node
|
28
29
|
from tico.utils.validate_args_kwargs import NativeGroupNormArgs, NativeLayerNormArgs
|
29
30
|
|
30
31
|
|
@@ -115,13 +116,13 @@ class DecomposeGroupNorm(PassBase):
|
|
115
116
|
modified = False
|
116
117
|
|
117
118
|
for node in graph.nodes:
|
118
|
-
if
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
119
|
+
if not is_target_node(
|
120
|
+
node,
|
121
|
+
[
|
122
|
+
torch.ops.aten.native_layer_norm.default,
|
123
|
+
torch.ops.aten.native_group_norm.default,
|
124
|
+
],
|
125
|
+
):
|
125
126
|
continue
|
126
127
|
|
127
128
|
if node.target == torch.ops.aten.native_layer_norm.default:
|
@@ -26,6 +26,7 @@ from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
|
|
26
26
|
from tico.utils.graph import add_placeholder
|
27
27
|
from tico.utils.passes import PassBase, PassResult
|
28
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
|
+
from tico.utils.utils import is_target_node
|
29
30
|
from tico.utils.validate_args_kwargs import Conv2DArgs
|
30
31
|
|
31
32
|
|
@@ -88,9 +89,7 @@ class DecomposeGroupedConv2d(PassBase):
|
|
88
89
|
modified = False
|
89
90
|
|
90
91
|
for node in graph.nodes:
|
91
|
-
if node.
|
92
|
-
continue
|
93
|
-
if not node.target in ops.aten.conv2d:
|
92
|
+
if not is_target_node(node, ops.aten.conv2d):
|
94
93
|
continue
|
95
94
|
|
96
95
|
args = Conv2DArgs(*node.args)
|