tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
tico/utils/graph.py ADDED
@@ -0,0 +1,282 @@
# Portions of this file are adapted from code originally authored by
# Meta Platforms, Inc. and affiliates, licensed under the BSD-style
# license found in the LICENSE file in the root directory of their source tree.

# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

if TYPE_CHECKING:
    import torch.fx
import torch
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind, InputSpec, TensorArgument

from tico.utils.utils import get_fake_mode, set_new_meta_val


def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
    assert node.op == "placeholder"

    return node.name in ep.graph_signature.inputs_to_parameters


def is_torch_buffer(node: torch.fx.Node, ep: ExportedProgram):
    assert node.op == "placeholder"

    return node.name in ep.graph_signature.inputs_to_buffers


def get_torch_param_value(node: torch.fx.Node, ep: ExportedProgram):
    assert isinstance(node, torch.fx.Node)
    assert node.op == "placeholder"
    assert (
        node.name in ep.graph_signature.inputs_to_parameters
    ), f"Node {node.name} is not in the parameters"  # FIX CALLER UNLESS

    param_name = ep.graph_signature.inputs_to_parameters[node.name]
    named_params = dict(ep.named_parameters())
    assert param_name in named_params

    return named_params[param_name].data


def get_torch_buffer_value(node: torch.fx.Node, ep: ExportedProgram):
    assert isinstance(node, torch.fx.Node)
    assert node.op == "placeholder"
    assert (
        node.name in ep.graph_signature.inputs_to_buffers
    ), f"Node {node.name} is not in the buffers"  # FIX CALLER UNLESS

    buf_name = ep.graph_signature.inputs_to_buffers[node.name]
    named_buf = dict(ep.named_buffers())
    assert buf_name in named_buf

    return named_buf[buf_name]


def get_first_user_input(exported_program: ExportedProgram) -> Optional[torch.fx.Node]:
    """Returns the first user input node in the graph."""
    first_user_input: Optional[torch.fx.Node] = None
    graph_module = exported_program.graph_module
    graph: torch.fx.Graph = graph_module.graph
    for node in graph.nodes:
        if (
            node.op == "placeholder"
            and node.name in exported_program.graph_signature.user_inputs
        ):
            first_user_input = node
            break

    return first_user_input


def generate_fqn(prefix: str, exported_program: ExportedProgram):
    """
    Generate a fully-qualified name for constants.

    This function prevents `exported_program.constants` from having duplicate keys.
    """
    cnt = len(exported_program.constants)
    while True:
        if f"{prefix}{cnt}" in exported_program.constants:
            cnt += 1
            continue
        break
    return f"{prefix}{cnt}"


def create_input_spec(node, input_kind: InputKind):
    """
    @ref https://pytorch.org/docs/stable/export.ir_spec.html#placeholder
    """
    if input_kind == InputKind.CONSTANT_TENSOR:
        return InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=node.name),
            target=node.target,  # type: ignore[arg-type]
            persistent=True,
        )
    else:
        raise NotImplementedError("NYI")


def validate_input_specs(exported_program):
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }

    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name not in name_to_spec_dict:
            raise RuntimeError(
                f"Placeholder node {node.name} does not have a corresponding input spec!"
            )


def add_placeholder(
    exported_program: ExportedProgram,
    tensor: torch.Tensor,
    prefix: str,
) -> torch.fx.Node:
    """
    Add a placeholder to the graph and update the exported program.
    """
    fqn_name = generate_fqn(prefix, exported_program)

    # Get fake mode before adding the placeholder.
    fake_mode = get_fake_mode(exported_program)

    first_user_input = get_first_user_input(exported_program)
    if not first_user_input:
        # Placeholder nodes must be the first N nodes in the nodes list of a graph.
        # Therefore, insert the newly created placeholders at the start of the node list.
        assert exported_program.graph.nodes
        first_node = list(exported_program.graph.nodes)[0]
        first_user_input = first_node

    # Add a placeholder to the graph.
    with exported_program.graph.inserting_before(first_user_input):
        const_node = exported_program.graph.placeholder(fqn_name)

    const_node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
    const_node.meta["val"].constant = tensor

    # Add a new constant to the exported program.
    exported_program.constants[const_node.name] = tensor

    # Use update (instead of append) if this assert is violated.
    assert const_node.name not in [
        s.arg.name for s in exported_program.graph_signature.input_specs
    ]

    # Append the new input spec.
    exported_program.graph_signature.input_specs.append(
        create_input_spec(const_node, InputKind.CONSTANT_TENSOR)
    )

    # Get old input specs.
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }

    # Add the new constant to the input specs dict.
    name_to_spec_dict.update(
        {const_node.name: create_input_spec(const_node, InputKind.CONSTANT_TENSOR)}
    )

    # Generate new input specs *in the same order as the nodes*.
    # IMPORTANT: Input specs and their placeholder nodes must have the same order.
    new_input_specs = []
    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue
        new_input_specs.append(name_to_spec_dict[node.name])
    exported_program.graph_signature.input_specs = new_input_specs

    return const_node


def is_single_value_tensor(t: torch.Tensor):
    if len(t.size()) == 0:
        return True
    if len(t.size()) == 1 and t.size()[0] == 1:
        return True

    return False


def get_module_name_chain(node: Optional[torch.fx.Node]) -> str:
    """
    Returns a slash-separated string of module names representing the
    hierarchical path of the FX node within the original model.

    If the node has no `nn_module_stack` metadata, "unknown" is returned.

    Example:
        "encoder/layer1/linear"

    Parameters
    ----------
    node: torch.fx.Node
        A node from an ExportedProgram graph.

    Returns
    -------
    str
        A human-readable string that describes the full module path.
    """
    if node is None:
        return "unknown"
    # Let's prefix "tico" for graph inputs
    if node.op == "placeholder" and "nn_module_stack" not in node.meta:
        return "tico"

    assert isinstance(node, torch.fx.Node)
    stack = node.meta.get("nn_module_stack")
    if stack:
        assert isinstance(stack, dict)
        # Retrieving the last element is enough.
        return next(reversed(stack.values()))[1]
    else:
        return "unknown"


def create_node(
    graph: torch.fx.Graph,
    target: torch._ops.OpOverload,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    origin: Optional[torch.fx.Node] = None,
) -> torch.fx.Node:
    """
    Insert a new node into *graph* and propagate metadata from *origin*.

    Parameters
    ----------
    graph : torch.fx.Graph
        The graph that will own the newly-created node.

    target : torch._ops.OpOverload
        The op to call (e.g. `torch.add` or a "call_function" target).

    args : Tuple[Any, ...], optional
        Positional arguments for the new node.

    kwargs : Dict[str, Any], optional
        Keyword arguments for the new node.

    origin : torch.fx.Node, optional
        If given, the "nn_module_stack" entry of `origin.meta` is copied
        onto the new node so that it retains the originating module context.

    Returns
    -------
    torch.fx.Node
        The freshly inserted node.
    """
    new_node = graph.call_function(target, args=args, kwargs=kwargs)
    if origin:
        assert isinstance(origin, torch.fx.Node), type(origin)
        # Propagate "nn_module_stack" to retain the originating module context
        # for meaningful node names.
        if "nn_module_stack" in origin.meta:
            new_node.meta["nn_module_stack"] = origin.meta["nn_module_stack"]

    return new_node
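A minimal sketch of how `add_placeholder` and `get_module_name_chain` might be driven from user code, assuming a recent `torch.export`; the `Tiny` module and the `tico_const_` prefix are illustrative, not part of the package:

import torch
from tico.utils.graph import add_placeholder, get_module_name_chain

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

ep = torch.export.export(Tiny(), (torch.randn(2),))

# Register a new constant tensor as a placeholder node; the helper keeps
# exported_program.constants and the graph signature's input specs in sync.
const = add_placeholder(ep, torch.ones(2), prefix="tico_const_")
print(const.name)  # e.g. "tico_const_0"

# Placeholder inputs without nn_module_stack metadata are reported as "tico".
for node in ep.graph.nodes:
    print(node.name, get_module_name_chain(node))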
tico/utils/logging.py ADDED
@@ -0,0 +1,45 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os


def _loggerLevel():
    TICO_LOG = os.environ.get("TICO_LOG")
    if TICO_LOG == "1":
        log_level = logging.FATAL
    elif TICO_LOG == "2":
        log_level = logging.WARNING
    elif TICO_LOG == "3":
        log_level = logging.INFO
    elif TICO_LOG == "4":
        log_level = logging.DEBUG
    else:
        log_level = logging.WARNING
    return log_level


LOG_LEVEL = _loggerLevel()


def getLogger(name: str):
    """
    Return a logger whose log level is set according to the `TICO_LOG` environment variable.
    """
    logging.basicConfig()
    logger = logging.getLogger(name)
    logger.setLevel(LOG_LEVEL)

    return logger
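Since `LOG_LEVEL` is computed once at import time, `TICO_LOG` must be set before `tico.utils.logging` is first imported. A short usage sketch (the values 1-4 select FATAL/WARNING/INFO/DEBUG; anything else falls back to WARNING):

import os
os.environ["TICO_LOG"] = "4"  # must be set before the module is imported

from tico.utils.logging import getLogger

logger = getLogger(__name__)
logger.debug("visible because TICO_LOG=4 selects logging.DEBUG")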
tico/utils/model.py ADDED
@@ -0,0 +1,37 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from tico.interpreter import infer


class CircleModel:
    def __init__(self, circle_binary: bytes):
        self.circle_binary = circle_binary

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return infer.infer(self.circle_binary, *args, **kwargs)

    @staticmethod
    def load(circle_path: str) -> CircleModel:
        with open(circle_path, "rb") as f:
            buf = bytes(f.read())
        return CircleModel(buf)

    def save(self, circle_path: str) -> None:
        with open(circle_path, "wb") as f:
            f.write(self.circle_binary)
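A sketch of the round trip `CircleModel` supports; "model.circle" and the input tensor are placeholders, and the expected inputs depend on the exported model:

import torch
from tico.utils.model import CircleModel

m = CircleModel.load("model.circle")       # hypothetical path to a serialized Circle binary
outputs = m(torch.randn(1, 3, 224, 224))   # forwards to tico.interpreter.infer.infer(...)
m.save("copy.circle")                      # writes the raw bytes back out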
tico/utils/mx/__init__.py ADDED
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
tico/utils/mx/elemwise_ops.py ADDED
@@ -0,0 +1,267 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.

Name: elemwise_ops.py

Pytorch functions for elementwise (i.e. bfloat) quantization.

Usage Notes:
 - Use the "Exposed Methods" below to implement autograd functions
 - Use autograd functions to then implement torch.nn.Module(s)
 - Do *not* use methods in this file in Modules, they have no defined
   backwards pass and will block gradient computation.
 - Avoid importing internal functions if at all possible.

Exposed Methods:
    quantize_elemwise_op - quantizes a tensor to bfloat or other
                           custom float format
"""
import torch

from .formats import RoundingMode, _get_format_params
from .formats import _get_min_norm, _get_max_norm


# -------------------------------------------------------------------------
# Helper funcs
# -------------------------------------------------------------------------
# Never explicitly compute 2**(-exp) since subnorm numbers have
# exponents smaller than -126
def _safe_lshift(x, bits, exp):
    if exp is None:
        return x * (2**bits)
    else:
        return x / (2**exp) * (2**bits)


def _safe_rshift(x, bits, exp):
    if exp is None:
        return x / (2**bits)
    else:
        return x / (2**bits) * (2**exp)


def _round_mantissa(A, bits, round, clamp=False):
    """
    Rounds the mantissa to the nearest `bits`, depending on the rounding method `round`
    Args:
        A {PyTorch tensor} -- Input tensor
        round {str} -- Rounding method
                       "floor" rounds to the floor
                       "nearest" rounds to ceil or floor, whichever is nearest
    Returns:
        A {PyTorch tensor} -- Tensor with mantissas rounded
    """
    if round == "dither":
        rand_A = torch.rand_like(A, requires_grad=False)
        A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
    elif round == "floor":
        A = torch.sign(A) * torch.floor(torch.abs(A))
    elif round == "nearest":
        A = torch.sign(A) * torch.floor(torch.abs(A) + 0.5)
    elif round == "even":
        absA = torch.abs(A)
        # find 0.5, 2.5, 4.5 ...
        maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype)
        A = torch.sign(A) * (torch.floor(absA + 0.5) - maskA)
    else:
        raise Exception("Unrecognized round method %s" % (round))

    # Clip values that cannot be expressed by the specified number of bits
    if clamp:
        max_mantissa = 2 ** (bits - 1) - 1
        A = torch.clamp(A, -max_mantissa, max_mantissa)
    return A


# -------------------------------------------------------------------------
# Main funcs
# -------------------------------------------------------------------------
def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest',
                            saturate_normals=False, allow_denorm=True,
                            custom_cuda=False):
    """Core function used for element-wise quantization
    Arguments:
        A {PyTorch tensor} -- A tensor to be quantized
        bits {int} -- Number of mantissa bits. Includes
                      sign bit and implicit one for floats
        exp_bits {int} -- Number of exponent bits, 0 for ints
        max_norm {float} -- Largest representable normal number
        round {str} -- Rounding mode: (floor, nearest, even)
        saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
                                   that exceed max norm are clamped.
                                   Must be True for correct MX conversion.
        allow_denorm {bool} -- If False, flush denorm numbers in the
                               elem_format to zero.
        custom_cuda {bool} -- If True, use custom CUDA kernels
    Returns:
        quantized tensor {PyTorch tensor} -- A tensor that has been quantized
    """
    A_is_sparse = A.is_sparse
    if A_is_sparse:
        if A.layout != torch.sparse_coo:
            raise NotImplementedError("Only COO layout sparse tensors are currently supported.")

        sparse_A = A.coalesce()
        A = sparse_A.values().clone()

    # custom cuda only supports floor and nearest rounding modes
    custom_cuda = custom_cuda and round in RoundingMode.string_enums()

    if custom_cuda:
        A = A.contiguous()

        from . import custom_extensions
        if A.device.type == "cuda":
            A = custom_extensions.funcs.quantize_elemwise_func_cuda(
                A, bits, exp_bits, max_norm, RoundingMode[round],
                saturate_normals, allow_denorm)
        elif A.device.type == "cpu":
            A = custom_extensions.funcs.quantize_elemwise_func_cpp(
                A, bits, exp_bits, max_norm, RoundingMode[round],
                saturate_normals, allow_denorm)
        return A

    # Flush values < min_norm to zero if denorms are not allowed
    if not allow_denorm and exp_bits > 0:
        min_norm = _get_min_norm(exp_bits)
        out = (torch.abs(A) >= min_norm).type(A.dtype) * A
    else:
        out = A

    if exp_bits != 0:
        private_exp = torch.floor(torch.log2(
            torch.abs(A) + (A == 0).type(A.dtype)))

        # The minimum representable exponent for 8 exp bits is -126
        min_exp = -(2**(exp_bits - 1)) + 2
        private_exp = private_exp.clip(min=min_exp)
    else:
        private_exp = None

    # Scale up so the appropriate number of bits are in the integer portion of the number
    out = _safe_lshift(out, bits - 2, private_exp)

    out = _round_mantissa(out, bits, round, clamp=False)

    # Undo scaling
    out = _safe_rshift(out, bits - 2, private_exp)

    # Set values > max_norm to Inf if desired, else clamp them
    if saturate_normals or exp_bits == 0:
        out = torch.clamp(out, min=-max_norm, max=max_norm)
    else:
        out = torch.where((torch.abs(out) > max_norm),
                          torch.sign(out) * float("Inf"), out)

    # handle Inf/NaN
    if not custom_cuda:
        out[A == float("Inf")] = float("Inf")
        out[A == -float("Inf")] = -float("Inf")
        out[torch.isnan(A)] = float("NaN")  # NaN != NaN, so use isnan instead of ==

    if A_is_sparse:
        out = torch.sparse_coo_tensor(sparse_A.indices(), out,
                                      sparse_A.size(), dtype=sparse_A.dtype,
                                      device=sparse_A.device,
                                      requires_grad=sparse_A.requires_grad)

    return out


def _quantize_elemwise(A, elem_format, round='nearest', custom_cuda=False,
                       saturate_normals=False, allow_denorm=True):
    """Quantize values to a defined format. See _quantize_elemwise_core()
    """
    if elem_format is None:
        return A

    ebits, mbits, _, max_norm, _ = _get_format_params(elem_format)

    output = _quantize_elemwise_core(
        A, mbits, ebits, max_norm,
        round=round, allow_denorm=allow_denorm,
        saturate_normals=saturate_normals,
        custom_cuda=custom_cuda)

    return output


def _quantize_bfloat(A, bfloat, round='nearest', custom_cuda=False, allow_denorm=True):
    """Quantize values to bfloatX format
    Arguments:
        bfloat {int} -- Total number of bits for the bfloatX format.
                        Includes 1 sign bit, 8 exp bits, and a variable
                        number of mantissa bits. Must be >= 9.
    """
    # Shortcut for no quantization
    if bfloat == 0 or bfloat == 32:
        return A

    max_norm = _get_max_norm(8, bfloat - 7)

    return _quantize_elemwise_core(
        A, bits=bfloat - 7, exp_bits=8, max_norm=max_norm, round=round,
        allow_denorm=allow_denorm, custom_cuda=custom_cuda)


def _quantize_fp(A, exp_bits=None, mantissa_bits=None,
                 round='nearest', custom_cuda=False, allow_denorm=True):
    """Quantize values to IEEE fpX format. The format defines NaN/Inf
    and subnorm numbers in the same way as FP32 and FP16.
    Arguments:
        exp_bits {int} -- number of bits used to store the exponent
        mantissa_bits {int} -- number of bits used to store the mantissa, not
                               including the sign bit or implicit 1
        round {str} -- Rounding mode, (floor, nearest, even)
    """
    # Shortcut for no quantization
    if exp_bits is None or mantissa_bits is None:
        return A

    max_norm = _get_max_norm(exp_bits, mantissa_bits + 2)

    output = _quantize_elemwise_core(
        A, bits=mantissa_bits + 2, exp_bits=exp_bits,
        max_norm=max_norm, round=round, allow_denorm=allow_denorm,
        custom_cuda=custom_cuda)

    return output


def quantize_elemwise_op(A, mx_specs, round=None):
    """A function used for element-wise quantization with mx_specs
    Arguments:
        A {PyTorch tensor} -- a tensor that needs to be quantized
        mx_specs {dictionary} -- dictionary to specify mx_specs
        round {str} -- Rounding mode, choose from (floor, nearest, even)
                       (default: "nearest")
    Returns:
        quantized value {PyTorch tensor} -- a tensor that has been quantized
    """
    if mx_specs is None:
        return A
    elif round is None:
        round = mx_specs['round']

    if mx_specs['bfloat'] == 16 and round == 'even' \
            and torch.cuda.is_bf16_supported() \
            and mx_specs['bfloat_subnorms'] == True:
        return A.to(torch.bfloat16)

    if mx_specs['bfloat'] > 0 and mx_specs['fp'] > 0:
        raise ValueError("Cannot set both [bfloat] and [fp] in mx_specs.")
    elif mx_specs['bfloat'] > 9:
        A = _quantize_bfloat(A, bfloat=mx_specs['bfloat'], round=round,
                             custom_cuda=mx_specs['custom_cuda'],
                             allow_denorm=mx_specs['bfloat_subnorms'])
    elif mx_specs['bfloat'] > 0 and mx_specs['bfloat'] <= 9:
        raise ValueError("Cannot set [bfloat] <= 9 in mx_specs.")
    elif mx_specs['fp'] > 6:
        A = _quantize_fp(A, exp_bits=5, mantissa_bits=mx_specs['fp'] - 6,
                         round=round, custom_cuda=mx_specs['custom_cuda'],
                         allow_denorm=mx_specs['bfloat_subnorms'])
    elif mx_specs['fp'] > 0 and mx_specs['fp'] <= 6:
        raise ValueError("Cannot set [fp] <= 6 in mx_specs.")
    return A
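A minimal sketch of `quantize_elemwise_op`, using only the `mx_specs` keys this file actually reads ('bfloat', 'fp', 'round', 'custom_cuda', 'bfloat_subnorms'); a real mx_specs dictionary may carry more keys:

import torch
from tico.utils.mx.elemwise_ops import quantize_elemwise_op

mx_specs = {
    "bfloat": 16,           # 1 sign + 8 exponent + 7 mantissa bits
    "fp": 0,                # disable the IEEE-fp path
    "round": "nearest",
    "custom_cuda": False,
    "bfloat_subnorms": True,
}

A = torch.randn(4, 4)
A_q = quantize_elemwise_op(A, mx_specs)
print((A - A_q).abs().max())  # small, mantissa-rounding sized error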