tico 0.1.0.dev250616__py3-none-any.whl → 0.1.0.dev250618__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +6 -2
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +3 -0
- tico/passes/cast_aten_where_arg_type.py +4 -1
- tico/passes/cast_mixed_type_args.py +4 -1
- tico/passes/convert_conv1d_to_conv2d.py +12 -4
- tico/passes/convert_layout_op_to_reshape.py +3 -2
- tico/passes/convert_repeat_to_expand_copy.py +5 -2
- tico/passes/convert_to_relu6.py +4 -3
- tico/passes/decompose_addmm.py +11 -7
- tico/passes/decompose_batch_norm.py +7 -11
- tico/passes/decompose_fake_quantize.py +12 -6
- tico/passes/decompose_fake_quantize_tensor_qparams.py +12 -6
- tico/passes/decompose_group_norm.py +50 -21
- tico/passes/decompose_grouped_conv2d.py +15 -7
- tico/passes/decompose_slice_scatter.py +9 -5
- tico/passes/fuse_leading_unsqueeze_reshape.py +8 -3
- tico/passes/legalize_predefined_layout_operators.py +33 -25
- tico/passes/lower_pow2_to_mul.py +3 -1
- tico/passes/lower_to_resize_nearest_neighbor.py +21 -10
- tico/passes/lower_to_slice.py +21 -11
- tico/passes/remove_redundant_permute.py +5 -3
- tico/passes/remove_redundant_reshape.py +5 -2
- tico/passes/remove_redundant_to_copy.py +4 -0
- tico/passes/restore_linear.py +7 -5
- tico/passes/segment_index_select.py +9 -5
- tico/utils/convert.py +2 -0
- tico/utils/graph.py +48 -2
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/RECORD +35 -34
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250616.dist-info → tico-0.1.0.dev250618.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250618"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -24,6 +24,7 @@ from torch.export import ExportedProgram
|
|
24
24
|
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
|
25
25
|
from tico.utils import logging
|
26
26
|
from tico.utils.errors import NotYetSupportedError
|
27
|
+
from tico.utils.graph import create_node
|
27
28
|
from tico.utils.passes import PassBase, PassResult
|
28
29
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
30
|
from tico.utils.utils import quant_min_max, set_new_meta_val
|
@@ -145,9 +146,11 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
|
|
145
146
|
|
146
147
|
with graph.inserting_before(node):
|
147
148
|
q_args = (inp, scale, zerop, min_, max_, dtype)
|
148
|
-
quantize =
|
149
|
+
quantize = create_node(
|
150
|
+
graph,
|
149
151
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
150
152
|
args=q_args,
|
153
|
+
origin=node,
|
151
154
|
)
|
152
155
|
quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
|
153
156
|
set_new_meta_val(quantize)
|
@@ -166,7 +169,8 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
|
|
166
169
|
dtype = getattr(torch, qparam.dtype)
|
167
170
|
with graph.inserting_after(node):
|
168
171
|
q_args = (node, scale, zerop, min_, max_, dtype)
|
169
|
-
quantize =
|
172
|
+
quantize = create_node(
|
173
|
+
graph,
|
170
174
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
171
175
|
args=q_args,
|
172
176
|
)
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import copy
|
20
|
+
|
21
|
+
import torch
|
22
|
+
from torch.export import ExportedProgram
|
23
|
+
|
24
|
+
from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.graph import add_placeholder, get_torch_param_value, is_torch_param
|
27
|
+
from tico.utils.passes import PassBase, PassResult
|
28
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
|
+
from tico.utils.validate_args_kwargs import LinearArgs
|
30
|
+
|
31
|
+
|
32
|
+
@trace_graph_diff_on_pass
class QuantizeBias(PassBase):
    """
    Quantize bias.

    This pass identifies fp32 biases of `aten.linear` nodes and quantizes them
    using scales of input and weights:

        bias_scale = input_scale * weight_scale   (one value per output channel)
        q_bias     = clamp(round(bias / bias_scale), dtype_min, dtype_max)

    The zero point of the quantized bias is always 0 and it is quantized along
    dimension 0.

    The integer dtype of the quantized bias is selected from the dtype of the
    input activation's qparam:
      - "int16" input -> int64 bias
      - "uint8" input -> int32 bias
    Linear nodes with any other input dtype are left untouched.

    This pass assumes that if bias is fp32, input and weights must have been quantized.
    """

    # Input activation qparam dtype -> integer dtype used for the quantized bias.
    _BIAS_DTYPE_FOR_INPUT = {
        "int16": torch.int64,
        "uint8": torch.int32,
    }

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == torch.ops.aten.linear.default:
                lin_args = LinearArgs(*node.args, **node.kwargs)
                inp = lin_args.input
                weights = lin_args.weight
                bias = lin_args.bias

                if bias is None:
                    continue

                # Only support bias is Parameter
                # TODO Is it possible that bias is not Parameter?
                if not is_torch_param(bias, exported_program):
                    continue

                bias_val: torch.Tensor = get_torch_param_value(bias, exported_program)
                if bias_val.dtype != torch.float32:
                    continue

                if QPARAM_KEY not in inp.meta:
                    continue

                if QPARAM_KEY not in weights.meta:
                    continue

                quant_dtype = self._BIAS_DTYPE_FOR_INPUT.get(
                    inp.meta[QPARAM_KEY].dtype
                )
                if quant_dtype is None:
                    # Unsupported input activation dtype; leave the bias as-is.
                    continue
                type_info = torch.iinfo(quant_dtype)

                i_scale = inp.meta[QPARAM_KEY].scale
                w_scale = weights.meta[QPARAM_KEY].scale

                assert i_scale is not None, f"{inp.name} has no input scale"
                assert w_scale is not None, f"{weights.name} has no weight scale"
                assert len(i_scale) == 1, "Input must be quantized per-tensor"
                assert (
                    len(w_scale) == bias_val.shape[0]
                ), "Weight scales must match the number of bias elements"

                bias_scale = torch.tensor(i_scale) * torch.tensor(w_scale)
                q_bias = torch.round(bias_val / bias_scale)
                q_bias = torch.clamp(q_bias, min=type_info.min, max=type_info.max)
                q_bias = q_bias.to(quant_dtype)

                q_bias_node = add_placeholder(exported_program, q_bias, bias.name)

                qparam = QuantParam()
                qparam.scale = bias_scale.tolist()
                assert qparam.scale is not None
                qparam.zero_point = [0] * len(qparam.scale)
                qparam.dtype = to_qparam_dtype(quant_dtype)
                qparam.quantized_dimension = 0
                q_bias_node.meta[QPARAM_KEY] = qparam

                # `bias` may be given positionally (args[2]) or as a keyword
                # (LinearArgs accepts both); rewire whichever slot holds it.
                if len(node.args) > 2:
                    node.update_arg(2, q_bias_node)
                else:
                    node.update_kwarg("bias", q_bias_node)

                logger.debug(f"Bias ({bias.name}) is quantized to {q_bias_node.name}.")

            # TODO Support more ops.

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        # Run only once.
        return PassResult(False)
|
@@ -145,6 +145,9 @@ class RemoveWeightDequantOp(PassBase):
|
|
145
145
|
if isinstance(dq_args, DequantizePerChannelArgs):
|
146
146
|
scales = get_constant(exported_program, dq_args.scales)
|
147
147
|
zero_ps = get_constant(exported_program, dq_args.zero_points)
|
148
|
+
|
149
|
+
# Sometimes users can give fp32 zero point. Let's update dtype here.
|
150
|
+
zero_ps = zero_ps.to(torch.int64)
|
148
151
|
quant_param.scale = scales.tolist()
|
149
152
|
quant_param.zero_point = zero_ps.tolist()
|
150
153
|
assert quant_param.zero_point is not None # To avoid mypy error
|
@@ -21,6 +21,7 @@ from torch.export import ExportedProgram
|
|
21
21
|
|
22
22
|
from tico.serialize.circle_mapping import extract_torch_dtype
|
23
23
|
from tico.utils import logging
|
24
|
+
from tico.utils.graph import create_node
|
24
25
|
from tico.utils.passes import PassBase, PassResult
|
25
26
|
from tico.utils.trace_decorators import (
|
26
27
|
trace_const_diff_on_pass,
|
@@ -158,10 +159,12 @@ class CastATenWhereArgType(PassBase):
|
|
158
159
|
f"{to_cast.name}({buf_data.dtype}) data range is out of {dtype_to_cast} range"
|
159
160
|
)
|
160
161
|
with graph_module.graph.inserting_after(to_cast):
|
161
|
-
cast =
|
162
|
+
cast = create_node(
|
163
|
+
graph,
|
162
164
|
torch.ops.aten._to_copy.default,
|
163
165
|
args=(to_cast,),
|
164
166
|
kwargs={"dtype": dtype_to_cast},
|
167
|
+
origin=to_cast,
|
165
168
|
)
|
166
169
|
# set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
|
167
170
|
set_new_meta_val(cast)
|
@@ -26,6 +26,7 @@ from torch.export import ExportedProgram
|
|
26
26
|
|
27
27
|
from tico.serialize.circle_mapping import extract_torch_dtype
|
28
28
|
from tico.utils import logging
|
29
|
+
from tico.utils.graph import create_node
|
29
30
|
from tico.utils.passes import PassBase, PassResult
|
30
31
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
31
32
|
from tico.utils.utils import is_target_node, set_new_meta_val
|
@@ -126,10 +127,12 @@ class CastMixedTypeArgs(PassBase):
|
|
126
127
|
|
127
128
|
if isinstance(arg_to_promote, torch.fx.Node):
|
128
129
|
with graph.inserting_after(arg_to_promote):
|
129
|
-
to_copy =
|
130
|
+
to_copy = create_node(
|
131
|
+
graph,
|
130
132
|
torch.ops.aten._to_copy.default,
|
131
133
|
(arg_to_promote,),
|
132
134
|
{"dtype": type_to_promote},
|
135
|
+
origin=arg_to_promote,
|
133
136
|
)
|
134
137
|
# set new meta["val"] in advance because we will use it below for checking if type promotion is valid.
|
135
138
|
set_new_meta_val(to_copy)
|
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.serialize.circle_graph import extract_shape
|
23
23
|
from tico.utils import logging
|
24
24
|
from tico.utils.errors import NotYetSupportedError
|
25
|
+
from tico.utils.graph import create_node
|
25
26
|
from tico.utils.passes import PassBase, PassResult
|
26
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
28
|
from tico.utils.utils import is_target_node
|
@@ -89,15 +90,19 @@ class ConvertConv1dToConv2d(PassBase):
|
|
89
90
|
)
|
90
91
|
|
91
92
|
with graph.inserting_after(input):
|
92
|
-
input_unsqueeze =
|
93
|
+
input_unsqueeze = create_node(
|
94
|
+
graph,
|
93
95
|
torch.ops.aten.unsqueeze.default,
|
94
96
|
args=(input, 3),
|
97
|
+
origin=input,
|
95
98
|
)
|
96
99
|
|
97
100
|
with graph.inserting_after(weight):
|
98
|
-
weight_unsqueeze =
|
101
|
+
weight_unsqueeze = create_node(
|
102
|
+
graph,
|
99
103
|
torch.ops.aten.unsqueeze.default,
|
100
104
|
args=(weight, 3),
|
105
|
+
origin=weight,
|
101
106
|
)
|
102
107
|
|
103
108
|
with graph.inserting_before(node):
|
@@ -106,7 +111,8 @@ class ConvertConv1dToConv2d(PassBase):
|
|
106
111
|
elif isinstance(padding, str):
|
107
112
|
conv2d_op = torch.ops.aten.conv2d.padding
|
108
113
|
|
109
|
-
conv2d =
|
114
|
+
conv2d = create_node(
|
115
|
+
graph,
|
110
116
|
conv2d_op,
|
111
117
|
args=(
|
112
118
|
input_unsqueeze,
|
@@ -118,9 +124,11 @@ class ConvertConv1dToConv2d(PassBase):
|
|
118
124
|
groups,
|
119
125
|
),
|
120
126
|
kwargs=node.kwargs,
|
127
|
+
origin=node,
|
121
128
|
)
|
122
129
|
|
123
|
-
conv_out_squeeze =
|
130
|
+
conv_out_squeeze = create_node(
|
131
|
+
graph,
|
124
132
|
torch.ops.aten.squeeze.dims,
|
125
133
|
args=(conv2d, [3]),
|
126
134
|
)
|
@@ -22,6 +22,7 @@ from torch.export import ExportedProgram
|
|
22
22
|
from tico.passes import ops
|
23
23
|
from tico.serialize.circle_mapping import extract_shape
|
24
24
|
from tico.utils import logging
|
25
|
+
from tico.utils.graph import create_node
|
25
26
|
from tico.utils.passes import PassBase, PassResult
|
26
27
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
28
|
from tico.utils.validate_args_kwargs import SqueezeArgs, UnSqueezeArgs, ViewArgs
|
@@ -48,11 +49,11 @@ class ConvertLayoutOpToReshape(PassBase):
|
|
48
49
|
out_shape = list(extract_shape(node))
|
49
50
|
|
50
51
|
with graph.inserting_after(node):
|
51
|
-
reshape_node =
|
52
|
+
reshape_node = create_node(
|
53
|
+
graph,
|
52
54
|
torch.ops.aten.reshape.default,
|
53
55
|
args=(input, out_shape),
|
54
56
|
)
|
55
|
-
|
56
57
|
node.replace_all_uses_with(reshape_node, propagate_meta=True)
|
57
58
|
|
58
59
|
logger.debug(f"{node.name} is replaced with {reshape_node.name}")
|
@@ -20,6 +20,7 @@ import torch
|
|
20
20
|
from torch.export import ExportedProgram
|
21
21
|
|
22
22
|
from tico.utils import logging
|
23
|
+
from tico.utils.graph import create_node
|
23
24
|
from tico.utils.passes import PassBase, PassResult
|
24
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
26
|
from tico.utils.utils import is_target_node
|
@@ -71,8 +72,10 @@ class ConvertRepeatToExpandCopy(PassBase):
|
|
71
72
|
expand_copy_args = (tensor, size)
|
72
73
|
|
73
74
|
with graph.inserting_after(node):
|
74
|
-
expand_copy_node =
|
75
|
-
|
75
|
+
expand_copy_node = create_node(
|
76
|
+
graph,
|
77
|
+
torch.ops.aten.expand_copy.default,
|
78
|
+
args=expand_copy_args,
|
76
79
|
)
|
77
80
|
node.replace_all_uses_with(expand_copy_node, propagate_meta=True)
|
78
81
|
|
tico/passes/convert_to_relu6.py
CHANGED
@@ -20,6 +20,7 @@ import torch
|
|
20
20
|
from torch.export import ExportedProgram
|
21
21
|
|
22
22
|
from tico.utils import logging
|
23
|
+
from tico.utils.graph import create_node
|
23
24
|
from tico.utils.passes import PassBase, PassResult
|
24
25
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
26
|
from tico.utils.validate_args_kwargs import ClampArgs, HardTanhArgs
|
@@ -58,7 +59,7 @@ class ConvertHardTanhToReLU6(Converter):
|
|
58
59
|
input = args.input
|
59
60
|
|
60
61
|
with graph.inserting_after(node):
|
61
|
-
relu_node = graph
|
62
|
+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
|
62
63
|
node.replace_all_uses_with(relu_node, propagate_meta=True)
|
63
64
|
|
64
65
|
|
@@ -84,7 +85,7 @@ class ConvertClampToReLU6(Converter):
|
|
84
85
|
input = args.input
|
85
86
|
|
86
87
|
with graph.inserting_after(node):
|
87
|
-
relu_node = graph
|
88
|
+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
|
88
89
|
node.replace_all_uses_with(relu_node, propagate_meta=True)
|
89
90
|
|
90
91
|
|
@@ -140,7 +141,7 @@ class ConvertDoubleClampsToReLU6(Converter):
|
|
140
141
|
input = prev_args.input
|
141
142
|
|
142
143
|
with graph.inserting_after(node):
|
143
|
-
relu_node = graph
|
144
|
+
relu_node = create_node(graph, torch.ops.aten.relu6.default, args=(input,))
|
144
145
|
node.replace_all_uses_with(relu_node, propagate_meta=True)
|
145
146
|
|
146
147
|
|
tico/passes/decompose_addmm.py
CHANGED
@@ -21,7 +21,7 @@ from torch.export import ExportedProgram
|
|
21
21
|
|
22
22
|
from tico.serialize.circle_mapping import extract_shape
|
23
23
|
from tico.utils import logging
|
24
|
-
from tico.utils.graph import add_placeholder
|
24
|
+
from tico.utils.graph import add_placeholder, create_node
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
27
|
from tico.utils.utils import is_target_node, set_new_meta_val
|
@@ -78,7 +78,9 @@ class DecomposeAddmm(PassBase):
|
|
78
78
|
|
79
79
|
with graph.inserting_before(node):
|
80
80
|
# out = beta * input + alpha * (mat1 @ mat2)
|
81
|
-
matmul =
|
81
|
+
matmul = create_node(
|
82
|
+
graph, torch.ops.aten.mm.default, (mat1, mat2), origin=node
|
83
|
+
)
|
82
84
|
set_new_meta_val(matmul)
|
83
85
|
|
84
86
|
if beta == 1:
|
@@ -90,7 +92,9 @@ class DecomposeAddmm(PassBase):
|
|
90
92
|
f"{node.name}_beta_zeros",
|
91
93
|
)
|
92
94
|
else:
|
93
|
-
bias =
|
95
|
+
bias = create_node(
|
96
|
+
graph, torch.ops.aten.mul.Tensor, (input, beta), origin=node
|
97
|
+
)
|
94
98
|
|
95
99
|
if alpha == 1:
|
96
100
|
scaled_matmul: torch.fx.Node | torch.Tensor = matmul
|
@@ -101,12 +105,12 @@ class DecomposeAddmm(PassBase):
|
|
101
105
|
f"{node.name}_alpha_zeros",
|
102
106
|
)
|
103
107
|
else:
|
104
|
-
scaled_matmul =
|
105
|
-
torch.ops.aten.mul.Tensor, (matmul, alpha)
|
108
|
+
scaled_matmul = create_node(
|
109
|
+
graph, torch.ops.aten.mul.Tensor, (matmul, alpha), origin=node
|
106
110
|
)
|
107
111
|
|
108
|
-
result =
|
109
|
-
torch.ops.aten.add.Tensor, (bias, scaled_matmul)
|
112
|
+
result = create_node(
|
113
|
+
graph, torch.ops.aten.add.Tensor, (bias, scaled_matmul)
|
110
114
|
)
|
111
115
|
|
112
116
|
node.replace_all_uses_with(result, propagate_meta=True)
|
@@ -24,6 +24,7 @@ from tico.utils import logging
|
|
24
24
|
from tico.utils.errors import NotYetSupportedError
|
25
25
|
from tico.utils.graph import (
|
26
26
|
add_placeholder,
|
27
|
+
create_node,
|
27
28
|
get_first_user_input,
|
28
29
|
get_torch_buffer_value,
|
29
30
|
get_torch_param_value,
|
@@ -32,16 +33,10 @@ from tico.utils.graph import (
|
|
32
33
|
)
|
33
34
|
from tico.utils.passes import PassBase, PassResult
|
34
35
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
35
|
-
from tico.utils.utils import
|
36
|
+
from tico.utils.utils import is_target_node
|
36
37
|
from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
|
37
38
|
|
38
39
|
|
39
|
-
def insert_node(graph: torch.fx.Graph, operation, args):
|
40
|
-
new_node = graph.call_function(operation, args)
|
41
|
-
|
42
|
-
return new_node
|
43
|
-
|
44
|
-
|
45
40
|
@trace_graph_diff_on_pass
|
46
41
|
class DecomposeBatchNorm(PassBase):
|
47
42
|
"""
|
@@ -173,19 +168,20 @@ class DecomposeBatchNorm(PassBase):
|
|
173
168
|
)
|
174
169
|
|
175
170
|
with gm.graph.inserting_before(node):
|
176
|
-
mul =
|
171
|
+
mul = create_node(
|
172
|
+
graph,
|
177
173
|
torch.ops.aten.mul.Tensor,
|
178
174
|
args=(input_, mul_const_node),
|
175
|
+
origin=node,
|
179
176
|
)
|
180
|
-
add =
|
177
|
+
add = create_node(
|
178
|
+
graph,
|
181
179
|
torch.ops.aten.add.Tensor,
|
182
180
|
args=(mul, add_const_node),
|
183
181
|
)
|
184
|
-
# Not set meta for propagating replacing get_item's meta.
|
185
182
|
get_item, *_ = node.users.keys()
|
186
183
|
get_item.replace_all_uses_with(add, propagate_meta=True)
|
187
184
|
|
188
|
-
fill_meta_val(exported_program)
|
189
185
|
logger.debug(f"{node.name} is decomposed to {mul.name} and {add.name}")
|
190
186
|
modified = True
|
191
187
|
|
@@ -23,6 +23,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
|
|
23
23
|
from torch.export import ExportedProgram
|
24
24
|
|
25
25
|
from tico.utils import logging
|
26
|
+
from tico.utils.graph import create_node
|
26
27
|
from tico.utils.passes import PassBase, PassResult
|
27
28
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
29
|
from tico.utils.validate_args_kwargs import FakeQuantizePerChannelArgs
|
@@ -69,6 +70,7 @@ class DecomposeFakeQuantize(PassBase):
|
|
69
70
|
modified = False
|
70
71
|
|
71
72
|
gm = exported_program.graph_module
|
73
|
+
g = gm.graph
|
72
74
|
qd = torch.ops.quantized_decomposed # type: ignore[return]
|
73
75
|
for node in gm.graph.nodes:
|
74
76
|
if node.op != "call_function":
|
@@ -83,17 +85,19 @@ class DecomposeFakeQuantize(PassBase):
|
|
83
85
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
84
86
|
}
|
85
87
|
with gm.graph.inserting_before(node):
|
86
|
-
quant =
|
88
|
+
quant = create_node(
|
89
|
+
g,
|
87
90
|
qd.quantize_per_tensor.default,
|
88
91
|
args=node.args,
|
89
92
|
kwargs=quant_kwargs,
|
93
|
+
origin=node,
|
90
94
|
)
|
91
|
-
dequnt =
|
95
|
+
dequnt = create_node(
|
96
|
+
g,
|
92
97
|
qd.dequantize_per_tensor.default,
|
93
98
|
args=(quant, *quant.args[1:]),
|
94
99
|
kwargs=quant.kwargs,
|
95
100
|
)
|
96
|
-
# Not set meta for propagating replacing node's meta.
|
97
101
|
node.replace_all_uses_with(dequnt, propagate_meta=True)
|
98
102
|
modified = True
|
99
103
|
|
@@ -107,17 +111,19 @@ class DecomposeFakeQuantize(PassBase):
|
|
107
111
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
108
112
|
}
|
109
113
|
with gm.graph.inserting_before(node):
|
110
|
-
quant =
|
114
|
+
quant = create_node(
|
115
|
+
g,
|
111
116
|
qd.quantize_per_channel.default,
|
112
117
|
args=node.args,
|
113
118
|
kwargs=quant_kwargs,
|
119
|
+
origin=node,
|
114
120
|
)
|
115
|
-
dequnt =
|
121
|
+
dequnt = create_node(
|
122
|
+
g,
|
116
123
|
qd.dequantize_per_channel.default,
|
117
124
|
args=(quant, *quant.args[1:]),
|
118
125
|
kwargs=quant.kwargs,
|
119
126
|
)
|
120
|
-
# Not set meta for propagating replacing node's meta.
|
121
127
|
node.replace_all_uses_with(dequnt, propagate_meta=True)
|
122
128
|
modified = True
|
123
129
|
|
@@ -30,6 +30,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
|
|
30
30
|
from torch.export import ExportedProgram
|
31
31
|
|
32
32
|
from tico.utils import logging
|
33
|
+
from tico.utils.graph import create_node
|
33
34
|
from tico.utils.passes import PassBase, PassResult
|
34
35
|
from tico.utils.trace_decorators import (
|
35
36
|
trace_const_diff_on_pass,
|
@@ -200,6 +201,7 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
200
201
|
modified = False
|
201
202
|
|
202
203
|
gm = exported_program.graph_module
|
204
|
+
g = gm.graph
|
203
205
|
qd = torch.ops.quantized_decomposed # type: ignore[return]
|
204
206
|
for node in gm.graph.nodes:
|
205
207
|
if node.op != "call_function":
|
@@ -226,17 +228,19 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
226
228
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
227
229
|
}
|
228
230
|
with gm.graph.inserting_before(node):
|
229
|
-
quant =
|
231
|
+
quant = create_node(
|
232
|
+
g,
|
230
233
|
qd.quantize_per_tensor.default,
|
231
234
|
args=(tensor, s_value, zp_value, quant_min, quant_max),
|
232
235
|
kwargs=quant_kwargs,
|
236
|
+
origin=node,
|
233
237
|
)
|
234
|
-
dequant =
|
238
|
+
dequant = create_node(
|
239
|
+
g,
|
235
240
|
qd.dequantize_per_tensor.default,
|
236
241
|
args=(quant, *quant.args[1:]),
|
237
242
|
kwargs=quant.kwargs,
|
238
243
|
)
|
239
|
-
# Not set meta for propagating replacing get_item's meta.
|
240
244
|
get_item.replace_all_uses_with(dequant, propagate_meta=True)
|
241
245
|
# If `mask` can be graph output, which prevents `eliminate_dead_code()` from eliminating `mask`.
|
242
246
|
# So, let's remove `mask` from the output.args first.
|
@@ -267,17 +271,19 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
267
271
|
**{"dtype": get_quant_type(quant_min, quant_max)},
|
268
272
|
}
|
269
273
|
with gm.graph.inserting_before(node):
|
270
|
-
quant =
|
274
|
+
quant = create_node(
|
275
|
+
g,
|
271
276
|
qd.quantize_per_tensor.default,
|
272
277
|
args=(tensor, s_value, zp_value, quant_min, quant_max),
|
273
278
|
kwargs=quant_kwargs,
|
279
|
+
origin=node,
|
274
280
|
)
|
275
|
-
dequant =
|
281
|
+
dequant = create_node(
|
282
|
+
g,
|
276
283
|
qd.dequantize_per_tensor.default,
|
277
284
|
args=(quant, *quant.args[1:]),
|
278
285
|
kwargs=quant.kwargs,
|
279
286
|
)
|
280
|
-
# Not set meta for propagating replacing get_item's meta.
|
281
287
|
node.replace_all_uses_with(dequant, propagate_meta=True)
|
282
288
|
modified = True
|
283
289
|
|