tico 0.1.0.dev250411__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/serialize/circle_graph.py
@@ -0,0 +1,264 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Any, cast, Dict, final, List, Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    import torch.fx
import numpy as np
import torch
from circle_schema import circle
from torch._subclasses.fake_tensor import FakeTensor

from tico.serialize.circle_mapping import (
    extract_circle_dtype,
    extract_shape,
    str_to_circle_dtype,
    to_circle_dtype,
)
from tico.serialize.pack import pack_buffer
from tico.serialize.quant_param import QPARAM_KEY
from tico.utils.utils import to_circle_qparam

"""
Type alias for const
"""
_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)
ConstDataElement = Union[
    int, float, bool, str, torch.Tensor, torch.device, torch.dtype, torch.layout
]
ConstData = Union[ConstDataElement, List[ConstDataElement]]


def is_const(arg) -> bool:
    if isinstance(arg, FakeTensor):
        return False
    if isinstance(arg, _PRIMITIVE_TYPES):
        return True
    if isinstance(arg, (tuple, list)):
        return all(map(is_const, arg))
    if isinstance(arg, dict):
        return all(map(is_const, arg.values()))
    return False


@final
class CircleModel(circle.Model.ModelT):
    def __init__(self):
        super().__init__()
        self.subgraphs: List[circle.SubGraph.SubGraphT] = []
        self.buffers: List[circle.Buffer.BufferT] = []

    def add_subgraph(self, graph: circle.SubGraph.SubGraphT) -> None:
        self.subgraphs.append(graph)

    def add_buffer(self, buffer: circle.Buffer.BufferT) -> int:
        """Return buffer id"""
        self.buffers.append(buffer)
        buf_id = len(self.buffers) - 1  # last index
        return buf_id


@final
class CircleSubgraph(circle.SubGraph.SubGraphT):
    def __init__(self, model: CircleModel):
        super().__init__()
        self.model: CircleModel = model
        self.name: str = "subgraph"
        self.inputs: List[int] = []
        self.outputs: List[int] = []
        self.tensors: List[circle.Tensor.TensorT] = []
        self.operators: List[circle.Operator.OperatorT] = []
        self.name_to_tid: Dict[str, int] = {}
        self.counter: defaultdict = defaultdict(int)

    # Generate a unique name with prefix.
    # Naming rule
    # - If no tensor has the same name with prefix, return prefix
    # - Otherwise, add postfix f"_{idx}" where idx increases by 1 from 0
    # Example
    # If prefix = "add", this function will find a unique name in the following order.
    # "add", "add_0", "add_1", ...
    def _gen_unique_name_with_prefix(self, prefix: str):
        name = prefix
        while self.has_tensor(name):
            index = self.counter[prefix]
            name = f"{prefix}_{index}"
            self.counter[prefix] += 1

        return name

    def _add_tensor(self, tensor: circle.Tensor.TensorT) -> None:
        self.tensors.append(tensor)
        self.name_to_tid[tensor.name] = len(self.tensors) - 1

    def add_operator(self, op: circle.Operator.OperatorT) -> None:
        self.operators.append(op)

    def add_input(self, input_name: str) -> None:
        assert input_name in self.name_to_tid, f"{input_name}"
        tid = self.name_to_tid[input_name]
        self.inputs.append(tid)

    def add_output(self, output: Any) -> None:
        if isinstance(output, str):
            assert output in self.name_to_tid
            output_name = output
        elif isinstance(output, int | float):
            # output is built-in type.
            circle_tensor = self.add_const_tensor(output)
            output_name = circle_tensor.name
        else:
            raise NotImplementedError(f"Unsupported output dtype: {type(output)}")
        tid = self.name_to_tid[output_name]
        self.outputs.append(tid)

    def has_tensor(self, name: str):
        return name in self.name_to_tid

    def add_tensor_from_node(
        self, node: torch.fx.node.Node, data: Optional[np.ndarray] = None
    ) -> None:
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(node.name)
        assert node.meta.get("val") is not None
        tensor.type = extract_circle_dtype(node)
        tensor.shape = list(extract_shape(node))
        if QPARAM_KEY in node.meta:
            tensor.quantization = to_circle_qparam(node.meta[QPARAM_KEY])
            tensor.type = str_to_circle_dtype(node.meta[QPARAM_KEY].dtype)

        buffer = circle.Buffer.BufferT()
        if data is not None and isinstance(data, np.ndarray):
            data = data.flatten()

            if QPARAM_KEY in node.meta:
                if node.meta[QPARAM_KEY].dtype == "uint4":
                    data = pack_buffer(data, "uint4")

            # Packing np.ndarray is faster than packing bytes
            buffer.data = data.view(np.uint8)  # type: ignore[assignment]
        else:
            assert data is None
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

    def add_const_tensor(self, data: ConstData) -> circle.Tensor.TensorT:
        assert is_const(data)
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix("const_tensor")
        assert not self.has_tensor(tensor.name)
        torch_t = torch.as_tensor(data=data)
        torch_t_shape = list(torch_t.size())
        tensor.type = to_circle_dtype(torch_dtype=torch_t.dtype)
        tensor.shape = torch_t_shape

        buffer = circle.Buffer.BufferT()
        buffer.data = torch_t.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    def add_tensor_from_scratch(
        self, prefix: str, shape: List[int], dtype: int
    ) -> circle.Tensor.TensorT:
        assert isinstance(dtype, int), f"{dtype} must be integer. Use to_circle_dtype."
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(prefix)
        tensor.type = dtype
        tensor.shape = shape

        buffer = circle.Buffer.BufferT()
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    # Some operators like `full`, `arange_start_step` or `scalar_tensor` needs buffers to be in-place updated.
    # TODO remove this function
    def update_tensor_buffer(
        self, data: ConstData, tensor_name: str = str()
    ) -> circle.Tensor.TensorT:
        assert is_const(data)
        assert self.has_tensor(tensor_name)
        data_tensor = torch.as_tensor(data=data)
        data_shape = list(data_tensor.size())
        op_tensor = self.tensors[self.name_to_tid[tensor_name]]
        assert op_tensor.type == to_circle_dtype(
            data_tensor.dtype
        ), f"{op_tensor.type}, {data_tensor.dtype}"
        assert op_tensor.shape == data_shape

        buffer = circle.Buffer.BufferT()
        # Packing np.ndarray is faster than packing bytes
        buffer.data = data_tensor.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        op_tensor.buffer = bid

        return op_tensor

    def get_tid_registered(
        self, node: Union[torch.fx.node.Node, circle.Tensor.TensorT]
    ) -> int:
        assert hasattr(node, "name"), "FIX CALLER UNLESS"

        tid = self.name_to_tid.get(node.name, None)

        if tid is None:
            raise KeyError(f"{node}({node.name}) is not registered.")

        assert tid < len(self.tensors)

        return tid

    def get_tensor(self, node: torch.fx.node.Node) -> circle.Tensor.TensorT:
        tid = self.get_tid_registered(node)

        return self.tensors[tid]

    def get_buffer(self, node: torch.fx.Node) -> circle.Buffer.BufferT:
        buf_id = self.get_tensor(node).buffer
        return self.model.buffers[buf_id]

    # TODO Rename, it doesn't only get_tid but also possibly add a new const tensor
    def get_tid(
        self, node: Union[torch.fx.node.Node, circle.Tensor.TensorT, ConstData]
    ) -> int:
        # return -1 if node is None. This is for generating CircleOutputExclude
        if node == None:
            return -1

        if hasattr(node, "name") and node.name in self.name_to_tid:
            return self.name_to_tid[node.name]

        if is_const(node):
            node_name = self.add_const_tensor(cast(ConstData, node)).name
            return self.name_to_tid[node_name]

        # Unreachable
        raise RuntimeError("fx Node was not converted to tensor.")
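For orientation, here is a minimal usage sketch (an illustration, not part of the diff): buffers are owned by CircleModel while tensors and operators live on CircleSubgraph, and add_const_tensor allocates both a tensor and its backing data buffer in one call. It assumes the wheel and its circle_schema dependency are installed.

    from circle_schema import circle
    from tico.serialize.circle_graph import CircleModel, CircleSubgraph

    model = CircleModel()
    model.add_buffer(circle.Buffer.BufferT())  # buffer 0 is left empty by convention

    graph = CircleSubgraph(model)
    t = graph.add_const_tensor([1, 2, 3])  # registers a tensor plus its data buffer
    graph.add_output(t.name)               # looks up the tensor id and marks it as an output

    assert graph.has_tensor("const_tensor")  # first call returns the bare prefix as the name
    assert len(model.buffers) == 2           # the empty buffer plus the const data buffer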
tico/serialize/circle_mapping.py
@@ -0,0 +1,177 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
    import torch.fx
import numpy as np
import torch
from circle_schema import circle


# Convert torch dtype to circle dtype
def to_circle_dtype(
    torch_dtype: torch.dtype,
) -> int:
    assert isinstance(torch_dtype, torch.dtype)
    dmap = {
        torch.float32: circle.TensorType.TensorType.FLOAT32,
        torch.float: circle.TensorType.TensorType.FLOAT32,
        torch.uint8: circle.TensorType.TensorType.UINT8,
        torch.int8: circle.TensorType.TensorType.INT8,
        torch.int16: circle.TensorType.TensorType.INT16,
        torch.short: circle.TensorType.TensorType.INT16,
        torch.int32: circle.TensorType.TensorType.INT32,
        torch.int: circle.TensorType.TensorType.INT32,
        torch.int64: circle.TensorType.TensorType.INT64,
        torch.bool: circle.TensorType.TensorType.BOOL,
    }

    if torch_dtype not in dmap:
        raise RuntimeError(f"Unsupported dtype {torch_dtype}")

    circle_type = dmap[torch_dtype]
    assert circle_type is not None
    return circle_type


# Convert str dtype used in QuantParam to circle dtype
def str_to_circle_dtype(
    str_dtype: str,
) -> int:
    dmap = {
        "float32": circle.TensorType.TensorType.FLOAT32,
        "float": circle.TensorType.TensorType.FLOAT32,
        "uint8": circle.TensorType.TensorType.UINT8,
        "int8": circle.TensorType.TensorType.INT8,
        "int16": circle.TensorType.TensorType.INT16,
        "short": circle.TensorType.TensorType.INT16,
        "int32": circle.TensorType.TensorType.INT32,
        "int": circle.TensorType.TensorType.INT32,
        "int64": circle.TensorType.TensorType.INT64,
        "bool": circle.TensorType.TensorType.BOOL,
        "uint4": circle.TensorType.TensorType.UINT4,
        # TODO Add more dtypes
    }

    if str_dtype not in dmap:
        raise RuntimeError(f"Unsupported dtype {str_dtype}")

    circle_type = dmap[str_dtype]
    assert circle_type is not None
    return circle_type


# Convert circle dtype to numpy dtype
def np_dtype_from_circle_dtype(circle_dtype: int):
    dmap = {
        circle.TensorType.TensorType.FLOAT32: np.float32,
        circle.TensorType.TensorType.UINT8: np.uint8,
        circle.TensorType.TensorType.INT8: np.int8,
        circle.TensorType.TensorType.INT16: np.int16,
        circle.TensorType.TensorType.INT32: np.int32,
        circle.TensorType.TensorType.INT64: np.int64,
        circle.TensorType.TensorType.BOOL: np.bool_,
    }

    if circle_dtype not in dmap:
        raise RuntimeError(f"Unsupported dtype {circle_dtype}")

    np_dtype = dmap[circle_dtype]
    assert np_dtype is not None
    return np_dtype


# Return dtype of node
def extract_torch_dtype(node: torch.fx.Node) -> torch.dtype:
    assert node.meta is not None
    assert node.meta.get("val") is not None

    val = node.meta.get("val")
    val_dtype = None
    if isinstance(val, torch.Tensor):
        assert isinstance(val.dtype, torch.dtype)
        val_dtype = val.dtype
    else:
        val_dtype = torch.tensor(val).dtype
    return val_dtype


def extract_circle_dtype(node: torch.fx.Node) -> int:
    return to_circle_dtype(extract_torch_dtype(node))


# Return shape of node
def extract_shape(node: torch.fx.Node) -> torch.Size:
    assert node.meta is not None
    assert node.meta.get("val") is not None

    val = node.meta.get("val")
    val_shape = None
    if isinstance(val, torch.Tensor):
        val_shape = val.size()
    else:
        val_shape = torch.tensor(val).shape

    return val_shape


# Return stride of node
def extract_stride(node: torch.fx.Node) -> Tuple[int, ...]:
    assert node.meta is not None
    assert node.meta.get("val") is not None

    val = node.meta.get("val")
    val_stride = None
    assert isinstance(val, torch.Tensor)
    val_stride = val.stride()

    return val_stride


def traverse_elements(iter, container_types=(list, tuple)):
    if isinstance(iter, container_types):
        for e in iter:
            for sub_e in traverse_elements(e, container_types):
                yield sub_e
    else:
        yield iter


def check_if_i32_range(axis: Union[list, int]):
    INT32_MAX = 2**31 - 1
    INT32_MIN = -(2**31)
    values = list(traverse_elements(axis))
    return all(INT32_MIN <= val <= INT32_MAX for val in values)


def circle_legalize_dtype_to(values, *, dtype: torch.dtype):
    """
    Legalize data types from `torch.int64` to `torch.int32`.

    Pytorch assumes python's built-in integer type is `torch.int64`.
    But, many of the circle infrastructures support only int32 type. E.g. circle-interpreter.

    So, if constants has values whose range is inside [INT32_MIN <= val <= INT32_MAX], we will legalize the data type to int32.

    TODO support more types

    NOTE. This function must be applied only to constant values.
    """
    if dtype != torch.int32:
        raise RuntimeError("Not supported data types.")
    if not check_if_i32_range(values):
        raise RuntimeError("'size' cannot be converted from int64 to int32.")
    return torch.as_tensor(values, dtype=dtype)
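A quick sketch (an illustration, not from the diff) of how these mapping helpers behave; each assertion follows directly from the tables and range check above.

    import torch
    from circle_schema import circle
    from tico.serialize.circle_mapping import (
        check_if_i32_range,
        circle_legalize_dtype_to,
        to_circle_dtype,
    )

    assert to_circle_dtype(torch.float32) == circle.TensorType.TensorType.FLOAT32

    # Python ints become torch.int64 under tracing; legalize narrows them to
    # int32 when every element fits the int32 range.
    size = circle_legalize_dtype_to([2, 3], dtype=torch.int32)
    assert size.dtype == torch.int32

    assert not check_if_i32_range([2**40])  # would raise inside circle_legalize_dtype_to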
tico/serialize/circle_serializer.py
@@ -0,0 +1,232 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from typing import Dict

import flatbuffers
import torch
from circle_schema import circle
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    InputKind,
    TensorArgument,
)

from tico.serialize.circle_mapping import to_circle_dtype
from tico.serialize.operators import *
from tico.serialize.circle_graph import CircleModel, CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import get_node_visitors
from tico.utils import logging


multiple_output_ops = [
    torch.ops.aten.split_with_sizes.default,
]

# Build circle model from ExportedProgram
# Return raw bytes of circle model
def build_circle(edge_program: ExportedProgram) -> bytes:
    logger = logging.getLogger(__name__)

    builder = flatbuffers.Builder()

    # Init Model
    model = CircleModel()

    # Add empty buffer at the front (convention)
    model.add_buffer(circle.Buffer.BufferT())

    # Create an empty subgraph (assume a single subgraph)
    graph = CircleSubgraph(model)

    # Export tensors
    logger.debug("---------------Export tensors--------------")
    buf_name_to_data = {name: buf for name, buf in edge_program.named_buffers()}
    for node in edge_program.graph.nodes:
        if node.op == "call_function":
            if node.target in multiple_output_ops:
                continue
            node_val = node.meta["val"]
            if node_val.layout != torch.strided:
                raise RuntimeError(
                    f"Only support dense tensors (node layout: {node_val.layout})"
                )
            graph.add_tensor_from_node(node)
            logger.debug(f"call_function: {node.name} tensor exported.")

        # placeholder: function input (including parameters, buffers, constant tensors)
        elif node.op == "placeholder":
            # placeholder invariants
            assert node.args is None or len(node.args) == 0  # Not support default param

            # parameters
            if node.name in edge_program.graph_signature.inputs_to_parameters:
                param_name = edge_program.graph_signature.inputs_to_parameters[
                    node.name
                ]
                param_data = edge_program.state_dict[param_name]

                assert isinstance(
                    param_data, torch.Tensor
                ), "Expect parameters to be a tensor"
                param_value = param_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, param_value)
                logger.debug(f"placeholder(param): {node.name} tensor exported.")
            elif node.name in edge_program.graph_signature.inputs_to_buffers:
                buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
                assert buffer_name in buf_name_to_data
                buffer_data = buf_name_to_data[buffer_name]
                assert isinstance(
                    buffer_data, torch.Tensor
                ), "Expect buffers to be a tensor"
                buffer_value = buffer_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, buffer_value)
                logger.debug(f"placeholder(buffer): {node.name} tensor exported.")
            elif (
                node.name
                in edge_program.graph_signature.inputs_to_lifted_tensor_constants
            ):
                ctensor_name = (
                    edge_program.graph_signature.inputs_to_lifted_tensor_constants[
                        node.name
                    ]
                )
                ctensor_data = edge_program.constants[ctensor_name]

                assert isinstance(
                    ctensor_data, torch.Tensor
                ), "Expect constant tensor to be a tensor"
                ctensor_value = ctensor_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, ctensor_value)
                logger.debug(
                    f"placeholder(constant tensor): {node.name} tensor exported."
                )
            else:
                user_inputs = [
                    specs
                    for specs in edge_program.graph_signature.input_specs
                    if specs.kind == InputKind.USER_INPUT
                ]
                constant_inputs = [
                    specs
                    for specs in user_inputs
                    if isinstance(specs.arg, ConstantArgument)
                ]
                name_to_value = {
                    specs.arg.name: specs.arg.value for specs in constant_inputs
                }
                # NoneType ConstantArgument is ignored.
                if node.name in name_to_value and name_to_value[node.name] == None:
                    continue
                graph.add_tensor_from_node(node)
                logger.debug(f"placeholder: {node.name} tensor exported.")

        # get_attr: retrieve parameter
        elif node.op == "get_attr":
            # node.name: Place where fetched attribute is saved
            # node.target: Attribute in the module
            attr_tensor = getattr(node.graph.owning_module, node.target)
            assert isinstance(attr_tensor, torch.Tensor)

            graph.add_tensor_from_scratch(
                prefix=node.name,
                shape=list(attr_tensor.shape),
                dtype=to_circle_dtype(attr_tensor.dtype),
            )

            logger.debug(f"get_attr: {node.name} tensor exported.")

        # output: function output
        elif node.op == "output":
            # output node itself does not need a buffer
            # argument of output node is assumed to be exported beforehand
            for output in node.args[0]:
                if isinstance(output, torch.fx.Node):
                    assert graph.has_tensor(output.name)
                continue

        # call_method: call method
        elif node.op == "call_method":
            raise AssertionError("Not yet implemented")

        # call_module: call 'forward' of module
        elif node.op == "call_module":
            raise AssertionError("Not yet implemented")

        else:
            # Add more if fx.Node is extended
            raise AssertionError(f"Unknown fx.Node op {node.op}")

    # Register inputs
    logger.debug("---------------Register inputs--------------")
    for in_spec in edge_program.graph_signature.input_specs:
        if in_spec.kind != InputKind.USER_INPUT:
            continue
        # NoneType ConstantArgument is ignored.
        if isinstance(in_spec.arg, ConstantArgument) and in_spec.arg.value == None:
            continue
        arg_name = in_spec.arg.name
        graph.add_input(arg_name)
        logger.debug(f"Registered input: {arg_name}")

    # Register outputs
    logger.debug("---------------Register outputs--------------")
    for user_output in edge_program.graph_signature.user_outputs:
        graph.add_output(user_output)
        logger.debug(f"Registered output: {user_output}")

    # Export operators
    logger.debug("---------------Export operators--------------")
    op_codes: Dict[OpCode, int] = {}
    visitors = get_node_visitors(op_codes, graph)
    for node in edge_program.graph.nodes:
        if node.op != "call_function":
            continue

        opcode = node.target
        if opcode == operator.getitem:
            continue
        if opcode not in visitors:
            raise RuntimeError(f"{opcode} is not yet supported")
        circle_op = visitors[opcode].define_node(node)

        if circle_op:
            graph.add_operator(circle_op)
            logger.debug(f"call_function: {node.name} ({opcode}) Op exported.")

    # Register subgraph
    model.subgraphs.append(graph)

    # Encode operator codes
    model.operatorCodes = [
        code for code, _ in sorted(op_codes.items(), key=lambda x: x[1])
    ]

    # Description
    model.description = "circle"

    # Set version
    model.version = 0

    # Finish model
    builder.Finish(model.Pack(builder), "CIR0".encode("utf8"))
    buf = builder.Output()

    return bytes(buf)
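An end-to-end sketch of the calling convention (inferred from the signature, not part of the diff): build_circle takes a torch.export ExportedProgram and returns flatbuffer bytes. Note that tico's public converter (see tico/utils/convert.py in the manifest) runs its lowering passes before serialization, so feeding a raw export directly, as below, may hit unsupported operators; this only illustrates the input and output types.

    import torch
    from tico.serialize.circle_serializer import build_circle

    class AddOne(torch.nn.Module):
        def forward(self, x):
            return x + 1.0

    # torch.export produces the ExportedProgram that build_circle consumes.
    ep = torch.export.export(AddOne(), (torch.randn(2, 3),))

    circle_bytes = build_circle(ep)  # flatbuffer bytes with the "CIR0" file identifier
    with open("add_one.circle", "wb") as f:
        f.write(circle_bytes)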
tico/serialize/operators/__init__.py
@@ -0,0 +1,28 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
from os.path import basename, dirname, isfile, join

from tico.utils.register_custom_op import RegisterOps


# Register custom ops to torch namespace
RegisterOps()

# Load all modules in the current directory
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
]