tico-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
tico/experimental/quantization/passes/propagate_qparam_backward.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import CatArgs, PermuteArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class PropagateQParamBackward(PassBase):
+    """
+    This pass propagates quantization parameters backward.
+
+    BEFORE)
+
+    node -> reshape (with meta[QPARAM_KEY])
+
+    AFTER)
+
+    node (with meta[QPARAM_KEY]) -> reshape (with meta[QPARAM_KEY])
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+            if QPARAM_KEY not in src.meta:
+                return
+
+            if (
+                QPARAM_KEY in dst.meta
+                and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+            ):
+                return
+
+            dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+            logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+        # Do reverse-order traversal for backward propagation
+        for node in reversed(graph.nodes):
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.cat.default:
+                concat_args = CatArgs(*node.args, **node.kwargs)
+                concat_inputs = concat_args.tensors
+
+                for concat_input in concat_inputs:
+                    _propagate_qparam_if_possible(node, concat_input)
+            elif node.target == torch.ops.aten.reshape.default:
+                args = ReshapeArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(node, args.input)
+            elif node.target == torch.ops.aten.permute.default:
+                permute_args = PermuteArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(node, permute_args.input)
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
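For illustration only (not part of the package): the backward rule above copies a consumer's qparam onto its producer when the producer has none. A minimal sketch with plain dicts standing in for torch.fx.Node.meta; the QPARAM_KEY string value is an assumption (tico defines the real one in tico/serialize/quant_param.py).

import copy

QPARAM_KEY = "_quant_param"  # assumed key; the real value lives in tico/serialize/quant_param.py

class FakeNode:
    """Stand-in for torch.fx.Node: just a name and a meta dict."""
    def __init__(self, name, meta=None):
        self.name = name
        self.meta = meta or {}

producer = FakeNode("linear_out")  # no qparam yet
reshape = FakeNode("reshape", {QPARAM_KEY: {"dtype": "int16", "scale": [0.02]}})

# Backward propagation: the reshape's qparam is copied onto its input,
# mirroring _propagate_qparam_if_possible(node, args.input) above.
if QPARAM_KEY in reshape.meta and QPARAM_KEY not in producer.meta:
    producer.meta[QPARAM_KEY] = copy.deepcopy(reshape.meta[QPARAM_KEY])

print(producer.meta[QPARAM_KEY])  # {'dtype': 'int16', 'scale': [0.02]}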
tico/experimental/quantization/passes/propagate_qparam_forward.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import (
+    CatArgs,
+    NegArgs,
+    PermuteArgs,
+    ReshapeArgs,
+    SliceArgs,
+)
+
+
+@trace_graph_diff_on_pass
+class PropagateQParamForward(PassBase):
+    """
+    A pass that propagates quantization parameters through operations that do not alter them.
+
+    This pass identifies and propagates quantization parameters through operations that
+    do not change their values, such as `permute`, `reshape`, `transpose`, `view` and
+    similar tensor transformations.
+
+    By ensuring that quantization parameters remain consistent across such operations,
+    this pass helps maintain correctness in quantization-aware representations.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+            if QPARAM_KEY not in src.meta:
+                return
+
+            if (
+                QPARAM_KEY in dst.meta
+                and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+            ):
+                return
+
+            dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+            logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.permute.default:
+                permute_args = PermuteArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(permute_args.input, node)
+            elif node.target == torch.ops.aten.reshape.default:
+                reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(reshape_args.input, node)
+            elif node.target == torch.ops.aten.slice.Tensor:
+                slice_args = SliceArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(slice_args.input, node)
+            elif node.target == torch.ops.aten.neg.default:
+                neg_args = NegArgs(*node.args, **node.kwargs)
+
+                if QPARAM_KEY not in neg_args.input.meta:
+                    continue
+                # Only support int16 for now
+                if neg_args.input.meta[QPARAM_KEY].dtype != "int16":
+                    continue
+
+                _propagate_qparam_if_possible(neg_args.input, node)
+
+            elif node.target == torch.ops.aten.cat.default:
+                concat_args = CatArgs(*node.args, **node.kwargs)
+                concat_inputs = concat_args.tensors
+
+                cond = True
+                for concat_input in concat_inputs:
+                    # Check all inputs have qparam
+                    if QPARAM_KEY not in concat_input.meta:
+                        cond = False
+                        break
+
+                    # Only support int16 for now
+                    if concat_input.meta[QPARAM_KEY].dtype != "int16":
+                        cond = False
+                        break
+
+                    if concat_input.meta[QPARAM_KEY].scale is None:
+                        cond = False
+                        break
+
+                    if len(concat_input.meta[QPARAM_KEY].scale) != 1:
+                        cond = False
+                        break
+
+                if not cond:
+                    continue
+
+                # Find max scale node
+                max_scale = 0.0
+                max_scale_node = None
+                for concat_input in concat_inputs:
+                    scale = concat_input.meta[QPARAM_KEY].scale[0]
+                    if max_scale < scale:
+                        max_scale = scale
+                        max_scale_node = concat_input
+
+                assert max_scale_node is not None
+                _propagate_qparam_if_possible(max_scale_node, node)
+
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
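As a quick illustration of the cat rule above (not tico code): among int16 per-tensor inputs, the input with the largest scale donates its qparam to the cat output, presumably because the largest scale covers every input's value range without clipping. The input names and scales below are made up.

scales = {"x0": 0.010, "x1": 0.025, "x2": 0.018}  # hypothetical per-input scales
donor = max(scales, key=lambda name: scales[name])
print(donor, scales[donor])  # x1 0.025 -> the cat output inherits x1's qparam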
tico/experimental/quantization/passes/quantize_bias.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+from tico.utils import logging
+from tico.utils.graph import add_placeholder, get_torch_param_value, is_torch_param
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import LinearArgs
+
+
+@trace_graph_diff_on_pass
+class QuantizeBias(PassBase):
+    """
+    Quantize bias.
+
+    This pass identifies fp32 biases and quantizes them using the scales of the input and weights.
+
+    This pass assumes that if the bias is fp32, the input and weights must have been quantized.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.linear.default:
+                lin_args = LinearArgs(*node.args, **node.kwargs)
+                inp = lin_args.input
+                weights = lin_args.weight
+                bias = lin_args.bias
+
+                if bias is None:
+                    continue
+
+                # Only support the case where bias is a Parameter
+                # TODO Is it possible that bias is not a Parameter?
+                if not is_torch_param(bias, exported_program):
+                    continue
+
+                bias_val: torch.Tensor = get_torch_param_value(bias, exported_program)
+                if bias_val.dtype != torch.float32:
+                    continue
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in weights.meta:
+                    continue
+
+                quant_dtype = None
+                if inp.meta[QPARAM_KEY].dtype == "int16":
+                    quant_dtype = torch.int64
+                elif inp.meta[QPARAM_KEY].dtype == "uint8":
+                    quant_dtype = torch.int32
+                else:
+                    continue
+
+                assert quant_dtype is not None
+
+                type_info = torch.iinfo(quant_dtype)
+
+                i_scale = inp.meta[QPARAM_KEY].scale
+                w_scale = weights.meta[QPARAM_KEY].scale
+
+                assert i_scale is not None
+                assert w_scale is not None
+                assert len(i_scale) == 1
+                assert len(w_scale) == bias_val.shape[0]
+
+                bias_scale = torch.tensor(i_scale) * torch.tensor(w_scale)
+                q_bias = torch.round(bias_val / bias_scale)
+                q_bias = torch.clamp(q_bias, min=type_info.min, max=type_info.max)
+                q_bias = q_bias.to(quant_dtype)
+
+                q_bias_node = add_placeholder(exported_program, q_bias, bias.name)
+
+                qparam = QuantParam()
+                qparam.scale = bias_scale.tolist()
+                assert qparam.scale is not None
+                qparam.zero_point = [0] * len(qparam.scale)
+                qparam.dtype = to_qparam_dtype(quant_dtype)
+                qparam.quantized_dimension = 0
+                q_bias_node.meta[QPARAM_KEY] = qparam
+
+                node.update_arg(2, q_bias_node)
+
+                logger.debug(f"Bias ({bias.name}) is quantized to {q_bias_node.name}.")
+
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
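A worked example of the bias formula above (all values made up): the bias scale is the product of the per-tensor input scale and the per-channel weight scales, and the fp32 bias is divided by it, rounded, clamped to the bias dtype's range, and cast.

import torch

i_scale = torch.tensor([0.02])           # per-tensor input scale (uint8 activations)
w_scale = torch.tensor([0.001, 0.004])   # per-channel weight scales
bias = torch.tensor([0.5, -1.2])         # fp32 bias, one value per channel

bias_scale = i_scale * w_scale           # [2.0e-5, 8.0e-5]
info = torch.iinfo(torch.int32)          # uint8 activations -> int32 bias, as above
q_bias = torch.clamp(torch.round(bias / bias_scale), info.min, info.max).to(torch.int32)
print(q_bias)  # tensor([ 25000, -15000], dtype=torch.int32)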
tico/experimental/quantization/passes/remove_weight_dequant_op.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._export.utils import (
+    get_buffer,
+    get_lifted_tensor_constant,
+    is_buffer,
+    is_lifted_tensor_constant,
+)
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import (
+    DequantizePerChannelArgs,
+    DequantizePerTensorArgs,
+)
+
+
+def get_constant(exported_program: ExportedProgram, node: torch.fx.Node):
+    assert isinstance(node, torch.fx.Node)
+    if node.name in exported_program.constants:
+        return exported_program.constants[node.name]
+    elif is_buffer(exported_program, node):
+        return get_buffer(exported_program, node)
+    elif is_lifted_tensor_constant(exported_program, node):
+        return get_lifted_tensor_constant(exported_program, node)
+    else:
+        raise RuntimeError("NYI constant")
+
+
+class ValRange:
+    def __init__(self, val: Union[torch.Tensor, List[int]]):
+        if isinstance(val, torch.Tensor):
+            self.max = torch.max(val).item()
+            self.min = torch.min(val).item()
+        elif type(val) == list:
+            self.max = max(val)
+            self.min = min(val)
+        else:
+            raise RuntimeError("Wrong dtype (val)")
+
+    def within(self, min_val, max_val):
+        return self.min >= min_val and self.max <= max_val
+
+
+# Infer dtype using weight, zero point, and dtype
+def infer_dtype(weight: torch.Tensor, zerop: List[int], dtype: torch.dtype) -> str:
+    weight_val = ValRange(weight)
+    zp_val = ValRange(zerop)
+
+    if weight_val.within(0, 15) and zp_val.within(0, 15) and dtype == torch.uint8:
+        return "uint4"
+    else:
+        return to_qparam_dtype(dtype)
+
+
+@trace_graph_diff_on_pass
+class RemoveWeightDequantOp(PassBase):
+    """
+    This pass identifies and removes any remaining Dequantize ops associated with
+    quantized weights.
+
+    Since the weights were already quantized earlier (and possibly kept in float by
+    attaching a DQ), the final stage of the quantization pipeline typically
+    does not require those DQ ops anymore.
+
+    NOTE Removing 'DQ' causes a semantic change: f32 -> quantized
+
+    [BEFORE]
+    W (quantized) - Dequantize (float)
+
+    [AFTER]
+    W (quantized)
+    """

+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for dq in graph.nodes:
+            if not dq.op == "call_function":
+                continue
+
+            if dq.target not in [
+                torch.ops.quantized_decomposed.dequantize_per_channel.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            ]:
+                continue
+            dq_args: Optional[DequantizePerChannelArgs | DequantizePerTensorArgs] = None
+
+            if (
+                dq.target
+                == torch.ops.quantized_decomposed.dequantize_per_channel.default
+            ):
+                dq_args = DequantizePerChannelArgs(*dq.args, **dq.kwargs)
+            elif (
+                dq.target
+                == torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            ):
+                dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+            else:
+                raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+            q_weight = dq_args.input
+            # All weights are placeholders.
+            if q_weight.op != "placeholder":
+                continue
+            # Check if the weight already has a quant param, since a DQ can be shared.
+            if QPARAM_KEY in q_weight.meta:
+                continue
+
+            q_weight_meta = q_weight.meta["val"]
+            assert isinstance(q_weight_meta, FakeTensor)
+            # Weight should have quantized values.
+            assert q_weight_meta.dtype != torch.float
+
+            q_weight_val = get_constant(exported_program, q_weight)
+            assert isinstance(q_weight_val, torch.Tensor)
+
+            quant_param = QuantParam()
+            if isinstance(dq_args, DequantizePerChannelArgs):
+                scales = get_constant(exported_program, dq_args.scales)
+                zero_ps = get_constant(exported_program, dq_args.zero_points)
+
+                # Sometimes users can give fp32 zero points. Let's update the dtype here.
+                zero_ps = zero_ps.to(torch.int64)
+                quant_param.scale = scales.tolist()
+                quant_param.zero_point = zero_ps.tolist()
+                assert quant_param.zero_point is not None  # To avoid mypy error
+                quant_param.quantized_dimension = dq_args.axis
+                quant_param.dtype = infer_dtype(
+                    q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                )
+            elif isinstance(dq_args, DequantizePerTensorArgs):
+                quant_param.scale = [dq_args.scale]
+                quant_param.zero_point = [dq_args.zero_point]
+                assert quant_param.zero_point is not None  # To avoid mypy error
+                quant_param.dtype = infer_dtype(
+                    q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                )
+            else:
+                raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+            q_weight.meta[QPARAM_KEY] = quant_param
+            dq.replace_all_uses_with(q_weight, propagate_meta=False)
+            logger.debug(f"{dq.name} is removed.")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
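The uint4 narrowing in infer_dtype above can be checked by hand (illustration only, made-up values): a torch.uint8 weight whose values and zero points all lie in [0, 15] is reported as "uint4"; anything else keeps its storage dtype.

import torch

w = torch.tensor([0, 3, 15], dtype=torch.uint8)  # hypothetical quantized weight
zero_points = [8]
fits_uint4 = (
    0 <= int(w.min()) and int(w.max()) <= 15
    and 0 <= min(zero_points) and max(zero_points) <= 15
)
print("uint4" if fits_uint4 else "uint8")  # uint4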
tico/experimental/quantization/public_interface.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Any, Dict, Optional, Type
+
+import torch
+
+from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
+    SmoothQuantQuantizer,
+)
+from tico.experimental.quantization.config import BaseConfig
+from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
+    "pt2e": PT2EQuantizer,
+    "gptq": GPTQQuantizer,
+    "smooth_quant": SmoothQuantQuantizer,
+}
+
+QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"
+
+
+def prepare(
+    model: torch.nn.Module,
+    quant_config: BaseConfig,
+    args: Optional[Any] = None,
+    kwargs: Optional[Dict[str, Any]] = None,
+    inplace: Optional[bool] = False,
+):
+    """
+    Prepare the model for quantization using the provided configuration.
+
+    Determines the appropriate quantizer based on the type of `quant_config` and
+    prepares the model accordingly.
+
+    Parameters:
+        model: The PyTorch model to be quantized.
+        quant_config (BaseConfig): The quantization configuration.
+        args (Any, optional): Positional example inputs required for activation quantization.
+        kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.
+        inplace (bool, optional): If True, the model is modified in place;
+            otherwise, a new prepared model is returned.
+
+    Returns:
+        The model prepared for quantization.
+    """
+    if quant_config.name == "pt2e" and inplace:
+        raise RuntimeError(
+            "In-place is not supported for PT2E quantization due to a limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+        )
+
+    model = model if inplace else copy.deepcopy(model)
+
+    quantizer = config_to_quantizer[quant_config.name](quant_config)
+    model = quantizer.prepare(model, args, kwargs)
+    setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)
+
+    return model
+
+
+def convert(model, inplace: Optional[bool] = False):
+    """
+    Convert the prepared model to a quantized model using the provided configuration.
+
+    Determines the appropriate quantizer based on the type of `quant_config` and
+    converts the model accordingly.
+
+    Parameters:
+        model: The prepared PyTorch model.
+        inplace (bool, optional): If True, the model is modified in place;
+            otherwise, a new quantized model is returned.
+
+    Returns:
+        The quantized model.
+    """
+    # Get the quantizer first, before calling deepcopy, which does not copy attributes properly.
+    if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
+        quantizer = getattr(model, QUANTIZER_ATTRIBUTE_NAME)
+        delattr(model, QUANTIZER_ATTRIBUTE_NAME)
+    else:
+        raise RuntimeError("Call prepare() function first.")
+
+    if isinstance(quantizer, PT2EQuantizer) and inplace:
+        raise RuntimeError(
+            "In-place is not supported for PT2E quantization due to a limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+        )
+
+    model = model if inplace else copy.deepcopy(model)
+
+    assert isinstance(quantizer, BaseQuantizer)
+    model = quantizer.convert(model)
+
+    return model
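A hedged usage sketch of the prepare()/convert() flow defined above. The concrete BaseConfig subclass name (PT2EConfig) is an assumption, as is the calibration loop; only the prepare/convert signatures and the pt2e/gptq/smooth_quant config names are taken from this file.

import torch

from tico.experimental.quantization.config import PT2EConfig  # assumed class name
from tico.experimental.quantization.public_interface import convert, prepare

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU()).eval()
example_args = (torch.randn(1, 4),)

# prepare() deep-copies the model (inplace=True is rejected for pt2e) and
# attaches the chosen quantizer under QUANTIZER_ATTRIBUTE_NAME.
prepared = prepare(model, PT2EConfig(), args=example_args, inplace=False)
for _ in range(8):
    prepared(*example_args)  # calibration forward passes
quantized = convert(prepared)  # retrieves the quantizer and finalizes the model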