tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/config/v1.py
CHANGED
```diff
@@ -20,6 +20,11 @@ from tico.config.base import CompileConfigBase
 @dataclass
 class CompileConfigV1(CompileConfigBase):
     legalize_causal_mask_value: bool = False
+    remove_constant_input: bool = False
+    convert_lhs_const_mm_to_fc: bool = False
+    convert_rhs_const_mm_to_fc: bool = True
+    convert_single_batch_lhs_const_bmm_to_fc: bool = False
+    convert_expand_to_slice_cat: bool = False

     def get(self, name: str):
         return super().get(name)
```
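The five new options are plain dataclass fields, so they can be toggled at construction time. A minimal sketch, assuming only what the hunk above shows (the `CompileConfigV1` dataclass and a `get` accessor that returns the field value); how the config object is handed to the compiler is not part of this hunk:

```python
from tico.config.v1 import CompileConfigV1

# Defaults follow the dataclass: only convert_rhs_const_mm_to_fc starts True.
config = CompileConfigV1(
    convert_single_batch_lhs_const_bmm_to_fc=True,
    convert_expand_to_slice_cat=True,  # opt in to the KV-cache expand rewrite
)
assert config.get("convert_rhs_const_mm_to_fc") is True
assert config.get("convert_expand_to_slice_cat") is True
```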
tico/passes/cast_mixed_type_args.py
CHANGED
```diff
@@ -41,6 +41,8 @@ ops_to_promote = {
     torch.ops.aten.ge.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.gt.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.gt.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    torch.ops.aten.le.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    torch.ops.aten.le.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.ops.aten.ne.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
```
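Registering `aten.le` in `ops_to_promote` matters because a comparison between an integer tensor and a float scalar is only correct after the integer operand is promoted. A small illustration of the eager-mode behavior the pass has to preserve:

```python
import torch

x = torch.tensor([1, 2, 3], dtype=torch.int32)
# Without float promotion, 1.5 would truncate to 1 and the result would differ.
print(torch.ops.aten.le.Scalar(x, 1.5))  # tensor([ True, False, False])
```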
tico/passes/convert_expand_to_slice_cat.py
ADDED
```diff
@@ -0,0 +1,153 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node
+from tico.utils.validate_args_kwargs import ExpandArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertExpandToSliceCat(PassBase):
+    """
+    This pass replaces `aten.reshape` + `aten.expand` pattern by rewriting it using
+    a series of `aten.slice` and `aten.cat` operations.
+
+    This pass is specialized for expand of KVCache.
+    - Expects (batch, num_key_value_heads, seq_len, head_dim) as input shape of reshape
+    """
+
+    def __init__(self, enabled: bool = False):
+        super().__init__()
+        self.enabled = enabled
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        if not self.enabled:
+            return PassResult(False)
+
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+
+        # This pass handles expand on EXPAND_DIM only
+        CAT_DIM = 1
+        EXPAND_DIM = 2
+
+        for node in graph.nodes:
+            if not isinstance(node, torch.fx.Node) or not is_target_node(
+                node, ops.aten.reshape
+            ):
+                continue
+
+            post_reshape = node
+            post_reshape_args = ReshapeArgs(*post_reshape.args, **post_reshape.kwargs)
+            post_reshape_input = post_reshape_args.input
+
+            if not isinstance(post_reshape_input, torch.fx.Node) or not is_target_node(
+                post_reshape_input, ops.aten.expand
+            ):
+                continue
+
+            expand = post_reshape_input
+            expand_args = ExpandArgs(*expand.args, **expand.kwargs)
+            expand_input = expand_args.input
+            expand_shape = extract_shape(expand)
+
+            if not isinstance(expand_input, torch.fx.Node) or not is_target_node(
+                expand_input, ops.aten.reshape
+            ):
+                continue
+
+            pre_reshape = expand_input
+            pre_reshape_args = ReshapeArgs(*pre_reshape.args, **pre_reshape.kwargs)
+            pre_reshape_input = pre_reshape_args.input
+            pre_reshape_shape = extract_shape(pre_reshape)
+
+            if pre_reshape_shape[EXPAND_DIM] != 1:
+                continue
+
+            reshape_input_shape = extract_shape(pre_reshape_input)
+
+            if len(expand_shape) != len(pre_reshape_shape):
+                continue
+
+            # Ensure all dimensions *except* at EXPAND_DIM are identical.
+            if not (
+                expand_shape[:EXPAND_DIM] == pre_reshape_shape[:EXPAND_DIM]
+                and expand_shape[EXPAND_DIM + 1 :]
+                == pre_reshape_shape[EXPAND_DIM + 1 :]
+            ):
+                continue
+
+            # Ensure the expansion dimension is a clean multiple.
+            if expand_shape[EXPAND_DIM] % pre_reshape_shape[EXPAND_DIM] != 0:
+                continue
+
+            expand_ratio = expand_shape[EXPAND_DIM] // pre_reshape_shape[EXPAND_DIM]
+
+            if expand_ratio <= 1:
+                continue
+
+            cat_nodes = []
+
+            for i in range(reshape_input_shape[CAT_DIM]):
+                with graph.inserting_before(expand):
+                    slice_copy_args = (pre_reshape_input, CAT_DIM, i, i + 1, 1)
+                    slice_node = create_node(
+                        graph,
+                        torch.ops.aten.slice.Tensor,
+                        args=slice_copy_args,
+                        origin=expand,
+                    )
+                with graph.inserting_after(slice_node):
+                    cat_args = ([slice_node] * expand_ratio, CAT_DIM)
+                    cat_node = create_node(
+                        graph,
+                        torch.ops.aten.cat.default,
+                        args=cat_args,
+                        origin=expand,
+                    )
+                cat_nodes.append(cat_node)
+
+            with graph.inserting_after(expand):
+                cat_args = (cat_nodes, CAT_DIM)
+                cat_node = create_node(
+                    graph,
+                    torch.ops.aten.cat.default,
+                    args=cat_args,
+                    origin=expand,
+                )
+            expand.replace_all_uses_with(cat_node)
+
+            modified = True
+            logger.debug(f"{expand.name} is replaced with {cat_node.name} operators")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
```
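For context, this is the reshape → expand → reshape idiom the pass matches (the `repeat_kv` pattern used for grouped-query attention KV caches), together with the slice/cat form it is rewritten into. A minimal sketch in plain PyTorch; the shapes are illustrative:

```python
import torch

batch, kv_heads, seq, head_dim, n_rep = 1, 2, 4, 8, 3
kv = torch.randn(batch, kv_heads, seq, head_dim)

# Pattern matched by the pass: insert a size-1 dim, expand it, fold it back.
expanded = (
    kv[:, :, None, :, :]
    .expand(batch, kv_heads, n_rep, seq, head_dim)
    .reshape(batch, kv_heads * n_rep, seq, head_dim)
)

# Equivalent slice/cat form the pass emits: slice out each KV head, repeat it
# expand_ratio times with cat, then cat the per-head groups (CAT_DIM == 1).
parts = [torch.cat([kv[:, i : i + 1]] * n_rep, dim=1) for i in range(kv_heads)]
assert torch.equal(expanded, torch.cat(parts, dim=1))
```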
tico/passes/convert_matmul_to_linear.py
ADDED
```diff
@@ -0,0 +1,312 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
+from torch.export import ExportedProgram
+
+from tico.serialize.circle_mapping import extract_shape
+
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import BmmArgs, MatmulArgs
+
+
+class Converter:  # type: ignore[empty-body]
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
+        pass
+
+
+class MatmulToLinearConverter(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = mm_args.input
+        rhs = mm_args.other
+
+        with graph.inserting_before(node):
+            transpose_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [1, 0]),
+            )
+            linear_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(lhs, transpose_node),
+            )
+            node.replace_all_uses_with(linear_node, propagate_meta=True)
+
+        return linear_node
+
+
+class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        rhs = mm_args.other
+        if isinstance(rhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, rhs):
+                return True
+            elif is_param(exported_program, rhs):
+                return True
+            elif is_buffer(exported_program, rhs):
+                return True
+            else:
+                return False
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)
+        lhs = mm_args.input
+        if isinstance(lhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, lhs):
+                return True
+            elif is_param(exported_program, lhs):
+                return True
+            elif is_buffer(exported_program, lhs):
+                return True
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class SingleBatchLhsConstBmmToLinearConverter(Converter):
+    """
+    Convert `single-batched & lhs-const BatchMatMul` to `linear` operation.
+
+    [1] exchange lhs and rhs
+    [2] transpose rhs
+    [3] transpose output
+
+    **Before**
+
+    lhs[1,a,b](const)   rhs[1,b,c]
+           |                |
+           |                |
+           ---------bmm---------
+                     |
+               output[1,a,c]
+
+
+    **After**
+
+    rhs[1,b,c]
+        |
+        tr      lhs'[a,b](const-folded)
+        |[1,c,b]     |
+        |            |
+        ---------fc--------
+                 |[1,c,a]
+                tr
+                 |
+           output[1,a,c]
+
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.bmm.default:
+            return False
+
+        bmm_args = BmmArgs(*node.args, **node.kwargs)
+        lhs = bmm_args.input
+        rhs = bmm_args.mat2
+
+        # [1] Single-batch
+        lhs_shape = extract_shape(lhs)
+        rhs_shape = extract_shape(rhs)
+
+        assert len(lhs_shape) == len(
+            rhs_shape
+        ), f"Bmm input's ranks must be the same but got {lhs_shape} and {rhs_shape}"
+
+        if not (lhs_shape[0] == rhs_shape[0] == 1):
+            return False
+
+        # [2] Lhs is constant
+        if not isinstance(lhs, torch.fx.Node):
+            return False
+        if not (
+            is_lifted_tensor_constant(exported_program, lhs)
+            or is_param(exported_program, lhs)
+            or is_buffer(exported_program, lhs)
+        ):
+            return False
+
+        return True
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = bmm_args.input  # const
+        rhs = bmm_args.mat2  # non-const
+        lhs_shape = extract_shape(lhs)
+        rhs_shape = extract_shape(rhs)
+        assert rhs_shape[0] == 1
+        assert lhs_shape[0] == 1
+
+        with graph.inserting_before(node):
+            rhs_tr = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [0, 2, 1]),
+            )
+            lhs_reshape = create_node(
+                graph,
+                torch.ops.aten.view.default,
+                args=(lhs, list(lhs_shape[1:])),
+            )
+
+            linear_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(rhs_tr, lhs_reshape),
+            )
+
+            tr_linear_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(linear_node, [0, 2, 1]),
+            )
+
+            node.replace_all_uses_with(tr_linear_node, propagate_meta=False)
+
+        return tr_linear_node
+
+
+@trace_graph_diff_on_pass
+class ConvertMatmulToLinear(PassBase):
+    """
+    This pass converts matmul (partially includes single-batch bmm) to linear selectively.
+
+    How to select between `matmul` and `linear`?
+
+    * Linear has better quantization accuracy (NPU backend)
+      Due to ONE compiler's quantization policy:
+      FullyConnected (=Linear) uses per-channel quantization for weight and per-tensor for input.
+      BatchMatmul (=matmul) uses per-tensor quantization for both rhs and lhs.
+
+    * Matmul to Linear requires Transpose, which may harm latency
+      When RHS is constant, the additional transpose can be folded.
+
+    [RHS non-const case]
+    Constant folding cannot be performed.
+
+    lhs    rhs (non-const)
+     |      |
+     |   transpose
+     |      |
+     -- linear --
+          |
+         out
+
+    [RHS const case]
+    Constant folding can be performed to:
+
+    lhs    rhs (const)          lhs    rhs (folded const)
+     |      |                    |      |
+     |   transpose               |      |
+     |      |                    |      |
+     -- linear --      -->       -- linear --
+          |                           |
+         out                         out
+
+
+    enable_lhs_const: If true, convert matmul where LHS is constant tensor. Default is False.
+    enable_rhs_const: If true, convert matmul where RHS is constant tensor. Default is True.
+    """
+
+    def __init__(
+        self,
+        enable_lhs_const: Optional[bool] = False,
+        enable_rhs_const: Optional[bool] = True,
+        enable_single_batch_lhs_const_bmm: Optional[bool] = False,
+    ):
+        super().__init__()
+        self.converters: List[Converter] = []
+        if enable_lhs_const:
+            self.converters.append(LhsConstMatmulToLinearConverter())
+        if enable_rhs_const:
+            self.converters.append(RhsConstMatmulToLinearConverter())
+        if enable_single_batch_lhs_const_bmm:
+            self.converters.append(SingleBatchLhsConstBmmToLinearConverter())
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if not node.op == "call_function":
+                continue
+
+            for converter in self.converters:
+                if not converter.match(exported_program, node):
+                    continue
+
+                new_node = converter.convert(exported_program, node)
+                modified = True
+                logger.debug(
+                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
+                )
+                continue
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
```
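A quick numerical check of the algebra behind the converters above; all names are local to the snippet:

```python
import torch
import torch.nn.functional as F

# MatmulToLinearConverter: mm(lhs, rhs) == linear(lhs, rhs.T).
x, w = torch.randn(4, 8), torch.randn(8, 16)
assert torch.allclose(torch.mm(x, w), F.linear(x, w.t()), atol=1e-6)

# SingleBatchLhsConstBmmToLinearConverter: swap operands, transpose rhs,
# run linear against the 2-D constant, then transpose the output back.
lhs, rhs = torch.randn(1, 5, 8), torch.randn(1, 8, 7)  # lhs plays the const role
out = F.linear(rhs.permute(0, 2, 1), lhs.view(5, 8))   # [1, 7, 5]
assert torch.allclose(torch.bmm(lhs, rhs), out.permute(0, 2, 1), atol=1e-6)
```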
tico/passes/convert_to_relu6.py
CHANGED
tico/passes/decompose_fake_quantize_tensor_qparams.py
CHANGED
```diff
@@ -244,10 +244,11 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
             # So, let's remove `mask` from the output.args first.
             # mask_user(output).args == (dequantize_per_tensor.tensor, mask)
             if mask:
-                len(mask) == 1
-                mask_user = list(mask[0].users.keys())[0]
-                assert len(mask_user.args) == 1
-                mask_user.args = ((mask_user.args[0][0],),)
+                assert len(mask) == 1
+                if len(mask[0].users) > 0:
+                    mask_user = list(mask[0].users.keys())[0]
+                    assert len(mask_user.args) == 1
+                    mask_user.args = ((mask_user.args[0][0],),)
                 modified = True
             if (
                 node.target
```
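The old code indexed the mask's first user unconditionally (and its `len(mask) == 1` check was a no-op expression); the fix turns the check into an assert and guards on the mask having any users at all. In `torch.fx`, `node.users` is a dict keyed by consumer nodes, so the guarded access pattern looks like this in isolation (a sketch outside the pass; `first_user` is a hypothetical helper):

```python
import torch.fx

def first_user(node: torch.fx.Node):
    # fx tracks consumers as dict keys; an unused node has an empty dict.
    users = list(node.users.keys())
    return users[0] if users else None
```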
tico/passes/ops.py
CHANGED
tico/passes/remove_redundant_assert_nodes.py
CHANGED
```diff
@@ -21,7 +21,9 @@ from tico.utils.utils import is_target_node


 assert_node_targets = [
+    torch.ops.aten._assert_scalar.default,
     torch.ops.aten._assert_tensor_metadata.default,
+    torch.ops.aten.sym_constrain_range_for_size.default,  # Related to symbolic shape validation
 ]


@@ -29,7 +31,7 @@ assert_node_targets = [
 class RemoveRedundantAssertionNodes(PassBase):
     """
     This removes redundant assertion nodes.
-
+    When assertion node is erased, related comparison nodes are also removed by graph.eliminate_dead_code().
     """

     def __init__(self):
```
tico/passes/remove_redundant_expand.py
CHANGED
```diff
@@ -46,7 +46,9 @@ class RemoveRedundantExpand(PassBase):
             input, size = args.input, args.size

             input_shape = extract_shape(input)
-
+            output_shape = extract_shape(node)
+
+            if input_shape != output_shape:
                 continue

             node.replace_all_uses_with(input, propagate_meta=False)
```
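The redundancy being removed is an `aten.expand` whose output shape equals its input shape, which is an identity:

```python
import torch

x = torch.randn(2, 3)
y = x.expand(2, 3)  # same shape as the input -> the node can be dropped
assert torch.equal(x, y)
```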
tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py
CHANGED
```diff
@@ -25,7 +25,7 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from tico.experimental.quantization.algorithm.gptq.quant import quantize, Quantizer
+from tico.quantization.algorithm.gptq.quant import quantize, Quantizer

 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
```
tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py
CHANGED
```diff
@@ -17,15 +17,17 @@ import types
 from typing import Any, Callable, Dict, List, Optional

 import torch
+from tqdm.auto import tqdm

-from tico.experimental.quantization.algorithm.gptq.gptq import GPTQ
-from tico.experimental.quantization.algorithm.gptq.utils import (
+from tico.quantization.algorithm.gptq.gptq import GPTQ
+from tico.quantization.algorithm.gptq.utils import (
     find_layers,
     gather_single_batch_from_dict,
     gather_single_batch_from_list,
 )
-from tico.experimental.quantization.config import GPTQConfig
-from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.quantization.config.gptq import GPTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer


 class StopForward(Exception):
@@ -34,6 +36,7 @@ class StopForward(Exception):
     pass


+@register_quantizer(GPTQConfig)
 class GPTQQuantizer(BaseQuantizer):
     """
     Quantizer for applying the GPTQ algorithm (typically for weight quantization).
@@ -43,7 +46,7 @@ class GPTQQuantizer(BaseQuantizer):
     3) convert(model) to consume the collected data and apply GPTQ.
     """

-    def __init__(self, config:
+    def __init__(self, config: GPTQConfig):
         super().__init__(config)

         # cache_args[i] -> list of the i-th positional argument for each batch
@@ -181,7 +184,14 @@ class GPTQQuantizer(BaseQuantizer):
             target_layers = [model]

         quantizers: Dict[str, Any] = {}
-        for l_idx, layer in enumerate(target_layers):
+        for l_idx, layer in enumerate(
+            tqdm(
+                target_layers,
+                desc="Quantizing layers",
+                unit="layer",
+                disable=not gptq_conf.show_progress,
+            )
+        ):
             # 1) Identify quantizable submodules within the layer
             full = find_layers(layer)
             sequential = [list(full.keys())]
@@ -210,7 +220,13 @@ class GPTQQuantizer(BaseQuantizer):

             # Run layer forward over all cached batches to build Hessian/statistics
             batch_num = self.num_batches
-            for batch_idx in range(batch_num):
+            for batch_idx in tqdm(
+                range(batch_num),
+                desc=f"[L{l_idx}] collecting",
+                leave=False,
+                unit="batch",
+                disable=not gptq_conf.show_progress,
+            ):
                 cache_args_batch = gather_single_batch_from_list(
                     self.cache_args, batch_idx
                 )
@@ -238,7 +254,13 @@ class GPTQQuantizer(BaseQuantizer):
                 gptq[name].free()

             # 4) After quantization, re-run the layer to produce outputs for the next layer
-            for batch_idx in range(batch_num):
+            for batch_idx in tqdm(
+                range(batch_num),
+                desc=f"[L{l_idx}] re-forward",
+                leave=False,
+                unit="batch",
+                disable=not gptq_conf.show_progress,
+            ):
                 cache_args_batch = gather_single_batch_from_list(
                     self.cache_args, batch_idx
                 )
```
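The progress-reporting pattern added throughout this file, in isolation: `tqdm` wraps the loop, and `disable=` turns it back into a plain iterator when progress display is off (`show_progress` is the flag read from the GPTQ config above). A standalone sketch:

```python
from tqdm.auto import tqdm

show_progress = True
for batch_idx in tqdm(
    range(100), desc="collecting", unit="batch", leave=False,
    disable=not show_progress,  # no-op iterator when progress is disabled
):
    pass  # per-batch work goes here
```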
tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py
CHANGED
```diff
@@ -25,14 +25,12 @@ from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 from torch.ao.quantization.quantizer.utils import _get_module_name_filter

-from tico.experimental.quantization.algorithm.pt2e.annotation.op import *
-import tico.experimental.quantization.algorithm.pt2e.annotation.spec as annot_spec
-import tico.experimental.quantization.algorithm.pt2e.annotation.utils as annot_utils
-import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
-from tico.experimental.quantization.algorithm.pt2e.annotation.config import (
-    QuantizationConfig,
-)
-from tico.experimental.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
+from tico.quantization.algorithm.pt2e.annotation.op import *
+import tico.quantization.algorithm.pt2e.annotation.spec as annot_spec
+import tico.quantization.algorithm.pt2e.annotation.utils as annot_utils
+import tico.quantization.algorithm.pt2e.utils as quant_utils
+from tico.quantization.algorithm.pt2e.annotation.config import QuantizationConfig
+from tico.quantization.algorithm.pt2e.transformation.convert_scalars_to_attrs import (
     convert_scalars_to_attrs,
 )

```
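The net effect of the `experimental/quantization → quantization` move on callers is that imports simply drop the `experimental` segment; the module paths below are taken from the hunk above (the exact public symbols re-exported by `tico.quantization` are not shown in this diff):

```python
# Before (0.1.0.dev250803):
#   import tico.experimental.quantization.algorithm.pt2e.utils as quant_utils
# After (0.1.0.dev251102):
import tico.quantization.algorithm.pt2e.utils as quant_utils
```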