tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/utils/record_input.py
ADDED

@@ -0,0 +1,102 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+from typing import Callable, List, Optional
+
+import torch.nn as nn
+
+
+class RecordingInput:
+    r"""Context manager that records the input values of model::forward().
+
+    Recording inputs is useful for preparing example inputs for torch.export.
+
+    Args:
+        condition: lambda deciding whether a given call should be recorded.
+            For example, to capture a call only when args["past_key_values"]
+            is not None:
+            condition = lambda args_dict: args_dict["past_key_values"] is not None
+        input_to_remove: list of argument names to remove.
+            Sometimes you may want to drop some argument values to keep the
+            exported graph tidy or correct. For example, "past_key_values" may
+            be not None but just an empty cache; then
+            input_to_remove = ["past_key_values"] makes life easy.
+
+    Example::
+        >>> with RecordingInput(model, input_to_remove=input_to_remove) as rec:
+        ...     outputs = model.generate(
+        ...         **inputs,
+        ...     )
+        ...     captured_input = rec.captured_input
+        >>> circle_model = tico.convert(model, captured_input)
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        condition: Callable[[dict], bool] = lambda args_dict: True,
+        *,
+        input_to_remove: Optional[List[str]] = [],
+    ):
+        self.module = module
+        self.forward_org = module.forward
+        self.condition = condition
+        self.input_to_remove = input_to_remove
+        self.sig = inspect.signature(self.forward_org)
+
+        for param in self.sig.parameters.values():
+            if param.kind == inspect.Parameter.KEYWORD_ONLY:
+                raise ValueError(f"Keyword-only parameter not supported: {param.name}")
+            if param.kind == inspect.Parameter.VAR_POSITIONAL:
+                raise ValueError(
+                    f"Var positional parameter not supported: {param.name}"
+                )
+
+        # NOTE: the name `kwargs` is removed since `kwargs` is a dict, not an arg itself.
+        # Args in kwargs are kept via sig.bind(*args, **kwargs) in capture_and_forward.
+        self.args_names = [
+            name
+            for name, param in self.sig.parameters.items()
+            if param.kind != inspect.Parameter.VAR_KEYWORD and name != "self"
+        ]
+        self.captured_input = None
+
+    def __enter__(self):
+        def capture_and_forward(*args, **kwargs):
+            bound = self.sig.bind(*args, **kwargs)
+            bound.apply_defaults()
+            args_dict = dict(bound.arguments)
+
+            def populate_args(args_dict, input_to_remove):
+                for key in input_to_remove:
+                    args_dict.pop(key, None)
+                args_tuple = tuple(
+                    args_dict.get(name, None) for name in self.args_names
+                )
+                return copy.deepcopy(args_tuple)
+
+            if self.condition(args_dict) and self.captured_input is None:
+                self.captured_input = populate_args(args_dict, self.input_to_remove)
+
+            return self.forward_org(*args, **kwargs)
+
+        self.module.forward = capture_and_forward
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.module.forward = self.forward_org
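For context, a minimal usage sketch of the class above (not part of the diff; the toy module and its `scale` argument are hypothetical). Note that arguments listed in `input_to_remove` are not dropped from the captured tuple; `populate_args` fills every name in `args_names`, so removed arguments show up as `None` at their positional slot.

import torch
import torch.nn as nn

from tico.utils.record_input import RecordingInput


class TinyModel(nn.Module):  # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x, scale=None):
        out = self.fc(x)
        return out if scale is None else out * scale


model = TinyModel()

# Record only calls where `scale` was actually passed, then drop it so the
# exported graph sees a single tensor input.
with RecordingInput(
    model,
    condition=lambda args_dict: args_dict["scale"] is not None,
    input_to_remove=["scale"],
) as rec:
    model(torch.randn(1, 4), scale=torch.tensor(2.0))

captured = rec.captured_input  # deep-copied tuple: (x, None)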
tico/utils/register_custom_op.py
CHANGED

@@ -31,9 +31,11 @@ def CircleResizeNearestNeighbor():
         W_scale_factor = size[2] / W
         if H_scale_factor != W_scale_factor:
             raise RuntimeError("Scale factor of H and W should be same.")
-
-
+        permuted = torch.permute(input_, [0, 3, 1, 2])
+        resized = torch.nn.functional.interpolate(
+            permuted, scale_factor=H_scale_factor, mode="nearest"
         )
+        return torch.permute(resized, [0, 2, 3, 1])

     @register_fake("circle_custom::resize_nearest_neighbor")
     def _(input_: torch.Tensor, size: List[int]):

@@ -631,7 +633,7 @@ def CircleInstanceNorm():
         bias: Optional[torch.Tensor] = None,
         running_mean: Optional[torch.Tensor] = None,
         running_var: Optional[torch.Tensor] = None,
-        use_input_stats: bool =
+        use_input_stats: bool = True,
         momentum: float = 0.1,
         eps: float = 1e-05,
         cudnn_enabled: bool = False,

@@ -639,7 +641,7 @@ def CircleInstanceNorm():
         NHWC_to_NCHW = [0, 3, 1, 2]
         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)

-        args = [NCHW_input, weight, bias, None, None,
+        args = [NCHW_input, weight, bias, None, None, True, momentum, eps, False]
         NCHW_output = torch.ops.aten.instance_norm.default(*args)
         NCHW_to_NHWC = [0, 2, 3, 1]
         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)

@@ -703,6 +705,28 @@ def CircleQuantizeMX():
         return input_


+def CircleRMSNorm():
+    @custom_op("circle_custom::rms_norm", mutates_args=())
+    def rms_norm(
+        hidden_states: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + eps)
+        return weight * hidden_states.to(input_dtype)
+
+    @register_fake("circle_custom::rms_norm")
+    def _(
+        hidden_states: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        return hidden_states.new_empty(hidden_states.size())
+
+
 # Add custom ops to the torch namespace
 def RegisterOps():
     CircleResizeNearestNeighbor()

@@ -715,3 +739,4 @@ def RegisterOps():
     CircleAvgPool2D()
     CircleInstanceNorm()
     CircleQuantizeMX()
+    CircleRMSNorm()
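A quick sketch of exercising the new custom op (assumptions: `RegisterOps()` succeeds, i.e. the installed PyTorch is new enough to provide `torch.library.custom_op`; tensor shapes are illustrative). It checks the registered op against a hand-written RMSNorm using the same math as the implementation above.

import torch

from tico.utils.register_custom_op import RegisterOps

RegisterOps()

x = torch.randn(2, 8, dtype=torch.float32)
w = torch.ones(8)
eps = 1e-5

# Dispatch through the torch.ops namespace created by @custom_op.
out = torch.ops.circle_custom.rms_norm(x, w, eps)

# Reference: x * rsqrt(mean(x^2) + eps), scaled by weight.
ref = w * (x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps))
assert torch.allclose(out, ref, atol=1e-6)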
tico/utils/serialize.py
CHANGED

@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
-
-import torch

 from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.circle_mapping import validate_circle_shape
 from tico.utils.graph import get_module_name_chain


@@ -40,3 +38,18 @@ def finalise_tensor_names(
     for tensor in graph.tensors:
         if tensor.name in graph.name_to_node:
             tensor.name = f"{get_module_name_chain(graph.name_to_node[tensor.name])}::{tensor.name}"
+
+
+def validate_tensor_shapes(
+    graph: CircleSubgraph,
+) -> None:
+    """
+    Validate all tensors' shapes against their shape signatures.
+    """
+    for tensor in graph.tensors:
+        try:
+            validate_circle_shape(tensor.shape, tensor.shapeSignature)
+        except Exception as e:
+            raise ValueError(
+                f"Tensor {tensor.name} has invalid shape ({tensor.shape}), shape_signature ({tensor.shapeSignature})"
+            ) from e
tico/utils/signature.py
ADDED

@@ -0,0 +1,247 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+import numpy as np
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_mapping import to_circle_shape
+from tico.utils.dtype import circle_dtype_to_torch_dtype
+from tico.utils.installed_packages import is_dynamic_cache_available
+
+
+def is_dynamic_cache_instance(value):
+    if is_dynamic_cache_available():
+        from transformers.cache_utils import DynamicCache
+
+        return isinstance(value, DynamicCache)
+    else:
+        return False
+
+
+def flatten_and_convert_kwargs(kwargs: dict) -> dict[str, torch.Tensor]:
+    result = {}  # type: ignore[var-annotated]
+    for k, v in kwargs.items():
+        if v is None:
+            continue
+        elif isinstance(v, (list, tuple)):
+            # 1. handle list
+            def unpack_recursive(name, value, store=None):
+                if store is None:
+                    store = {}
+
+                if isinstance(value, (tuple, list)):
+                    for i, v in enumerate(value):
+                        # recursive call: append the index to the name and explore the lower level
+                        unpack_recursive(f"{name}_{i}", v, store)
+                else:
+                    # base types (scalar etc.) are stored directly
+                    store[name] = value
+
+                return store
+
+            unpack_recursive(k, v, result)
+        elif is_dynamic_cache_instance(v):
+            # 2. handle DynamicCache
+            for idx, cache_val in enumerate(v.key_cache):
+                result[f"{k}_key_cache_{idx}"] = cache_val
+
+            for idx, cache_val in enumerate(v.value_cache):
+                result[f"{k}_value_cache_{idx}"] = cache_val
+        else:
+            result[k] = v
+
+    # 3. Convert to tensors
+    for k, v in result.items():
+        result[k] = v if isinstance(v, torch.Tensor) else torch.tensor(v)
+
+    return result
+
+
+def flatten_and_convert_args(args: Sequence) -> tuple:
+    result = []  # type: ignore[var-annotated]
+    for item in args:
+        if item is None:
+            continue
+
+        # 1. recurse on lists and tuples
+        if isinstance(item, (list, tuple)):
+            result.extend(flatten_and_convert_args(item))
+            continue
+
+        # 2. handle DynamicCache
+        if is_dynamic_cache_available():
+            from transformers.cache_utils import DynamicCache
+
+            if isinstance(item, DynamicCache):
+                # NOTE The tensor order is: key_in → key_out → value_in → value_out
+                #
+                # Refer to https://github.com/huggingface/transformers/blob/3457e8e73e4f5532cc69059682b1ba4484d7e7e8/src/transformers/cache_utils.py#L557
+                # ```
+                # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+                # ```
+                result.extend(item.key_cache)
+                result.extend(item.value_cache)
+                continue
+
+        # 3. Convert to tensors
+        result.append(item if isinstance(item, torch.Tensor) else torch.tensor(item))
+
+    return tuple(result)
+
+
+class ModelInputSpec:
+    @classmethod
+    def load(cls, circle_path):
+        def load(circle_path: str) -> bytes:
+            with open(circle_path, "rb") as f:
+                buf = bytes(f.read())
+            return buf
+
+        circle_binary = load(circle_path)
+        return cls(circle_binary)
+
+    def __init__(self, circle_binary):
+        model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+        assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
+
+        graph = model.Subgraphs(0)
+        tensors = [graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())]
+
+        self.names = [t.Name().decode("utf-8").split("::")[-1] for t in tensors]
+        self.shapes = [t.ShapeAsNumpy() for t in tensors]
+        self.shape_signatures = list(
+            map(
+                lambda x: None if (isinstance(x, int) and x == 0) else x,
+                (t.ShapeSignatureAsNumpy() for t in tensors),
+            )
+        )
+        self.types: list[torch.dtype] = [
+            circle_dtype_to_torch_dtype(t.Type()) for t in tensors
+        ]
+        self.name_to_idx = {name: idx for idx, name in enumerate(self.names)}
+
+    def bind(self, args, kwargs, check=True):
+        """Convert args and kwargs into an ordered list that follows the model input order."""
+        inputs = []
+        args = flatten_and_convert_args(args)
+        kwargs = flatten_and_convert_kwargs(kwargs)
+
+        arg_num = len(args) + len(kwargs)
+        m_input_num = len(self.names)
+        if arg_num != m_input_num:
+            raise ValueError(
+                f"Mismatch: number of model inputs and number of passed arguments are not the same: inputs({m_input_num}) != passed({arg_num}), input spec: {self.names}"
+            )
+
+        # 1. positional arguments
+        for i, val in enumerate(args):
+            name = self.names[i]
+            inputs.append(val)
+
+        # 2. keyword arguments
+        for idx in range(len(args), len(self.names)):
+            name = self.names[idx]
+            inputs.append(kwargs[name])
+
+        if check:
+            self.check_types(inputs)
+            self.check_shapes(inputs)
+
+        return inputs
+
+    def check_types(self, inputs):
+        """Check the types of input values."""
+        for i, (inp, ref_type) in enumerate(zip(inputs, self.types)):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor, int, float)
+            ), f"Input '{self.names[i]}' must be a torch tensor or a scalar."
+
+            if isinstance(inp, torch.Tensor):
+                if inp.dtype != ref_type:
+                    raise TypeError(
+                        f"Input '{self.names[i]}' type {inp.dtype} != expected {ref_type}"
+                    )
+            else:
+                # Scalars (int, float)
+                if ref_type == torch.float32:
+                    if not isinstance(inp, float):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                elif ref_type == torch.int64:
+                    if not isinstance(inp, int):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                else:
+                    print(f"Unexpected ref_type: {ref_type}")
+
+    def check_shapes(self, inputs):
+        """Check the shapes of input values."""
+
+        def merge(shape, shape_sig):
+            """Merge the shape signature into the shape."""
+            from copy import deepcopy
+
+            shape_merged = deepcopy(shape)
+            if shape_sig is not None:
+                for idx, ss in enumerate(shape_sig):
+                    if ss == -1:
+                        shape_merged[idx] = -1
+
+            return shape_merged
+
+        for i, (inp, ref_shape, ref_shape_sig) in enumerate(
+            zip(inputs, self.shapes, self.shape_signatures)
+        ):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor, int, float)
+            ), f"Input '{self.names[i]}' must be a torch tensor or a scalar."
+
+            if isinstance(inp, torch.Tensor):  # Tensor
+                in_shape, in_shape_sig = to_circle_shape(inp.size())
+
+                if len(in_shape) != len(ref_shape):
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(in_shape)} != expected {len(ref_shape)}"
+                    )
+
+                in_merged_shape = merge(in_shape, in_shape_sig)
+                ref_merged_shape = merge(ref_shape, ref_shape_sig)
+                for in_shp, ref_shp in zip(in_merged_shape, ref_merged_shape):
+                    if ref_shp == -1:
+                        continue
+                    if in_shp == -1:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has unknown dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig})"
+                        )
+                    if in_shp != ref_shp:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has wrong dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig})"
+                        )
+            elif isinstance(inp, (int, float)):  # Scalar
+                if len(ref_shape) > 0:
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(ref_shape)}"
+                    )
+            else:
+                print(f"Unexpected input type: {type(inp)}")
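A usage sketch for ModelInputSpec (the file path, input names, and shapes below are hypothetical). Nested lists/tuples and Hugging Face DynamicCache objects are flattened before counting, so the argument count compared against the model's inputs is the post-flattening count.

import torch

from tico.utils.signature import ModelInputSpec

spec = ModelInputSpec.load("model.circle")  # hypothetical path
print(spec.names)  # e.g. ["input_ids", "attention_mask"]

# Positional args fill the first input slots; remaining slots are looked up
# in kwargs by name. check=True runs check_types() and check_shapes().
ordered = spec.bind(
    args=(torch.zeros(1, 16, dtype=torch.int64),),
    kwargs={"attention_mask": torch.ones(1, 16, dtype=torch.int64)},
    check=True,
)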
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Runtime **capability-detection helpers** for the `torch.export` stack.
|
|
17
|
+
|
|
18
|
+
Instead of sprinkling version checks like `torch.__version__ >= "2.9"` throughout
|
|
19
|
+
the codebase, import these helpers once and branch on the feature you need.
|
|
20
|
+
|
|
21
|
+
Each probe executes only **once per process** thanks to `functools.lru_cache`,
|
|
22
|
+
so the overhead is negligible.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import functools
|
|
26
|
+
|
|
27
|
+
import torch
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@functools.lru_cache(maxsize=None)
|
|
31
|
+
def export_produces_slice() -> bool:
|
|
32
|
+
"""
|
|
33
|
+
Compile a minimal model with `torch.export.export` and inspect its FX graph
|
|
34
|
+
to see whether an `aten.slice.Tensor` node appears.
|
|
35
|
+
|
|
36
|
+
Returns
|
|
37
|
+
-------
|
|
38
|
+
bool
|
|
39
|
+
* ``True`` — downstream passes should expect redundant **slice** nodes.
|
|
40
|
+
* ``False`` — downstream passes should expect only a **select** node.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
class _Probe(torch.nn.Module):
|
|
44
|
+
def forward(self, x): # simple slice: keep all dims except 3rd
|
|
45
|
+
return x[:, :, 1]
|
|
46
|
+
|
|
47
|
+
def get_example_inputs(self):
|
|
48
|
+
return (torch.randn(1, 4, 4),)
|
|
49
|
+
|
|
50
|
+
m = _Probe()
|
|
51
|
+
ep = torch.export.export(m, m.get_example_inputs())
|
|
52
|
+
return any(n.target == torch.ops.aten.slice.Tensor for n in ep.graph.nodes)
|
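A sketch of how a graph pass might branch on this probe (the branch bodies are placeholders, not code from the diff):

from tico.utils.torch_compat import export_produces_slice

if export_produces_slice():
    # This PyTorch build lowers x[:, :, 1] to aten.slice + aten.select,
    # so expect (and clean up) the redundant slice nodes.
    ...
else:
    # Only an aten.select node is produced; nothing to fold.
    ...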
tico/utils/utils.py
CHANGED

@@ -21,7 +21,6 @@ from typing import List

 import torch
 from circle_schema import circle
-from packaging.version import Version
 from torch._guards import detect_fake_mode
 from torch.export import ExportedProgram
 from torch.utils import _pytree as pytree

@@ -29,10 +28,6 @@ from torch.utils import _pytree as pytree
 from tico.serialize.quant_param import QuantParam


-HAS_TORCH_OVER_25 = Version(torch.__version__) >= Version("2.5.0")
-HAS_TORCH_OVER_28_DEV = Version(torch.__version__) >= Version("2.8.0.dev")
-
-
 def get_fake_mode(exported_program: ExportedProgram):
     fake_mode = detect_fake_mode(
         tuple(

@@ -84,73 +79,70 @@ def enforce_type(callable):
     def check_types(*args, **kwargs):
         parameters = dict(zip(spec.args, args))
         parameters.update(kwargs)
-        for name, value in parameters.items():
-            if name == "self":
-                # skip 'self' in spec.args
-                continue

-
-
-
-            #
-
-
-
-                actual_type = tuple(
-                    [_flatten_type(t) for t in typing.get_args(type_hint)]
-                )
-            else:
-                actual_type = (type_hint,)
-            return actual_type
+        # Return a tuple of flattened types.
+        # Q) What is flatten?
+        # A) Optional/Union is not included. Below are included.
+        #    collections: List, Set, ...
+        #    primitive types: int, str, ...
+        def _flatten_type(type_hint) -> tuple:
+            # `get_origin` maps Union[...] and Optional[...] varieties to Union
+            if typing.get_origin(type_hint) == typing.Union:
+                # ex. typing.Union[list, int] -> (list, int)
+                # ex. typing.Optional[torch.fx.Node] -> (torch.fx.Node, NoneType)
+                actual_type = tuple(
+                    _flatten_type(t) for t in typing.get_args(type_hint)
+                )
+            else:
+                actual_type = (type_hint,)
+            return actual_type

-
+        # Return True if value matches type_hint,
+        # return False otherwise.
+        def _check_type(value, type_hint):
+            if type_hint == typing.Any:
+                return True

-
-        def _check_type(value, type_hint):
-            if type_hint == typing.Any:
-                return True
+            if isinstance(type_hint, tuple):
+                return any(_check_type(value, t) for t in type_hint)

-
-
+            if typing.get_origin(type_hint) in (list, set):
+                if not isinstance(value, typing.get_origin(type_hint)):
+                    return False

-
-            if not
+                for v in value:
+                    if not any(_check_type(v, t) for t in typing.get_args(type_hint)):
                         return False

-
-            if not any(
-                [_check_type(v, t) for t in typing.get_args(type_hint)]
-            ):
-                return False
+                return True

-
+            if typing.get_origin(type_hint) is dict:
+                if not isinstance(value, typing.get_origin(type_hint)):
+                    return False

-
-
+                for k, v in value.items():
+                    k_type, v_type = typing.get_args(type_hint)
+                    if not _check_type(k, k_type):
+                        return False
+                    if not _check_type(v, v_type):
                         return False

-
-            k_type, v_type = typing.get_args(type_hint)
-            if not _check_type(k, k_type):
-                return False
-            if not _check_type(v, v_type):
-                return False
+                return True

-
+            # TODO: Support more type hints
+            return isinstance(value, type_hint)

-
-
+        for name, value in parameters.items():
+            if name == "self":
+                # skip 'self' in spec.args
+                continue

+            assert (
+                name in spec.annotations
+            ), f"All parameters require type hints. {name} needs a type hint"
+
+            type_hint = spec.annotations[name]
+            type_hint = _flatten_type(type_hint)
             type_check_result = _check_type(value, type_hint)
             if not type_check_result:
                 raise ArgTypeError(