tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +128 -0
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/utils/convert.py +20 -15
- tico/utils/dtype.py +22 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/signature.py +247 -0
- tico/utils/utils.py +50 -53
- tico/utils/validate_args_kwargs.py +37 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/config/v1.py
CHANGED
|
@@ -20,6 +20,11 @@ from tico.config.base import CompileConfigBase
|
|
|
20
20
|
@dataclass
|
|
21
21
|
class CompileConfigV1(CompileConfigBase):
|
|
22
22
|
legalize_causal_mask_value: bool = False
|
|
23
|
+
remove_constant_input: bool = False
|
|
24
|
+
convert_lhs_const_mm_to_fc: bool = False
|
|
25
|
+
convert_rhs_const_mm_to_fc: bool = True
|
|
26
|
+
convert_single_batch_lhs_const_bmm_to_fc: bool = False
|
|
27
|
+
convert_expand_to_slice_cat: bool = False
|
|
23
28
|
|
|
24
29
|
def get(self, name: str):
|
|
25
30
|
return super().get(name)
|
|
@@ -41,6 +41,8 @@ ops_to_promote = {
|
|
|
41
41
|
torch.ops.aten.ge.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
42
42
|
torch.ops.aten.gt.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
43
43
|
torch.ops.aten.gt.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
44
|
+
torch.ops.aten.le.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
45
|
+
torch.ops.aten.le.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
44
46
|
torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
45
47
|
torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
46
48
|
torch.ops.aten.ne.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import torch.fx
|
|
19
|
+
import torch
|
|
20
|
+
from torch.export import ExportedProgram
|
|
21
|
+
|
|
22
|
+
from tico.passes import ops
|
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
|
24
|
+
from tico.utils import logging
|
|
25
|
+
from tico.utils.graph import create_node
|
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
|
28
|
+
from tico.utils.utils import is_target_node
|
|
29
|
+
from tico.utils.validate_args_kwargs import ExpandArgs, ReshapeArgs
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@trace_graph_diff_on_pass
class ConvertExpandToSliceCat(PassBase):
    """
    Rewrite a `reshape -> expand -> reshape` pattern using `slice` + `cat` ops.

    This pass is specialized for expand of KVCache. It expects the input of
    the first reshape to be (batch, num_key_value_heads, seq_len, head_dim):

        x:            (batch, num_key_value_heads, seq_len, head_dim)
        pre_reshape:  (batch, num_key_value_heads, 1, seq_len, head_dim)
        expand:       (batch, num_key_value_heads, n_rep, seq_len, head_dim)
        post_reshape: any reshape of the expand output

    Every head of `x` is sliced out along dim 1 and concatenated `n_rep` times
    with itself, and the per-head results are concatenated along dim 1. The
    result, (batch, num_key_value_heads * n_rep, seq_len, head_dim), has the
    same row-major element order as the expand output. It can therefore replace
    the expand as the input of the post-reshape.

    Args:
        enabled: The pass does nothing unless this is True.
    """

    def __init__(self, enabled: bool = False):
        super().__init__()
        self.enabled = enabled

    def call(self, exported_program: ExportedProgram) -> PassResult:
        if not self.enabled:
            return PassResult(False)

        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        # This pass handles expand on EXPAND_DIM only
        CAT_DIM = 1
        EXPAND_DIM = 2

        for node in graph.nodes:
            if not isinstance(node, torch.fx.Node) or not is_target_node(
                node, ops.aten.reshape
            ):
                continue

            post_reshape = node
            post_reshape_args = ReshapeArgs(*post_reshape.args, **post_reshape.kwargs)
            post_reshape_input = post_reshape_args.input

            if not isinstance(post_reshape_input, torch.fx.Node) or not is_target_node(
                post_reshape_input, ops.aten.expand
            ):
                continue

            expand = post_reshape_input
            expand_args = ExpandArgs(*expand.args, **expand.kwargs)
            expand_input = expand_args.input
            expand_shape = extract_shape(expand)

            if not isinstance(expand_input, torch.fx.Node) or not is_target_node(
                expand_input, ops.aten.reshape
            ):
                continue

            pre_reshape = expand_input
            pre_reshape_args = ReshapeArgs(*pre_reshape.args, **pre_reshape.kwargs)
            pre_reshape_input = pre_reshape_args.input
            pre_reshape_shape = extract_shape(pre_reshape)

            # Check ranks before indexing EXPAND_DIM so malformed patterns are
            # skipped instead of raising IndexError.
            if len(pre_reshape_shape) <= EXPAND_DIM:
                continue
            if len(expand_shape) != len(pre_reshape_shape):
                continue

            if pre_reshape_shape[EXPAND_DIM] != 1:
                continue

            reshape_input_shape = extract_shape(pre_reshape_input)

            # The pre-reshape must only insert a size-1 axis at EXPAND_DIM.
            # Only then does slicing its input along CAT_DIM select exactly one
            # head at a time.
            squeezed_pre_shape = tuple(pre_reshape_shape[:EXPAND_DIM]) + tuple(
                pre_reshape_shape[EXPAND_DIM + 1 :]
            )
            if tuple(reshape_input_shape) != squeezed_pre_shape:
                continue

            # Ensure all dimensions *except* at EXPAND_DIM are identical.
            if not (
                expand_shape[:EXPAND_DIM] == pre_reshape_shape[:EXPAND_DIM]
                and expand_shape[EXPAND_DIM + 1 :]
                == pre_reshape_shape[EXPAND_DIM + 1 :]
            ):
                continue

            # pre_reshape_shape[EXPAND_DIM] == 1 was checked above, so the
            # expansion ratio is simply the expanded size.
            expand_ratio = expand_shape[EXPAND_DIM]

            if expand_ratio <= 1:
                continue

            cat_nodes = []

            for i in range(reshape_input_shape[CAT_DIM]):
                with graph.inserting_before(expand):
                    # Take head `i` while keeping its axis: (b, 1, s, d).
                    slice_copy_args = (pre_reshape_input, CAT_DIM, i, i + 1, 1)
                    slice_node = create_node(
                        graph,
                        torch.ops.aten.slice.Tensor,
                        args=slice_copy_args,
                        origin=expand,
                    )
                with graph.inserting_after(slice_node):
                    # Repeat that head `expand_ratio` times: (b, expand_ratio, s, d).
                    cat_args = ([slice_node] * expand_ratio, CAT_DIM)
                    cat_node = create_node(
                        graph,
                        torch.ops.aten.cat.default,
                        args=cat_args,
                        origin=expand,
                    )
                    cat_nodes.append(cat_node)

            with graph.inserting_after(expand):
                # Stitch all repeated heads together: (b, h * expand_ratio, s, d).
                cat_args = (cat_nodes, CAT_DIM)
                cat_node = create_node(
                    graph,
                    torch.ops.aten.cat.default,
                    args=cat_args,
                    origin=expand,
                )
                # Rewire only the matched post-reshape. cat_node has a different
                # rank than `expand`, so other users of `expand` (if any) must keep
                # consuming the original node. DCE drops `expand` once it is unused.
                post_reshape.replace_input_with(expand, cat_node)

            modified = True
            logger.debug(f"{expand.name} is replaced with {cat_node.name} operators")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import List, Optional, TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import torch.fx
|
|
19
|
+
import torch
|
|
20
|
+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
|
|
21
|
+
from torch.export import ExportedProgram
|
|
22
|
+
|
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
|
24
|
+
|
|
25
|
+
from tico.utils import logging
|
|
26
|
+
from tico.utils.graph import create_node
|
|
27
|
+
from tico.utils.passes import PassBase, PassResult
|
|
28
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
|
29
|
+
from tico.utils.validate_args_kwargs import BmmArgs, MatmulArgs
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Converter:  # type: ignore[empty-body]
    """
    Base interface for one matmul/bmm -> linear rewrite strategy.

    Subclasses override `match` to decide whether a graph node is eligible.
    They override `convert` to rewrite that node and return its replacement.
    The base implementation matches nothing and converts nothing.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
        """Return whether `node` can be converted; the base class always declines."""
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
        """Rewrite `node` and return its replacement; the base class is a no-op."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MatmulToLinearConverter(Converter):
    """
    Rewrite `mm(lhs, rhs)` as `linear(lhs, permute(rhs, [1, 0]))`.

    `linear(x, w)` computes `x @ w.T`, so feeding the transposed rhs as the
    weight reproduces `lhs @ rhs`. Subclasses supply `match`.
    """

    def __init__(self):
        super().__init__()

    def convert(self, exported_program, node) -> torch.fx.Node:
        """Insert permute + linear before `node`, redirect its users, and return the linear node."""
        graph = exported_program.graph_module.graph
        args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        with graph.inserting_before(node):
            weight = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(args.other, [1, 0]),
            )
            replacement = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(args.input, weight),
            )

        # The replacement has the same output as `mm`, so its meta is reused.
        node.replace_all_uses_with(replacement, propagate_meta=True)
        return replacement
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    """
    Convert `aten.mm` whose RHS is a parameter, buffer, or lifted tensor constant.

    With a constant RHS, the inserted transpose operates on constant data and
    is therefore a candidate for constant folding.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        """Return True for `mm` nodes whose `other` operand is a constant graph input."""
        if node.target != torch.ops.aten.mm.default:
            return False

        rhs = MatmulArgs(*node.args, **node.kwargs).other  # type: ignore[arg-type]
        if not isinstance(rhs, torch.fx.Node):
            return False

        return (
            is_lifted_tensor_constant(exported_program, rhs)
            or is_param(exported_program, rhs)
            or is_buffer(exported_program, rhs)
        )

    def convert(self, exported_program, node) -> torch.fx.Node:
        return super().convert(exported_program, node)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    """
    Convert `aten.mm` whose LHS is a parameter, buffer, or lifted tensor constant.

    The RHS is transposed at runtime here, since it is not known to be constant.
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        """Return True for `mm` nodes whose `input` operand is a constant graph input."""
        if node.target != torch.ops.aten.mm.default:
            return False

        lhs = MatmulArgs(*node.args, **node.kwargs).input
        if not isinstance(lhs, torch.fx.Node):
            return False

        return (
            is_lifted_tensor_constant(exported_program, lhs)
            or is_param(exported_program, lhs)
            or is_buffer(exported_program, lhs)
        )

    def convert(self, exported_program, node) -> torch.fx.Node:
        return super().convert(exported_program, node)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class SingleBatchLhsConstBmmToLinearConverter(Converter):
    """
    Convert a single-batch, lhs-constant `aten.bmm` to a `linear` operation.

    [1] exchange lhs and rhs
    [2] transpose rhs
    [3] transpose output

    This uses lhs @ rhs == (rhs^T @ lhs^T)^T together with
    linear(x, w) == x @ w^T. The bmm becomes
    permute(linear(permute(rhs, [0, 2, 1]), lhs'), [0, 2, 1]), where lhs' is the
    constant lhs viewed without its size-1 batch dimension.

    **Before**

        lhs[1,a,b](const)    rhs[1,b,c]
               |                 |
               -------bmm---------
                       |
                 output[1,a,c]

    **After**

                 rhs[1,b,c]
                     |
                    tr          lhs'[a,b](const-folded)
                     |[1,c,b]        |
                     --------fc-------
                             |[1,c,a]
                            tr
                             |
                       output[1,a,c]
    """

    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        """Return True for rank-3 `bmm` nodes with batch size 1 whose lhs is constant."""
        if not node.target == torch.ops.aten.bmm.default:
            return False

        bmm_args = BmmArgs(*node.args, **node.kwargs)
        lhs = bmm_args.input
        rhs = bmm_args.mat2

        # [1] Lhs is constant. This is checked first: it is cheap and it
        # guarantees lhs is a Node before its shape is extracted.
        if not isinstance(lhs, torch.fx.Node):
            return False
        if not (
            is_lifted_tensor_constant(exported_program, lhs)
            or is_param(exported_program, lhs)
            or is_buffer(exported_program, lhs)
        ):
            return False

        # [2] Single-batch, rank-3 operands. `convert` permutes with [0, 2, 1],
        # which needs rank 3. A predicate declines unexpected shapes instead of
        # asserting, so one odd node cannot abort the whole pass.
        lhs_shape = extract_shape(lhs)
        rhs_shape = extract_shape(rhs)
        if len(lhs_shape) != 3 or len(rhs_shape) != 3:
            return False

        if not (lhs_shape[0] == rhs_shape[0] == 1):
            return False

        return True

    def convert(self, exported_program, node) -> torch.fx.Node:
        """Insert permute/view/linear/permute before `node`, redirect its users, return the last permute."""
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        lhs = bmm_args.input  # const
        rhs = bmm_args.mat2  # non-const
        lhs_shape = extract_shape(lhs)
        rhs_shape = extract_shape(rhs)
        # Preconditions established by `match`.
        assert rhs_shape[0] == 1
        assert lhs_shape[0] == 1

        with graph.inserting_before(node):
            # rhs[1,b,c] -> [1,c,b]
            rhs_tr = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(rhs, [0, 2, 1]),
            )
            # Drop the batch dim of the constant lhs: [1,a,b] -> [a,b]
            lhs_reshape = create_node(
                graph,
                torch.ops.aten.view.default,
                args=(lhs, list(lhs_shape[1:])),
            )

            # [1,c,b] @ [a,b]^T -> [1,c,a]
            linear_node = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(rhs_tr, lhs_reshape),
            )

            # [1,c,a] -> [1,a,c], the original bmm output layout
            tr_linear_node = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(linear_node, [0, 2, 1]),
            )

        node.replace_all_uses_with(tr_linear_node, propagate_meta=False)

        return tr_linear_node
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@trace_graph_diff_on_pass
class ConvertMatmulToLinear(PassBase):
    """
    Selectively convert matmul (`aten.mm`, and optionally single-batch `aten.bmm`) to linear.

    How to select between `matmul` and `linear`?

    * Linear has better quantization accuracy (NPU backend).
      This is due to ONE compiler's quantization policy:
      FullyConnected (= Linear) uses per-channel quantization for the weight and
      per-tensor quantization for the input.
      BatchMatmul (= matmul) uses per-tensor quantization for both rhs and lhs.

    * Matmul to Linear requires a Transpose, which may harm latency.
      When RHS is constant, the additional transpose can be folded.

    [RHS non-const case]
    Constant folding cannot be performed.

        lhs     rhs (non-const)
         |       |
         |    transpose
         |       |
         -- linear --
              |
             out

    [RHS const case]
    Constant folding can be performed:

        lhs     rhs (const)            lhs     rhs (folded const)
         |       |                      |       |
         |    transpose                 |       |
         |       |                      |       |
         -- linear --        -->        -- linear --
              |                              |
             out                            out

    Args:
        enable_lhs_const: If true, convert matmul where LHS is a constant tensor. Default is False.
        enable_rhs_const: If true, convert matmul where RHS is a constant tensor. Default is True.
        enable_single_batch_lhs_const_bmm: If true, convert single-batch bmm where LHS is a
            constant tensor. Default is False.
    """

    def __init__(
        self,
        enable_lhs_const: Optional[bool] = False,
        enable_rhs_const: Optional[bool] = True,
        enable_single_batch_lhs_const_bmm: Optional[bool] = False,
    ):
        super().__init__()
        # Converters are tried in this order; the first match wins.
        self.converters: List[Converter] = []
        if enable_lhs_const:
            self.converters.append(LhsConstMatmulToLinearConverter())
        if enable_rhs_const:
            self.converters.append(RhsConstMatmulToLinearConverter())
        if enable_single_batch_lhs_const_bmm:
            self.converters.append(SingleBatchLhsConstBmmToLinearConverter())

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not node.op == "call_function":
                continue

            for converter in self.converters:
                if not converter.match(exported_program, node):
                    continue

                new_node = converter.convert(exported_program, node)
                modified = True
                logger.debug(
                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
                )
                # Rewrite each node at most once. Otherwise a node matched by
                # several converters would be converted again after its users
                # were already redirected. An example is an `mm` whose lhs and
                # rhs are both constant, with both lhs/rhs options enabled.
                # The repeat conversion creates duplicate dead nodes and log
                # entries.
                break

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
tico/passes/convert_to_relu6.py
CHANGED
|
@@ -244,10 +244,11 @@ class DecomposeFakeQuantizeTensorQParams(PassBase):
|
|
|
244
244
|
# So, let's remove `mask` from the output.args first.
|
|
245
245
|
# mask_user(output).args == (dequantize_per_tensor.tensor, mask)
|
|
246
246
|
if mask:
|
|
247
|
-
len(mask) == 1
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
247
|
+
assert len(mask) == 1
|
|
248
|
+
if len(mask[0].users) > 0:
|
|
249
|
+
mask_user = list(mask[0].users.keys())[0]
|
|
250
|
+
assert len(mask_user.args) == 1
|
|
251
|
+
mask_user.args = ((mask_user.args[0][0],),)
|
|
251
252
|
modified = True
|
|
252
253
|
if (
|
|
253
254
|
node.target
|
tico/passes/ops.py
CHANGED
|
@@ -21,7 +21,9 @@ from tico.utils.utils import is_target_node
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
assert_node_targets = [
|
|
24
|
+
torch.ops.aten._assert_scalar.default,
|
|
24
25
|
torch.ops.aten._assert_tensor_metadata.default,
|
|
26
|
+
torch.ops.aten.sym_constrain_range_for_size.default, # Related to symbolic shape validation
|
|
25
27
|
]
|
|
26
28
|
|
|
27
29
|
|
|
@@ -29,7 +31,7 @@ assert_node_targets = [
|
|
|
29
31
|
class RemoveRedundantAssertionNodes(PassBase):
|
|
30
32
|
"""
|
|
31
33
|
This removes redundant assertion nodes.
|
|
32
|
-
|
|
34
|
+
When assertion node is erased, related comparison nodes are also removed by graph.eliminate_dead_code().
|
|
33
35
|
"""
|
|
34
36
|
|
|
35
37
|
def __init__(self):
|
|
@@ -46,7 +46,9 @@ class RemoveRedundantExpand(PassBase):
|
|
|
46
46
|
input, size = args.input, args.size
|
|
47
47
|
|
|
48
48
|
input_shape = extract_shape(input)
|
|
49
|
-
|
|
49
|
+
output_shape = extract_shape(node)
|
|
50
|
+
|
|
51
|
+
if input_shape != output_shape:
|
|
50
52
|
continue
|
|
51
53
|
|
|
52
54
|
node.replace_all_uses_with(input, propagate_meta=False)
|
|
@@ -25,7 +25,7 @@ from typing import Optional
|
|
|
25
25
|
import torch
|
|
26
26
|
import torch.nn as nn
|
|
27
27
|
|
|
28
|
-
from tico.
|
|
28
|
+
from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
|
|
29
29
|
|
|
30
30
|
torch.backends.cuda.matmul.allow_tf32 = False
|
|
31
31
|
torch.backends.cudnn.allow_tf32 = False
|
|
@@ -36,6 +36,11 @@ class GPTQ:
|
|
|
36
36
|
self.layer = layer
|
|
37
37
|
self.dev = self.layer.weight.device
|
|
38
38
|
W = layer.weight.data.clone()
|
|
39
|
+
if isinstance(self.layer, nn.Conv2d):
|
|
40
|
+
W = W.flatten(1)
|
|
41
|
+
|
|
42
|
+
if isinstance(self.layer, nn.Conv1d):
|
|
43
|
+
W = W.t()
|
|
39
44
|
self.rows = W.shape[0]
|
|
40
45
|
self.columns = W.shape[1]
|
|
41
46
|
self.H: Optional[torch.Tensor] = torch.zeros(
|
|
@@ -48,10 +53,22 @@ class GPTQ:
|
|
|
48
53
|
if len(inp.shape) == 2:
|
|
49
54
|
inp = inp.unsqueeze(0)
|
|
50
55
|
tmp = inp.shape[0]
|
|
51
|
-
if isinstance(self.layer, nn.Linear):
|
|
52
|
-
if len(inp.shape)
|
|
56
|
+
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, nn.Conv1d):
|
|
57
|
+
if len(inp.shape) > 2:
|
|
53
58
|
inp = inp.reshape((-1, inp.shape[-1]))
|
|
54
59
|
inp = inp.t()
|
|
60
|
+
if isinstance(self.layer, nn.Conv2d):
|
|
61
|
+
unfold = nn.Unfold(
|
|
62
|
+
self.layer.kernel_size,
|
|
63
|
+
dilation=self.layer.dilation,
|
|
64
|
+
padding=self.layer.padding,
|
|
65
|
+
stride=self.layer.stride,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
inp = unfold(inp)
|
|
69
|
+
inp = inp.permute([1, 0, 2])
|
|
70
|
+
inp = inp.flatten(1)
|
|
71
|
+
|
|
55
72
|
self.H *= self.nsamples / (self.nsamples + tmp)
|
|
56
73
|
self.nsamples += tmp
|
|
57
74
|
inp = math.sqrt(2 / self.nsamples) * inp.float()
|
|
@@ -67,6 +84,10 @@ class GPTQ:
|
|
|
67
84
|
verbose=False,
|
|
68
85
|
):
|
|
69
86
|
W = self.layer.weight.data.clone()
|
|
87
|
+
if isinstance(self.layer, nn.Conv2d):
|
|
88
|
+
W = W.flatten(1)
|
|
89
|
+
if isinstance(self.layer, nn.Conv1d):
|
|
90
|
+
W = W.t()
|
|
70
91
|
W = W.float()
|
|
71
92
|
tick = time.time()
|
|
72
93
|
if not self.quantizer.ready():
|