tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (181):
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
@@ -18,20 +18,16 @@ from typing import Dict
18
18
  import flatbuffers
19
19
  import torch
20
20
  from circle_schema import circle
21
- from torch.export.exported_program import (
22
- ConstantArgument,
23
- ExportedProgram,
24
- InputKind,
25
- TensorArgument,
26
- )
27
-
28
- from tico.serialize.circle_mapping import to_circle_dtype
21
+ from torch.export.exported_program import ConstantArgument, ExportedProgram, InputKind
22
+
23
+ from tico.config import CompileConfigBase, get_default_config
24
+ from tico.serialize.circle_mapping import to_circle_dtype, to_circle_shape
29
25
  from tico.serialize.operators import *
30
26
  from tico.serialize.circle_graph import CircleModel, CircleSubgraph
31
27
  from tico.serialize.operators.hashable_opcode import OpCode
32
28
  from tico.serialize.operators.node_visitor import get_node_visitors
33
29
  from tico.utils import logging
34
- from tico.utils.serialize import finalise_tensor_names
30
+ from tico.utils.serialize import finalise_tensor_names, validate_tensor_shapes
35
31
 
36
32
 
37
33
  multiple_output_ops = [
@@ -39,161 +35,58 @@ multiple_output_ops = [
39
35
  torch.ops.aten.max.dim,
40
36
  ]
41
37
 
42
- # Build circle model from ExportedProgram
43
- # Return raw bytes of circle model
44
- def build_circle(edge_program: ExportedProgram) -> bytes:
45
- logger = logging.getLogger(__name__)
46
38
 
47
- builder = flatbuffers.Builder()
39
+ def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
40
+ """Initialize a new Circle model and subgraph.
48
41
 
49
- # Init Model
42
+ Returns:
43
+ Tuple containing the model and subgraph
44
+ """
50
45
  model = CircleModel()
51
-
52
- # Add empty buffer at the front (convention)
53
- model.add_buffer(circle.Buffer.BufferT())
54
-
55
- # Create an empty subgraph (assume a single subgraph)
46
+ model.add_buffer(circle.Buffer.BufferT()) # Add empty buffer at the front
56
47
  graph = CircleSubgraph(model)
48
+ return model, graph
57
49
 
58
- # Export tensors
59
- logger.debug("---------------Export tensors--------------")
60
- buf_name_to_data = {name: buf for name, buf in edge_program.named_buffers()}
61
- for node in edge_program.graph.nodes:
62
- if node.op == "call_function":
63
- if node.target in multiple_output_ops:
64
- continue
65
- node_val = node.meta["val"]
66
- if node_val.layout != torch.strided:
67
- raise RuntimeError(
68
- f"Only support dense tensors (node layout: {node_val.layout})"
69
- )
70
- graph.add_tensor_from_node(node)
71
- logger.debug(f"call_function: {node.name} tensor exported.")
72
-
73
- # placeholder: function input (including parameters, buffers, constant tensors)
74
- elif node.op == "placeholder":
75
- # placeholder invariants
76
- assert node.args is None or len(node.args) == 0 # Not support default param
77
-
78
- # parameters
79
- if node.name in edge_program.graph_signature.inputs_to_parameters:
80
- param_name = edge_program.graph_signature.inputs_to_parameters[
81
- node.name
82
- ]
83
- param_data = edge_program.state_dict[param_name]
84
-
85
- assert isinstance(
86
- param_data, torch.Tensor
87
- ), "Expect parameters to be a tensor"
88
- param_value = param_data.cpu().detach().numpy()
89
-
90
- graph.add_tensor_from_node(node, param_value)
91
- logger.debug(f"placeholder(param): {node.name} tensor exported.")
92
- elif node.name in edge_program.graph_signature.inputs_to_buffers:
93
- buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
94
- assert buffer_name in buf_name_to_data
95
- buffer_data = buf_name_to_data[buffer_name]
96
- assert isinstance(
97
- buffer_data, torch.Tensor
98
- ), "Expect buffers to be a tensor"
99
- buffer_value = buffer_data.cpu().detach().numpy()
100
-
101
- graph.add_tensor_from_node(node, buffer_value)
102
- logger.debug(f"placeholder(buffer): {node.name} tensor exported.")
103
- elif (
104
- node.name
105
- in edge_program.graph_signature.inputs_to_lifted_tensor_constants
106
- ):
107
- ctensor_name = (
108
- edge_program.graph_signature.inputs_to_lifted_tensor_constants[
109
- node.name
110
- ]
111
- )
112
- ctensor_data = edge_program.constants[ctensor_name]
113
-
114
- assert isinstance(
115
- ctensor_data, torch.Tensor
116
- ), "Expect constant tensor to be a tensor"
117
- ctensor_value = ctensor_data.cpu().detach().numpy()
118
50
 
119
- graph.add_tensor_from_node(node, ctensor_value)
120
- logger.debug(
121
- f"placeholder(constant tensor): {node.name} tensor exported."
122
- )
123
- else:
124
- user_inputs = [
125
- specs
126
- for specs in edge_program.graph_signature.input_specs
127
- if specs.kind == InputKind.USER_INPUT
128
- ]
129
- constant_inputs = [
130
- specs
131
- for specs in user_inputs
132
- if isinstance(specs.arg, ConstantArgument)
133
- ]
134
- name_to_value = {
135
- specs.arg.name: specs.arg.value for specs in constant_inputs
136
- }
137
- # NoneType ConstantArgument is ignored.
138
- if node.name in name_to_value and name_to_value[node.name] == None:
139
- continue
140
- graph.add_tensor_from_node(node)
141
- logger.debug(f"placeholder: {node.name} tensor exported.")
142
-
143
- # get_attr: retrieve parameter
144
- elif node.op == "get_attr":
145
- # node.name: Place where fetched attribute is saved
146
- # node.target: Attribute in the module
147
- attr_tensor = getattr(node.graph.owning_module, node.target)
148
- assert isinstance(attr_tensor, torch.Tensor)
51
+ def build_circle(
52
+ ep: ExportedProgram, config: CompileConfigBase = get_default_config()
53
+ ) -> bytes:
54
+ """Convert ExportedProgram to Circle format.
149
55
 
150
- graph.add_tensor_from_scratch(
151
- prefix=node.name,
152
- shape=list(attr_tensor.shape),
153
- dtype=to_circle_dtype(attr_tensor.dtype),
154
- source_node=node,
155
- )
156
-
157
- logger.debug(f"get_attr: {node.name} tensor exported.")
158
-
159
- # output: function output
160
- elif node.op == "output":
161
- # output node itself does not need a buffer
162
- # argument of output node is assumed to be exported beforehand
163
- for output in node.args[0]:
164
- if isinstance(output, torch.fx.Node):
165
- assert graph.has_tensor(output.name)
166
- continue
56
+ Args:
57
+ ep: The exported PyTorch program to convert
167
58
 
168
- # call_method: call method
169
- elif node.op == "call_method":
170
- raise AssertionError("Not yet implemented")
171
-
172
- # call_module: call 'forward' of module
173
- elif node.op == "call_module":
174
- raise AssertionError("Not yet implemented")
59
+ Returns:
60
+ bytes: Raw bytes of the Circle model
61
+ """
62
+ logger = logging.getLogger(__name__)
63
+ builder = flatbuffers.Builder()
64
+ model, graph = _initialize_model()
175
65
 
176
- else:
177
- # Add more if fx.Node is extended
178
- raise AssertionError(f"Unknown fx.Node op {node.op}")
66
+ # Export tensors
67
+ _export_tensors(graph, ep)
179
68
 
180
69
  # Register inputs
181
70
  logger.debug("---------------Register inputs--------------")
182
- for in_spec in edge_program.graph_signature.input_specs:
71
+ for in_spec in ep.graph_signature.input_specs:
183
72
  if in_spec.kind != InputKind.USER_INPUT:
184
73
  continue
185
- # NoneType ConstantArgument is ignored.
186
- if isinstance(in_spec.arg, ConstantArgument) and in_spec.arg.value == None:
187
- continue
74
+ if isinstance(in_spec.arg, ConstantArgument):
75
+ # ConstantArgument is ignored when option is given
76
+ if config.get("remove_constant_input"):
77
+ continue
78
+ # NoneType ConstantArgument is ignored.
79
+ if in_spec.arg.value == None:
80
+ continue
188
81
  arg_name = in_spec.arg.name
189
82
  graph.add_input(arg_name)
190
83
  logger.debug(f"Registered input: {arg_name}")
191
84
 
192
85
  # Register outputs
193
86
  logger.debug("---------------Register outputs--------------")
194
- for user_output in edge_program.graph_signature.user_outputs:
87
+ for user_output in ep.graph_signature.user_outputs:
195
88
  if user_output == None:
196
- logger.debug(f"Ignore 'None' output")
89
+ logger.debug("Ignore 'None' output")
197
90
  continue
198
91
 
199
92
  graph.add_output(user_output)
@@ -203,7 +96,7 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
203
96
  logger.debug("---------------Export operators--------------")
204
97
  op_codes: Dict[OpCode, int] = {}
205
98
  visitors = get_node_visitors(op_codes, graph)
206
- for node in edge_program.graph.nodes:
99
+ for node in ep.graph.nodes:
207
100
  if node.op != "call_function":
208
101
  continue
209
102
 
@@ -218,8 +111,10 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
218
111
  graph.add_operator(circle_op)
219
112
  logger.debug(f"call_function: {node.name} ({opcode}) Op exported.")
220
113
 
221
- # Register subgraph
222
114
  finalise_tensor_names(graph)
115
+ validate_tensor_shapes(graph)
116
+
117
+ # Register subgraph
223
118
  model.subgraphs.append(graph)
224
119
 
225
120
  # Encode operator codes
@@ -227,10 +122,8 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
227
122
  code for code, _ in sorted(op_codes.items(), key=lambda x: x[1])
228
123
  ]
229
124
 
230
- # Description
125
+ # Final model settings
231
126
  model.description = "circle"
232
-
233
- # Set version
234
127
  model.version = 0
235
128
 
236
129
  # Finish model
@@ -238,3 +131,215 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
238
131
  buf = builder.Output()
239
132
 
240
133
  return bytes(buf)
134
+
135
+
136
+ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None:
137
+ """Export all tensors from the exported program to the circle graph.
138
+
139
+ Args:
140
+ graph: The CircleSubgraph to add tensors to
141
+ ep: The exported PyTorch program
142
+ """
143
+ logger = logging.getLogger(__name__)
144
+ logger.debug("---------------Export tensors--------------")
145
+ buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
146
+
147
+ for node in ep.graph.nodes:
148
+ if node.op == "call_function":
149
+ if node.target in multiple_output_ops:
150
+ continue
151
+ node_val = node.meta["val"]
152
+ if node_val.layout != torch.strided:
153
+ raise RuntimeError(
154
+ f"Only support dense tensors (node layout: {node_val.layout})"
155
+ )
156
+ graph.add_tensor_from_node(node)
157
+ logger.debug(f"call_function: {node.name} tensor exported.")
158
+
159
+ elif node.op == "placeholder":
160
+ _handle_placeholder_node(graph, node, ep, buf_name_to_data)
161
+
162
+ elif node.op == "get_attr":
163
+ _handle_get_attr_node(graph, node)
164
+
165
+ elif node.op == "output":
166
+ for output in node.args[0]:
167
+ if isinstance(output, torch.fx.Node):
168
+ assert graph.has_tensor(output.name)
169
+ continue
170
+
171
+ elif node.op == "call_method":
172
+ raise AssertionError("Not yet implemented")
173
+
174
+ elif node.op == "call_module":
175
+ raise AssertionError("Not yet implemented")
176
+
177
+ else:
178
+ raise AssertionError(f"Unknown fx.Node op {node.op}")
179
+
180
+
181
+ def _handle_placeholder_node(
182
+ graph: CircleSubgraph,
183
+ node: torch.fx.Node,
184
+ ep: ExportedProgram,
185
+ buf_name_to_data: dict,
186
+ ) -> None:
187
+ """Handle a placeholder node during tensor export."""
188
+ # placeholder invariants
189
+ assert node.args is None or len(node.args) == 0 # Not support default param
190
+
191
+ if node.name in ep.graph_signature.inputs_to_parameters:
192
+ _handle_parameter_node(graph, node, ep)
193
+ elif node.name in ep.graph_signature.inputs_to_buffers:
194
+ _handle_buffer_node(graph, node, ep, buf_name_to_data)
195
+ elif node.name in ep.graph_signature.inputs_to_lifted_tensor_constants:
196
+ _handle_constant_tensor_node(graph, node, ep)
197
+ else:
198
+ _handle_user_input_node(graph, node, ep)
199
+
200
+
201
+ def _handle_parameter_node(
202
+ graph: CircleSubgraph,
203
+ node: torch.fx.Node,
204
+ ep: ExportedProgram,
205
+ ) -> None:
206
+ """Handle a parameter placeholder node by exporting its tensor data.
207
+
208
+ Args:
209
+ graph: CircleSubgraph to add tensor to
210
+ node: The parameter node to process
211
+ ep: ExportedProgram containing parameter data
212
+ """
213
+ param_name = ep.graph_signature.inputs_to_parameters[node.name]
214
+ param_data = ep.state_dict[param_name]
215
+
216
+ if not isinstance(param_data, torch.Tensor):
217
+ raise ValueError(f"Parameter {param_name} is not a tensor")
218
+
219
+ tensor_value = param_data.cpu().detach().numpy()
220
+ graph.add_tensor_from_node(node, tensor_value)
221
+
222
+ logger = logging.getLogger(__name__)
223
+ logger.debug(f"Exported parameter tensor: {node.name}")
224
+
225
+
226
+ def _handle_buffer_node(
227
+ graph: CircleSubgraph,
228
+ node: torch.fx.Node,
229
+ ep: ExportedProgram,
230
+ buf_name_to_data: dict,
231
+ ) -> None:
232
+ """Handle a buffer placeholder node by exporting its tensor data.
233
+
234
+ Args:
235
+ graph: CircleSubgraph to add tensor to
236
+ node: The buffer node to process
237
+ ep: ExportedProgram containing buffer info
238
+ buf_name_to_data: Mapping of buffer names to data
239
+ """
240
+ buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
241
+
242
+ if buffer_name not in buf_name_to_data:
243
+ raise ValueError(f"Buffer {buffer_name} not found in buffer data")
244
+
245
+ buffer_data = buf_name_to_data[buffer_name]
246
+
247
+ if not isinstance(buffer_data, torch.Tensor):
248
+ raise ValueError(f"Buffer {buffer_name} is not a tensor")
249
+
250
+ tensor_value = buffer_data.cpu().detach().numpy()
251
+ graph.add_tensor_from_node(node, tensor_value)
252
+
253
+ logger = logging.getLogger(__name__)
254
+ logger.debug(f"Exported buffer tensor: {node.name}")
255
+
256
+
257
+ def _handle_constant_tensor_node(
258
+ graph: CircleSubgraph,
259
+ node: torch.fx.Node,
260
+ ep: ExportedProgram,
261
+ ) -> None:
262
+ """Handle a constant tensor placeholder node by exporting its tensor data.
263
+
264
+ Args:
265
+ graph: CircleSubgraph to add tensor to
266
+ node: The constant tensor node to process
267
+ ep: ExportedProgram containing constant data
268
+ """
269
+ ctensor_name = ep.graph_signature.inputs_to_lifted_tensor_constants[node.name]
270
+
271
+ if ctensor_name not in ep.constants:
272
+ raise ValueError(f"Constant tensor {ctensor_name} not found")
273
+
274
+ ctensor_data = ep.constants[ctensor_name]
275
+
276
+ if not isinstance(ctensor_data, torch.Tensor):
277
+ raise ValueError(f"Constant tensor {ctensor_name} is not a tensor")
278
+
279
+ tensor_value = ctensor_data.cpu().detach().numpy()
280
+ graph.add_tensor_from_node(node, tensor_value)
281
+
282
+ logger = logging.getLogger(__name__)
283
+ logger.debug(f"Exported constant tensor: {node.name}")
284
+
285
+
286
+ def _handle_user_input_node(
287
+ graph: CircleSubgraph,
288
+ node: torch.fx.Node,
289
+ ep: ExportedProgram,
290
+ ) -> None:
291
+ """Handle a user input placeholder node by exporting its tensor data.
292
+
293
+ Args:
294
+ graph: CircleSubgraph to add tensor to
295
+ node: The user input node to process
296
+ ep: ExportedProgram containing input specs
297
+ """
298
+ user_inputs = [
299
+ specs
300
+ for specs in ep.graph_signature.input_specs
301
+ if specs.kind == InputKind.USER_INPUT
302
+ ]
303
+ constant_inputs = [
304
+ specs for specs in user_inputs if isinstance(specs.arg, ConstantArgument)
305
+ ]
306
+ name_to_value = {specs.arg.name: specs.arg.value for specs in constant_inputs}
307
+
308
+ # Skip NoneType ConstantArgument
309
+ if node.name in name_to_value and name_to_value[node.name] is None:
310
+ return
311
+
312
+ graph.add_tensor_from_node(node)
313
+
314
+ logger = logging.getLogger(__name__)
315
+ logger.debug(f"Exported user input tensor: {node.name}")
316
+
317
+
318
+ def _handle_get_attr_node(
319
+ graph: CircleSubgraph,
320
+ node: torch.fx.Node,
321
+ ) -> None:
322
+ """Handle a get_attr node by exporting its tensor data.
323
+
324
+ Args:
325
+ graph: CircleSubgraph to add tensor to
326
+ node: The get_attr node to process
327
+ """
328
+ assert isinstance(node.target, str)
329
+ attr_tensor = getattr(node.graph.owning_module, node.target)
330
+
331
+ if not isinstance(attr_tensor, torch.Tensor):
332
+ raise ValueError(f"Attribute {node.target} is not a tensor")
333
+
334
+ attr_shape, attr_shape_signature = to_circle_shape(attr_tensor.shape)
335
+
336
+ graph.add_tensor_from_scratch(
337
+ prefix=node.name,
338
+ shape=attr_shape,
339
+ shape_signature=attr_shape_signature,
340
+ dtype=to_circle_dtype(attr_tensor.dtype),
341
+ source_node=node,
342
+ )
343
+
344
+ logger = logging.getLogger(__name__)
345
+ logger.debug(f"Exported attribute tensor: {node.name}")
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from contextlib import contextmanager
16
+
17
+ import torch
18
+
19
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
20
+
21
+
22
+ def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor):
23
+ return torch.ops.circle_custom.rms_norm(
24
+ hidden_states, self.weight, self.variance_epsilon
25
+ )
26
+
27
+
28
+ @contextmanager
29
+ def patched_llama_rmsnorm():
30
+ orig = LlamaRMSNorm.forward
31
+ LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter
32
+ try:
33
+ yield
34
+ finally:
35
+ LlamaRMSNorm.forward = orig
@@ -22,7 +22,7 @@ from circle_schema import circle
22
22
  from tico.serialize.circle_graph import CircleSubgraph
23
23
  from tico.serialize.circle_mapping import (
24
24
  circle_legalize_dtype_to,
25
- extract_circle_dtype,
25
+ extract_circle_shape,
26
26
  extract_shape,
27
27
  extract_torch_dtype,
28
28
  )
@@ -100,19 +100,10 @@ class AnyVisitor(NodeVisitor):
100
100
  keepdim = args.keepdim
101
101
 
102
102
  input_shape = list(extract_shape(input))
103
- output_shape = list(extract_shape(node))
104
-
105
- dim_i32 = None
106
103
  if dim is None:
107
- dims = tuple(i for i in range(0, len(input_shape)))
108
- dim_i32 = tuple(
109
- circle_legalize_dtype_to(dim, dtype=torch.int32) for dim in dims
110
- )
111
- if isinstance(dim, int):
112
- dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
113
- if isinstance(dim, tuple):
114
- dim_i32 = tuple(circle_legalize_dtype_to(d, dtype=torch.int32) for d in dim)
115
- assert dim_i32 is not None
104
+ dim = tuple(i for i in range(0, len(input_shape)))
105
+
106
+ dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
116
107
 
117
108
  inputs = [
118
109
  input,
@@ -126,9 +117,11 @@ class AnyVisitor(NodeVisitor):
126
117
  if dtype_torch in [torch.int32, torch.int64, torch.float32, torch.float64]:
127
118
  dst_dtype_circle = circle.TensorType.TensorType.BOOL
128
119
  dst_dtype_torch = torch.bool
120
+ dst_shape, dst_shape_signature = extract_circle_shape(input)
129
121
  ne_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
130
122
  prefix=f"{input.name}_ne",
131
- shape=input_shape,
123
+ shape=dst_shape,
124
+ shape_signature=dst_shape_signature,
132
125
  dtype=dst_dtype_circle,
133
126
  source_node=input,
134
127
  )
@@ -22,7 +22,11 @@ import torch
22
22
  from circle_schema import circle
23
23
 
24
24
  from tico.serialize.circle_graph import CircleSubgraph
25
- from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
25
+ from tico.serialize.circle_mapping import (
26
+ extract_circle_dtype,
27
+ extract_shape,
28
+ to_circle_shape,
29
+ )
26
30
  from tico.serialize.operators.hashable_opcode import OpCode
27
31
  from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
28
32
  from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -57,7 +61,7 @@ class AvgPool2DVisitor(NodeVisitor):
57
61
  return True
58
62
 
59
63
  def has_same_padding(self, args: AvgPool2dArgs) -> bool:
60
- input_shape = list(extract_shape(args.input))
64
+ input_shape: torch.Size = extract_shape(args.input)
61
65
  kernel_size = args.kernel_size
62
66
  stride = args.stride
63
67
  assert stride
@@ -137,7 +141,7 @@ class AvgPool2DVisitor(NodeVisitor):
137
141
  ],
138
142
  dtype=torch.int32,
139
143
  )
140
- input_shape = list(extract_shape(input))
144
+ input_shape = extract_shape(input)
141
145
  input_dtype: int = extract_circle_dtype(input)
142
146
  padded_input_shape = [
143
147
  input_shape[0],
@@ -147,10 +151,13 @@ class AvgPool2DVisitor(NodeVisitor):
147
151
  ]
148
152
  padded_input_shape[1] += padding[0] * 2
149
153
  padded_input_shape[2] += padding[1] * 2
154
+
150
155
  # create padded input tensor
156
+ padded_cshape, padded_cshape_signature = to_circle_shape(padded_input_shape)
151
157
  padded_input_tensor = self.graph.add_tensor_from_scratch(
152
158
  prefix=f"{input.name}_pad_output",
153
- shape=padded_input_shape,
159
+ shape=padded_cshape,
160
+ shape_signature=padded_cshape_signature,
154
161
  dtype=input_dtype,
155
162
  source_node=node,
156
163
  )
@@ -21,12 +21,9 @@ import torch
21
21
  from circle_schema import circle
22
22
 
23
23
  from tico.passes import ops
24
+ from tico.serialize.circle_graph import CircleSubgraph
24
25
 
25
- from tico.serialize.circle_graph import (
26
- CircleSubgraph,
27
- extract_circle_dtype,
28
- extract_shape,
29
- )
26
+ from tico.serialize.circle_mapping import extract_circle_dtype, extract_circle_shape
30
27
  from tico.serialize.operators.hashable_opcode import OpCode
31
28
  from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
32
29
  from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -104,12 +101,13 @@ class ClampVisitor(NodeVisitor):
104
101
  return self.define_minimum_node([input, max_val], [node])
105
102
 
106
103
  elif min_val is not None and max_val is not None:
107
- input_shape = extract_shape(input)
104
+ input_shape, input_shape_signature = extract_circle_shape(input)
108
105
  input_dtype = extract_circle_dtype(input)
109
106
  minimum_tensor = self.graph.add_tensor_from_scratch(
110
107
  prefix=f"{input.name}_min",
111
108
  dtype=input_dtype,
112
- shape=list(input_shape),
109
+ shape=input_shape,
110
+ shape_signature=input_shape_signature,
113
111
  source_node=node,
114
112
  )
115
113
  minimum_opertor = self.define_minimum_node(