tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/passes/convert_matmul_to_linear.py ADDED
@@ -0,0 +1,312 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
+from torch.export import ExportedProgram
+
+from tico.serialize.circle_mapping import extract_shape
+
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import BmmArgs, MatmulArgs
+
+
+class Converter:  # type: ignore[empty-body]
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
+        pass
+
+
+class MatmulToLinearConverter(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = mm_args.input
+        rhs = mm_args.other
+
+        with graph.inserting_before(node):
+            transpose_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [1, 0]),
+            )
+            linear_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(lhs, transpose_node),
+            )
+            node.replace_all_uses_with(linear_node, propagate_meta=True)
+
+        return linear_node
+
+
+class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        rhs = mm_args.other
+        if isinstance(rhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, rhs):
+                return True
+            elif is_param(exported_program, rhs):
+                return True
+            elif is_buffer(exported_program, rhs):
+                return True
+            else:
+                return False
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)
+        lhs = mm_args.input
+        if isinstance(lhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, lhs):
+                return True
+            elif is_param(exported_program, lhs):
+                return True
+            elif is_buffer(exported_program, lhs):
+                return True
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class SingleBatchLhsConstBmmToLinearConverter(Converter):
+    """
+    Convert `single-batched & lhs-const BatchMatMul` to `linear` operation.
+
+    [1] exchange lhs and rhs
+    [2] transpose rhs
+    [3] transpose output
+
+    **Before**
+
+    lhs[1,a,b](const)   rhs[1,b,c]
+         |                  |
+         |                  |
+         ---------bmm---------
+                   |
+             output[1,a,c]
+
+
+    **After**
+
+    rhs[1,b,c]
+        |
+        tr        lhs'[a,b](const-folded)
+        |[1,c,b]      |
+        |             |
+        ---------fc--------
+                 |[1,c,a]
+                 tr
+                 |
+           output[1,a,c]
+
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.bmm.default:
+            return False
+
+        bmm_args = BmmArgs(*node.args, **node.kwargs)
+        lhs = bmm_args.input
+        rhs = bmm_args.mat2
+
+        # [1] Single-batch
+        lhs_shape = extract_shape(lhs)
+        rhs_shape = extract_shape(rhs)
+
+        assert len(lhs_shape) == len(
+            rhs_shape
+        ), f"Bmm input's ranks must be the same but got {lhs_shape} and {rhs_shape}"
+
+        if not (lhs_shape[0] == rhs_shape[0] == 1):
+            return False
+
+        # [2] Lhs is constant
+        if not isinstance(lhs, torch.fx.Node):
+            return False
+        if not (
+            is_lifted_tensor_constant(exported_program, lhs)
+            or is_param(exported_program, lhs)
+            or is_buffer(exported_program, lhs)
+        ):
+            return False
+
+        return True
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = bmm_args.input  # const
+        rhs = bmm_args.mat2  # non-const
+        lhs_shape = extract_shape(lhs)
+        rhs_shape = extract_shape(rhs)
+        assert rhs_shape[0] == 1
+        assert lhs_shape[0] == 1
+
+        with graph.inserting_before(node):
+            rhs_tr = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [0, 2, 1]),
+            )
+            lhs_reshape = create_node(
+                graph,
+                torch.ops.aten.view.default,
+                args=(lhs, list(lhs_shape[1:])),
+            )
+
+            linear_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(rhs_tr, lhs_reshape),
+            )
+
+            tr_linear_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(linear_node, [0, 2, 1]),
+            )
+
+        node.replace_all_uses_with(tr_linear_node, propagate_meta=False)
+
+        return tr_linear_node
+
+
+@trace_graph_diff_on_pass
+class ConvertMatmulToLinear(PassBase):
+    """
+    This pass converts matmul(partially includes single-batch bmm) to linear selectively
+
+    How to select between `matmul` and `linear`?
+
+    * Linear has better quantization accuracy (NPU backend)
+      Due to ONE compiler's quantization policy;
+      FullyConnected(=Linear) uses per-channel quantization for weight and per-tensor for input.
+      BatchMatmul(=matmul) uses per-tensor quantization for both rhs and lhs.
+
+    * Matmul to Linear requires Transpose, which may harm latency
+      When RHS is constant, addtional transpose can be folded.
+
+    [RHS non-const case]
+    Constant folding cannot be performed.
+
+    lhs   rhs (non-const)
+     |       |
+     |   transpose
+     |       |
+     -- linear --
+          |
+         out
+
+    [RHS const case]
+    Constant folding can be performed to
+
+    lhs   rhs (const)           lh    rhs (folded const)
+     |       |                   |       |
+     |   transpose               |       |
+     |       |                   |       |
+     -- linear --      -->       -- linear --
+          |                           |
+         out                         out
+
+
+    enable_lhs_const: If true, convert matmul where LHS is constant tensor. Default is False.
+    enable_rhs_const: If true, convert matmul where RHS is constant tensor. Default is True.
+    """
+
+    def __init__(
+        self,
+        enable_lhs_const: Optional[bool] = False,
+        enable_rhs_const: Optional[bool] = True,
+        enable_single_batch_lhs_const_bmm: Optional[bool] = False,
+    ):
+        super().__init__()
+        self.converters: List[Converter] = []
+        if enable_lhs_const:
+            self.converters.append(LhsConstMatmulToLinearConverter())
+        if enable_rhs_const:
+            self.converters.append(RhsConstMatmulToLinearConverter())
+        if enable_single_batch_lhs_const_bmm:
+            self.converters.append(SingleBatchLhsConstBmmToLinearConverter())
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if not node.op == "call_function":
+                continue
+
+            for converter in self.converters:
+                if not converter.match(exported_program, node):
+                    continue
+
+                new_node = converter.convert(exported_program, node)
+                modified = True
+                logger.debug(
+                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
+                )
+                continue
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
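The equivalences the two converter families rely on are easy to check numerically with plain PyTorch. This is a standalone sketch, not part of the diff; shapes and names are illustrative:

    import torch
    import torch.nn.functional as F

    # mm -> permute + linear: F.linear computes x @ W.T, so x @ w == F.linear(x, w.t())
    x, w = torch.randn(5, 3), torch.randn(3, 4)
    assert torch.allclose(x @ w, F.linear(x, w.t()), atol=1e-6)

    # single-batch, lhs-const bmm -> fc, following the docstring diagram:
    # bmm(lhs[1,a,b], rhs[1,b,c]) == permute(linear(permute(rhs), lhs[0]), [0,2,1])
    a, b, c = 2, 3, 4
    lhs, rhs = torch.randn(1, a, b), torch.randn(1, b, c)
    out = F.linear(rhs.permute(0, 2, 1), lhs.reshape(a, b)).permute(0, 2, 1)
    assert torch.allclose(torch.bmm(lhs, rhs), out, atol=1e-6)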
tico/passes/convert_to_relu6.py CHANGED
tico/passes/decompose_addmm.py CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch.export import ExportedProgram
 
 from tico.serialize.circle_mapping import extract_shape
-from tico.utils import logging
 from tico.utils.graph import add_placeholder, create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -59,8 +58,6 @@ class DecomposeAddmm(PassBase):
         super().__init__()
 
     def call(self, exported_program: ExportedProgram) -> PassResult:
-        logger = logging.getLogger(__name__)
-
         gm = exported_program.graph_module
         graph: torch.fx.Graph = gm.graph
         modified = False
tico/passes/decompose_batch_norm.py CHANGED
@@ -96,9 +96,9 @@ class DecomposeBatchNorm(PassBase):
         eps = args.eps
 
         if not running_mean:
-            raise NotYetSupportedError(
+            raise NotYetSupportedError("running_mean=None is not supported yet")
         if not running_var:
-            raise NotYetSupportedError(
+            raise NotYetSupportedError("running_var=None is not supported yet")
 
         """
         Only support the cases generated from torch.nn.BatchNorm2d module,
tico/passes/decompose_fake_quantize.py CHANGED
@@ -19,10 +19,8 @@ if TYPE_CHECKING:
     import torch
 
 # To import torch.ops.quantized_decomposed related operator
-from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.export import ExportedProgram
 
-from tico.utils import logging
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -66,7 +64,6 @@ class DecomposeFakeQuantize(PassBase):
         super().__init__()
 
     def call(self, exported_program: ExportedProgram) -> PassResult:
-        logger = logging.getLogger(__name__)
        modified = False
 
        gm = exported_program.graph_module
tico/passes/decompose_fake_quantize_tensor_qparams.py CHANGED
@@ -26,10 +26,8 @@ from torch._export.utils import (
 )
 
 # To import torch.ops.quantized_decomposed related operator
-from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.export import ExportedProgram
 
-from tico.utils import logging
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import (
@@ -246,10 +244,11 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
             # So, let's remove `mask` from the output.args first.
             # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
             if mask:
-                len(mask) == 1
-
-
-
+                assert len(mask) == 1
+                if len(mask[0].users) > 0:
+                    mask_user = list(mask[0].users.keys())[0]
+                    assert len(mask_user.args) == 1
+                    mask_user.args = ((mask_user.args[0][0],),)
                 modified = True
             if (
                 node.target
tico/passes/decompose_group_norm.py CHANGED
@@ -22,7 +22,6 @@ import torch
 from torch.export import ExportedProgram
 
 from tico.serialize.circle_mapping import extract_shape
-from tico.utils import logging
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
@@ -126,8 +125,6 @@ class DecomposeGroupNorm(PassBase):
         )
 
     def call(self, exported_program: ExportedProgram) -> PassResult:
-        logger = logging.getLogger(__name__)
-
         gm = exported_program.graph_module
         graph: torch.fx.Graph = gm.graph
         modified = False
tico/passes/legalize_predefined_layout_operators.py CHANGED
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
     import torch
 from torch.export import ExportedProgram
 
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.errors import NotYetSupportedError
 from tico.utils.graph import create_node
@@ -206,7 +206,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
 
         args = ConvTranspose2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
-        padding = args.padding
         groups = args.groups
         dilation = args.dilation
 
@@ -288,13 +287,12 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         input = args.input
         weight = args.weight
         bias = args.bias
-        eps = args.eps
 
         running_mean = args.running_mean
         running_var = args.running_var
         use_input_stats = args.use_input_stats
 
-        if not
+        if not use_input_stats:
             raise NotYetSupportedError("Only support use_input_stats is True.")
         if not isinstance(running_mean, NoneType):
             raise NotYetSupportedError("Only support running_mean=None")
@@ -350,10 +348,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
         args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input_ = args.input
-        kernel_size = args.kernel_size
-        stride = args.stride
-        padding = args.padding
-        dilation = args.dilation
         ceil_mode = args.ceil_mode
         if ceil_mode:
             raise NotYetSupportedError("Only support non-ceil model.")
@@ -402,9 +396,6 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         # avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
         args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input_ = args.input
-        kernel_size = args.kernel_size
-        stride = args.stride
-        padding = args.padding
         ceil_mode = args.ceil_mode
         if ceil_mode:
             raise NotYetSupportedError("Only support non-ceil model.")
tico/passes/lower_to_resize_nearest_neighbor.py CHANGED
@@ -67,7 +67,7 @@ class LowerToResizeNearestNeighbor(PassBase):
             return None
         # indices = [None, None, H index, W index]
         N, C, H, W = indices
-        if N
+        if N is not None or C is not None:
             return None
         if not isinstance(H, torch.fx.Node):
             return None
tico/passes/lower_to_slice.py CHANGED
@@ -28,7 +28,7 @@ from torch._export.utils import (
 from torch.export import ExportedProgram
 
 from tico.passes import ops
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.graph import create_node, is_single_value_tensor
 from tico.utils.passes import PassBase, PassResult
tico/passes/merge_consecutive_cat.py CHANGED
@@ -51,7 +51,7 @@ class MergeConsecutiveCat(PassBase):
             if not prev_cat.op == "call_function":
                 continue
 
-            if
+            if prev_cat.target not in ops.aten.cat:
                 continue
 
             prev_args = CatArgs(*prev_cat.args, **prev_cat.kwargs)  # type: ignore[arg-type]
tico/passes/ops.py CHANGED
@@ -69,10 +69,10 @@ class AtenOps:
             torch.ops.aten.unsqueeze_copy.default,
         ]
         self.view = [
-            torch.ops.aten.view,
             torch.ops.aten.view.default,
             torch.ops.aten.view_copy.default,
         ]
+        self._to_copy = [torch.ops.aten._to_copy.default]
 
 
 aten = AtenOps()
tico/passes/remove_redundant_assert_nodes.py CHANGED
@@ -21,7 +21,9 @@ from tico.utils.utils import is_target_node
 
 
 assert_node_targets = [
+    torch.ops.aten._assert_scalar.default,
     torch.ops.aten._assert_tensor_metadata.default,
+    torch.ops.aten.sym_constrain_range_for_size.default,  # Related to symbolic shape validation
 ]
 
 
@@ -29,7 +31,7 @@ assert_node_targets = [
 class RemoveRedundantAssertionNodes(PassBase):
     """
     This removes redundant assertion nodes.
-
+    When assertion node is erased, related comparison nodes are also removed by graph.eliminate_dead_code().
     """
 
     def __init__(self):
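The docstring addition above is straightforward to reproduce with vanilla torch.fx: once an assertion node is erased, the comparison feeding it has no remaining users and `eliminate_dead_code()` sweeps it. A standalone sketch (the function `f` is illustrative, not from the package):

    import torch
    import torch.fx

    def f(x):
        cmp = x.sum() > 0                            # comparison node
        torch._assert(cmp, "sum must be positive")   # assertion node, no users
        return x * 2

    gm = torch.fx.symbolic_trace(f)
    for node in [n for n in gm.graph.nodes if n.target is torch._assert]:
        gm.graph.erase_node(node)      # drop the assertion itself
    gm.graph.eliminate_dead_code()     # also removes the now-dead `sum` and `>` nodes
    gm.recompile()
    print(gm.code)                     # only the `x * 2` computation remains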
tico/passes/remove_redundant_expand.py CHANGED
@@ -12,11 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    import torch.fx
-import torch
 from torch.export import ExportedProgram
 
 from tico.passes import ops
@@ -51,7 +46,9 @@ class RemoveRedundantExpand(PassBase):
             input, size = args.input, args.size
 
             input_shape = extract_shape(input)
-
+            output_shape = extract_shape(node)
+
+            if input_shape != output_shape:
                 continue
 
             node.replace_all_uses_with(input, propagate_meta=False)
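For context on the tightened check: an `expand` whose output shape equals its input shape is an identity, which is what the pass now verifies before replacing the node with its input. A minimal illustration (not from the package):

    import torch

    x = torch.randn(4, 1, 8)
    y = x.expand(4, 1, 8)   # output shape == input shape -> the expand is a no-op
    assert y.shape == x.shape and torch.equal(y, x)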
tico/passes/remove_redundant_reshape.py CHANGED
@@ -90,7 +90,7 @@ class RemoveRedundantReshapePattern1(PassBase):
             if len(permute.users) != 1:
                 continue
             permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
-
+            permute_dims = permute_args.dims
             # (1xAxBxC) - `aten.permute` - (1xAxCxB)
             if permute_dims != [0, 1, 3, 2]:
                 continue
@@ -172,7 +172,7 @@ class RemoveRedundantReshapePattern2(PassBase):
             if len(permute.users) != 1:
                 continue
             permute_args = PermuteArgs(*permute.args, **permute.kwargs)  # type: ignore[arg-type]
-
+            permute_dims = permute_args.dims
             # (1xAxBxC) - `aten.permute` - (Bx1xAxC)
             if permute_dims != [2, 0, 1, 3]:
                 continue
@@ -262,7 +262,7 @@ class RemoveRedundantReshapePattern3(PassBase):
                 continue
 
             # add
-            if
+            if add.target not in ops.aten.add:
                 continue
             add_args = AddTensorArgs(*add.args, **add.kwargs)  # type: ignore[arg-type]
             reshape_2, reshape_3 = add_args.input, add_args.other
@@ -272,7 +272,7 @@ class RemoveRedundantReshapePattern3(PassBase):
             # reshape_2
             if not reshape_2.op == "call_function":
                 continue
-            if
+            if reshape_2.target not in ops.aten.reshape:
                 continue
             reshape_2_args = ReshapeArgs(*reshape_2.args, **reshape_2.kwargs)  # type: ignore[arg-type]
             reshape_2_input = reshape_2_args.input
@@ -280,7 +280,7 @@ class RemoveRedundantReshapePattern3(PassBase):
             # reshape_3
             if not reshape_3.op == "call_function":
                 continue
-            if
+            if reshape_3.target not in ops.aten.reshape:
                 continue
             reshape_3_args = ReshapeArgs(*reshape_3.args, **reshape_3.kwargs)  # type: ignore[arg-type]
             reshape_3_input = reshape_3_args.input
tico/passes/segment_index_select.py CHANGED
@@ -29,7 +29,7 @@ from torch._export.utils import (
 from torch.export import ExportedProgram
 
 from tico.passes import ops
-from tico.serialize.
+from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.graph import add_placeholder, create_node, is_single_value_tensor
 from tico.utils.passes import PassBase, PassResult
tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py RENAMED
@@ -25,7 +25,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from tico.
+from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
 
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False