tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
tico/utils/graph.py
ADDED
@@ -0,0 +1,282 @@
|
|
1
|
+
# Portions of this file are adapted from code originally authored by
|
2
|
+
# Meta Platforms, Inc. and affiliates, licensed under the BSD-style
|
3
|
+
# license found in the LICENSE file in the root directory of their source tree.
|
4
|
+
|
5
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
6
|
+
#
|
7
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8
|
+
# you may not use this file except in compliance with the License.
|
9
|
+
# You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing, software
|
14
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16
|
+
# See the License for the specific language governing permissions and
|
17
|
+
# limitations under the License.
|
18
|
+
|
19
|
+
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
import torch.fx
|
23
|
+
import torch
|
24
|
+
from torch.export import ExportedProgram
|
25
|
+
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
|
26
|
+
|
27
|
+
from tico.utils.utils import get_fake_mode, set_new_meta_val
|
28
|
+
|
29
|
+
|
30
|
+
def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
|
31
|
+
assert node.op == "placeholder"
|
32
|
+
|
33
|
+
return node.name in ep.graph_signature.inputs_to_parameters
|
34
|
+
|
35
|
+
|
36
|
+
def is_torch_buffer(node: torch.fx.Node, ep: ExportedProgram):
|
37
|
+
assert node.op == "placeholder"
|
38
|
+
|
39
|
+
return node.name in ep.graph_signature.inputs_to_buffers
|
40
|
+
|
41
|
+
|
42
|
+
def get_torch_param_value(node: torch.fx.Node, ep: ExportedProgram):
|
43
|
+
assert isinstance(node, torch.fx.Node)
|
44
|
+
assert node.op == "placeholder"
|
45
|
+
assert (
|
46
|
+
node.name in ep.graph_signature.inputs_to_parameters
|
47
|
+
), "Node {node.name} is not in the parameters" # FIX CALLER UNLESS
|
48
|
+
|
49
|
+
param_name = ep.graph_signature.inputs_to_parameters[node.name]
|
50
|
+
named_params = dict(ep.named_parameters())
|
51
|
+
assert param_name in named_params
|
52
|
+
|
53
|
+
return named_params[param_name].data
|
54
|
+
|
55
|
+
|
56
|
+
def get_torch_buffer_value(node: torch.fx.Node, ep: ExportedProgram):
|
57
|
+
assert isinstance(node, torch.fx.Node)
|
58
|
+
assert node.op == "placeholder"
|
59
|
+
assert (
|
60
|
+
node.name in ep.graph_signature.inputs_to_buffers
|
61
|
+
), "Node {node.name} is not in the buffers" # FIX CALLER UNLESS
|
62
|
+
|
63
|
+
buf_name = ep.graph_signature.inputs_to_buffers[node.name]
|
64
|
+
named_buf = dict(ep.named_buffers())
|
65
|
+
assert buf_name in named_buf
|
66
|
+
|
67
|
+
return named_buf[buf_name]
|
68
|
+
|
69
|
+
|
70
|
+
def get_first_user_input(exported_program: ExportedProgram) -> Optional[torch.fx.Node]:
|
71
|
+
"""Returns the first user input node in the graph."""
|
72
|
+
first_user_input: Optional[torch.fx.Node] = None
|
73
|
+
graph_module = exported_program.graph_module
|
74
|
+
graph: torch.fx.Graph = graph_module.graph
|
75
|
+
for node in graph.nodes:
|
76
|
+
if (
|
77
|
+
node.op == "placeholder"
|
78
|
+
and node.name in exported_program.graph_signature.user_inputs
|
79
|
+
):
|
80
|
+
first_user_input = node
|
81
|
+
break
|
82
|
+
|
83
|
+
return first_user_input
|
84
|
+
|
85
|
+
|
86
|
+
def generate_fqn(prefix: str, exported_program: ExportedProgram):
|
87
|
+
"""
|
88
|
+
Generate fully-qualized name for constants.
|
89
|
+
|
90
|
+
This function prevents `exported_program.constants` from having duplicate keys.
|
91
|
+
"""
|
92
|
+
cnt = len(exported_program.constants)
|
93
|
+
while True:
|
94
|
+
if f"{prefix}{cnt}" in exported_program.constants:
|
95
|
+
cnt += 1
|
96
|
+
continue
|
97
|
+
break
|
98
|
+
return f"{prefix}{cnt}"
|
99
|
+
|
100
|
+
|
101
|
+
def create_input_spec(node, input_kind: InputKind):
|
102
|
+
"""
|
103
|
+
@ref https://pytorch.org/docs/stable/export.ir_spec.html#placeholder
|
104
|
+
"""
|
105
|
+
if input_kind == InputKind.CONSTANT_TENSOR:
|
106
|
+
return InputSpec(
|
107
|
+
kind=InputKind.CONSTANT_TENSOR,
|
108
|
+
arg=TensorArgument(name=node.name),
|
109
|
+
target=node.target, # type: ignore[arg-type]
|
110
|
+
persistent=True,
|
111
|
+
)
|
112
|
+
else:
|
113
|
+
raise NotImplementedError("NYI")
|
114
|
+
|
115
|
+
|
116
|
+
def validate_input_specs(exported_program):
    """
    Check that every placeholder in the graph has a matching input spec.

    A placeholder matches when some entry of
    ``exported_program.graph_signature.input_specs`` has ``arg.name`` equal
    to the placeholder's name. Spec order and extra specs are not checked.

    Raises
    ------
    RuntimeError
        For the first placeholder that has no corresponding input spec.
    """
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }

    for node in exported_program.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name not in name_to_spec_dict:
            # f-string so the offending node name actually appears in the error.
            raise RuntimeError(
                f"Placeholder node {node.name} does not have corresponding input spec!"
            )
|
129
|
+
|
130
|
+
|
131
|
+
def add_placeholder(
    exported_program: ExportedProgram,
    tensor: torch.Tensor,
    prefix: str,
) -> torch.fx.Node:
    """
    Add a constant-tensor placeholder to the graph and update the exported program.

    Steps:
    1. Pick a non-colliding name with ``generate_fqn(prefix, ...)``.
    2. Insert a placeholder before the first user-input placeholder, or
       before the very first graph node when there is no user input.
    3. Set ``meta["val"]`` to a static-shape fake tensor from the program's
       fake mode, with ``.constant`` pointing at *tensor*.
    4. Register *tensor* in ``exported_program.constants`` under the node's name.
    5. Rebuild ``graph_signature.input_specs`` in placeholder order, adding a
       CONSTANT_TENSOR spec for the new node.

    Parameters
    ----------
    exported_program : ExportedProgram
        Program to modify in place.
    tensor : torch.Tensor
        Constant value the new placeholder represents.
    prefix : str
        Name prefix for the new constant.

    Returns
    -------
    torch.fx.Node
        The newly inserted placeholder node.
    """
    fqn_name = generate_fqn(prefix, exported_program)

    # Get fake mode before adding placeholder (presumably get_fake_mode reads
    # existing node metadata, which the new placeholder does not have yet).
    fake_mode = get_fake_mode(exported_program)

    insert_point = get_first_user_input(exported_program)
    if insert_point is None:
        # Placeholder nodes must be the first N nodes in the nodes list of a graph.
        # Therefore, insert the newly created placeholders at the start of the node list.
        assert exported_program.graph.nodes
        insert_point = next(iter(exported_program.graph.nodes))

    # Add a placeholder to the graph.
    with exported_program.graph.inserting_before(insert_point):
        const_node = exported_program.graph.placeholder(fqn_name)

    const_node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
    const_node.meta["val"].constant = tensor

    # Add a new constant to the exported program.
    exported_program.constants[const_node.name] = tensor

    # Existing specs keyed by argument name.
    name_to_spec_dict = {
        s.arg.name: s for s in exported_program.graph_signature.input_specs
    }
    # Use update (instead of insert) if this assert is violated
    assert const_node.name not in name_to_spec_dict
    name_to_spec_dict[const_node.name] = create_input_spec(
        const_node, InputKind.CONSTANT_TENSOR
    )

    # Generate new input spec *in the same order of nodes*
    # IMPORTANT Input specs and their placeholder nodes must have the same order.
    exported_program.graph_signature.input_specs = [
        name_to_spec_dict[node.name]
        for node in exported_program.graph.nodes
        if node.op == "placeholder"
    ]

    return const_node
|
192
|
+
|
193
|
+
|
194
|
+
def is_single_value_tensor(t: torch.Tensor):
    """Return True for a 0-d tensor or a 1-d tensor of length one, else False."""
    shape = t.size()
    return len(shape) == 0 or (len(shape) == 1 and shape[0] == 1)
|
201
|
+
|
202
|
+
|
203
|
+
def get_module_name_chain(node: Optional[torch.fx.Node]) -> str:
    """
    Return a human-readable name for the module context of *node*.

    The name comes from the node's ``nn_module_stack`` metadata. Only the
    innermost (last-inserted) entry is consulted, and the second element of
    its value tuple is returned unchanged.

    NOTE(review): earlier docs described the result as a slash-separated
    module path such as "encoder/layer1/linear". The actual format depends on
    how ``nn_module_stack`` was populated upstream — confirm.

    Parameters
    ----------
    node: torch.fx.Node or None
        A node from an ExportedProgram graph.

    Returns
    -------
    str
        - "unknown" if *node* is None or has no (or an empty) ``nn_module_stack``.
        - "tico" for a placeholder (graph input) without ``nn_module_stack``.
        - Otherwise the value described above.
    """
    if node is None:
        return "unknown"
    # Validate the type before touching node attributes, so a wrong argument
    # fails here rather than slipping through the placeholder check below.
    assert isinstance(node, torch.fx.Node)

    # Let's prefix "tico" for graph inputs
    if node.op == "placeholder" and "nn_module_stack" not in node.meta:
        return "tico"

    stack = node.meta.get("nn_module_stack")
    if not stack:
        return "unknown"
    assert isinstance(stack, dict)
    # Retrieving the last element is enough.
    return next(reversed(stack.values()))[1]
|
237
|
+
|
238
|
+
|
239
|
+
def create_node(
    graph: torch.fx.Graph,
    target: torch._ops.OpOverload,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    origin: Optional[torch.fx.Node] = None,
) -> torch.fx.Node:
    """
    Insert a new ``call_function`` node into *graph*, optionally inheriting
    module context from *origin*.

    The node is created with ``graph.call_function``, so it lands at the
    graph's current insertion point (see ``torch.fx.Graph.inserting_before``
    / ``inserting_after``).

    Parameters
    ----------
    graph : torch.fx.Graph
        The graph that will own the newly-created node.

    target : torch._ops.OpOverload
        The callable used as the node target (e.g. a ``torch.ops.aten`` overload).

    args : Tuple[Any, ...], optional
        Positional arguments for the new node.

    kwargs : Dict[str, Any], optional
        Keyword arguments for the new node.

    origin : torch.fx.Node, optional
        If given, only its ``"nn_module_stack"`` meta entry (when present) is
        copied onto the new node. No other metadata is copied, and ``"val"``
        is NOT computed here; callers that need ``meta["val"]`` must set it.
        NOTE(review): a previous docstring promised that all of
        ``origin.meta`` except "val" would be copied and "val" recomputed.
        Confirm which behavior callers rely on.

    Returns
    -------
    torch.fx.Node
        The freshly inserted node.
    """
    new_node = graph.call_function(target, args=args, kwargs=kwargs)
    if origin is not None:
        assert isinstance(origin, torch.fx.Node), type(origin)
        # Propagate "nn_module_stack" to retain the originating module context
        # for meaningful node names.
        if "nn_module_stack" in origin.meta:
            new_node.meta["nn_module_stack"] = origin.meta["nn_module_stack"]

    return new_node
|
tico/utils/logging.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import logging
|
16
|
+
import os
|
17
|
+
|
18
|
+
|
19
|
+
def _loggerLevel():
|
20
|
+
TICO_LOG = os.environ.get("TICO_LOG")
|
21
|
+
if TICO_LOG == "1":
|
22
|
+
log_level = logging.FATAL
|
23
|
+
elif TICO_LOG == "2":
|
24
|
+
log_level = logging.WARNING
|
25
|
+
elif TICO_LOG == "3":
|
26
|
+
log_level = logging.INFO
|
27
|
+
elif TICO_LOG == "4":
|
28
|
+
log_level = logging.DEBUG
|
29
|
+
else:
|
30
|
+
log_level = logging.WARNING
|
31
|
+
return log_level
|
32
|
+
|
33
|
+
|
34
|
+
LOG_LEVEL = _loggerLevel()
|
35
|
+
|
36
|
+
|
37
|
+
def getLogger(name: str):
    """
    Return the logger called *name*, with its level set to ``LOG_LEVEL``.

    ``LOG_LEVEL`` is derived from the ``TICO_LOG`` environment variable at
    import time. ``logging.basicConfig()`` is called first so that the root
    logger has a handler; per the stdlib docs, it does nothing once the root
    logger is already configured.
    """
    logging.basicConfig()
    named_logger = logging.getLogger(name)
    named_logger.setLevel(LOG_LEVEL)
    return named_logger
|
tico/utils/model.py
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from typing import Any
|
18
|
+
|
19
|
+
from tico.interpreter import infer
|
20
|
+
|
21
|
+
|
22
|
+
class CircleModel:
    """In-memory Circle model: holds the serialized model bytes."""

    def __init__(self, circle_binary: bytes):
        # Serialized Circle model contents.
        self.circle_binary = circle_binary

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``infer.infer`` on this model's bytes with the given inputs."""
        return infer.infer(self.circle_binary, *args, **kwargs)

    @staticmethod
    def load(circle_path: str) -> CircleModel:
        """Read the file at *circle_path* and wrap its bytes in a CircleModel."""
        with open(circle_path, "rb") as stream:
            data = stream.read()
        return CircleModel(bytes(data))

    def save(self, circle_path: str) -> None:
        """Write this model's bytes to *circle_path*, replacing any existing file."""
        with open(circle_path, "wb") as stream:
            stream.write(self.circle_binary)
|
@@ -0,0 +1 @@
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|
@@ -0,0 +1,267 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) Microsoft Corporation.
|
3
|
+
Licensed under the MIT License.
|
4
|
+
|
5
|
+
Name: elemwise_ops.py
|
6
|
+
|
7
|
+
Pytorch functions for elementwise (i.e. bfloat) quantization.
|
8
|
+
|
9
|
+
Usage Notes:
|
10
|
+
- Use the "Exposed Methods" below to implement autograd functions
|
11
|
+
- Use autograd functions to then implement torch.nn.Module(s)
|
12
|
+
- Do *not* use methods in this file in Modules, they have no defined
|
13
|
+
backwards pass and will block gradient computation.
|
14
|
+
- Avoid importing internal function if at all possible.
|
15
|
+
|
16
|
+
Exposed Methods:
|
17
|
+
quantize_elemwise_op - quantizes a tensor to bfloat or other
|
18
|
+
custom float format
|
19
|
+
"""
|
20
|
+
import torch
|
21
|
+
|
22
|
+
from .formats import RoundingMode, _get_format_params
|
23
|
+
from .formats import _get_min_norm, _get_max_norm
|
24
|
+
|
25
|
+
|
26
|
+
# -------------------------------------------------------------------------
|
27
|
+
# Helper funcs
|
28
|
+
# -------------------------------------------------------------------------
|
29
|
+
# Never explicitly compute 2**(-exp) since subnorm numbers have
|
30
|
+
# exponents smaller than -126
|
31
|
+
def _safe_lshift(x, bits, exp):
|
32
|
+
if exp is None:
|
33
|
+
return x * (2**bits)
|
34
|
+
else:
|
35
|
+
return x / (2 ** exp) * (2**bits)
|
36
|
+
|
37
|
+
|
38
|
+
def _safe_rshift(x, bits, exp):
|
39
|
+
if exp is None:
|
40
|
+
return x / (2**bits)
|
41
|
+
else:
|
42
|
+
return x / (2**bits) * (2 ** exp)
|
43
|
+
|
44
|
+
|
45
|
+
def _round_mantissa(A, bits, round, clamp=False):
|
46
|
+
"""
|
47
|
+
Rounds mantissa to nearest bits depending on the rounding method 'round'
|
48
|
+
Args:
|
49
|
+
A {PyTorch tensor} -- Input tensor
|
50
|
+
round {str} -- Rounding method
|
51
|
+
"floor" rounds to the floor
|
52
|
+
"nearest" rounds to ceil or floor, whichever is nearest
|
53
|
+
Returns:
|
54
|
+
A {PyTorch tensor} -- Tensor with mantissas rounded
|
55
|
+
"""
|
56
|
+
|
57
|
+
if round == "dither":
|
58
|
+
rand_A = torch.rand_like(A, requires_grad=False)
|
59
|
+
A = torch.sign(A) * torch.floor(torch.abs(A) + rand_A)
|
60
|
+
elif round == "floor":
|
61
|
+
A = torch.sign(A) * torch.floor(torch.abs(A))
|
62
|
+
elif round == "nearest":
|
63
|
+
A = torch.sign(A) * torch.floor(torch.abs(A) + 0.5)
|
64
|
+
elif round == "even":
|
65
|
+
absA = torch.abs(A)
|
66
|
+
# find 0.5, 2.5, 4.5 ...
|
67
|
+
maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype)
|
68
|
+
A = torch.sign(A) * (torch.floor(absA + 0.5) - maskA)
|
69
|
+
else:
|
70
|
+
raise Exception("Unrecognized round method %s" % (round))
|
71
|
+
|
72
|
+
# Clip values that cannot be expressed by the specified number of bits
|
73
|
+
if clamp:
|
74
|
+
max_mantissa = 2 ** (bits - 1) - 1
|
75
|
+
A = torch.clamp(A, -max_mantissa, max_mantissa)
|
76
|
+
return A
|
77
|
+
|
78
|
+
|
79
|
+
# -------------------------------------------------------------------------
|
80
|
+
# Main funcs
|
81
|
+
# -------------------------------------------------------------------------
|
82
|
+
def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest',
|
83
|
+
saturate_normals=False, allow_denorm=True,
|
84
|
+
custom_cuda=False):
|
85
|
+
""" Core function used for element-wise quantization
|
86
|
+
Arguments:
|
87
|
+
A {PyTorch tensor} -- A tensor to be quantized
|
88
|
+
bits {int} -- Number of mantissa bits. Includes
|
89
|
+
sign bit and implicit one for floats
|
90
|
+
exp_bits {int} -- Number of exponent bits, 0 for ints
|
91
|
+
max_norm {float} -- Largest representable normal number
|
92
|
+
round {str} -- Rounding mode: (floor, nearest, even)
|
93
|
+
saturate_normals {bool} -- If True, normal numbers (i.e., not NaN/Inf)
|
94
|
+
that exceed max norm are clamped.
|
95
|
+
Must be True for correct MX conversion.
|
96
|
+
allow_denorm {bool} -- If False, flush denorm numbers in the
|
97
|
+
elem_format to zero.
|
98
|
+
custom_cuda {str} -- If True, use custom CUDA kernels
|
99
|
+
Returns:
|
100
|
+
quantized tensor {PyTorch tensor} -- A tensor that has been quantized
|
101
|
+
"""
|
102
|
+
A_is_sparse = A.is_sparse
|
103
|
+
if A_is_sparse:
|
104
|
+
if A.layout != torch.sparse_coo:
|
105
|
+
raise NotImplementedError("Only COO layout sparse tensors are currently supported.")
|
106
|
+
|
107
|
+
sparse_A = A.coalesce()
|
108
|
+
A = sparse_A.values().clone()
|
109
|
+
|
110
|
+
# custom cuda only support floor and nearest rounding modes
|
111
|
+
custom_cuda = custom_cuda and round in RoundingMode.string_enums()
|
112
|
+
|
113
|
+
if custom_cuda:
|
114
|
+
A = A.contiguous()
|
115
|
+
|
116
|
+
from . import custom_extensions
|
117
|
+
if A.device.type == "cuda":
|
118
|
+
A = custom_extensions.funcs.quantize_elemwise_func_cuda(
|
119
|
+
A, bits, exp_bits, max_norm, RoundingMode[round],
|
120
|
+
saturate_normals, allow_denorm)
|
121
|
+
elif A.device.type == "cpu":
|
122
|
+
A = custom_extensions.funcs.quantize_elemwise_func_cpp(
|
123
|
+
A, bits, exp_bits, max_norm, RoundingMode[round],
|
124
|
+
saturate_normals, allow_denorm)
|
125
|
+
return A
|
126
|
+
|
127
|
+
# Flush values < min_norm to zero if denorms are not allowed
|
128
|
+
if not allow_denorm and exp_bits > 0:
|
129
|
+
min_norm = _get_min_norm(exp_bits)
|
130
|
+
out = (torch.abs(A) >= min_norm).type(A.dtype) * A
|
131
|
+
else:
|
132
|
+
out = A
|
133
|
+
|
134
|
+
if exp_bits != 0:
|
135
|
+
private_exp = torch.floor(torch.log2(
|
136
|
+
torch.abs(A) + (A == 0).type(A.dtype)))
|
137
|
+
|
138
|
+
# The minimum representable exponent for 8 exp bits is -126
|
139
|
+
min_exp = -(2**(exp_bits-1)) + 2
|
140
|
+
private_exp = private_exp.clip(min=min_exp)
|
141
|
+
else:
|
142
|
+
private_exp = None
|
143
|
+
|
144
|
+
# Scale up so appropriate number of bits are in the integer portion of the number
|
145
|
+
out = _safe_lshift(out, bits - 2, private_exp)
|
146
|
+
|
147
|
+
out = _round_mantissa(out, bits, round, clamp=False)
|
148
|
+
|
149
|
+
# Undo scaling
|
150
|
+
out = _safe_rshift(out, bits - 2, private_exp)
|
151
|
+
|
152
|
+
# Set values > max_norm to Inf if desired, else clamp them
|
153
|
+
if saturate_normals or exp_bits == 0:
|
154
|
+
out = torch.clamp(out, min=-max_norm, max=max_norm)
|
155
|
+
else:
|
156
|
+
out = torch.where((torch.abs(out) > max_norm),
|
157
|
+
torch.sign(out) * float("Inf"), out)
|
158
|
+
|
159
|
+
# handle Inf/NaN
|
160
|
+
if not custom_cuda:
|
161
|
+
out[A == float("Inf")] = float("Inf")
|
162
|
+
out[A == -float("Inf")] = -float("Inf")
|
163
|
+
out[A == float("NaN")] = float("NaN")
|
164
|
+
|
165
|
+
if A_is_sparse:
|
166
|
+
output = torch.sparse_coo_tensor(sparse_A.indices(), output,
|
167
|
+
sparse_A.size(), dtype=sparse_A.dtype, device=sparse_A.device,
|
168
|
+
requires_grad=sparse_A.requires_grad)
|
169
|
+
|
170
|
+
return out
|
171
|
+
|
172
|
+
|
173
|
+
def _quantize_elemwise(A, elem_format, round='nearest', custom_cuda=False,
|
174
|
+
saturate_normals=False, allow_denorm=True):
|
175
|
+
""" Quantize values to a defined format. See _quantize_elemwise_core()
|
176
|
+
"""
|
177
|
+
if elem_format == None:
|
178
|
+
return A
|
179
|
+
|
180
|
+
ebits, mbits, _, max_norm, _ = _get_format_params(elem_format)
|
181
|
+
|
182
|
+
output = _quantize_elemwise_core(
|
183
|
+
A, mbits, ebits, max_norm,
|
184
|
+
round=round, allow_denorm=allow_denorm,
|
185
|
+
saturate_normals=saturate_normals,
|
186
|
+
custom_cuda=custom_cuda)
|
187
|
+
|
188
|
+
return output
|
189
|
+
|
190
|
+
|
191
|
+
def _quantize_bfloat(A, bfloat, round='nearest', custom_cuda=False, allow_denorm=True):
|
192
|
+
""" Quantize values to bfloatX format
|
193
|
+
Arguments:
|
194
|
+
bfloat {int} -- Total number of bits for bfloatX format,
|
195
|
+
Includes 1 sign, 8 exp bits, and variable
|
196
|
+
mantissa bits. Must be >= 9.
|
197
|
+
"""
|
198
|
+
# Shortcut for no quantization
|
199
|
+
if bfloat == 0 or bfloat == 32:
|
200
|
+
return A
|
201
|
+
|
202
|
+
max_norm = _get_max_norm(8, bfloat-7)
|
203
|
+
|
204
|
+
return _quantize_elemwise_core(
|
205
|
+
A, bits=bfloat-7, exp_bits=8, max_norm=max_norm, round=round,
|
206
|
+
allow_denorm=allow_denorm, custom_cuda=custom_cuda)
|
207
|
+
|
208
|
+
|
209
|
+
def _quantize_fp(A, exp_bits=None, mantissa_bits=None,
|
210
|
+
round='nearest', custom_cuda=False, allow_denorm=True):
|
211
|
+
""" Quantize values to IEEE fpX format. The format defines NaN/Inf
|
212
|
+
and subnorm numbers in the same way as FP32 and FP16.
|
213
|
+
Arguments:
|
214
|
+
exp_bits {int} -- number of bits used to store exponent
|
215
|
+
mantissa_bits {int} -- number of bits used to store mantissa, not
|
216
|
+
including sign or implicit 1
|
217
|
+
round {str} -- Rounding mode, (floor, nearest, even)
|
218
|
+
"""
|
219
|
+
# Shortcut for no quantization
|
220
|
+
if exp_bits is None or mantissa_bits is None:
|
221
|
+
return A
|
222
|
+
|
223
|
+
max_norm = _get_max_norm(exp_bits, mantissa_bits+2)
|
224
|
+
|
225
|
+
output = _quantize_elemwise_core(
|
226
|
+
A, bits=mantissa_bits + 2, exp_bits=exp_bits,
|
227
|
+
max_norm=max_norm, round=round, allow_denorm=allow_denorm,
|
228
|
+
custom_cuda=custom_cuda)
|
229
|
+
|
230
|
+
return output
|
231
|
+
|
232
|
+
|
233
|
+
def quantize_elemwise_op(A, mx_specs, round=None):
    """A function used for element-wise quantization with mx_specs
    Arguments:
      A {PyTorch tensor} -- a tensor that needs to be quantized
      mx_specs {dictionary} -- dictionary to specify mx_specs. Keys read
                               here: 'round', 'bfloat', 'fp',
                               'bfloat_subnorms', 'custom_cuda'.
                               None disables quantization (A is returned).
      round {str} -- Rounding mode, choose from (floor, nearest, even, dither).
                     When None (the default), mx_specs['round'] is used.
    Returns:
      quantized value {PyTorch tensor} -- a tensor that has been quantized.
        NOTE: on the bfloat16 fast path (bfloat == 16, round == 'even',
        torch.cuda.is_bf16_supported(), bfloat_subnorms == True) the result
        is A cast to torch.bfloat16, so its dtype differs from other paths.
    Raises:
      ValueError -- if both 'bfloat' and 'fp' are set, if 0 < bfloat <= 9,
                    or if 0 < fp <= 6.
    """
    if mx_specs is None:
        return A
    elif round is None:
        round = mx_specs['round']

    # Native bfloat16 cast already rounds to nearest-even with subnormals.
    if (mx_specs['bfloat'] == 16 and round == 'even'
            and torch.cuda.is_bf16_supported()
            and mx_specs['bfloat_subnorms'] == True):
        return A.to(torch.bfloat16)

    if mx_specs['bfloat'] > 0 and mx_specs['fp'] > 0:
        raise ValueError("Cannot set both [bfloat] and [fp] in mx_specs.")
    elif mx_specs['bfloat'] > 9:
        A = _quantize_bfloat(A, bfloat=mx_specs['bfloat'], round=round,
                             custom_cuda=mx_specs['custom_cuda'],
                             allow_denorm=mx_specs['bfloat_subnorms'])
    elif mx_specs['bfloat'] > 0 and mx_specs['bfloat'] <= 9:
        raise ValueError("Cannot set [bfloat] <= 9 in mx_specs.")
    elif mx_specs['fp'] > 6:
        # fp total bits = 1 sign + 5 exponent + (fp - 6) mantissa bits.
        # NOTE(review): subnormal handling is taken from 'bfloat_subnorms'
        # here as well — confirm this is intended for the fp path.
        A = _quantize_fp(A, exp_bits=5, mantissa_bits=mx_specs['fp'] - 6,
                         round=round, custom_cuda=mx_specs['custom_cuda'],
                         allow_denorm=mx_specs['bfloat_subnorms'])
    elif mx_specs['fp'] > 0 and mx_specs['fp'] <= 6:
        raise ValueError("Cannot set [fp] <= 6 in mx_specs.")
    return A
|