tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,101 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from functools import wraps
16
+
17
+ import torch
18
+ from torch.export import ExportedProgram
19
+
20
+ from tico.utils.diff_graph import capture, capture_const, log, log_const
21
+ from tico.utils.passes import PassBase
22
+
23
+
24
def trace_const_diff_on_pass(cls):
    """Decorator for PassBase to trace const diff.

    Wraps the ``call`` method defined directly on ``cls`` (not one inherited
    from a base class). Before the pass runs, the incoming ``ExportedProgram``
    is passed to ``capture_const``. Afterwards it is passed to ``log_const``,
    titled with the pass class name.

    Args:
        cls: A subclass of ``PassBase``.

    Returns:
        ``cls`` itself, with ``call`` replaced by the traced wrapper when
        ``cls`` defines ``call`` in its own namespace.
    """
    # Report the offending class itself; `type(cls)` would always be `type`.
    assert issubclass(cls, PassBase), cls

    def _call_traced(fn):
        @wraps(fn)
        def wrapped(*args):
            # Expected call shape: `call(self, exported_program)`.
            _, exported_program = args
            assert isinstance(exported_program, ExportedProgram)
            graph_module = exported_program.graph_module
            assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
            capture_const(exported_program)
            ret = fn(*args)
            log_const(exported_program, title=str(cls.__name__), recapture=False)
            return ret

        return wrapped

    # Replace the `call` function with its traced version. Look it up directly
    # in the class namespace instead of mutating the class while iterating it.
    class_namespace = vars(cls)
    if "call" in class_namespace:
        setattr(cls, "call", _call_traced(class_namespace["call"]))
    return cls
48
+
49
+
50
def trace_graph_diff_on_pass(cls):
    """Decorator for PassBase to trace graph diff.

    Wraps the ``call`` method defined directly on ``cls`` (not one inherited
    from a base class). Before the pass runs, the graph of the incoming
    ``ExportedProgram``'s graph module is passed to ``capture``. Afterwards the
    same graph is passed to ``log``, titled with the pass class name.

    Args:
        cls: A subclass of ``PassBase``.

    Returns:
        ``cls`` itself, with ``call`` replaced by the traced wrapper when
        ``cls`` defines ``call`` in its own namespace.
    """
    # Report the offending class itself; `type(cls)` would always be `type`.
    assert issubclass(cls, PassBase), cls

    def _call_traced(fn):
        @wraps(fn)
        def wrapped(*args):
            # Expected call shape: `call(self, exported_program)`.
            _, exported_program = args
            assert isinstance(exported_program, ExportedProgram)
            graph_module = exported_program.graph_module
            assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
            capture(graph_module.graph)
            ret = fn(*args)
            log(graph_module.graph, title=str(cls.__name__), recapture=False)
            return ret

        return wrapped

    # Replace the `call` function with its traced version. Look it up directly
    # in the class namespace instead of mutating the class while iterating it.
    class_namespace = vars(cls)
    if "call" in class_namespace:
        setattr(cls, "call", _call_traced(class_namespace["call"]))
    return cls
74
+
75
+
76
def trace_const_diff_on_func(fn):
    """Decorator for a single-``ExportedProgram`` function to trace const diff.

    The input program is handed to ``capture_const`` before ``fn`` runs. The
    value ``fn`` returns (presumably an ``ExportedProgram``) is then handed to
    ``log_const``, titled with ``fn``'s name, and returned unchanged.
    """

    @wraps(fn)
    def _traced(ep: torch.export.ExportedProgram):
        assert isinstance(ep, torch.export.ExportedProgram)
        capture_const(ep)
        out = fn(ep)
        title = str(fn.__name__)
        log_const(out, title=title, recapture=False)
        return out

    return _traced
88
+
89
+
90
def trace_graph_diff_on_func(fn):
    """Decorator for a single-``ExportedProgram`` function to trace graph diff.

    The input program's ``.graph`` is handed to ``capture`` before ``fn`` runs.
    Afterwards the ``.graph`` of whatever ``fn`` returns is handed to ``log``,
    titled with ``fn``'s name, so ``fn`` must return an object exposing
    ``.graph`` (presumably an ``ExportedProgram``).
    """

    @wraps(fn)
    def _traced(ep: torch.export.ExportedProgram):
        assert isinstance(ep, torch.export.ExportedProgram)
        capture(ep.graph)
        out = fn(ep)
        title = str(fn.__name__)
        log(out.graph, title=title, recapture=False)
        return out

    return _traced
tico/utils/utils.py ADDED
@@ -0,0 +1,406 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import subprocess
17
+ import typing
18
+ import warnings
19
+ from functools import wraps
20
+ from typing import List
21
+
22
+ import torch
23
+ from circle_schema import circle
24
+ from packaging.version import Version
25
+ from torch._guards import detect_fake_mode
26
+ from torch.export import ExportedProgram
27
+ from torch.utils import _pytree as pytree
28
+
29
+ from tico.serialize.quant_param import QuantParam
30
+
31
+
32
# Feature flags derived from the installed torch version.
# The 2.8 bound is "2.8.0.dev", presumably so that dev/pre-release builds of
# 2.8 also satisfy the check (PEP 440 orders .dev before final releases).
_TORCH_VERSION = Version(torch.__version__)
HAS_TORCH_OVER_25 = _TORCH_VERSION >= Version("2.5.0")
HAS_TORCH_OVER_28_DEV = _TORCH_VERSION >= Version("2.8.0.dev")
34
+
35
+
36
+ def get_fake_mode(exported_program: ExportedProgram):
37
+ fake_mode = detect_fake_mode(
38
+ tuple(
39
+ node.meta["val"]
40
+ for node in exported_program.graph.nodes
41
+ if node.op == "placeholder"
42
+ )
43
+ )
44
+ assert fake_mode is not None
45
+ return fake_mode
46
+
47
+
48
class SuppressWarning:
    """
    Context manager that ignores matching warnings inside a ``with`` block.

    On entry, the current warning filters are snapshotted (via
    ``warnings.catch_warnings``) and an "ignore" filter for
    ``warning_category`` / ``regex`` is installed. On exit, the snapshot is
    restored verbatim. Filters that were active before the block (e.g. an
    "error" or "ignore" filter set by the caller) are therefore preserved,
    and repeated use does not accumulate filters.

    Args:
        warning_category: Warning class to suppress (subclasses included).
        regex: Regular expression matched against the start of the warning
            message (``warnings.filterwarnings`` ``message`` semantics).
    """

    def __init__(self, warning_category: type[Warning], regex):
        self.warning_category = warning_category
        self.regex = regex
        self._catcher = None

    def __enter__(self):
        # Snapshot the global filter list so it can be restored exactly on exit.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.filterwarnings(
            "ignore", category=self.warning_category, message=self.regex
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        catcher, self._catcher = self._catcher, None
        if catcher is not None:
            # Restores the pre-`__enter__` filters; exceptions are not suppressed.
            catcher.__exit__(exc_type, exc_val, exc_tb)
62
+
63
+
64
class ArgTypeError(Exception):
    """Raised when an argument's runtime type does not match its type hint (see ``enforce_type``)."""
70
+
71
+
72
+ def enforce_type(callable):
73
+ """Check types for your callable's signature
74
+
75
+ NOTE Place this one above @dataclass decorator if you want to use it with dataclass initializer.
76
+ Ex.
77
+ @enforce_type
78
+ @dataclass
79
+ class Args:
80
+ ...
81
+ """
82
+ spec = inspect.getfullargspec(callable)
83
+
84
+ def check_types(*args, **kwargs):
85
+ parameters = dict(zip(spec.args, args))
86
+ parameters.update(kwargs)
87
+ for name, value in parameters.items():
88
+ if name == "self":
89
+ # skip 'self' in spec.args
90
+ continue
91
+
92
+ assert (
93
+ name in spec.annotations
94
+ ), f"All parameter require type hints. {name} needs a type hint"
95
+
96
+ type_hint = spec.annotations[name]
97
+
98
+ # Return tuple of flattened types.
99
+ # Q) What is flatten?
100
+ # A) Optional/Union is not included. Below are included.
101
+ # collections: List, Set, ...
102
+ # primitive types: int, str, ...
103
+ def _flatten_type(type_hint) -> tuple:
104
+ # `get_origin` maps Union[...] and Optional[...] varieties to Union
105
+ if typing.get_origin(type_hint) == typing.Union:
106
+ # ex. typing.Union[list, int] -> (list, int)
107
+ # ex. typing.Optional[torch.fx.Node] -> (torch.fx.Node, NoneType)
108
+ actual_type = tuple(
109
+ [_flatten_type(t) for t in typing.get_args(type_hint)]
110
+ )
111
+ else:
112
+ actual_type = (type_hint,)
113
+ return actual_type
114
+
115
+ type_hint = _flatten_type(type_hint)
116
+
117
+ # Return true if value matches with type_hint
118
+ # Return false otherwise
119
+ def _check_type(value, type_hint):
120
+ if type_hint == typing.Any:
121
+ return True
122
+
123
+ if isinstance(type_hint, tuple):
124
+ return any([_check_type(value, t) for t in type_hint])
125
+
126
+ if typing.get_origin(type_hint) in (list, set):
127
+ if not isinstance(value, typing.get_origin(type_hint)):
128
+ return False
129
+
130
+ for v in value:
131
+ if not any(
132
+ [_check_type(v, t) for t in typing.get_args(type_hint)]
133
+ ):
134
+ return False
135
+
136
+ return True
137
+
138
+ if typing.get_origin(type_hint) == dict:
139
+ if not isinstance(value, typing.get_origin(type_hint)):
140
+ return False
141
+
142
+ for k, v in value.items():
143
+ k_type, v_type = typing.get_args(type_hint)
144
+ if not _check_type(k, k_type):
145
+ return False
146
+ if not _check_type(v, v_type):
147
+ return False
148
+
149
+ return True
150
+
151
+ # TODO: Support more type hints
152
+ return isinstance(value, type_hint)
153
+
154
+ type_check_result = _check_type(value, type_hint)
155
+ if not type_check_result:
156
+ raise ArgTypeError(
157
+ "Unexpected type for '{}' (expected {} but found {})".format(
158
+ name, type_hint, type(value)
159
+ )
160
+ )
161
+
162
+ def decorate(func):
163
+ @wraps(func)
164
+ def wrapper(*args, **kwargs):
165
+ check_types(*args, **kwargs)
166
+ return func(*args, **kwargs)
167
+
168
+ return wrapper
169
+
170
+ if inspect.isclass(callable):
171
+ callable.__init__ = decorate(callable.__init__)
172
+ return callable
173
+
174
+ return decorate(callable)
175
+
176
+
177
def fill_meta_val(exported_program: ExportedProgram):
    """
    Populate ``meta["val"]`` for ``call_function`` nodes that lack one.

    Nodes are visited in ``exported_program.graph.nodes`` order. Any
    ``call_function`` node whose ``meta["val"]`` is missing or ``None`` is
    recomputed with ``set_new_meta_val``. Other node kinds are left as-is.
    """
    for node in exported_program.graph.nodes:
        assert hasattr(node, "meta"), f"{node.name} does not have meta attribute"

        missing_val = node.meta.get("val", None) is None
        if missing_val and node.op == "call_function":
            set_new_meta_val(node)
184
+
185
+
186
def set_new_meta_val(node: torch.fx.node.Node):
    """
    Recompute ``node.meta["val"]`` by calling ``node.target``.

    Every ``torch.fx.Node`` found (at any nesting depth) in ``node.args`` /
    ``node.kwargs`` is replaced by that node's own ``meta["val"]`` before the
    call, and the call's result is stored back into ``node.meta["val"]``.

    Typical uses:
    - right after creating a new node
    - after updating a node's args or kwargs
    """
    assert isinstance(node, torch.fx.node.Node)

    def _val_of(arg_node):
        return arg_node.meta["val"]

    # `node.target` must be called with values (e.g. FakeTensors), not fx Nodes.
    resolved_args, resolved_kwargs = pytree.tree_map_only(
        torch.fx.Node, _val_of, (node.args, node.kwargs)
    )
    node.meta["val"] = node.target(*resolved_args, **resolved_kwargs)  # type: ignore[operator]
205
+
206
+
207
def unset_meta_val(node: torch.fx.node.Node):
    """
    Remove ``node.meta["val"]`` if it is present (no-op otherwise).

    Use this when a node's value must be recomputed but some of its inputs
    (e.g. newly created args) do not have a value yet. The entry is dropped
    so that a later ``FillMetaVal`` pass can fill it in.
    """
    assert isinstance(node, torch.fx.node.Node)
    node.meta.pop("val", None)
220
+
221
+
222
def run_bash_cmd(command: typing.List[str]) -> subprocess.CompletedProcess[str]:
    """
    Run a command given as a list of program arguments and return the result.

    The command is executed directly with ``subprocess.run`` (no shell),
    with ``check=True`` and text-mode capture of stdout and stderr.

    Args:
        command (List[str]): The program followed by its arguments,
            e.g. ``["ls", "-l"]``.

    Returns:
        subprocess.CompletedProcess[str]: The finished process. Its captured
        standard output is in ``.stdout`` and its standard error in ``.stderr``.

    Raises:
        ValueError: If ``command`` is not a list of strings.
        RuntimeError: If the command exits with a non-zero status. The message
            includes the command line, exit code, stdout and stderr. The
            original ``subprocess.CalledProcessError`` is chained as
            ``__cause__``.

    Example:
        >>> completed_process = run_bash_cmd(["echo", "Hello, World!"])
        >>> completed_process.stdout
        'Hello, World!\\n'
    """
    if not isinstance(command, list) or not all(isinstance(c, str) for c in command):
        raise ValueError("Command must be a list of strings.")
    try:
        return subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as err:
        cmd_str = " ".join(err.cmd)
        msg = f"Error while running command:\n\n $ {cmd_str}"
        msg += "\n"
        msg += "[EXIT CODE]\n"
        msg += f"{err.returncode}\n"
        msg += "[STDOUT]\n"
        msg += err.stdout
        msg += "[STDERR]\n"
        msg += err.stderr
        # Chain explicitly so the CalledProcessError is kept as the cause.
        raise RuntimeError(f"Failed.\n\n {msg}") from err
257
+
258
+
259
def has_quantization_ops(graph: torch.fx.Graph):
    """
    Return True if ``graph`` contains a quantize/dequantize ``call_function`` node.

    Only these ``quantized_decomposed`` overloads are recognized:
    ``quantize_per_tensor``, ``quantize_per_channel``,
    ``dequantize_per_tensor`` and ``dequantize_per_channel`` (``.default``).
    Fake-quantize ops are not checked. A True result can be used to decide
    whether quantization-specific passes need to run.

    Parameters:
        graph: The fx graph to inspect.

    Returns:
        bool: True if at least one matching node exists, False otherwise.
    """
    quant_targets = (
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
    )
    return any(
        node.op == "call_function" and node.target in quant_targets
        for node in graph.nodes
    )
287
+
288
+
289
def to_circle_qparam(qparam: QuantParam):
    """
    Build a circle ``QuantizationParametersT`` from a ``QuantParam``.

    Each non-``None`` field of ``qparam`` is copied to its circle counterpart
    (``scale``, ``zero_point -> zeroPoint``,
    ``quantized_dimension -> quantizedDimension``, ``min``, ``max``). Fields
    that are ``None`` are left at the circle object's defaults.
    """
    field_map = (
        ("scale", "scale"),
        ("zero_point", "zeroPoint"),
        ("quantized_dimension", "quantizedDimension"),
        ("min", "min"),
        ("max", "max"),
    )
    circle_qparam = circle.QuantizationParameters.QuantizationParametersT()
    for src_name, dst_name in field_map:
        value = getattr(qparam, src_name)
        if value is not None:
            setattr(circle_qparam, dst_name, value)
    return circle_qparam
307
+
308
+
309
def quant_min_max(dtype: str):
    """
    Return the ``(qmin, qmax)`` integer range for a quantized dtype name.

    Only ``"uint8"`` and ``"int16"`` are supported. Any other value raises
    ``NotImplementedError``.
    """
    supported = (
        ("uint8", (0, 255)),
        ("int16", (-32768, 32767)),
    )
    for name, bounds in supported:
        if dtype == name:
            return bounds
    raise NotImplementedError(f"NYI dtype: {dtype}")
316
+
317
+
318
def get_quant_dtype(qmin: int, qmax: int):
    """
    Map a ``(qmin, qmax)`` quantization range to its dtype name.

    Args:
        qmin (int): Minimum quantized value.
        qmax (int): Maximum quantized value.

    Returns:
        str: One of "int16", "uint16", "int8", "uint8", "int4", "uint4".
        Both the full int16 range and the symmetric range (-32767, 32767)
        map to "int16".

    Raises:
        ValueError: If the (qmin, qmax) pair is not a known range.
    """
    dtype_by_range = {
        (-32768, 32767): "int16",
        (-32767, 32767): "int16",  # symmetric int16 range
        (0, 65535): "uint16",
        (-128, 127): "int8",
        (0, 255): "uint8",
        (-8, 7): "int4",
        (0, 15): "uint4",
    }
    dtype = dtype_by_range.get((qmin, qmax))
    if dtype is None:
        raise ValueError(f"Unsupported quantization range: ({qmin}, {qmax})")
    return dtype
346
+
347
+
348
+ def broadcastable(
349
+ shape_a: List[int] | torch.Size, shape_b: List[int] | torch.Size
350
+ ) -> bool:
351
+ """
352
+ Return **True** if two shapes are broadcast-compatible under the standard
353
+ NumPy/PyTorch rules.
354
+
355
+ Broadcasting rule
356
+ --------------------------------
357
+ - Align the shapes **right-to-left**.
358
+ - For each aligned dimension `(a, b)` one of the following must hold
359
+ - `a == b` (sizes match)
360
+ - `a == 1` (shape-A can repeat along that dim)
361
+ - `b == 1` (shape-B can repeat along that dim)
362
+ - When one shape is shorter, treat its missing leading dims as `1`.
363
+
364
+ Examples
365
+ --------
366
+ >>> _broadcastable([8, 16, 32], [16, 32])
367
+ True
368
+ >>> _broadcastable([8, 16, 32], [1, 32])
369
+ True
370
+ >>> _broadcastable([8, 16, 32], [8, 32, 16])
371
+ False
372
+ """
373
+ # Walk from the last dim to the front
374
+ len_a, len_b = len(shape_a), len(shape_b)
375
+ max_len = max(len_a, len_b)
376
+ for i in range(1, max_len + 1):
377
+ dim_a = shape_a[-i] if i <= len_a else 1
378
+ dim_b = shape_b[-i] if i <= len_b else 1
379
+ if dim_a != 1 and dim_b != 1 and dim_a != dim_b:
380
+ return False
381
+ return True
382
+
383
+
384
+ def is_target_node(
385
+ node: torch.fx.Node, target_ops: list[torch._ops.OpOverload] | torch._ops.OpOverload
386
+ ):
387
+ """
388
+ Check whether a given node is a `call_function` node that matches one of the specified targets.
389
+
390
+ Args:
391
+ node (torch.fx.Node): The node to check.
392
+ target_ops (Iterable[Callable]): A list or set of target operations to match (e.g., ops.aten.reshape).
393
+
394
+ Returns:
395
+ bool: True if the node is a call_function, its target is in `target_ops`.
396
+ """
397
+ if not isinstance(target_ops, list):
398
+ target_ops = [target_ops]
399
+ assert all(isinstance(t, torch._ops.OpOverload) for t in target_ops), target_ops
400
+
401
+ if node.op != "call_function":
402
+ return False
403
+ if node.target not in target_ops:
404
+ return False
405
+
406
+ return True