tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
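Most of the churn in this release is the promotion of the quantization stack out of the experimental namespace (tico/experimental/quantization → tico/quantization). A hedged sketch of a compatibility shim for downstream code, assuming only the package path changed and no public symbols were renamed:

import importlib

# Prefer the new location (0.1.0.dev251102), fall back to the old one.
try:
    quantization = importlib.import_module("tico.quantization")
except ModuleNotFoundError:
    quantization = importlib.import_module("tico.experimental.quantization")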
tico/utils/record_input.py ADDED
@@ -0,0 +1,102 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import inspect
+from typing import Callable, List, Optional
+
+import torch.nn as nn
+
+
+class RecordingInput:
+    r"""Context manager that records the input values of model::forward()
+
+    Recording inputs is useful for preparing example inputs for torch.export.
+
+    Args:
+        condition: lambda deciding whether a given call should be recorded
+
+            For example, to capture only calls where args_dict["past_key_values"] is not None:
+            condition = lambda args_dict: args_dict["past_key_values"] is not None
+
+        input_to_remove: list of arg names to remove
+
+            Sometimes you may want to remove some arg values to make the exported graph tidy or correct.
+            For example, "past_key_values" may be not None but just an empty cache. Then
+            input_to_remove = ["past_key_values"] makes life easy.
+
+    Example::
+        >>> with RecordingInput(model, input_to_remove=input_to_remove) as rec:
+        ...     outputs = model.generate(
+        ...         **inputs,
+        ...     )
+        ...     captured_input = rec.captured_input
+        >>> circle_model = tico.convert(model, captured_input)
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        condition: Callable[[dict], bool] = lambda args_dict: True,
+        *,
+        input_to_remove: Optional[List[str]] = [],
+    ):
+        self.module = module
+        self.forward_org = module.forward
+        self.condition = condition
+        self.input_to_remove = input_to_remove
+        self.sig = inspect.signature(self.forward_org)
+
+        for param in self.sig.parameters.values():
+            if param.kind == inspect.Parameter.KEYWORD_ONLY:
+                raise ValueError(f"Keyword-only parameter not supported: {param.name}")
+            if param.kind == inspect.Parameter.VAR_POSITIONAL:
+                raise ValueError(
+                    f"Var positional parameter not supported: {param.name}"
+                )
+
+        # NOTE: the name `kwargs` is removed since `kwargs` is a dict, not an arg itself.
+        # args in kwargs are kept via sig.bind(*args, **kwargs) in capture_and_forward.
+        self.args_names = [
+            name
+            for name, param in self.sig.parameters.items()
+            if param.kind != inspect.Parameter.VAR_KEYWORD and name != "self"
+        ]
+        self.captured_input = None
+
+    def __enter__(self):
+        def capture_and_forward(*args, **kwargs):
+            bound = self.sig.bind(*args, **kwargs)
+            bound.apply_defaults()
+            args_dict = dict(bound.arguments)
+
+            def populate_args(args_dict, input_to_remove):
+                for key in input_to_remove:
+                    args_dict.pop(key, None)
+                args_tuple = tuple(
+                    args_dict.get(name, None) for name in self.args_names
+                )
+                return copy.deepcopy(args_tuple)
+
+            if self.condition(args_dict) and self.captured_input is None:
+                self.captured_input = populate_args(args_dict, self.input_to_remove)
+
+            return self.forward_org(*args, **kwargs)
+
+        self.module.forward = capture_and_forward
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.module.forward = self.forward_org
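A hedged sketch of the conditional-capture path described in the docstring above; `model` and `inputs` are placeholders for a HuggingFace-style causal LM and its tokenized batch:

from tico.utils.record_input import RecordingInput

# Hypothetical: record the first decode step (cache already populated),
# then drop the cache from the captured example input.
cond = lambda d: d.get("past_key_values") is not None
with RecordingInput(model, condition=cond, input_to_remove=["past_key_values"]) as rec:
    model.generate(**inputs, max_new_tokens=4)
example_input = rec.captured_input  # tuple ordered like forward()'s parameters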
tico/utils/register_custom_op.py CHANGED
@@ -31,9 +31,11 @@ def CircleResizeNearestNeighbor():
         W_scale_factor = size[2] / W
         if H_scale_factor != W_scale_factor:
             raise RuntimeError("Scale factor of H and W should be same.")
-        return torch.nn.functional.interpolate(
-            input_, scale_factor=H_scale_factor, mode="nearest"
+        permuted = torch.permute(input_, [0, 3, 1, 2])
+        resized = torch.nn.functional.interpolate(
+            permuted, scale_factor=H_scale_factor, mode="nearest"
         )
+        return torch.permute(resized, [0, 2, 3, 1])
 
     @register_fake("circle_custom::resize_nearest_neighbor")
     def _(input_: torch.Tensor, size: List[int]):
@@ -631,7 +633,7 @@ def CircleInstanceNorm():
         bias: Optional[torch.Tensor] = None,
         running_mean: Optional[torch.Tensor] = None,
         running_var: Optional[torch.Tensor] = None,
-        use_input_stats: bool = False,
+        use_input_stats: bool = True,
         momentum: float = 0.1,
         eps: float = 1e-05,
         cudnn_enabled: bool = False,
@@ -639,7 +641,7 @@ def CircleInstanceNorm():
         NHWC_to_NCHW = [0, 3, 1, 2]
         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
 
-        args = [NCHW_input, weight, bias, None, None, False, momentum, eps, False]
+        args = [NCHW_input, weight, bias, None, None, True, momentum, eps, False]
         NCHW_output = torch.ops.aten.instance_norm.default(*args)
         NCHW_to_NHWC = [0, 2, 3, 1]
         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
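The `use_input_stats` flip looks cosmetic but is not: `aten.instance_norm` requires `running_mean`/`running_var` to be defined when `use_input_stats` is false, and this wrapper passes `None` for both, so the old default could not have worked. With `True`, statistics come from the input itself. A quick sketch of the now-exercised path (argument values are illustrative):

import torch

x = torch.randn(1, 3, 8, 8)           # NCHW
w, b = torch.ones(3), torch.zeros(3)
# use_input_stats=True: per-sample mean/var computed from x; running stats unused
y = torch.ops.aten.instance_norm.default(x, w, b, None, None, True, 0.1, 1e-5, False)
assert torch.allclose(y, torch.nn.functional.instance_norm(x, weight=w, bias=b), atol=1e-6)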
@@ -703,6 +705,28 @@ def CircleQuantizeMX():
         return input_
 
 
+def CircleRMSNorm():
+    @custom_op("circle_custom::rms_norm", mutates_args=())
+    def rms_norm(
+        hidden_states: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + eps)
+        return weight * hidden_states.to(input_dtype)
+
+    @register_fake("circle_custom::rms_norm")
+    def _(
+        hidden_states: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        return hidden_states.new_empty(hidden_states.size())
+
+
 # Add custom ops to the torch namespace
 def RegisterOps():
     CircleResizeNearestNeighbor()
@@ -715,3 +739,4 @@ def RegisterOps():
     CircleAvgPool2D()
     CircleInstanceNorm()
     CircleQuantizeMX()
+    CircleRMSNorm()
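The new `circle_custom::rms_norm` op is a direct transcription of the standard RMSNorm formula, x * rsqrt(mean(x^2) + eps) scaled by weight. A hedged sanity check against that formula, assuming RegisterOps() has been called so the op is visible under torch.ops:

import torch
from tico.utils.register_custom_op import RegisterOps

RegisterOps()  # registers the circle_custom::* ops, including rms_norm
x, w, eps = torch.randn(2, 8, 16), torch.ones(16), 1e-5
ref = w * (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)
out = torch.ops.circle_custom.rms_norm(x, w, eps)
assert torch.allclose(out, ref)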
tico/utils/serialize.py CHANGED
@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
-
-import torch
 
 from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.circle_mapping import validate_circle_shape
 from tico.utils.graph import get_module_name_chain
 
 
@@ -40,3 +38,18 @@ def finalise_tensor_names(
     for tensor in graph.tensors:
         if tensor.name in graph.name_to_node:
             tensor.name = f"{get_module_name_chain(graph.name_to_node[tensor.name])}::{tensor.name}"
+
+
+def validate_tensor_shapes(
+    graph: CircleSubgraph,
+) -> None:
+    """
+    Validate all tensors' shapes against their shape signatures.
+    """
+    for tensor in graph.tensors:
+        try:
+            validate_circle_shape(tensor.shape, tensor.shapeSignature)
+        except Exception as e:
+            raise ValueError(
+                f"Tensor {tensor.name} has invalid shape ({tensor.shape}), shape_signature ({tensor.shapeSignature})"
+            ) from e
tico/utils/signature.py ADDED
@@ -0,0 +1,247 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+import numpy as np
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_mapping import to_circle_shape
+from tico.utils.dtype import circle_dtype_to_torch_dtype
+from tico.utils.installed_packages import is_dynamic_cache_available
+
+
+def is_dynamic_cache_instance(value):
+    if is_dynamic_cache_available():
+        from transformers.cache_utils import DynamicCache
+
+        return isinstance(value, DynamicCache)
+    else:
+        return False
+
+
+def flatten_and_convert_kwargs(kwargs: dict) -> dict[str, torch.Tensor]:
+    result = {}  # type: ignore[var-annotated]
+    for k, v in kwargs.items():
+        if v is None:
+            continue
+        elif isinstance(v, (list, tuple)):
+            # 1. handle list
+            def unpack_recursive(name, value, store=None):
+                if store is None:
+                    store = {}
+
+                if isinstance(value, (tuple, list)):
+                    for i, v in enumerate(value):
+                        # recursive call. Append index to name and explore lower level
+                        unpack_recursive(f"{name}_{i}", v, store)
+                else:
+                    # base type (scalar etc.) directly stored
+                    store[name] = value
+
+                return store
+
+            unpack_recursive(k, v, result)
+        elif is_dynamic_cache_instance(v):
+            # 2. handle DynamicCache
+            for idx, cache_val in enumerate(v.key_cache):
+                result[f"{k}_key_cache_{idx}"] = cache_val
+
+            for idx, cache_val in enumerate(v.value_cache):
+                result[f"{k}_value_cache_{idx}"] = cache_val
+        else:
+            result[k] = v
+
+    # 3. Convert to tensors
+    for k, v in result.items():
+        result[k] = v if isinstance(v, torch.Tensor) else torch.tensor(v)
+
+    return result
+
+
+def flatten_and_convert_args(args: Sequence) -> tuple:
+    result = []  # type: ignore[var-annotated]
+    for item in args:
+        if item is None:
+            continue
+
+        # 1. recursion on list and tuple
+        if isinstance(item, (list, tuple)):
+            result.extend(flatten_and_convert_args(item))
+            continue
+
+        # 2. handle DynamicCache
+        if is_dynamic_cache_available():
+            from transformers.cache_utils import DynamicCache
+
+            if isinstance(item, DynamicCache):
+                # NOTE The tensor order is: key_in → key_out → value_in → value_out
+                #
+                # Refer to https://github.com/huggingface/transformers/blob/3457e8e73e4f5532cc69059682b1ba4484d7e7e8/src/transformers/cache_utils.py#L557
+                # ```
+                # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+                # ```
+                result.extend(item.key_cache)
+                result.extend(item.value_cache)
+                continue
+
+        # 3. Convert to tensors
+        result.append(item if isinstance(item, torch.Tensor) else torch.tensor(item))
+
+    return tuple(result)
+
+
+class ModelInputSpec:
+    @classmethod
+    def load(cls, circle_path):
+        def load(circle_path: str) -> bytes:
+            with open(circle_path, "rb") as f:
+                buf = bytes(f.read())
+            return buf
+
+        circle_binary = load(circle_path)
+        return cls(circle_binary)
+
+    def __init__(self, circle_binary):
+        model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+        assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
+
+        graph = model.Subgraphs(0)
+        tensors = [graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())]
+
+        self.names = [t.Name().decode("utf-8").split("::")[-1] for t in tensors]
+        self.shapes = [t.ShapeAsNumpy() for t in tensors]
+        self.shape_signatures = list(
+            map(
+                lambda x: None if (isinstance(x, int) and x == 0) else x,
+                (t.ShapeSignatureAsNumpy() for t in tensors),
+            )
+        )
+        self.types: list[torch.dtype] = [
+            circle_dtype_to_torch_dtype(t.Type()) for t in tensors
+        ]
+        self.name_to_idx = {name: idx for idx, name in enumerate(self.names)}
+
+    def bind(self, args, kwargs, check=True):
+        """Convert args and kwargs into an ordered list according to model input order"""
+        inputs = []
+        args = flatten_and_convert_args(args)
+        kwargs = flatten_and_convert_kwargs(kwargs)
+
+        arg_num = len(args) + len(kwargs)
+        m_input_num = len(self.names)
+        if arg_num != m_input_num:
+            raise ValueError(
+                f"Mismatch: number of model inputs and number of passed arguments are not the same: inputs({m_input_num}) != passed({arg_num}), input spec: {self.names}"
+            )
+
+        # 1. positional arguments
+        for i, val in enumerate(args):
+            name = self.names[i]
+            inputs.append(val)
+
+        # 2. keyword arguments
+        for idx in range(len(args), len(self.names)):
+            name = self.names[idx]
+            inputs.append(kwargs[name])
+
+        if check:
+            self.check_types(inputs)
+            self.check_shapes(inputs)
+
+        return inputs
+
+    def check_types(self, inputs):
+        """Check the types of input values"""
+        for i, (inp, ref_type) in enumerate(zip(inputs, self.types)):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor, int, float)
+            ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+            if isinstance(inp, torch.Tensor):
+                if inp.dtype != ref_type:
+                    raise TypeError(
+                        f"Input '{self.names[i]}' type {inp.dtype} != expected {ref_type}"
+                    )
+            else:
+                # Scalars (int, float)
+                if ref_type == torch.float32:
+                    if not isinstance(inp, float):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                elif ref_type == torch.int64:
+                    if not isinstance(inp, int):
+                        raise TypeError(
+                            f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+                        )
+                else:
+                    print(f"Unexpected ref_type: {ref_type}")
+
+    def check_shapes(self, inputs):
+        """Check the shapes of input values"""
+
+        def merge(shape, shape_sig):
+            """
+            Merge shape signature with shape
+            """
+            from copy import deepcopy
+
+            shape_merged = deepcopy(shape)
+            if shape_sig is not None:
+                for idx, ss in enumerate(shape_sig):
+                    if ss == -1:
+                        shape_merged[idx] = -1
+
+            return shape_merged
+
+        for i, (inp, ref_shape, ref_shape_sig) in enumerate(
+            zip(inputs, self.shapes, self.shape_signatures)
+        ):
+            # TODO: Support more data types (np array)
+            assert isinstance(
+                inp, (torch.Tensor, int, float)
+            ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+            if isinstance(inp, torch.Tensor):  # Tensor
+                in_shape, in_shape_sig = to_circle_shape(inp.size())
+
+                if len(in_shape) != len(ref_shape):
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(in_shape)} != expected {len(ref_shape)}"
+                    )
+
+                in_merged_shape = merge(in_shape, in_shape_sig)
+                ref_merged_shape = merge(ref_shape, ref_shape_sig)
+                for in_shp, ref_shp in zip(in_merged_shape, ref_merged_shape):
+                    if ref_shp == -1:
+                        continue
+                    if in_shp == -1:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has unknown dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig})"
+                        )
+                    if in_shp != ref_shp:
+                        raise ValueError(
+                            f"Input '{self.names[i]}' has wrong dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig})"
+                        )
+            elif isinstance(inp, (int, float)):  # Scalar
+                if len(ref_shape) > 0:
+                    raise ValueError(
+                        f"Input '{self.names[i]}' has invalid rank {len(ref_shape)}"
+                    )
+            else:
+                print(f"Unexpected input type: {type(inp)}")
tico/utils/torch_compat.py ADDED
@@ -0,0 +1,52 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Runtime **capability-detection helpers** for the `torch.export` stack.
+
+Instead of sprinkling version checks like `torch.__version__ >= "2.9"` throughout
+the codebase, import these helpers once and branch on the feature you need.
+
+Each probe executes only **once per process** thanks to `functools.lru_cache`,
+so the overhead is negligible.
+"""
+
+import functools
+
+import torch
+
+
+@functools.lru_cache(maxsize=None)
+def export_produces_slice() -> bool:
+    """
+    Compile a minimal model with `torch.export.export` and inspect its FX graph
+    to see whether an `aten.slice.Tensor` node appears.
+
+    Returns
+    -------
+    bool
+        * ``True``  — downstream passes should expect redundant **slice** nodes.
+        * ``False`` — downstream passes should expect only a **select** node.
+    """
+
+    class _Probe(torch.nn.Module):
+        def forward(self, x):  # simple slice: keep all dims except 3rd
+            return x[:, :, 1]
+
+        def get_example_inputs(self):
+            return (torch.randn(1, 4, 4),)
+
+    m = _Probe()
+    ep = torch.export.export(m, m.get_example_inputs())
+    return any(n.target == torch.ops.aten.slice.Tensor for n in ep.graph.nodes)
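A hedged sketch of how a pass pipeline might branch on this probe; the pass name in the comment is hypothetical, not one from this package:

from tico.utils.torch_compat import export_produces_slice

passes = []
if export_produces_slice():
    # Placeholder for a slice-cleanup pass: older torch.export emits an
    # extra aten.slice before aten.select, so cleanup is only needed here.
    passes.append("remove_redundant_slice")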
tico/utils/utils.py CHANGED
@@ -21,7 +21,6 @@ from typing import List
 
 import torch
 from circle_schema import circle
-from packaging.version import Version
 from torch._guards import detect_fake_mode
 from torch.export import ExportedProgram
 from torch.utils import _pytree as pytree
@@ -29,10 +28,6 @@ from torch.utils import _pytree as pytree
 from tico.serialize.quant_param import QuantParam
 
 
-HAS_TORCH_OVER_25 = Version(torch.__version__) >= Version("2.5.0")
-HAS_TORCH_OVER_28_DEV = Version(torch.__version__) >= Version("2.8.0.dev")
-
-
 def get_fake_mode(exported_program: ExportedProgram):
     fake_mode = detect_fake_mode(
         tuple(
@@ -84,73 +79,70 @@ def enforce_type(callable):
     def check_types(*args, **kwargs):
         parameters = dict(zip(spec.args, args))
         parameters.update(kwargs)
-        for name, value in parameters.items():
-            if name == "self":
-                # skip 'self' in spec.args
-                continue
 
-            assert (
-                name in spec.annotations
-            ), f"All parameter require type hints. {name} needs a type hint"
-
-            type_hint = spec.annotations[name]
-
-            # Return tuple of flattened types.
-            # Q) What is flatten?
-            # A) Optional/Union is not included. Below are included.
-            #    collections: List, Set, ...
-            #    primitive types: int, str, ...
-            def _flatten_type(type_hint) -> tuple:
-                # `get_origin` maps Union[...] and Optional[...] varieties to Union
-                if typing.get_origin(type_hint) == typing.Union:
-                    # ex. typing.Union[list, int] -> (list, int)
-                    # ex. typing.Optional[torch.fx.Node] -> (torch.fx.Node, NoneType)
-                    actual_type = tuple(
-                        [_flatten_type(t) for t in typing.get_args(type_hint)]
-                    )
-                else:
-                    actual_type = (type_hint,)
-                return actual_type
+        # Return tuple of flattened types.
+        # Q) What is flatten?
+        # A) Optional/Union is not included. Below are included.
+        #    collections: List, Set, ...
+        #    primitive types: int, str, ...
+        def _flatten_type(type_hint) -> tuple:
+            # `get_origin` maps Union[...] and Optional[...] varieties to Union
+            if typing.get_origin(type_hint) == typing.Union:
+                # ex. typing.Union[list, int] -> (list, int)
+                # ex. typing.Optional[torch.fx.Node] -> (torch.fx.Node, NoneType)
+                actual_type = tuple(
+                    _flatten_type(t) for t in typing.get_args(type_hint)
+                )
+            else:
+                actual_type = (type_hint,)
+            return actual_type
 
-            type_hint = _flatten_type(type_hint)
+        # Return true if value matches with type_hint
+        # Return false otherwise
+        def _check_type(value, type_hint):
+            if type_hint == typing.Any:
+                return True
 
-            # Return true if value matches with type_hint
-            # Return false otherwise
-            def _check_type(value, type_hint):
-                if type_hint == typing.Any:
-                    return True
+            if isinstance(type_hint, tuple):
+                return any(_check_type(value, t) for t in type_hint)
 
-                if isinstance(type_hint, tuple):
-                    return any([_check_type(value, t) for t in type_hint])
+            if typing.get_origin(type_hint) in (list, set):
+                if not isinstance(value, typing.get_origin(type_hint)):
+                    return False
 
-                if typing.get_origin(type_hint) in (list, set):
-                    if not isinstance(value, typing.get_origin(type_hint)):
+                for v in value:
+                    if not any(_check_type(v, t) for t in typing.get_args(type_hint)):
                         return False
 
-                    for v in value:
-                        if not any(
-                            [_check_type(v, t) for t in typing.get_args(type_hint)]
-                        ):
-                            return False
+                return True
 
-                    return True
+            if typing.get_origin(type_hint) is dict:
+                if not isinstance(value, typing.get_origin(type_hint)):
+                    return False
 
-                if typing.get_origin(type_hint) == dict:
-                    if not isinstance(value, typing.get_origin(type_hint)):
+                for k, v in value.items():
+                    k_type, v_type = typing.get_args(type_hint)
+                    if not _check_type(k, k_type):
+                        return False
+                    if not _check_type(v, v_type):
                         return False
 
-                    for k, v in value.items():
-                        k_type, v_type = typing.get_args(type_hint)
-                        if not _check_type(k, k_type):
-                            return False
-                        if not _check_type(v, v_type):
-                            return False
+                return True
 
-                    return True
+            # TODO: Support more type hints
+            return isinstance(value, type_hint)
 
-                # TODO: Support more type hints
-                return isinstance(value, type_hint)
+        for name, value in parameters.items():
+            if name == "self":
+                # skip 'self' in spec.args
+                continue
 
+            assert (
+                name in spec.annotations
+            ), f"All parameter require type hints. {name} needs a type hint"
+
+            type_hint = spec.annotations[name]
+            type_hint = _flatten_type(type_hint)
             type_check_result = _check_type(value, type_hint)
            if not type_check_result:
                 raise ArgTypeError(
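The refactor above hoists `_flatten_type` and `_check_type` out of the per-parameter loop so they are defined once per call instead of once per argument. For reference, a hedged sketch of the decorator in use; the decorated function is illustrative only, not from the package:

import typing
from tico.utils.utils import enforce_type

@enforce_type
def scale(values: typing.List[int], factor: int):
    # Hypothetical function: every parameter needs a type hint,
    # and a mismatch raises ArgTypeError at call time.
    return [v * factor for v in values]

scale([1, 2, 3], 2)    # passes the runtime check
# scale("abc", 2)      # would raise ArgTypeError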