tico-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
tico/utils/trace_decorators.py
ADDED
@@ -0,0 +1,101 @@
```python
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps

import torch
from torch.export import ExportedProgram

from tico.utils.diff_graph import capture, capture_const, log, log_const
from tico.utils.passes import PassBase


def trace_const_diff_on_pass(cls):
    """Decorator for PassBase to trace const diff"""

    assert issubclass(cls, PassBase), type(cls)

    def _call_traced(fn):
        @wraps(fn)
        def wrapped(*args):
            _, exported_program = args
            assert isinstance(exported_program, ExportedProgram)
            graph_module = exported_program.graph_module
            assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
            capture_const(exported_program)
            ret = fn(*args)
            log_const(exported_program, title=str(cls.__name__), recapture=False)
            return ret

        return wrapped

    # Replace the `call` method with its traced version.
    for key, val in vars(cls).items():
        if key == "call":
            setattr(cls, key, _call_traced(val))
    return cls


def trace_graph_diff_on_pass(cls):
    """Decorator for PassBase to trace graph diff"""

    assert issubclass(cls, PassBase), type(cls)

    def _call_traced(fn):
        @wraps(fn)
        def wrapped(*args):
            _, exported_program = args
            assert isinstance(exported_program, ExportedProgram)
            graph_module = exported_program.graph_module
            assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
            capture(graph_module.graph)
            ret = fn(*args)
            log(graph_module.graph, title=str(cls.__name__), recapture=False)
            return ret

        return wrapped

    # Replace the `call` method with its traced version.
    for key, val in vars(cls).items():
        if key == "call":
            setattr(cls, key, _call_traced(val))
    return cls


def trace_const_diff_on_func(fn):
    """Decorator for function to trace const diff"""

    @wraps(fn)
    def wrapped(ep: torch.export.ExportedProgram):
        assert isinstance(ep, torch.export.ExportedProgram)
        capture_const(ep)
        ret = fn(ep)
        log_const(ret, title=str(fn.__name__), recapture=False)
        return ret

    return wrapped


def trace_graph_diff_on_func(fn):
    """Decorator for function to trace graph diff"""

    @wraps(fn)
    def wrapped(ep: torch.export.ExportedProgram):
        assert isinstance(ep, torch.export.ExportedProgram)
        capture(ep.graph)
        ret = fn(ep)
        log(ret.graph, title=str(fn.__name__), recapture=False)
        return ret

    return wrapped
```
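These decorators are applied at class-definition time and wrap only the pass's `call` method (or a bare function). Below is a minimal usage sketch; `MyCleanupPass` is hypothetical, and the `call(self, exported_program)` signature follows the `_, exported_program = args` unpacking the wrappers perform:

```python
from torch.export import ExportedProgram

from tico.utils.passes import PassBase
from tico.utils.trace_decorators import trace_graph_diff_on_pass


@trace_graph_diff_on_pass
class MyCleanupPass(PassBase):  # hypothetical pass, for illustration only
    def call(self, exported_program: ExportedProgram):
        # Rewrite exported_program.graph here. The decorator snapshots the
        # graph before this method runs and logs the diff, titled
        # "MyCleanupPass", after it returns.
        ...
```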
tico/utils/utils.py
ADDED
@@ -0,0 +1,406 @@
```python
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import subprocess
import typing
import warnings
from functools import wraps
from typing import List

import torch
from circle_schema import circle
from packaging.version import Version
from torch._guards import detect_fake_mode
from torch.export import ExportedProgram
from torch.utils import _pytree as pytree

from tico.serialize.quant_param import QuantParam


HAS_TORCH_OVER_25 = Version(torch.__version__) >= Version("2.5.0")
HAS_TORCH_OVER_28_DEV = Version(torch.__version__) >= Version("2.8.0.dev")


def get_fake_mode(exported_program: ExportedProgram):
    fake_mode = detect_fake_mode(
        tuple(
            node.meta["val"]
            for node in exported_program.graph.nodes
            if node.op == "placeholder"
        )
    )
    assert fake_mode is not None
    return fake_mode


class SuppressWarning:
    def __init__(self, warning_category: type[Warning], regex):
        self.warning_category = warning_category
        self.regex = regex

    def __enter__(self):
        warnings.filterwarnings(
            "ignore", category=self.warning_category, message=self.regex
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        warnings.filterwarnings(
            "default", category=self.warning_category, message=self.regex
        )


class ArgTypeError(Exception):
    """
    Invalid argument type
    """

    pass


def enforce_type(callable):
    """Check types for your callable's signature

    NOTE Place this above the @dataclass decorator if you want to use it with a dataclass initializer.
    Ex.
        @enforce_type
        @dataclass
        class Args:
            ...
    """
    spec = inspect.getfullargspec(callable)

    def check_types(*args, **kwargs):
        parameters = dict(zip(spec.args, args))
        parameters.update(kwargs)
        for name, value in parameters.items():
            if name == "self":
                # skip 'self' in spec.args
                continue

            assert (
                name in spec.annotations
            ), f"All parameters require type hints. {name} needs a type hint"

            type_hint = spec.annotations[name]

            # Return a tuple of flattened types.
            # Q) What does "flattened" mean?
            # A) Optional/Union wrappers are unwrapped. The result contains:
            #    collections: List, Set, ...
            #    primitive types: int, str, ...
            def _flatten_type(type_hint) -> tuple:
                # `get_origin` maps Union[...] and Optional[...] varieties to Union
                if typing.get_origin(type_hint) == typing.Union:
                    # ex. typing.Union[list, int] -> (list, int)
                    # ex. typing.Optional[torch.fx.Node] -> (torch.fx.Node, NoneType)
                    actual_type = tuple(
                        [_flatten_type(t) for t in typing.get_args(type_hint)]
                    )
                else:
                    actual_type = (type_hint,)
                return actual_type

            type_hint = _flatten_type(type_hint)

            # Return True if value matches type_hint, False otherwise.
            def _check_type(value, type_hint):
                if type_hint == typing.Any:
                    return True

                if isinstance(type_hint, tuple):
                    return any([_check_type(value, t) for t in type_hint])

                if typing.get_origin(type_hint) in (list, set):
                    if not isinstance(value, typing.get_origin(type_hint)):
                        return False

                    for v in value:
                        if not any(
                            [_check_type(v, t) for t in typing.get_args(type_hint)]
                        ):
                            return False

                    return True

                if typing.get_origin(type_hint) == dict:
                    if not isinstance(value, typing.get_origin(type_hint)):
                        return False

                    for k, v in value.items():
                        k_type, v_type = typing.get_args(type_hint)
                        if not _check_type(k, k_type):
                            return False
                        if not _check_type(v, v_type):
                            return False

                    return True

                # TODO: Support more type hints
                return isinstance(value, type_hint)

            type_check_result = _check_type(value, type_hint)
            if not type_check_result:
                raise ArgTypeError(
                    "Unexpected type for '{}' (expected {} but found {})".format(
                        name, type_hint, type(value)
                    )
                )

    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            check_types(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    if inspect.isclass(callable):
        callable.__init__ = decorate(callable.__init__)
        return callable

    return decorate(callable)
```
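A short sketch of `enforce_type` combined with `@dataclass`, following the pattern from its docstring; the `Args` fields here are illustrative:

```python
from dataclasses import dataclass
from typing import List, Optional

from tico.utils.utils import ArgTypeError, enforce_type


@enforce_type
@dataclass
class Args:  # illustrative fields
    dims: List[int]
    name: Optional[str] = None


Args(dims=[1, 2, 3])            # OK
Args(dims=[1, 2, 3], name="x")  # OK: Optional[str] flattens to (str, NoneType)
try:
    Args(dims="not-a-list")     # wrong type for List[int]
except ArgTypeError as e:
    print(e)                    # Unexpected type for 'dims' ...
```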
```python
def fill_meta_val(exported_program: ExportedProgram):
    for node in exported_program.graph.nodes:
        assert hasattr(node, "meta"), f"{node.name} does not have meta attribute"

        if node.meta.get("val", None) is None:
            if node.op == "call_function":
                set_new_meta_val(node)


def set_new_meta_val(node: torch.fx.node.Node):
    """
    Set node.meta["val"].

    There are some cases when node.meta["val"] should be updated.
    - After creating a new node
    - After updating a node's args or kwargs
    """
    assert isinstance(node, torch.fx.node.Node)

    # `node.target()` expects `Tensor` arguments, so replace every
    # `torch.fx.Node` argument with its `FakeTensor` meta value.
    args, kwargs = pytree.tree_map_only(
        torch.fx.Node,
        lambda n: n.meta["val"],
        (node.args, node.kwargs),
    )
    new_val = node.target(*args, **kwargs)  # type: ignore[operator]
    node.meta["val"] = new_val


def unset_meta_val(node: torch.fx.node.Node):
    """
    Unset node.meta["val"].

    - When to use it?
        When we need to update a node's meta val but some predecessors' meta
        values are not decided yet (e.g., newly created args), simply unset the
        meta val and expect `FillMetaVal` to fill it in.
    """
    assert isinstance(node, torch.fx.node.Node)

    if "val" in node.meta:
        del node.meta["val"]
```
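A minimal sketch of the intended round trip: when a pass rewrites a node's inputs, the cached `meta["val"]` becomes stale, so it is unset and re-derived. The module and node selection below are illustrative:

```python
import torch
from torch.export import export

from tico.utils.utils import fill_meta_val, unset_meta_val


class M(torch.nn.Module):  # illustrative module
    def forward(self, x):
        return x + x


ep = export(M(), (torch.randn(2, 2),))
add_node = next(n for n in ep.graph.nodes if n.op == "call_function")

unset_meta_val(add_node)  # drop the (hypothetically stale) FakeTensor value
fill_meta_val(ep)         # re-runs set_new_meta_val on call_function nodes lacking "val"
assert add_node.meta["val"].shape == torch.Size([2, 2])
```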
```python
def run_bash_cmd(command: typing.List[str]) -> subprocess.CompletedProcess[str]:
    """
    Executes a given bash command represented as a sequence of program arguments
    using subprocess and returns the output.

    Args:
        command (List[str]): A sequence of program arguments.

    Returns:
        subprocess.CompletedProcess[str]: The completed process, whose `stdout`
        holds the standard output of the executed command.

    Example:
        >>> completed_process = run_bash_cmd(["echo", "Hello, World!"])
        >>> print(completed_process.stdout)
        'Hello, World!\\n'

        >>> cp = run_bash_cmd(["ls", "-l"])
        >>> print(cp.stdout)
        'drwxrwxr-x  8 user group  4096 12월  3 17:16 tico\\n'
    """
    if not isinstance(command, list) or not all(isinstance(c, str) for c in command):
        raise ValueError("Command must be a list of strings.")
    try:
        return subprocess.run(command, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as err:
        cmd_str = " ".join(err.cmd)
        msg = f"Error while running command:\n\n $ {cmd_str}"
        msg += "\n"
        msg += "[EXIT CODE]\n"
        msg += f"{err.returncode}\n"
        msg += "[STDOUT]\n"
        msg += err.stdout
        msg += "[STDERR]\n"
        msg += err.stderr
        raise RuntimeError(f"Failed.\n\n {msg}")


def has_quantization_ops(graph: torch.fx.Graph):
    """
    Checks whether the given fx graph contains any quantization-related operations.

    This function inspects the provided graph to determine if it includes operations
    associated with quantization (e.g., quantize, dequantize, fake quantize, etc.).
    The presence of such operations can be used to decide whether to run subsequent
    quantization-specific passes on the graph.

    Parameters:
        graph: The fx graph to be examined. It is expected that the graph supports
            iteration or traversal over its constituent operations.

    Returns:
        bool: True if the graph contains one or more quantization-related operations,
            False otherwise.
    """
    quantized_ops = [
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
    ]
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        if node.target in quantized_ops:
            return True

    return False
```
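`has_quantization_ops` is a cheap gate for the quantization-only passes. A small sketch on a plain float graph, which contains none of the listed ops; the module is illustrative:

```python
import torch
from torch.export import export

from tico.utils.utils import has_quantization_ops


class M(torch.nn.Module):  # illustrative float module
    def forward(self, x):
        return torch.relu(x)


ep = export(M(), (torch.randn(4),))
assert has_quantization_ops(ep.graph) is False
# Once pt2e quantization inserts quantize/dequantize ops, this returns True and
# can gate the passes under tico/experimental/quantization/passes/.
```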
```python
def to_circle_qparam(qparam: QuantParam):
    circle_qparam = circle.QuantizationParameters.QuantizationParametersT()
    if qparam.scale is not None:
        circle_qparam.scale = qparam.scale

    if qparam.zero_point is not None:
        circle_qparam.zeroPoint = qparam.zero_point

    if qparam.quantized_dimension is not None:
        circle_qparam.quantizedDimension = qparam.quantized_dimension

    if qparam.min is not None:
        circle_qparam.min = qparam.min

    if qparam.max is not None:
        circle_qparam.max = qparam.max

    return circle_qparam


def quant_min_max(dtype: str):
    if dtype == "uint8":
        return (0, 255)
    elif dtype == "int16":
        return (-32768, 32767)
    else:
        raise NotImplementedError(f"NYI dtype: {dtype}")


def get_quant_dtype(qmin: int, qmax: int):
    """
    Returns the string representation of the quantized data type based on qmin and qmax.

    Args:
        qmin (int): Minimum quantized value.
        qmax (int): Maximum quantized value.

    Returns:
        str: A string representing the quantized data type, such as "int8", "uint4", etc.

    Raises:
        ValueError: If the (qmin, qmax) pair is not supported.
    """
    known_ranges = {
        (-32768, 32767): "int16",
        (-32767, 32767): "int16",
        (0, 65535): "uint16",
        (-128, 127): "int8",
        (0, 255): "uint8",
        (-8, 7): "int4",
        (0, 15): "uint4",
    }

    if (qmin, qmax) in known_ranges:
        return known_ranges[(qmin, qmax)]
    else:
        raise ValueError(f"Unsupported quantization range: ({qmin}, {qmax})")
```
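`quant_min_max` and `get_quant_dtype` round-trip for the ranges they share; a few illustrative calls based on the tables above:

```python
from tico.utils.utils import get_quant_dtype, quant_min_max

assert quant_min_max("uint8") == (0, 255)
assert get_quant_dtype(*quant_min_max("uint8")) == "uint8"
assert get_quant_dtype(*quant_min_max("int16")) == "int16"
assert get_quant_dtype(-32767, 32767) == "int16"  # the narrowed range maps to the same dtype
# quant_min_max("int8") raises NotImplementedError: only "uint8" and "int16" are handled so far.
```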
```python
def broadcastable(
    shape_a: List[int] | torch.Size, shape_b: List[int] | torch.Size
) -> bool:
    """
    Return **True** if two shapes are broadcast-compatible under the standard
    NumPy/PyTorch rules.

    Broadcasting rule
    --------------------------------
    - Align the shapes **right-to-left**.
    - For each aligned dimension `(a, b)`, one of the following must hold:
        - `a == b` (sizes match)
        - `a == 1` (shape-A can repeat along that dim)
        - `b == 1` (shape-B can repeat along that dim)
    - When one shape is shorter, treat its missing leading dims as `1`.

    Examples
    --------
    >>> broadcastable([8, 16, 32], [16, 32])
    True
    >>> broadcastable([8, 16, 32], [1, 32])
    True
    >>> broadcastable([8, 16, 32], [8, 32, 16])
    False
    """
    # Walk from the last dim to the front
    len_a, len_b = len(shape_a), len(shape_b)
    max_len = max(len_a, len_b)
    for i in range(1, max_len + 1):
        dim_a = shape_a[-i] if i <= len_a else 1
        dim_b = shape_b[-i] if i <= len_b else 1
        if dim_a != 1 and dim_b != 1 and dim_a != dim_b:
            return False
    return True


def is_target_node(
    node: torch.fx.Node, target_ops: list[torch._ops.OpOverload] | torch._ops.OpOverload
):
    """
    Check whether a given node is a `call_function` node that matches one of the
    specified targets.

    Args:
        node (torch.fx.Node): The node to check.
        target_ops (list[torch._ops.OpOverload] | torch._ops.OpOverload): The target
            operation(s) to match (e.g., torch.ops.aten.reshape.default).

    Returns:
        bool: True if the node is a `call_function` node and its target is in
            `target_ops`.
    """
    if not isinstance(target_ops, list):
        target_ops = [target_ops]
    assert all(isinstance(t, torch._ops.OpOverload) for t in target_ops), target_ops

    if node.op != "call_function":
        return False
    if node.target not in target_ops:
        return False

    return True
```
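A short sketch of `is_target_node` on an exported graph; the module is illustrative, and the target may be a single `OpOverload` or a list:

```python
import torch
from torch.export import export

from tico.utils.utils import is_target_node


class M(torch.nn.Module):  # illustrative module
    def forward(self, x):
        return x + 1


ep = export(M(), (torch.randn(2),))
adds = [n for n in ep.graph.nodes if is_target_node(n, torch.ops.aten.add.Tensor)]
assert len(adds) == 1

# A list of targets also works:
any(
    is_target_node(n, [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor])
    for n in ep.graph.nodes
)
```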