tico 0.1.0.dev250616__py3-none-any.whl → 0.1.0.dev250617__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +6 -2
- tico/passes/cast_aten_where_arg_type.py +4 -1
- tico/passes/cast_mixed_type_args.py +4 -1
- tico/passes/convert_conv1d_to_conv2d.py +12 -4
- tico/passes/convert_layout_op_to_reshape.py +3 -2
- tico/passes/convert_repeat_to_expand_copy.py +5 -2
- tico/passes/convert_to_relu6.py +4 -3
- tico/passes/decompose_addmm.py +11 -7
- tico/passes/decompose_batch_norm.py +7 -11
- tico/passes/decompose_fake_quantize.py +12 -6
- tico/passes/decompose_fake_quantize_tensor_qparams.py +12 -6
- tico/passes/decompose_group_norm.py +50 -21
- tico/passes/decompose_grouped_conv2d.py +15 -7
- tico/passes/decompose_slice_scatter.py +9 -5
- tico/passes/fuse_leading_unsqueeze_reshape.py +8 -3
- tico/passes/legalize_predefined_layout_operators.py +33 -25
- tico/passes/lower_pow2_to_mul.py +3 -1
- tico/passes/lower_to_resize_nearest_neighbor.py +21 -10
- tico/passes/lower_to_slice.py +21 -11
- tico/passes/remove_redundant_permute.py +5 -3
- tico/passes/remove_redundant_reshape.py +5 -2
- tico/passes/remove_redundant_to_copy.py +4 -0
- tico/passes/restore_linear.py +7 -5
- tico/passes/segment_index_select.py +9 -5
- tico/utils/graph.py +48 -2
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250617.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250617.dist-info}/RECORD +32 -32
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250617.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250617.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250617.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250617.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250617"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -24,6 +24,7 @@ from torch.export import ExportedProgram
|
|
24
24
|
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
|
25
25
|
from tico.utils import logging
|
26
26
|
from tico.utils.errors import NotYetSupportedError
|
27
|
+
from tico.utils.graph import create_node
|
27
28
|
from tico.utils.passes import PassBase, PassResult
|
28
29
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
30
|
from tico.utils.utils import quant_min_max, set_new_meta_val
|
@@ -145,9 +146,11 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
|
|
145
146
|
|
146
147
|
with graph.inserting_before(node):
|
147
148
|
q_args = (inp, scale, zerop, min_, max_, dtype)
|
148
|
-
quantize =
|
149
|
+
quantize = create_node(
|
150
|
+
graph,
|
149
151
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
150
152
|
args=q_args,
|
153
|
+
origin=node,
|
151
154
|
)
|
152
155
|
quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
|
153
156
|
set_new_meta_val(quantize)
|
@@ -166,7 +169,8 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
|
|
166
169
|
dtype = getattr(torch, qparam.dtype)
|
167
170
|
with graph.inserting_after(node):
|
168
171
|
q_args = (node, scale, zerop, min_, max_, dtype)
|
169
|
-
quantize =
|
172
|
+
quantize = create_node(
|
173
|
+
graph,
|
170
174
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
171
175
|
args=q_args,
|
172
176
|
)
|
@@ -21,6 +21,7 @@ from torch.export import ExportedProgram
|
|
21
21
|
|
22
22
|
from tico.serialize.circle_mapping import extract_torch_dtype
|
23
23
|
from tico.utils import logging
|
24
|
+
from tico.utils.graph import create_node
|
24
25
|
from tico.utils.passes import PassBase, PassResult
|
25
26
|
from tico.utils.trace_decorators import (
|
26
27
|
trace_const_diff_on_pass,
|
@@ -158,10 +159,12 @@ class CastATenWhereArgType(PassBase):
|
|
158
159
|
f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
|
159
160
|
)
|
160
161
|
with graph_module.graph.inserting_after(to_cast):
|
161
|
-
cast =
|
162
|
+
cast = create_node(
|
163
|
+
graph,
|
162
164
|
torch.ops.aten._to_copy.default,
|
163
165
|
args=(to_cast,),
|
164
166
|
kwargs={"dtype": dtype_to_cast},
|
167
|
+
origin=to_cast,
|
165
168
|
)
|
166
169
|
# set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
|
167
170
|
set_new_meta_val(cast)
|
@@ -26,6 +26,7 @@ from torch.export import ExportedProgram
|
|
26
26
|
|
27
27
|
from tico.serialize.circle_mapping import extract_torch_dtype
|
28
28
|
from tico.utils import logging
|
29
|
+
from tico.utils.graph import create_node
|
29
30
|
from tico.utils.passes import PassBase, PassResult
|
30
31
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
31
32
|
from tico.utils.utils import is_target_node, set_new_meta_val
|
@@ -126,10 +127,12 @@ class CastMixedTypeArgs(PassBase):
|
|
126
127
|
|
127
128
|
if isinstance(arg_to_promote, torch.fx.Node):
|
128
129
|
with graph.inserting_after(arg_to_promote):
|
129
|
-
to_copy =
|
130
|
+
to_copy = create_node(
|
131
|
+
graph,
|
130
132
|
torch.ops.aten._to_copy.default,
|
131
133
|
(arg_to_promote,),
|
132
134
|
{"dtype": type_to_promote},
|
135
|
+
origin=arg_to_promote,
|
133
136
|
)
|
134
137
|
# set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
|
135
138
|
set_new_meta_val(to_copy)
|
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.serialize.circle_graph import extract_shape
|
23
23
|
from tico.utils import logging
|
24
24
|
from tico.utils.errors import NotYetSupportedError
|
25
|
+
from tico.utils.graph import create_node
|
25
26
|
from tico.utils.passes import PassBase, PassResult
|
26
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
28
|
from tico.utils.utils import is_target_node
|
@@ -89,15 +90,19 @@ class ConvertConv1dToConv2d(PassBase):
|
|
89
90
|
)
|
90
91
|
|
91
92
|
with graph.inserting_after(input):
|
92
|
-
input_unsqueeze =
|
93
|
+
input_unsqueeze = create_node(
|
94
|
+
graph,
|
93
95
|
torch.ops.aten.unsqueeze.default,
|
94
96
|
args=(input, 3),
|
97
|
+
origin=input,
|
95
98
|
)
|
96
99
|
|
97
100
|
with graph.inserting_after(weight):
|
98
|
-
weight_unsqueeze =
|
101
|
+
weight_unsqueeze = create_node(
|
102
|
+
graph,
|
99
103
|
torch.ops.aten.unsqueeze.default,
|
100
104
|
args=(weight, 3),
|
105
|
+
origin=weight,
|
101
106
|
)
|
102
107
|
|
103
108
|
with graph.inserting_before(node):
|
@@ -106,7 +111,8 @@ class ConvertConv1dToConv2d(PassBase):
|
|
106
111
|
elif isinstance(padding, str):
|
107
112
|
conv2d_op = torch.ops.aten.conv2d.padding
|
108
113
|
|
109
|
-
conv2d =
|
114
|
+
conv2d = create_node(
|
115
|
+
graph,
|
110
116
|
conv2d_op,
|
111
117
|
args=(
|
112
118
|
input_unsqueeze,
|
@@ -118,9 +124,11 @@ class ConvertConv1dToConv2d(PassBase):
|
|
118
124
|
groups,
|
119
125
|
),
|
120
126
|
kwargs=node.kwargs,
|
127
|
+
origin=node,
|
121
128
|
)
|
122
129
|
|
123
|
-
conv_out_squeeze =
|
130
|
+
conv_out_squeeze = create_node(
|
131
|
+
graph,
|
124
132
|
torch.ops.aten.squeeze.dims,
|
125
133
|
args=(conv2d, [3]),
|
126
134
|
)
|
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.passes import ops
|
23
23
|
from tico.serialize.circle_mapping import extract_shape
|
24
24
|
from tico.utils import logging
|
25
|
+
from tico.utils.graph import create_node
|
25
26
|
from tico.utils.passes import PassBase, PassResult
|
26
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
28
|
from tico.utils.validate_args_kwargs import SqueezeArgs, UnSqueezeArgs, ViewArgs
|
@@ -48,11 +49,11 @@ class ConvertLayoutOpToReshape(PassBase):
|
|
48
49
|
out_shape = list(extract_shape(node))
|
49
50
|
|
50
51
|
with graph.inserting_after(node):
|
51
|
-
reshape_node =
|
52
|
+
reshape_node = create_node(
|
53
|
+
graph,
|
52
54
|
torch.ops.aten.reshape.default,
|
53
55
|
args=(input, out_shape),
|
54
56
|
)
|
55
|
-
|
56
57
|
node.replace_all_uses_with(reshape_node, propagate_meta=True)
|
57
58
|
|
58
59
|
logger.debug(f"{node.name} is replaced with {reshape_node.name}")
|
@@ -20,6 +20,7 @@ import torch
|
|
20
20
|
from torch.export import ExportedProgram
|
21
21
|
|
22
22
|
from tico.utils import logging
|
23
|
+
from tico.utils.graph import create_node
|
23
24
|
from tico.utils.passes import PassBase, PassResult
|
24
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
26
|
from tico.utils.utils import is_target_node
|
@@ -71,8 +72,10 @@ class ConvertRepeatToExpandCopy(PassBase):
|
|
71
72
|
expand_copy_args = (tensor, size)
|
72
73
|
|
73
74
|
with graph.inserting_after(node):
|
74
|
-
expand_copy_node =
|
75
|
-
|
75
|
+
expand_copy_node = create_node(
|
76
|
+
graph,
|
77
|
+
torch.ops.aten.expand_copy.default,
|
78
|
+
args=expand_copy_args,
|
76
79
|
)
|
77
80
|
node.replace_all_uses_with(expand_copy_node, propagate_meta=True)
|
78
81
|
|
tico/passes/convert_to_relu6.py
CHANGED
@@ -20,6 +20,7 @@ import torch
|
|
20
20
|
from torch.export import ExportedProgram
|
21
21
|
|
22
22
|
from tico.utils import logging
|
23
|
+
from tico.utils.graph import create_node
|
23
24
|
from tico.utils.passes import PassBase, PassResult
|
24
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
26
|
from tico.utils.validate_args_kwargs import ClampArgs, HardTanhArgs
|
@@ -58,7 +59,7 @@ class ConvertHardTanhToReLU6(Converter):
|
|
58
59
|
input = args.input
|
59
60
|
|
60
61
|
with graph.inserting_after(node):
|
61
|
-
relu_node = graph
|
62
|
+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
|
62
63
|
node.replace_all_uses_with(relu_node, propagate_meta=True)
|
63
64
|
|
64
65
|
|
@@ -84,7 +85,7 @@ class ConvertClampToReLU6(Converter):
|
|
84
85
|
input = args.input
|
85
86
|
|
86
87
|
with graph.inserting_after(node):
|
87
|
-
relu_node = graph
|
88
|
+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
|
88
89
|
node.replace_all_uses_with(relu_node, propagate_meta=True)
|
89
90
|
|
90
91
|
|
@@ -140,7 +141,7 @@ class ConvertDoubleClampsToReLU6(Converter):
|
|
140
141
|
input = prev_args.input
|
141
142
|
|
142
143
|
with graph.inserting_after(node):
|
143
|
-
relu_node = graph
|
144
|
+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
|
144
145
|
node.replace_all_uses_with(relu_node, propagate_meta=True)
|
145
146
|
|
146
147
|
|
tico/passes/decompose_addmm.py
CHANGED
@@ -21,7 +21,7 @@ from torch.export import ExportedProgram
|
|
21
21
|
|
22
22
|
from tico.serialize.circle_mapping import extract_shape
|
23
23
|
from tico.utils import logging
|
24
|
-
from tico.utils.graph import add_placeholder
|
24
|
+
from tico.utils.graph import add_placeholder, create_node
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
27
|
from tico.utils.utils import is_target_node, set_new_meta_val
|
@@ -78,7 +78,9 @@ class DecomposeAddmm(PassBase):
|
|
78
78
|
|
79
79
|
with graph.inserting_before(node):
|
80
80
|
# out = beta * input + alpha * (mat1 @ mat2)
|
81
|
-
matmul =
|
81
|
+
matmul = create_node(
|
82
|
+
graph, torch.ops.aten.mm.default, (mat1, mat2), origin=node
|
83
|
+
)
|
82
84
|
set_new_meta_val(matmul)
|
83
85
|
|
84
86
|
if beta == 1:
|
@@ -90,7 +92,9 @@ class DecomposeAddmm(PassBase):
|
|
90
92
|
f"{node.name}_beta_zeros",
|
91
93
|
)
|
92
94
|
else:
|
93
|
-
bias =
|
95
|
+
bias = create_node(
|
96
|
+
graph, torch.ops.aten.mul.Tensor, (input, beta), origin=node
|
97
|
+
)
|
94
98
|
|
95
99
|
if alpha == 1:
|
96
100
|
scaled_matmul: torch.fx.Node | torch.Tensor = matmul
|
@@ -101,12 +105,12 @@ class DecomposeAddmm(PassBase):
|
|
101
105
|
f"{node.name}_alpha_zeros",
|
102
106
|
)
|
103
107
|
else:
|
104
|
-
scaled_matmul =
|
105
|
-
torch.ops.aten.mul.Tensor, (matmul, alpha)
|
108
|
+
scaled_matmul = create_node(
|
109
|
+
graph, torch.ops.aten.mul.Tensor, (matmul, alpha), origin=node
|
106
110
|
)
|
107
111
|
|
108
|
-
result =
|
109
|
-
torch.ops.aten.add.Tensor, (bias, scaled_matmul)
|
112
|
+
result = create_node(
|
113
|
+
graph, torch.ops.aten.add.Tensor, (bias, scaled_matmul)
|
110
114
|
)
|
111
115
|
|
112
116
|
node.replace_all_uses_with(result, propagate_meta=True)
|
@@ -24,6 +24,7 @@ from tico.utils import logging
|
|
24
24
|
from tico.utils.errors import NotYetSupportedError
|
25
25
|
from tico.utils.graph import (
|
26
26
|
add_placeholder,
|
27
|
+
create_node,
|
27
28
|
get_first_user_input,
|
28
29
|
get_torch_buffer_value,
|
29
30
|
get_torch_param_value,
|
@@ -32,16 +33,10 @@ from tico.utils.graph import (
|
|
32
33
|
)
|
33
34
|
from tico.utils.passes import PassBase, PassResult
|
34
35
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
35
|
-
from tico.utils.utils import
|
36
|
+
from tico.utils.utils import is_target_node
|
36
37
|
from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
|
37
38
|
|
38
39
|
|
39
|
-
def insert_node(graph: torch.fx.Graph, operation, args):
|
40
|
-
new_node = graph.call_function(operation, args)
|
41
|
-
|
42
|
-
return new_node
|
43
|
-
|
44
|
-
|
45
40
|
@trace_graph_diff_on_pass
|
46
41
|
class DecomposeBatchNorm(PassBase):
|
47
42
|
"""
|
@@ -173,19 +168,20 @@ class DecomposeBatchNorm(PassBase):
|
|
173
168
|
)
|
174
169
|
|
175
170
|
with gm.graph.inserting_before(node):
|
176
|
-
mul =
|
171
|
+
mul = create_node(
|
172
|
+
graph,
|
177
173
|
torch.ops.aten.mul.Tensor,
|
178
174
|
args=(input_, mul_const_node),
|
175
|
+
origin=node,
|
179
176
|
)
|
180
|
-
add =
|
177
|
+
add = create_node(
|
178
|
+
graph,
|
181
179
|
torch.ops.aten.add.Tensor,
|
182
180
|
args=(mul, add_const_node),
|
183
181
|
)
|
184
|
-
# Not set meta for propagating replacing get_item's meta.
|
185
182
|
get_item, *_ = node.users.keys()
|
186
183
|
get_item.replace_all_uses_with(add, propagate_meta=True)
|
187
184
|
|
188
|
-
fill_meta_val(exported_program)
|
189
185
|
logger.debug(f"{node.name} is decomposed to {mul.name} and {add.name}")
|
190
186
|
modified = True
|
191
187
|
|
@@ -23,6 +23,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
|
|
23
23
|
from torch.export import ExportedProgram
|
24
24
|
|
25
25
|
from tico.utils import logging
|
26
|
+
from tico.utils.graph import create_node
|
26
27
|
from tico.utils.passes import PassBase, PassResult
|
27
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
29
|
from tico.utils.validate_args_kwargs import FakeQuantizePerChannelArgs
|
@@ -69,6 +70,7 @@ class DecomposeFakeQuantize(PassBase):
|
|
69
70
|
modified = False
|
70
71
|
|
71
72
|
gm = exported_program.graph_module
|
73
|
+
g = gm.graph
|
72
74
|
qd = torch.ops.quantized_decomposed # type: ignore[return]
|
73
75
|
for node in gm.graph.nodes:
|
74
76
|
if node.op != "call_function":
|
@@ -83,17 +85,19 @@ class DecomposeFakeQuantize(PassBase):
|
|
83
85
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
84
86
|
}
|
85
87
|
with gm.graph.inserting_before(node):
|
86
|
-
quant =
|
88
|
+
quant = create_node(
|
89
|
+
g,
|
87
90
|
qd.quantize_per_tensor.default,
|
88
91
|
args=node.args,
|
89
92
|
kwargs=quant_kwargs,
|
93
|
+
origin=node,
|
90
94
|
)
|
91
|
-
dequnt =
|
95
|
+
dequnt = create_node(
|
96
|
+
g,
|
92
97
|
qd.dequantize_per_tensor.default,
|
93
98
|
args=(quant, *quant.args[1:]),
|
94
99
|
kwargs=quant.kwargs,
|
95
100
|
)
|
96
|
-
# Not set meta for propagating replacing node's meta.
|
97
101
|
node.replace_all_uses_with(dequnt, propagate_meta=True)
|
98
102
|
modified = True
|
99
103
|
|
@@ -107,17 +111,19 @@ class DecomposeFakeQuantize(PassBase):
|
|
107
111
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
108
112
|
}
|
109
113
|
with gm.graph.inserting_before(node):
|
110
|
-
quant =
|
114
|
+
quant = create_node(
|
115
|
+
g,
|
111
116
|
qd.quantize_per_channel.default,
|
112
117
|
args=node.args,
|
113
118
|
kwargs=quant_kwargs,
|
119
|
+
origin=node,
|
114
120
|
)
|
115
|
-
dequnt =
|
121
|
+
dequnt = create_node(
|
122
|
+
g,
|
116
123
|
qd.dequantize_per_channel.default,
|
117
124
|
args=(quant, *quant.args[1:]),
|
118
125
|
kwargs=quant.kwargs,
|
119
126
|
)
|
120
|
-
# Not set meta for propagating replacing node's meta.
|
121
127
|
node.replace_all_uses_with(dequnt, propagate_meta=True)
|
122
128
|
modified = True
|
123
129
|
|
@@ -30,6 +30,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
|
|
30
30
|
from torch.export import ExportedProgram
|
31
31
|
|
32
32
|
from tico.utils import logging
|
33
|
+
from tico.utils.graph import create_node
|
33
34
|
from tico.utils.passes import PassBase, PassResult
|
34
35
|
from tico.utils.trace_decorators import (
|
35
36
|
trace_const_diff_on_pass,
|
@@ -200,6 +201,7 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
200
201
|
modified = False
|
201
202
|
|
202
203
|
gm = exported_program.graph_module
|
204
|
+
g = gm.graph
|
203
205
|
qd = torch.ops.quantized_decomposed # type: ignore[return]
|
204
206
|
for node in gm.graph.nodes:
|
205
207
|
if node.op != "call_function":
|
@@ -226,17 +228,19 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
226
228
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
227
229
|
}
|
228
230
|
with gm.graph.inserting_before(node):
|
229
|
-
quant =
|
231
|
+
quant = create_node(
|
232
|
+
g,
|
230
233
|
qd.quantize_per_tensor.default,
|
231
234
|
args=(tensor, s_value, zp_value, quant_min, quant_max),
|
232
235
|
kwargs=quant_kwargs,
|
236
|
+
origin=node,
|
233
237
|
)
|
234
|
-
dequant =
|
238
|
+
dequant = create_node(
|
239
|
+
g,
|
235
240
|
qd.dequantize_per_tensor.default,
|
236
241
|
args=(quant, *quant.args[1:]),
|
237
242
|
kwargs=quant.kwargs,
|
238
243
|
)
|
239
|
-
# Not set meta for propagating replacing get_item's meta.
|
240
244
|
get_item.replace_all_uses_with(dequant, propagate_meta=True)
|
241
245
|
# If `mask` can be graph output, which prevents `eliminate_dead_code()` from eliminating `mask`.
|
242
246
|
# So, let's remove `mask` from the output.args first.
|
@@ -267,17 +271,19 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
267
271
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
268
272
|
}
|
269
273
|
with gm.graph.inserting_before(node):
|
270
|
-
quant =
|
274
|
+
quant = create_node(
|
275
|
+
g,
|
271
276
|
qd.quantize_per_tensor.default,
|
272
277
|
args=(tensor, s_value, zp_value, quant_min, quant_max),
|
273
278
|
kwargs=quant_kwargs,
|
279
|
+
origin=node,
|
274
280
|
)
|
275
|
-
dequant =
|
281
|
+
dequant = create_node(
|
282
|
+
g,
|
276
283
|
qd.dequantize_per_tensor.default,
|
277
284
|
args=(quant, *quant.args[1:]),
|
278
285
|
kwargs=quant.kwargs,
|
279
286
|
)
|
280
|
-
# Not set meta for propagating replacing get_item's meta.
|
281
287
|
node.replace_all_uses_with(dequant, propagate_meta=True)
|
282
288
|
modified = True
|
283
289
|
|
@@ -23,6 +23,7 @@ from torch.export import ExportedProgram
|
|
23
23
|
|
24
24
|
from tico.serialize.circle_mapping import extract_shape
|
25
25
|
from tico.utils import logging
|
26
|
+
from tico.utils.graph import create_node
|
26
27
|
from tico.utils.passes import PassBase, PassResult
|
27
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
29
|
from tico.utils.utils import is_target_node
|
@@ -89,24 +90,40 @@ class DecomposeGroupNorm(PassBase):
|
|
89
90
|
def __init__(self):
|
90
91
|
super().__init__()
|
91
92
|
|
92
|
-
def _insert_norm(self, graph, tensor, eps):
|
93
|
+
def _insert_norm(self, graph, tensor, eps, origin):
|
93
94
|
"""
|
94
95
|
Insert (tensor - mean) / sqrt(var + eps)) into the graph
|
95
96
|
and return the normalized tensor node.
|
96
97
|
"""
|
97
|
-
mean =
|
98
|
-
|
98
|
+
mean = create_node(
|
99
|
+
graph,
|
100
|
+
torch.ops.aten.mean.dim,
|
101
|
+
(tensor, [-1]),
|
102
|
+
{"keepdim": True},
|
103
|
+
origin=origin,
|
99
104
|
)
|
100
|
-
deviation =
|
101
|
-
|
102
|
-
var = graph.call_function(
|
103
|
-
torch.ops.aten.mean.dim, (squared, [-1]), {"keepdim": True}
|
105
|
+
deviation = create_node(
|
106
|
+
graph, torch.ops.aten.sub.Tensor, (tensor, mean), origin=origin
|
104
107
|
)
|
105
|
-
|
108
|
+
squared = create_node(
|
109
|
+
graph, torch.ops.aten.pow.Tensor_Scalar, (deviation, 2), origin=origin
|
110
|
+
)
|
111
|
+
var = create_node(
|
112
|
+
graph,
|
113
|
+
torch.ops.aten.mean.dim,
|
114
|
+
(squared, [-1]),
|
115
|
+
{"keepdim": True},
|
116
|
+
origin=origin,
|
117
|
+
)
|
118
|
+
inverse_std = create_node(
|
119
|
+
graph,
|
106
120
|
torch.ops.aten.rsqrt.default,
|
107
|
-
(graph
|
121
|
+
(create_node(graph, torch.ops.aten.add.Tensor, (var, eps), origin=origin),),
|
122
|
+
origin=origin,
|
123
|
+
)
|
124
|
+
return create_node(
|
125
|
+
graph, torch.ops.aten.mul.Tensor, (deviation, inverse_std), origin=origin
|
108
126
|
)
|
109
|
-
return graph.call_function(torch.ops.aten.mul.Tensor, (deviation, inverse_std))
|
110
127
|
|
111
128
|
def call(self, exported_program: ExportedProgram) -> PassResult:
|
112
129
|
logger = logging.getLogger(__name__)
|
@@ -178,17 +195,23 @@ class DecomposeGroupNorm(PassBase):
|
|
178
195
|
# Branch only on whether a reshape is needed; the normalization is shared.
|
179
196
|
if norm_size != x_shape[-1]:
|
180
197
|
# Pack groups so that the last dimension equals norm_size.
|
181
|
-
packed =
|
182
|
-
|
198
|
+
packed = create_node(
|
199
|
+
graph,
|
200
|
+
torch.ops.aten.reshape.default,
|
201
|
+
(x, pack_shape),
|
202
|
+
origin=node,
|
183
203
|
)
|
184
|
-
normed = self._insert_norm(graph, packed, eps)
|
204
|
+
normed = self._insert_norm(graph, packed, eps, origin=node)
|
185
205
|
# Restore the original shape after normalization.
|
186
|
-
layer_norm =
|
187
|
-
|
206
|
+
layer_norm = create_node(
|
207
|
+
graph,
|
208
|
+
torch.ops.aten.reshape.default,
|
209
|
+
(normed, x_shape),
|
210
|
+
origin=node,
|
188
211
|
)
|
189
212
|
else:
|
190
213
|
# The input already has norm_size in the last dimension.
|
191
|
-
layer_norm = self._insert_norm(graph, x, eps)
|
214
|
+
layer_norm = self._insert_norm(graph, x, eps, origin=node)
|
192
215
|
|
193
216
|
# weight
|
194
217
|
if weight:
|
@@ -197,13 +220,17 @@ class DecomposeGroupNorm(PassBase):
|
|
197
220
|
assert weight_shape[0] == C
|
198
221
|
reshape_size = [1] * len(x_shape)
|
199
222
|
reshape_size[1] = C
|
200
|
-
weight =
|
223
|
+
weight = create_node(
|
224
|
+
graph,
|
201
225
|
torch.ops.aten.view.default,
|
202
226
|
(weight, reshape_size),
|
227
|
+
origin=node,
|
203
228
|
)
|
204
|
-
layer_norm =
|
229
|
+
layer_norm = create_node(
|
230
|
+
graph,
|
205
231
|
torch.ops.aten.mul.Tensor,
|
206
232
|
(layer_norm, weight),
|
233
|
+
origin=node,
|
207
234
|
)
|
208
235
|
|
209
236
|
# bias
|
@@ -213,15 +240,17 @@ class DecomposeGroupNorm(PassBase):
|
|
213
240
|
assert bias_shape[0] == C
|
214
241
|
reshape_size = [1] * len(x_shape)
|
215
242
|
reshape_size[1] = C
|
216
|
-
bias =
|
243
|
+
bias = create_node(
|
244
|
+
graph,
|
217
245
|
torch.ops.aten.view.default,
|
218
246
|
(bias, reshape_size),
|
247
|
+
origin=node,
|
219
248
|
)
|
220
|
-
layer_norm =
|
249
|
+
layer_norm = create_node(
|
250
|
+
graph,
|
221
251
|
torch.ops.aten.add.Tensor,
|
222
252
|
(layer_norm, bias),
|
223
253
|
)
|
224
|
-
|
225
254
|
# Reset last node's meta for propagating replacing node's meta.
|
226
255
|
layer_norm.meta = {}
|
227
256
|
|
@@ -23,7 +23,7 @@ from tico.passes import ops
|
|
23
23
|
from tico.serialize.circle_mapping import extract_shape
|
24
24
|
from tico.utils import logging
|
25
25
|
from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
|
26
|
-
from tico.utils.graph import add_placeholder
|
26
|
+
from tico.utils.graph import add_placeholder, create_node
|
27
27
|
from tico.utils.passes import PassBase, PassResult
|
28
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
29
|
from tico.utils.utils import is_target_node
|
@@ -159,19 +159,26 @@ class DecomposeGroupedConv2d(PassBase):
|
|
159
159
|
|
160
160
|
conv2d_tensors = []
|
161
161
|
for i in range(groups):
|
162
|
-
sliced_input =
|
162
|
+
sliced_input = create_node(
|
163
|
+
graph,
|
163
164
|
torch.ops.aten.slice.Tensor,
|
164
165
|
(input_, 1, group_size * i, group_size * (i + 1), 1),
|
166
|
+
origin=node,
|
165
167
|
)
|
166
|
-
sliced_weight =
|
168
|
+
sliced_weight = create_node(
|
169
|
+
graph,
|
167
170
|
torch.ops.aten.slice.Tensor,
|
168
171
|
(weight, 0, out_group_size * i, out_group_size * (i + 1), 1),
|
172
|
+
origin=node,
|
169
173
|
)
|
170
|
-
sliced_bias =
|
174
|
+
sliced_bias = create_node(
|
175
|
+
graph,
|
171
176
|
torch.ops.aten.slice.Tensor,
|
172
177
|
(bias, 0, out_group_size * i, out_group_size * (i + 1), 1),
|
178
|
+
origin=node,
|
173
179
|
)
|
174
|
-
conv2d_tensor =
|
180
|
+
conv2d_tensor = create_node(
|
181
|
+
graph,
|
175
182
|
conv2d_op,
|
176
183
|
(
|
177
184
|
sliced_input,
|
@@ -182,11 +189,12 @@ class DecomposeGroupedConv2d(PassBase):
|
|
182
189
|
dilation,
|
183
190
|
1,
|
184
191
|
),
|
192
|
+
origin=node,
|
185
193
|
)
|
186
194
|
conv2d_tensors.append(conv2d_tensor)
|
187
195
|
|
188
|
-
concat_output =
|
189
|
-
torch.ops.aten.cat.default, (conv2d_tensors, 1)
|
196
|
+
concat_output = create_node(
|
197
|
+
graph, torch.ops.aten.cat.default, (conv2d_tensors, 1)
|
190
198
|
)
|
191
199
|
|
192
200
|
node.replace_all_uses_with(concat_output, propagate_meta=True)
|