tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_assert_nodes.py +3 -1
  10. tico/passes/remove_redundant_expand.py +3 -1
  11. tico/quantization/__init__.py +6 -0
  12. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  29. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  31. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  32. tico/quantization/config/base.py +26 -0
  33. tico/quantization/config/gptq.py +29 -0
  34. tico/quantization/config/pt2e.py +25 -0
  35. tico/quantization/config/ptq.py +119 -0
  36. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  37. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
  38. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  39. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  40. tico/quantization/evaluation/metric.py +146 -0
  41. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  42. tico/quantization/passes/__init__.py +1 -0
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/__init__.py +1 -0
  47. tico/quantization/wrapq/dtypes.py +70 -0
  48. tico/quantization/wrapq/examples/__init__.py +1 -0
  49. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  50. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  51. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  52. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  53. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  54. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  55. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  56. tico/quantization/wrapq/mode.py +32 -0
  57. tico/quantization/wrapq/observers/__init__.py +1 -0
  58. tico/quantization/wrapq/observers/affine_base.py +128 -0
  59. tico/quantization/wrapq/observers/base.py +98 -0
  60. tico/quantization/wrapq/observers/ema.py +62 -0
  61. tico/quantization/wrapq/observers/identity.py +74 -0
  62. tico/quantization/wrapq/observers/minmax.py +39 -0
  63. tico/quantization/wrapq/observers/mx.py +60 -0
  64. tico/quantization/wrapq/qscheme.py +40 -0
  65. tico/quantization/wrapq/quantizer.py +179 -0
  66. tico/quantization/wrapq/utils/__init__.py +1 -0
  67. tico/quantization/wrapq/utils/introspection.py +167 -0
  68. tico/quantization/wrapq/utils/metrics.py +124 -0
  69. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  70. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  71. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  72. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  73. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  74. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  75. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  76. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  77. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  78. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  79. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  80. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  81. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  82. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  83. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  84. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  85. tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
  86. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  87. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  88. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  89. tico/quantization/wrapq/wrappers/registry.py +128 -0
  90. tico/serialize/circle_serializer.py +11 -4
  91. tico/serialize/operators/adapters/__init__.py +1 -0
  92. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  93. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  94. tico/serialize/operators/op_le.py +54 -0
  95. tico/serialize/operators/op_mm.py +15 -132
  96. tico/serialize/operators/op_rmsnorm.py +65 -0
  97. tico/utils/convert.py +20 -15
  98. tico/utils/dtype.py +22 -0
  99. tico/utils/register_custom_op.py +29 -4
  100. tico/utils/signature.py +247 -0
  101. tico/utils/utils.py +50 -53
  102. tico/utils/validate_args_kwargs.py +37 -0
  103. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
  104. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
  105. tico/experimental/quantization/__init__.py +0 -6
  106. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  107. tico/experimental/quantization/evaluation/metric.py +0 -109
  108. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  109. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  111. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  112. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  113. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  114. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  115. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  116. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  117. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  118. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  119. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  120. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  121. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  122. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  123. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  124. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  125. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  126. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  127. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  128. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  129. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  130. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
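Most of the listing is the quantization stack leaving the experimental namespace (tico/experimental/quantization/... → tico/quantization/...). A minimal, hypothetical import-migration sketch follows; the module paths come from the file list above, but the symbols re-exported by tico.quantization are not visible in this diff:

    # Hypothetical sketch: prefer the new namespace, fall back to the old one.
    # Module paths are taken from the file list above; adjust to the symbols you actually use.
    try:
        from tico.quantization import public_interface  # 0.1.0.dev251106 layout
    except ImportError:
        from tico.experimental.quantization import public_interface  # pre-dev251106 layout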
tico/serialize/operators/op_mm.py CHANGED
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
  import torch
  from circle_schema import circle
 
- from tico.serialize.circle_graph import CircleSubgraph, is_const
+ from tico.serialize.circle_graph import CircleSubgraph
  from tico.serialize.operators.hashable_opcode import OpCode
  from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
  from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -28,9 +28,9 @@ from tico.utils.validate_args_kwargs import MatmulArgs
 
 
  @register_node_visitor
- class MatmulDefaultVisitor(NodeVisitor):
+ class MatmulVisitor(NodeVisitor):
  """
- Convert matmul to equavalent BatchMatMul or FullyConnected with Transpose.
+ Convert matmul to Circle BatchMatMul
  """
 
  target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -38,131 +38,7 @@ class MatmulDefaultVisitor(NodeVisitor):
  def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
  super().__init__(op_codes, graph)
 
- # NOTE: Matmul is equivalent to Batch MatMul (batch=1)
- def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
- def set_bmm_option(operator):
- operator.builtinOptionsType = (
- circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
- )
- option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
- option.adjointLhs, option.adjointRhs = False, False
- option.asymmetricQuantizeInputs = False
- operator.builtinOptions = option
-
- op_index = get_op_index(
- circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
- )
- operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
- set_bmm_option(operator)
-
- return operator
-
- def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
- def set_transpose_option(operator):
- operator.builtinOptionsType = (
- circle.BuiltinOptions.BuiltinOptions.TransposeOptions
- )
- option = circle.TransposeOptions.TransposeOptionsT()
- operator.builtinOptions = option
-
- transpose_op_index = get_op_index(
- circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
- )
- operator = create_builtin_operator(
- self.graph, transpose_op_index, inputs, outputs
- )
- set_transpose_option(operator)
- return operator
-
- def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
- def set_fc_option(operator):
- operator.builtinOptionsType = (
- circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
- )
- option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
-
- option.fusedActivationFunction = (
- circle.ActivationFunctionType.ActivationFunctionType.NONE
- )
- option.weightsFormat = (
- circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
- )
- option.keepNumDims = False
- option.asymmetricQuantizeInputs = False
- option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
-
- operator.builtinOptions = option
-
- fc_op_index = get_op_index(
- circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
- )
- operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
- set_fc_option(operator)
- return operator
-
- """
- Define FullyConnnected with Tranpose operator.
- Note that those sets of operators are equivalent.
- (1) Matmul
- matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')
-
- (2) Transpose + FullyConneccted
- transpose( rhs[K, W'] ) -> trs_output[W', K]
- fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')
- """
-
- def define_fc_with_transpose(
- self, node, inputs, outputs
- ) -> circle.Operator.OperatorT:
- lhs, rhs = inputs
-
- # get transpose shape
- rhs_tid: int = self.graph.get_tid_registered(rhs)
- rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
- rhs_name: str = rhs.name
- rhs_type: int = rhs_tensor.type
- rhs_shape: List[int] = rhs_tensor.shape
- assert len(rhs_shape) == 2, len(rhs_shape)
- rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]
-
- # create transpose output tensor
- trs_output = self.graph.add_tensor_from_scratch(
- prefix=f"{rhs_name}_transposed_output",
- shape=rhs_shape_transpose,
- shape_signature=None,
- dtype=rhs_type,
- source_node=node,
- )
- trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
- trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
- self.graph.add_operator(trs_operator)
-
- # define fc node
- fc_input = lhs
- fc_weight = trs_output
- fc_shape = [fc_weight.shape[0]]
- fc_bias = self.graph.add_const_tensor(
- data=[0.0] * fc_shape[0], source_node=node
- )
-
- operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)
-
- return operator
-
- def define_node(
- self, node: torch.fx.Node, prior_latency=True
- ) -> circle.Operator.OperatorT:
- """
- NOTE: Possibility of accuracy-latency trade-off
- From ONE compiler's perspective:
- - BMM uses per-tensor quantization for both rhs and lhs.
- - FC uses per-channel quantization for weight and per-tensor for input.
- Thus, FC is better in terms of accuracy.
- FC necessarily involves an additional transpose operation to be identical with mm.
- If transposed operand is const, it can be optimized by constant folding.
- Thus, convert FC only if tranpose can be folded.
- TODO set prior_latency outside
- """
+ def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
  args = MatmulArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
  input = args.input
  other = args.other
@@ -170,9 +46,16 @@ class MatmulDefaultVisitor(NodeVisitor):
  inputs = [input, other]
  outputs = [node]
 
- if not is_const(other) and prior_latency:
- operator = self.define_bmm_node(inputs, outputs)
- else:
- operator = self.define_fc_with_transpose(node, inputs, outputs)
+ op_index = get_op_index(
+ circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+ )
+ operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+ operator.builtinOptionsType = (
+ circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
+ )
+ option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
+ option.adjointLhs, option.adjointRhs = False, False
+ option.asymmetricQuantizeInputs = False
+ operator.builtinOptions = option
 
  return operator
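The FullyConnected-with-Transpose lowering removed here appears to be superseded by the new graph-level pass tico/passes/convert_matmul_to_linear.py listed above, so the serializer now always emits BatchMatMul. A short sketch of the equivalence the removed docstring described (illustrative only; 2-D operands assumed):

    import torch

    lhs = torch.randn(4, 8)    # [H, K]
    rhs = torch.randn(8, 16)   # [K, W']
    out_mm = torch.mm(lhs, rhs)                        # what aten.mm computes
    out_fc = torch.nn.functional.linear(lhs, rhs.t())  # FullyConnected expects weight [W', K]
    assert torch.allclose(out_mm, out_fc, atol=1e-6)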
tico/serialize/operators/op_rmsnorm.py ADDED
@@ -0,0 +1,65 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+ import torch._ops
+ import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_graph import CircleSubgraph
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.validate_args_kwargs import CircleRMSNormArgs
+
+
+ @register_node_visitor
+ class RMSNormVisitor(NodeVisitor):
+ target: List[torch._ops.OpOverload] = [
+ torch.ops.circle_custom.rms_norm.default,
+ ]
+
+ def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+ super().__init__(op_codes, graph)
+
+ def define_node(
+ self,
+ node: torch.fx.Node,
+ ) -> circle.Operator.OperatorT:
+ args = CircleRMSNormArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
+ input = args.input
+ weight = args.weight
+ eps = args.eps
+
+ op_index = get_op_index(
+ circle.BuiltinOperator.BuiltinOperator.RMS_NORM, self._op_codes
+ )
+
+ inputs = [input, weight]
+ outputs = [node]
+ operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+ # Op-specific option
+ operator.builtinOptionsType = (
+ circle.BuiltinOptions.BuiltinOptions.RmsNormOptions
+ )
+ option = circle.RmsNormOptions.RmsNormOptionsT()
+ option.epsilon = eps
+
+ operator.builtinOptions = option
+
+ return operator
tico/utils/convert.py CHANGED
@@ -20,26 +20,14 @@ import torch
  from torch.export import export, ExportedProgram
 
  from tico.config import CompileConfigBase, get_default_config
- from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
- from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
- InsertQuantizeOnDtypeMismatch,
- )
- from tico.experimental.quantization.passes.propagate_qparam_backward import (
- PropagateQParamBackward,
- )
- from tico.experimental.quantization.passes.propagate_qparam_forward import (
- PropagateQParamForward,
- )
- from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
- from tico.experimental.quantization.passes.remove_weight_dequant_op import (
- RemoveWeightDequantOp,
- )
  from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
  from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
  from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
  from tico.passes.const_prop_pass import ConstPropPass
  from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+ from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
  from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+ from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
  from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
  from tico.passes.convert_to_relu6 import ConvertToReLU6
  from tico.passes.decompose_addmm import DecomposeAddmm
@@ -72,6 +60,14 @@ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
  from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
  from tico.passes.restore_linear import RestoreLinear
  from tico.passes.segment_index_select import SegmentIndexSelectConst
+ from tico.quantization.passes.fold_quant_ops import FoldQuantOps
+ from tico.quantization.passes.insert_quantize_on_dtype_mismatch import (
+ InsertQuantizeOnDtypeMismatch,
+ )
+ from tico.quantization.passes.propagate_qparam_backward import PropagateQParamBackward
+ from tico.quantization.passes.propagate_qparam_forward import PropagateQParamForward
+ from tico.quantization.passes.quantize_bias import QuantizeBias
+ from tico.quantization.passes.remove_weight_dequant_op import RemoveWeightDequantOp
  from tico.serialize.circle_serializer import build_circle
  from tico.serialize.operators.node_visitor import get_support_targets
  from tico.utils import logging
@@ -141,6 +137,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
  or torch.__version__.startswith("2.7")
  or torch.__version__.startswith("2.8")
  or torch.__version__.startswith("2.9")
+ or torch.__version__.startswith("2.10")
  ):
  return run_decompositions(exported_program)
  else:
@@ -249,6 +246,14 @@ def convert_exported_module_to_circle(
  ConstPropPass(),
  SegmentIndexSelectConst(),
  LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+ ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
+ ConvertMatmulToLinear(
+ enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+ enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+ enable_single_batch_lhs_const_bmm=config.get(
+ "convert_single_batch_lhs_const_bmm_to_fc"
+ ),
+ ),
  LowerToResizeNearestNeighbor(),
  LegalizePreDefinedLayoutOperators(),
  LowerPow2ToMul(),
@@ -287,7 +292,7 @@
 
  check_unsupported_target(exported_program)
  check_training_ops(exported_program)
- circle_program = build_circle(exported_program)
+ circle_program = build_circle(exported_program, config)
 
  return circle_program
 
tico/utils/dtype.py CHANGED
@@ -1,6 +1,8 @@
  import numpy as np
  import torch
 
+ from circle_schema import circle
+
  NUMPY_TO_TORCH_DTYPE_DICT = {
  np.dtype("float32"): torch.float32,
  np.dtype("float64"): torch.float64,
@@ -15,6 +17,26 @@ NUMPY_TO_TORCH_DTYPE_DICT = {
  np.dtype("bool"): torch.bool,
  }
 
+ CIRCLE_TO_TORCH_DTYPE_DICT = {
+ circle.TensorType.TensorType.FLOAT32: torch.float32,
+ circle.TensorType.TensorType.UINT8: torch.uint8,
+ circle.TensorType.TensorType.INT8: torch.int8,
+ circle.TensorType.TensorType.INT16: torch.int16,
+ circle.TensorType.TensorType.INT32: torch.int32,
+ circle.TensorType.TensorType.INT64: torch.int64,
+ circle.TensorType.TensorType.BOOL: torch.bool,
+ }
+
 
  def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype:
  return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype]
+
+
+ def circle_dtype_to_torch_dtype(circle_dtype: int) -> torch.dtype:
+ assert isinstance(circle_dtype, int)
+ if circle_dtype not in CIRCLE_TO_TORCH_DTYPE_DICT:
+ raise RuntimeError(f"Unsupported dtype {circle_dtype}")
+
+ torch_dtype = CIRCLE_TO_TORCH_DTYPE_DICT[circle_dtype]
+ assert torch_dtype is not None
+ return torch_dtype
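A short usage sketch for the new helper, based only on the mapping added above:

    import torch
    from circle_schema import circle
    from tico.utils.dtype import circle_dtype_to_torch_dtype

    assert circle_dtype_to_torch_dtype(circle.TensorType.TensorType.FLOAT32) is torch.float32
    assert circle_dtype_to_torch_dtype(circle.TensorType.TensorType.INT64) is torch.int64
    # Circle types missing from CIRCLE_TO_TORCH_DTYPE_DICT raise RuntimeError.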
tico/utils/register_custom_op.py CHANGED
@@ -31,9 +31,11 @@ def CircleResizeNearestNeighbor():
  W_scale_factor = size[2] / W
  if H_scale_factor != W_scale_factor:
  raise RuntimeError("Scale factor of H and W should be same.")
- return torch.nn.functional.interpolate(
- input_, scale_factor=H_scale_factor, mode="nearest"
+ permuted = torch.permute(input_, [0, 3, 1, 2])
+ resized = torch.nn.functional.interpolate(
+ permuted, scale_factor=H_scale_factor, mode="nearest"
  )
+ return torch.permute(resized, [0, 2, 3, 1])
 
  @register_fake("circle_custom::resize_nearest_neighbor")
  def _(input_: torch.Tensor, size: List[int]):
@@ -631,7 +633,7 @@ def CircleInstanceNorm():
  bias: Optional[torch.Tensor] = None,
  running_mean: Optional[torch.Tensor] = None,
  running_var: Optional[torch.Tensor] = None,
- use_input_stats: bool = False,
+ use_input_stats: bool = True,
  momentum: float = 0.1,
  eps: float = 1e-05,
  cudnn_enabled: bool = False,
@@ -639,7 +641,7 @@ def CircleInstanceNorm():
  NHWC_to_NCHW = [0, 3, 1, 2]
  NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
 
- args = [NCHW_input, weight, bias, None, None, False, momentum, eps, False]
+ args = [NCHW_input, weight, bias, None, None, True, momentum, eps, False]
  NCHW_output = torch.ops.aten.instance_norm.default(*args)
  NCHW_to_NHWC = [0, 2, 3, 1]
  NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
@@ -703,6 +705,28 @@ def CircleQuantizeMX():
  return input_
 
 
+ def CircleRMSNorm():
+ @custom_op("circle_custom::rms_norm", mutates_args=())
+ def rms_norm(
+ hidden_states: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-05,
+ ) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
+ return weight * hidden_states.to(input_dtype)
+
+ @register_fake("circle_custom::rms_norm")
+ def _(
+ hidden_states: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-05,
+ ) -> torch.Tensor:
+ return hidden_states.new_empty(hidden_states.size())
+
+
  # Add custom ops to the torch namespace
  def RegisterOps():
  CircleResizeNearestNeighbor()
@@ -715,3 +739,4 @@ def RegisterOps():
  CircleAvgPool2D()
  CircleInstanceNorm()
  CircleQuantizeMX()
+ CircleRMSNorm()
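Once RegisterOps() has run, the new op is reachable under torch.ops.circle_custom. A minimal call sketch, with the signature taken from the hunk above and the shapes chosen purely for illustration:

    import torch
    from tico.utils.register_custom_op import RegisterOps

    RegisterOps()  # registers circle_custom::rms_norm along with the other custom ops
    hidden_states = torch.randn(2, 4, 16)  # arbitrary example shape
    weight = torch.ones(16)                # scale over the last dimension
    out = torch.ops.circle_custom.rms_norm(hidden_states, weight, 1e-5)
    assert out.shape == hidden_states.shape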
tico/utils/signature.py ADDED
@@ -0,0 +1,247 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Sequence
+
+ import numpy as np
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_mapping import to_circle_shape
+ from tico.utils.dtype import circle_dtype_to_torch_dtype
+ from tico.utils.installed_packages import is_dynamic_cache_available
+
+
+ def is_dynamic_cache_instance(value):
+ if is_dynamic_cache_available():
+ from transformers.cache_utils import DynamicCache
+
+ return isinstance(value, DynamicCache)
+ else:
+ return False
+
+
+ def flatten_and_convert_kwargs(kwargs: dict) -> dict[str, torch.Tensor]:
+ result = {} # type: ignore[var-annotated]
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ elif isinstance(v, (list, tuple)):
+ # 1. handle list
+ def unpack_recursive(name, value, store=None):
+ if store is None:
+ store = {}
+
+ if isinstance(value, (tuple, list)):
+ for i, v in enumerate(value):
+ # recursive call. Append index to name and explore lower level
+ unpack_recursive(f"{name}_{i}", v, store)
+ else:
+ # base type (scalar etc.) directly stored
+ store[name] = value
+
+ return store
+
+ unpack_recursive(k, v, result)
+ elif is_dynamic_cache_instance(v):
+ # 2. handle DynamicCache
+ for idx, cache_val in enumerate(v.key_cache):
+ result[f"{k}_key_cache_{idx}"] = cache_val
+
+ for idx, cache_val in enumerate(v.value_cache):
+ result[f"{k}_value_cache_{idx}"] = cache_val
+ else:
+ result[k] = v
+
+ # 3. Convert to tensors
+ for k, v in result.items():
+ result[k] = v if isinstance(v, torch.Tensor) else torch.tensor(v)
+
+ return result
+
+
+ def flatten_and_convert_args(args: Sequence) -> tuple:
+ result = [] # type: ignore[var-annotated]
+ for item in args:
+ if item is None:
+ continue
+
+ # 1. recursion on list and tuple
+ if isinstance(item, (list, tuple)):
+ result.extend(flatten_and_convert_args(item))
+ continue
+
+ # 2. handle DynamicCache
+ if is_dynamic_cache_available():
+ from transformers.cache_utils import DynamicCache
+
+ if isinstance(item, DynamicCache):
+ # NOTE The tensor order is: key_in → key_out → value_in → value_out
+ #
+ # Refer to https://github.com/huggingface/transformers/blob/3457e8e73e4f5532cc69059682b1ba4484d7e7e8/src/transformers/cache_utils.py#L557
+ # ```
+ # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+ # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ # ```
+ result.extend(item.key_cache)
+ result.extend(item.value_cache)
+ continue
+
+ # 3. Convert to tensors
+ result.append(item if isinstance(item, torch.Tensor) else torch.tensor(item))
+
+ return tuple(result)
+
+
+ class ModelInputSpec:
+ @classmethod
+ def load(cls, circle_path):
+ def load(circle_path: str) -> bytes:
+ with open(circle_path, "rb") as f:
+ buf = bytes(f.read())
+ return buf
+
+ circle_binary = load(circle_path)
+ return cls(circle_binary)
+
+ def __init__(self, circle_binary):
+ model = circle.Model.Model.GetRootAsModel(circle_binary, 0)
+ assert model.SubgraphsLength() == 1, "Only one subgraph is supported"
+
+ graph = model.Subgraphs(0)
+ tensors = [graph.Tensors(graph.Inputs(o)) for o in range(graph.InputsLength())]
+
+ self.names = [t.Name().decode("utf-8").split("::")[-1] for t in tensors]
+ self.shapes = [t.ShapeAsNumpy() for t in tensors]
+ self.shape_signatures = list(
+ map(
+ lambda x: None if (isinstance(x, int) and x == 0) else x,
+ (t.ShapeSignatureAsNumpy() for t in tensors),
+ )
+ )
+ self.types: list[torch.dtype] = [
+ circle_dtype_to_torch_dtype(t.Type()) for t in tensors
+ ]
+ self.name_to_idx = {name: idx for idx, name in enumerate(self.names)}
+
+ def bind(self, args, kwargs, check=True):
+ """Convert args and kwargs into an ordered list according to model input order"""
+ inputs = []
+ args = flatten_and_convert_args(args)
+ kwargs = flatten_and_convert_kwargs(kwargs)
+
+ arg_num = len(args) + len(kwargs)
+ m_input_num = len(self.names)
+ if arg_num != m_input_num:
+ raise ValueError(
+ f"Mismatch: number of model inputs and number of passed arguments are not the same: inputs({m_input_num}) != passed({arg_num}), input spec: {self.names}"
+ )
+
+ # 1. positional arguments
+ for i, val in enumerate(args):
+ name = self.names[i]
+ inputs.append(val)
+
+ # 2. keyword arguments
+ for idx in range(len(args), len(self.names)):
+ name = self.names[idx]
+ inputs.append(kwargs[name])
+
+ if check:
+ self.check_types(inputs)
+ self.check_shapes(inputs)
+
+ return inputs
+
+ def check_types(self, inputs):
+ """Check the types of input values"""
+ for i, (inp, ref_type) in enumerate(zip(inputs, self.types)):
+ # TODO: Support more data types (np array)
+ assert isinstance(
+ inp, (torch.Tensor | int | float)
+ ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+ if isinstance(inp, torch.Tensor):
+ if inp.dtype != ref_type:
+ raise TypeError(
+ f"Input '{self.names[i]}' type {inp.dtype} != expected {ref_type}"
+ )
+ else:
+ # Scalars (int, float)
+ if ref_type == torch.float32:
+ if not isinstance(inp, (float)):
+ raise TypeError(
+ f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+ )
+ elif ref_type == torch.int64:
+ if not isinstance(inp, (int)):
+ raise TypeError(
+ f"Input '{self.names[i]}' type {type(inp)} != expected {ref_type}"
+ )
+ else:
+ print(f"Unexpected ref_type: {ref_type}")
+
+ def check_shapes(self, inputs):
+ """Check the shapes of input values"""
+
+ def merge(shape, shape_sig):
+ """
+ Merge shape signature with shape
+ """
+ from copy import deepcopy
+
+ shape_merged = deepcopy(shape)
+ if shape_sig is not None:
+ for idx, ss in enumerate(shape_sig):
+ if ss == -1:
+ shape_merged[idx] = -1
+
+ return shape_merged
+
+ for i, (inp, ref_shape, ref_shape_sig) in enumerate(
+ zip(inputs, self.shapes, self.shape_signatures)
+ ):
+ # TODO: Support more data types (np array)
+ assert isinstance(
+ inp, (torch.Tensor | int | float)
+ ), f"Input '{self.names[i]}' type must be a torch tensor or scalar."
+
+ if isinstance(inp, torch.Tensor): # Tensor
+ in_shape, in_shape_sig = to_circle_shape(inp.size())
+
+ if len(in_shape) != len(ref_shape):
+ raise ValueError(
+ f"Input '{self.names[i]}' has invalid rank {len(in_shape)}!= expected {len(ref_shape)}"
+ )
+
+ in_merged_shape = merge(in_shape, in_shape_sig)
+ ref_merged_shape = merge(ref_shape, ref_shape_sig)
+ for in_shp, ref_shp in zip(in_merged_shape, ref_merged_shape):
+ if ref_shp == -1:
+ continue
+ if in_shp == -1:
+ raise ValueError(
+ f"Input '{self.names[i]}' has unknown dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig}) "
+ )
+ if in_shp != ref_shp:
+ raise ValueError(
+ f"Input '{self.names[i]}' has wrong dimension {inp.size()} != expected shape({ref_shape}) / shape signature({ref_shape_sig}) "
+ )
+ elif isinstance(inp, (int, float)): # Scalar
+ if len(ref_shape) > 0:
+ raise ValueError(
+ f"Input '{self.names[i]}' has invalid rank {len(ref_shape)}"
+ )
+ else:
+ print(f"Unexpected input type: {type(inp)}")