tico 0.1.0.dev250411__py3-none-any.whl
This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/experimental/quantization/passes/propagate_qparam_forward.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import (
+    CatArgs,
+    NegArgs,
+    PermuteArgs,
+    ReshapeArgs,
+    SliceArgs,
+)
+
+
+@trace_graph_diff_on_pass
+class PropagateQParamForward(PassBase):
+    """
+    A pass that propagates quantization parameters through operations that do not alter them.
+
+    This pass identifies and propagates quantization parameters through operations that
+    do not change their values, such as `permute`, `reshape`, `transpose`, `view`, and
+    similar tensor transformations.
+
+    By ensuring that quantization parameters remain consistent across such operations,
+    this pass helps maintain correctness in quantization-aware representations.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+            if QPARAM_KEY not in src.meta:
+                return
+
+            if (
+                QPARAM_KEY in dst.meta
+                and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+            ):
+                return
+
+            dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+            logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.permute.default:
+                permute_args = PermuteArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(permute_args.input, node)
+            elif node.target == torch.ops.aten.reshape.default:
+                reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(reshape_args.input, node)
+            elif node.target == torch.ops.aten.slice.Tensor:
+                slice_args = SliceArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(slice_args.input, node)
+            elif node.target == torch.ops.aten.neg.default:
+                neg_args = NegArgs(*node.args, **node.kwargs)
+
+                if QPARAM_KEY not in neg_args.input.meta:
+                    continue
+                # Only support int16 for now
+                if neg_args.input.meta[QPARAM_KEY].dtype != "int16":
+                    continue
+
+                _propagate_qparam_if_possible(neg_args.input, node)
+
+            elif node.target == torch.ops.aten.cat.default:
+                concat_args = CatArgs(*node.args, **node.kwargs)
+                concat_inputs = concat_args.tensors
+
+                cond = True
+                for concat_input in concat_inputs:
+                    # Check that all inputs have qparam
+                    if QPARAM_KEY not in concat_input.meta:
+                        cond = False
+                        break
+
+                    # Only support int16 for now
+                    if concat_input.meta[QPARAM_KEY].dtype != "int16":
+                        cond = False
+                        break
+
+                    if concat_input.meta[QPARAM_KEY].scale is None:
+                        cond = False
+                        break
+
+                    if len(concat_input.meta[QPARAM_KEY].scale) != 1:
+                        cond = False
+                        break
+
+                if not cond:
+                    continue
+
+                # Find the node with the max scale
+                max_scale = 0.0
+                max_scale_node = None
+                for concat_input in concat_inputs:
+                    scale = concat_input.meta[QPARAM_KEY].scale[0]
+                    if max_scale < scale:
+                        max_scale = scale
+                        max_scale_node = concat_input
+
+                assert max_scale_node is not None
+                _propagate_qparam_if_possible(max_scale_node, node)
+
+        # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
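
The propagation itself is just a guarded deep copy of the `QPARAM_KEY` entry in the producer node's `meta` dict. A minimal sketch of that mechanism, using a symbolically traced toy module rather than a real `ExportedProgram` (the toy module, shapes, and qparam values below are illustrative assumptions, not tico code):

```python
import copy

import torch
import torch.fx

from tico.serialize.quant_param import QPARAM_KEY, QuantParam


class Toy(torch.nn.Module):
    def forward(self, x):
        return x.reshape(4)


gm = torch.fx.symbolic_trace(Toy())
x_node, reshape_node = [n for n in gm.graph.nodes if n.op != "output"]

# Attach a quantization parameter to the producer node.
qparam = QuantParam()
qparam.scale = [0.02]
qparam.zero_point = [0]
qparam.dtype = "int16"
x_node.meta[QPARAM_KEY] = qparam

# The pass deep-copies the qparam onto value-preserving consumers such as
# reshape, skipping destinations whose existing qparam has a different dtype.
reshape_node.meta[QPARAM_KEY] = copy.deepcopy(x_node.meta[QPARAM_KEY])
assert reshape_node.meta[QPARAM_KEY].scale == [0.02]
```

For `cat`, the pass instead picks the input with the largest per-tensor scale and propagates that one, so the concatenated range is covered.
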
tico/experimental/quantization/passes/remove_weight_dequant_op.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import (
+    DequantizePerChannelArgs,
+    DequantizePerTensorArgs,
+)
+
+
+def get_constant(exported_program: ExportedProgram, node: torch.fx.Node):
+    assert isinstance(node, torch.fx.Node)
+    if node.name in exported_program.constants:
+        return exported_program.constants[node.name]
+    elif node.name in exported_program.graph_signature.inputs_to_buffers:
+        buffer_name = exported_program.graph_signature.inputs_to_buffers[node.name]
+        named_buffer = dict(exported_program.named_buffers())
+        return named_buffer[buffer_name]
+    else:
+        raise RuntimeError("NYI constant")
+
+
+class ValRange:
+    def __init__(self, val: Union[torch.Tensor, List[int]]):
+        if isinstance(val, torch.Tensor):
+            self.max = torch.max(val).item()
+            self.min = torch.min(val).item()
+        elif type(val) == list:
+            self.max = max(val)
+            self.min = min(val)
+        else:
+            raise RuntimeError("Wrong dtype (val)")
+
+    def within(self, min_val, max_val):
+        return self.min >= min_val and self.max <= max_val
+
+
+# Infer dtype using weight, zero point, and dtype
+def infer_dtype(weight: torch.Tensor, zerop: List[int], dtype: torch.dtype) -> str:
+    weight_val = ValRange(weight)
+    zp_val = ValRange(zerop)
+
+    if weight_val.within(0, 15) and zp_val.within(0, 15) and dtype == torch.uint8:
+        return "uint4"
+    else:
+        return to_qparam_dtype(dtype)
+
+
+@trace_graph_diff_on_pass
+class RemoveWeightDequantOp(PassBase):
+    """
+    This pass identifies and removes any remaining Dequantize ops associated with
+    quantized weights.
+
+    Since weights were already quantized earlier (and possibly kept in float by
+    attaching a DQ), the final stage of the quantization pipeline typically
+    does not require those DQ ops anymore.
+
+    NOTE Removing 'DQ' causes a semantic change: f32 -> quantized
+
+    [BEFORE]
+    W (quantized) - Dequantize (float)
+
+    [AFTER]
+    W (quantized)
+    """

+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for dq in graph.nodes:
+            if not dq.op == "call_function":
+                continue
+
+            if dq.target not in [
+                torch.ops.quantized_decomposed.dequantize_per_channel.default,
+                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            ]:
+                continue
+            dq_args: Optional[DequantizePerChannelArgs | DequantizePerTensorArgs] = None
+
+            if (
+                dq.target
+                == torch.ops.quantized_decomposed.dequantize_per_channel.default
+            ):
+                dq_args = DequantizePerChannelArgs(*dq.args, **dq.kwargs)
+            elif (
+                dq.target
+                == torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            ):
+                dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+            else:
+                raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+            q_weight = dq_args.input
+            # All weights are placeholders.
+            if q_weight.op != "placeholder":
+                continue
+            # Check if DQ already has quant param because DQ can be shared.
+            if QPARAM_KEY in q_weight.meta:
+                continue
+
+            q_weight_meta = q_weight.meta["val"]
+            assert isinstance(q_weight_meta, FakeTensor)
+            # Weight should have quantized values.
+            assert q_weight_meta.dtype != torch.float
+
+            q_weight_val = get_constant(exported_program, q_weight)
+            assert isinstance(q_weight_val, torch.Tensor)
+
+            quant_param = QuantParam()
+            if isinstance(dq_args, DequantizePerChannelArgs):
+                scales = get_constant(exported_program, dq_args.scales)
+                zero_ps = get_constant(exported_program, dq_args.zero_points)
+                quant_param.scale = scales.tolist()
+                quant_param.zero_point = zero_ps.tolist()
+                assert quant_param.zero_point is not None  # To avoid mypy error
+                quant_param.quantized_dimension = dq_args.axis
+                quant_param.dtype = infer_dtype(
+                    q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                )
+            elif isinstance(dq_args, DequantizePerTensorArgs):
+                quant_param.scale = [dq_args.scale]
+                quant_param.zero_point = [dq_args.zero_point]
+                assert quant_param.zero_point is not None  # To avoid mypy error
+                quant_param.dtype = infer_dtype(
+                    q_weight_val, quant_param.zero_point, q_weight_meta.dtype
+                )
+            else:
+                raise RuntimeError(f"Invalid DQ target: {dq.target}")
+
+            q_weight.meta[QPARAM_KEY] = quant_param
+            dq.replace_all_uses_with(q_weight, propagate_meta=False)
+            logger.debug(f"{dq.name} is removed.")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
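
`infer_dtype` narrows the declared dtype when a nominally uint8 weight and its zero points all fit in the 4-bit range. A quick illustration (assumes `tico` is importable; the tensor values are made up):

```python
import torch

from tico.experimental.quantization.passes.remove_weight_dequant_op import infer_dtype

# Every value and the zero point lie in [0, 15], so uint8 narrows to "uint4".
w4 = torch.tensor([[0, 7], [15, 3]], dtype=torch.uint8)
print(infer_dtype(w4, zerop=[8], dtype=torch.uint8))  # -> "uint4"

# Values outside [0, 15] keep the mapped dtype from to_qparam_dtype(torch.uint8).
w8 = torch.tensor([[0, 112], [240, 48]], dtype=torch.uint8)
print(infer_dtype(w8, zerop=[8], dtype=torch.uint8))
```
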
tico/experimental/quantization/public_interface.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Any, Dict, Optional, Type
+
+import torch
+
+from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
+from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
+from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
+    SmoothQuantQuantizer,
+)
+from tico.experimental.quantization.config import BaseConfig
+from tico.experimental.quantization.quantizer import BaseQuantizer
+
+
+config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
+    "pt2e": PT2EQuantizer,
+    "gptq": GPTQQuantizer,
+    "smooth_quant": SmoothQuantQuantizer,
+}
+
+QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"
+
+
+def prepare(
+    model: torch.nn.Module,
+    quant_config: BaseConfig,
+    args: Optional[Any] = None,
+    kwargs: Optional[Dict[str, Any]] = None,
+    inplace: Optional[bool] = False,
+):
+    """
+    Prepare the model for quantization using the provided configuration.
+
+    Determines the appropriate quantizer based on the type of `quant_config` and
+    prepares the model accordingly.
+
+    Parameters:
+        model: The PyTorch model to be quantized.
+        quant_config (BaseConfig): The quantization configuration.
+        args (Any, optional): Positional example inputs required for activation quantization.
+        kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.
+        inplace (bool, optional): If True, the model is modified in place;
+            otherwise, a new prepared model is returned.
+
+    Returns:
+        The model prepared for quantization.
+    """
+    if quant_config.name == "pt2e" and inplace:
+        raise RuntimeError(
+            "In-place is not supported for PT2E quantization due to a limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+        )
+
+    model = model if inplace else copy.deepcopy(model)
+
+    quantizer = config_to_quantizer[quant_config.name](quant_config)
+    model = quantizer.prepare(model, args, kwargs)
+    setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)
+
+    return model
+
+
+def convert(model, inplace: Optional[bool] = False):
+    """
+    Convert the prepared model to a quantized model.
+
+    Determines the appropriate quantizer from the one attached to the model during
+    `prepare()` and converts the model accordingly.
+
+    Parameters:
+        model: The prepared PyTorch model.
+        inplace (bool, optional): If True, the model is modified in place;
+            otherwise, a new quantized model is returned.
+
+    Returns:
+        The quantized model.
+    """
+    # Get the quantizer first, before calling deepcopy, which does not copy attributes properly.
+    if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
+        quantizer = getattr(model, QUANTIZER_ATTRIBUTE_NAME)
+        delattr(model, QUANTIZER_ATTRIBUTE_NAME)
+    else:
+        raise RuntimeError("Call prepare() function first.")
+
+    if isinstance(quantizer, PT2EQuantizer) and inplace:
+        raise RuntimeError(
+            "In-place is not supported for PT2E quantization due to a limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
+        )
+
+    model = model if inplace else copy.deepcopy(model)
+
+    assert isinstance(quantizer, BaseQuantizer)
+    model = quantizer.convert(model)
+
+    return model
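
Together, `prepare()` and `convert()` form the two-step public flow: prepare, calibrate, convert. A hedged usage sketch (the concrete `BaseConfig` subclass and the calibration loop are assumptions; only the two functions and the "pt2e"/"gptq"/"smooth_quant" registry keys appear in this diff):

```python
import torch

from tico.experimental.quantization.public_interface import convert, prepare

model = torch.nn.Linear(4, 4)
example = torch.randn(1, 4)
quant_config = ...  # a BaseConfig subclass whose .name is "pt2e", "gptq", or "smooth_quant"

# prepare() deep-copies the model (inplace=False), runs the matching
# quantizer's prepare(), and stashes the quantizer as "tico_quantizer".
model = prepare(model, quant_config, args=(example,))

# ... feed representative inputs through `model` here for calibration ...

# convert() pops the stashed quantizer and produces the quantized model.
qmodel = convert(model)
```
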
tico/experimental/quantization/quantizer.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+
+import torch
+
+from tico.experimental.quantization.config import BaseConfig
+
+
+class BaseQuantizer(ABC):
+    """
+    Abstract base class for quantizers that apply a quantization algorithm to a target model.
+    """
+
+    def __init__(self, config: BaseConfig):
+        """
+        Initialize the quantizer with the given configuration.
+
+        Parameters:
+            config (BaseConfig): Quantization configuration parameters.
+        """
+        self.config = config
+
+    @abstractmethod
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Prepare the given model for quantization based on the provided algorithm-specific
+        configuration. This involves setting up necessary observers or hooks, and may
+        optionally use example inputs, which is particularly useful for activation quantization.
+
+        Parameters:
+            model: The target PyTorch model.
+            args (Any, optional): Positional example inputs required for activation quantization.
+            kwargs (Dict[str, Any], optional): Keyword example inputs required for activation quantization.
+
+        Returns:
+            The prepared model.
+        """
+        pass
+
+    @abstractmethod
+    def convert(self, model):
+        """
+        Convert the prepared (or calibrated) model into its quantized form. This function
+        leverages the statistics collected during calibration to perform the quantization
+        transformation.
+
+        Parameters:
+            model: The prepared PyTorch model.
+
+        Returns:
+            The quantized model.
+        """
+        pass
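
The concrete algorithms in this package (PT2E, GPTQ, SmoothQuant) subclass this ABC. A minimal sketch of the contract, using a hypothetical no-op quantizer that is not part of tico:

```python
from typing import Any, Dict, Optional

import torch

from tico.experimental.quantization.config import BaseConfig
from tico.experimental.quantization.quantizer import BaseQuantizer


class IdentityQuantizer(BaseQuantizer):
    """Hypothetical quantizer satisfying both abstract methods."""

    def prepare(
        self,
        model: torch.nn.Module,
        args: Optional[Any] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        # A real quantizer would install observers/hooks here.
        return model

    def convert(self, model):
        # A real quantizer would fold collected statistics into quantized ops.
        return model
```
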
tico/interpreter/__init__.py
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/interpreter/infer.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import numpy as np
+import torch
+from circle_schema import circle
+
+from tico.interpreter.interpreter import Interpreter
+from tico.serialize.circle_mapping import np_dtype_from_circle_dtype, to_circle_dtype
+
+
+def preprocess_inputs(inputs: Any):
+    """
+    Preprocess user inputs for circle inference.
+
+    1. None inputs are ignored.
+    2. A list/tuple input is flattened, matching how inputs are flattened when a
+       torch module is exported.
+       e.g. inputs = (torch.Tensor, [2,3,4]) -> inputs = (torch.Tensor, 2, 3, 4)
+    """
+    l = []
+    for value in inputs:
+        if value is None:
+            continue
+        if isinstance(value, (tuple, list)):
+            for val in value:
+                l.append(val)
+        else:
+            l.append(value)
+    # Check if it is a list of lists.
+    if any(isinstance(item, (tuple, list)) for item in l):
+        l = preprocess_inputs(l)
+    return tuple(l)
+
+
+def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any:
+    # When converting a model, it is assumed that the order of keyword arguments is maintained.
+    user_inputs = args + tuple(kwargs.values())
+    user_inputs = preprocess_inputs(user_inputs)
+    # Cast them to torch.Tensor to keep things simple.
+    user_inputs = tuple(
+        torch.tensor(user_input) if not isinstance(user_input, torch.Tensor) else user_input
+        for user_input in user_inputs
+    )
+
+    # Get the input spec from the circle binary.
+    model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+    assert model.SubgraphsLength() == 1
+    graph = model.Subgraphs(0)
+    model_input_tensors = [
+        graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())
+    ]
+    model_input_shapes_np = [t.ShapeAsNumpy() for t in model_input_tensors]
+    model_input_types_cm = [t.Type() for t in model_input_tensors]
+
+    # Check that the dtype and shape of the user-given inputs match those in the model binary.
+    if len(model_input_shapes_np) != len(user_inputs):
+        raise RuntimeError(
+            f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})"
+        )
+    for input_idx, user_input in enumerate(user_inputs):
+        # Shape
+        if list(user_input.shape) != list(model_input_shapes_np[input_idx]):
+            raise RuntimeError(
+                f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})"
+            )
+        # Data type
+        user_input_type_cm = to_circle_dtype(user_input.dtype)
+        if user_input_type_cm != model_input_types_cm[input_idx]:
+            raise RuntimeError(
+                f"Mismatch input {input_idx} data type : input({user_input_type_cm}) != circle model({model_input_types_cm[input_idx]})"
+            )
+
+    # Initialize the interpreter
+    intp = Interpreter(circle_binary)
+
+    # Set inputs
+    for input_idx, user_input in enumerate(user_inputs):
+        intp.writeInputTensor(input_idx, user_input)
+
+    # Interpret
+    intp.interpret()
+
+    # Retrieve the outputs' dtype and shape from the circle model
+    model_output_tensors = [
+        graph.Tensors(graph.Outputs(o)) for o in range(graph.OutputsLength())
+    ]
+    model_output_shapes_np = [t.ShapeAsNumpy() for t in model_output_tensors]
+    model_output_types_cm = [t.Type() for t in model_output_tensors]
+
+    output = []
+    # Get outputs
+    for output_idx in range(len(model_output_tensors)):
+        result: np.ndarray = np.empty(
+            model_output_shapes_np[output_idx],
+            dtype=np_dtype_from_circle_dtype(model_output_types_cm[output_idx]),
+        )
+        intp.readOutputTensor(output_idx, result)
+        output.append(result)
+
+    if len(output) == 1:
+        return output[0]
+    else:
+        return output
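
`infer()` is the round-trip check for a compiled model: it validates input shapes and dtypes against the circle binary, runs the interpreter, and returns numpy outputs. A usage sketch (the "model.circle" path and `reference_model` are hypothetical; producing the circle binary is outside this file):

```python
import torch

from tico.interpreter.infer import infer

with open("model.circle", "rb") as f:  # hypothetical compiled model
    circle_binary = f.read()

x = torch.randn(1, 3, 8, 8)
# Returns one numpy array, or a list of arrays when the model has several outputs.
circle_out = infer(circle_binary, x)

torch_out = reference_model(x)  # hypothetical float reference
print((torch.from_numpy(circle_out) - torch_out).abs().max())
```
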