tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/__init__.py CHANGED
@@ -20,8 +20,16 @@ from packaging.version import Version
20
20
  from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
+ __all__ = [
24
+ "CompileConfigV1",
25
+ "get_default_config",
26
+ "convert",
27
+ "convert_from_exported_program",
28
+ "convert_from_pt2",
29
+ ]
30
+
23
31
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250714"
32
+ __version__ = "0.1.0.dev251102"
25
33
 
26
34
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
35
  SECURE_TORCH_VERSION = "2.6.0"
tico/config/base.py CHANGED
@@ -31,7 +31,7 @@ class CompileConfigBase:
31
31
  config = cls()
32
32
  for key in config_dict:
33
33
  if key in config.to_dict():
34
- assert type(config.get(key)) == bool
34
+ assert isinstance(config.get(key), bool)
35
35
  config.set(key, config_dict[key])
36
36
 
37
37
  return config
tico/config/v1.py CHANGED
@@ -20,6 +20,11 @@ from tico.config.base import CompileConfigBase
20
20
  @dataclass
21
21
  class CompileConfigV1(CompileConfigBase):
22
22
  legalize_causal_mask_value: bool = False
23
+ remove_constant_input: bool = False
24
+ convert_lhs_const_mm_to_fc: bool = False
25
+ convert_rhs_const_mm_to_fc: bool = True
26
+ convert_single_batch_lhs_const_bmm_to_fc: bool = False
27
+ convert_expand_to_slice_cat: bool = False
23
28
 
24
29
  def get(self, name: str):
25
30
  return super().get(name)
tico/passes/cast_aten_where_arg_type.py CHANGED
@@ -176,7 +176,7 @@ class CastATenWhereArgType(PassBase):
176
176
  node_dtype = extract_torch_dtype(node)
177
177
  assert (
178
178
  node_dtype == node_dtype_ori
179
- ), f"Type casting doesn't change node's dtype."
179
+ ), "Type casting doesn't change node's dtype."
180
180
 
181
181
  logger.debug(
182
182
  f"{to_cast.name}'s dtype was casted from {buf_data.dtype} to {dtype_to_cast}"
tico/passes/cast_clamp_mixed_type_args.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+
24
+ from tico.serialize.circle_mapping import extract_torch_dtype
25
+ from tico.utils import logging
26
+ from tico.utils.graph import create_node
27
+ from tico.utils.passes import PassBase, PassResult
28
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.utils import is_target_node, set_new_meta_val
30
+ from tico.utils.validate_args_kwargs import ClampArgs
31
+
32
+
33
+ @trace_graph_diff_on_pass
34
+ class CastClampMixedTypeArgs(PassBase):
35
+ """
36
+ This pass ensures consistent dtypes for clamp operations by:
37
+ 1. Converting min/max arguments to match output dtype when provided
38
+ 2. Inserting cast operations when input dtype differs from output dtype
39
+
40
+ Behavior Examples:
41
+ - When input dtype differs from output:
42
+ Inserts _to_copy operation to convert input
43
+ - When min/max dtype differs from output:
44
+ Converts min/max values to output dtype
45
+
46
+ (Case 1, if input dtype is different from output dtype)
47
+ [before]
48
+
49
+ input min(or max)
50
+ (dtype=int) (dtype=float)
51
+ | |
52
+ clamp <----------------+
53
+ |
54
+ output
55
+ (dtype=float)
56
+
57
+ [after]
58
+
59
+ input min(or max)
60
+ (dtype=int) (dtype=float)
61
+ | |
62
+ cast |
63
+ (in=int, out=float) |
64
+ | |
65
+ clamp <--------------+
66
+ |
67
+ output
68
+ (dtype=float)
69
+
70
+ (Case 2, if min(or max) dtype is different from output dtype)
71
+ [before]
72
+
73
+ input min(or max)
74
+ (dtype=float) (dtype=int)
75
+ | |
76
+ clamp <----------------+
77
+ |
78
+ output
79
+ (dtype=float)
80
+
81
+ [after]
82
+
83
+ input min(or max)
84
+ (dtype=float) (dtype=float)
85
+ | |
86
+ clamp <--------------+
87
+ |
88
+ output
89
+ (dtype=float)
90
+ """
91
+
92
+ def __init__(self):
93
+ super().__init__()
94
+
95
+ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
96
+ logger = logging.getLogger(__name__)
97
+ modified = False
98
+
99
+ graph_module = exported_program.graph_module
100
+ graph = graph_module.graph
101
+
102
+ # clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
103
+ args = ClampArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
104
+
105
+ input = args.input
106
+ min = args.min
107
+ max = args.max
108
+
109
+ input_dtype = extract_torch_dtype(input)
110
+ output_dtype = extract_torch_dtype(node)
111
+
112
+ def _convert_arg(arg, arg_name: str):
113
+ if arg is None:
114
+ return False
115
+
116
+ arg_dtype = torch.tensor(arg).dtype
117
+ arg_idx = node.args.index(arg)
118
+ if arg_dtype != output_dtype:
119
+ assert output_dtype in [torch.float, torch.int]
120
+ if output_dtype == torch.float:
121
+ arg = float(arg)
122
+ else:
123
+ arg = int(arg)
124
+ node.update_arg(arg_idx, arg)
125
+ logger.debug(
126
+ f"Casting {arg_name} value from {arg_dtype} to {output_dtype} for clamp operation at {node.name}"
127
+ )
128
+ return True
129
+ return False
130
+
131
+ modified |= _convert_arg(min, "min")
132
+ modified |= _convert_arg(max, "max")
133
+
134
+ if input_dtype != output_dtype:
135
+ logger.debug(
136
+ f"Inserting cast from {input_dtype} to {output_dtype} for input {input.name}"
137
+ )
138
+ with graph.inserting_after(input):
139
+ to_copy = create_node(
140
+ graph,
141
+ torch.ops.aten._to_copy.default,
142
+ (input,),
143
+ {"dtype": output_dtype},
144
+ origin=input,
145
+ )
146
+ set_new_meta_val(to_copy)
147
+ node.update_arg(node.args.index(input), to_copy)
148
+
149
+ modified = True
150
+
151
+ return modified
152
+
153
+ def call(self, exported_program: ExportedProgram) -> PassResult:
154
+ target_op = ops.aten.clamp
155
+
156
+ graph_module = exported_program.graph_module
157
+ graph = graph_module.graph
158
+ modified = False
159
+ for node in graph.nodes:
160
+ if not is_target_node(node, target_op):
161
+ continue
162
+
163
+ modified |= self.convert(exported_program, node)
164
+
165
+ graph.eliminate_dead_code()
166
+ graph.lint()
167
+ graph_module.recompile()
168
+
169
+ return PassResult(modified)
tico/passes/cast_mixed_type_args.py CHANGED
@@ -41,6 +41,8 @@ ops_to_promote = {
41
41
  torch.ops.aten.ge.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
42
42
  torch.ops.aten.gt.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
43
43
  torch.ops.aten.gt.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
44
+ torch.ops.aten.le.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
45
+ torch.ops.aten.le.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
44
46
  torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
45
47
  torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
46
48
  torch.ops.aten.ne.Scalar: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
@@ -124,7 +126,7 @@ class CastMixedTypeArgs(PassBase):
124
126
  if rhs_val.dtype == type_to_promote:
125
127
  ori_type = lhs_val.dtype
126
128
  arg_to_promote = lhs
127
- assert arg_to_promote != None
129
+ assert arg_to_promote is not None
128
130
 
129
131
  if isinstance(arg_to_promote, torch.fx.Node):
130
132
  with graph.inserting_after(arg_to_promote):
@@ -178,7 +180,7 @@ class CastMixedTypeArgs(PassBase):
178
180
  node_dtype = extract_torch_dtype(node)
179
181
  assert (
180
182
  node_dtype == node_dtype_ori
181
- ), f"Type casting doesn't change node's dtype."
183
+ ), "Type casting doesn't change node's dtype."
182
184
 
183
185
  graph.eliminate_dead_code()
184
186
  graph.lint()
tico/passes/const_prop_pass.py CHANGED
@@ -301,7 +301,7 @@ class ConstPropPass(PassBase):
301
301
  graph.eliminate_dead_code()
302
302
  graph_module.recompile()
303
303
 
304
- logger.debug(f"Constant nodes are propagated")
304
+ logger.debug("Constant nodes are propagated")
305
305
  # Constant folding can be done with only one time run. Let's set `modified` to False.
306
306
  modified = False
307
307
  return PassResult(modified)
tico/passes/convert_conv1d_to_conv2d.py CHANGED
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
19
19
  import torch
20
20
  from torch.export import ExportedProgram
21
21
 
22
- from tico.serialize.circle_graph import extract_shape
22
+ from tico.serialize.circle_mapping import extract_shape
23
23
  from tico.utils import logging
24
24
  from tico.utils.errors import NotYetSupportedError
25
25
  from tico.utils.graph import create_node
tico/passes/convert_expand_to_slice_cat.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.graph import create_node
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
29
+ from tico.utils.validate_args_kwargs import ExpandArgs, ReshapeArgs
30
+
31
+
32
+ @trace_graph_diff_on_pass
33
+ class ConvertExpandToSliceCat(PassBase):
34
+ """
35
+ This pass replaces `aten.reshape` + `aten.expand` pattern by rewriting it using
36
+ a series of `aten.slice` and `aten.cat` operations.
37
+
38
+ This pass is specialized for expand of KVCache.
39
+ - Expects (batch, num_key_value_heads, seq_len, head_dim) as input shape of reshape
40
+ """
41
+
42
+ def __init__(self, enabled: bool = False):
43
+ super().__init__()
44
+ self.enabled = enabled
45
+
46
+ def call(self, exported_program: ExportedProgram) -> PassResult:
47
+ if not self.enabled:
48
+ return PassResult(False)
49
+
50
+ logger = logging.getLogger(__name__)
51
+
52
+ graph_module = exported_program.graph_module
53
+ graph = graph_module.graph
54
+ modified = False
55
+
56
+ # This pass handles expand on EXPAND_DIM only
57
+ CAT_DIM = 1
58
+ EXPAND_DIM = 2
59
+
60
+ for node in graph.nodes:
61
+ if not isinstance(node, torch.fx.Node) or not is_target_node(
62
+ node, ops.aten.reshape
63
+ ):
64
+ continue
65
+
66
+ post_reshape = node
67
+ post_reshape_args = ReshapeArgs(*post_reshape.args, **post_reshape.kwargs)
68
+ post_reshape_input = post_reshape_args.input
69
+
70
+ if not isinstance(post_reshape_input, torch.fx.Node) or not is_target_node(
71
+ post_reshape_input, ops.aten.expand
72
+ ):
73
+ continue
74
+
75
+ expand = post_reshape_input
76
+ expand_args = ExpandArgs(*expand.args, **expand.kwargs)
77
+ expand_input = expand_args.input
78
+ expand_shape = extract_shape(expand)
79
+
80
+ if not isinstance(expand_input, torch.fx.Node) or not is_target_node(
81
+ expand_input, ops.aten.reshape
82
+ ):
83
+ continue
84
+
85
+ pre_reshape = expand_input
86
+ pre_reshape_args = ReshapeArgs(*pre_reshape.args, **pre_reshape.kwargs)
87
+ pre_reshape_input = pre_reshape_args.input
88
+ pre_reshape_shape = extract_shape(pre_reshape)
89
+
90
+ if pre_reshape_shape[EXPAND_DIM] != 1:
91
+ continue
92
+
93
+ reshape_input_shape = extract_shape(pre_reshape_input)
94
+
95
+ if len(expand_shape) != len(pre_reshape_shape):
96
+ continue
97
+
98
+ # Ensure all dimensions *except* at EXPAND_DIM are identical.
99
+ if not (
100
+ expand_shape[:EXPAND_DIM] == pre_reshape_shape[:EXPAND_DIM]
101
+ and expand_shape[EXPAND_DIM + 1 :]
102
+ == pre_reshape_shape[EXPAND_DIM + 1 :]
103
+ ):
104
+ continue
105
+
106
+ # Ensure the expansion dimension is a clean multiple.
107
+ if expand_shape[EXPAND_DIM] % pre_reshape_shape[EXPAND_DIM] != 0:
108
+ continue
109
+
110
+ expand_ratio = expand_shape[EXPAND_DIM] // pre_reshape_shape[EXPAND_DIM]
111
+
112
+ if expand_ratio <= 1:
113
+ continue
114
+
115
+ cat_nodes = []
116
+
117
+ for i in range(reshape_input_shape[CAT_DIM]):
118
+ with graph.inserting_before(expand):
119
+ slice_copy_args = (pre_reshape_input, CAT_DIM, i, i + 1, 1)
120
+ slice_node = create_node(
121
+ graph,
122
+ torch.ops.aten.slice.Tensor,
123
+ args=slice_copy_args,
124
+ origin=expand,
125
+ )
126
+ with graph.inserting_after(slice_node):
127
+ cat_args = ([slice_node] * expand_ratio, CAT_DIM)
128
+ cat_node = create_node(
129
+ graph,
130
+ torch.ops.aten.cat.default,
131
+ args=cat_args,
132
+ origin=expand,
133
+ )
134
+ cat_nodes.append(cat_node)
135
+
136
+ with graph.inserting_after(expand):
137
+ cat_args = (cat_nodes, CAT_DIM)
138
+ cat_node = create_node(
139
+ graph,
140
+ torch.ops.aten.cat.default,
141
+ args=cat_args,
142
+ origin=expand,
143
+ )
144
+ expand.replace_all_uses_with(cat_node)
145
+
146
+ modified = True
147
+ logger.debug(f"{expand.name} is replaced with {cat_node.name} operators")
148
+
149
+ graph.eliminate_dead_code()
150
+ graph.lint()
151
+ graph_module.recompile()
152
+
153
+ return PassResult(modified)