tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,319 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from collections import defaultdict
|
16
|
+
from typing import Any, cast, Dict, final, List, Optional, TYPE_CHECKING, Union
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch.fx
|
20
|
+
import numpy as np
|
21
|
+
import torch
|
22
|
+
from circle_schema import circle
|
23
|
+
from torch._subclasses.fake_tensor import FakeTensor
|
24
|
+
|
25
|
+
from tico.serialize.circle_mapping import (
|
26
|
+
extract_circle_dtype,
|
27
|
+
extract_shape,
|
28
|
+
str_to_circle_dtype,
|
29
|
+
to_circle_dtype,
|
30
|
+
)
|
31
|
+
from tico.serialize.pack import pack_buffer
|
32
|
+
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
|
33
|
+
from tico.utils.utils import to_circle_qparam
|
34
|
+
|
35
|
+
"""
|
36
|
+
Type alias for const
|
37
|
+
"""
|
38
|
+
_PRIMITIVE_TYPES = (
|
39
|
+
float,
|
40
|
+
int,
|
41
|
+
bool,
|
42
|
+
str,
|
43
|
+
torch.Tensor,
|
44
|
+
torch.device,
|
45
|
+
torch.dtype,
|
46
|
+
torch.layout,
|
47
|
+
)
|
48
|
+
ConstDataElement = Union[
|
49
|
+
int, float, bool, str, torch.Tensor, torch.device, torch.dtype, torch.layout
|
50
|
+
]
|
51
|
+
ConstData = Union[ConstDataElement, List[ConstDataElement]]
|
52
|
+
|
53
|
+
|
54
|
+
def is_const(arg) -> bool:
    """Tell whether `arg` counts as constant data.

    The checks run in this order:
      * a `FakeTensor` is never constant;
      * an instance of any type in `_PRIMITIVE_TYPES` is constant;
      * a dict is constant when all of its values are constant;
      * a tuple or list is constant when all of its elements are constant;
      * everything else (including `None`) is not constant.
    """
    if isinstance(arg, FakeTensor):
        return False
    if isinstance(arg, _PRIMITIVE_TYPES):
        return True
    if isinstance(arg, dict):
        return all(is_const(value) for value in arg.values())
    if isinstance(arg, (tuple, list)):
        return all(is_const(element) for element in arg)
    return False
|
64
|
+
|
65
|
+
|
66
|
+
@final
class CircleModel(circle.Model.ModelT):
    """Circle model that collects subgraphs and the buffers they reference."""

    def __init__(self):
        super().__init__()
        self.subgraphs: List[circle.SubGraph.SubGraphT] = []
        self.buffers: List[circle.Buffer.BufferT] = []

    def add_subgraph(self, graph: circle.SubGraph.SubGraphT) -> None:
        """Append `graph` to the list of subgraphs."""
        self.subgraphs.append(graph)

    def add_buffer(self, buffer: circle.Buffer.BufferT) -> int:
        """Append `buffer` and return its id, i.e. its index in `self.buffers`."""
        # The new buffer lands at the current end of the list.
        new_id = len(self.buffers)
        self.buffers.append(buffer)
        return new_id
|
81
|
+
|
82
|
+
|
83
|
+
@final
class CircleSubgraph(circle.SubGraph.SubGraphT):
    """Circle subgraph that is filled in tensor by tensor and operator by operator.

    Tensors, operators and the input/output tensor ids are kept on this object;
    every tensor's buffer is registered on the owning `CircleModel`. Tensors are
    addressed by name through `name_to_tid`.
    """

    def __init__(self, model: CircleModel):
        super().__init__()
        self.model: CircleModel = model
        self.name: str = "subgraph"
        self.inputs: List[int] = []
        self.outputs: List[int] = []
        self.tensors: List[circle.Tensor.TensorT] = []
        self.operators: List[circle.Operator.OperatorT] = []
        # Tensor name -> index of that tensor in `self.tensors`.
        self.name_to_tid: Dict[str, int] = {}
        # Mapping from Circle tensor names to their originating FX nodes.
        # Used to trace back tensor definitions to their source and finalize
        # human-readable tensor names after serialization.
        self.name_to_node: Dict[str, torch.fx.Node] = {}
        # Per-prefix running index used when generating unique tensor names.
        self.counter: defaultdict = defaultdict(int)

    # Generate a unique name with prefix.
    # Naming rule
    # - If no tensor has the same name with prefix, return prefix
    # - Otherwise, add postfix f"_{idx}" where idx increases by 1 from 0
    # Example
    # If prefix = "add", this function will find a unique name in the following order.
    # "add", "add_0", "add_1", ...
    def _gen_unique_name_with_prefix(self, prefix: str):
        name = prefix
        while self.has_tensor(name):
            index = self.counter[prefix]
            name = f"{prefix}_{index}"
            self.counter[prefix] += 1

        return name

    def _add_tensor(self, tensor: circle.Tensor.TensorT) -> None:
        """Append `tensor` and record its index under its name.

        The name must not be registered yet.
        """
        # Check uniqueness first so a failing assertion does not leave an
        # unindexed tensor behind in `self.tensors`.
        assert tensor.name not in self.name_to_tid
        self.tensors.append(tensor)
        self.name_to_tid[tensor.name] = len(self.tensors) - 1

    def add_operator(self, op: circle.Operator.OperatorT) -> None:
        """Append `op` to the operator list."""
        self.operators.append(op)

    def add_input(self, input_name: str) -> None:
        """Mark the already-registered tensor `input_name` as a subgraph input."""
        assert input_name in self.name_to_tid, f"{input_name}"
        tid = self.name_to_tid[input_name]
        self.inputs.append(tid)

    def add_output(self, output: Any) -> None:
        """Mark a tensor as a subgraph output.

        `output` is either the name of a registered tensor or an int/float
        value, for which a new constant tensor is created first.

        Raises
        ------
        NotImplementedError
            If `output` is neither a str nor an int/float.
        """
        if isinstance(output, str):
            assert output in self.name_to_tid
            output_name = output
        # A tuple works on every Python version; `int | float` inside
        # isinstance() requires Python 3.10 or newer.
        elif isinstance(output, (int, float)):
            # output is built-in type.
            circle_tensor = self.add_const_tensor(output)
            output_name = circle_tensor.name
        else:
            raise NotImplementedError(f"Unsupported output dtype: {type(output)}")
        tid = self.name_to_tid[output_name]
        self.outputs.append(tid)

    def has_tensor(self, name: str):
        """Return True if a tensor called `name` is registered."""
        return name in self.name_to_tid

    def add_tensor_from_node(
        self, node: torch.fx.Node, data: Optional[np.ndarray] = None
    ) -> None:
        """Register a tensor for `node`, optionally backed by constant `data`.

        The tensor name is derived from `node.name`; type and shape are
        extracted from the node. When `node.meta` holds `QPARAM_KEY`, the
        quantization parameters are attached and the tensor type is taken from
        the qparam dtype instead.

        `data` must be an `np.ndarray` or None. An array is flattened, packed
        with `pack_buffer` when the qparam dtype is "uint4", and stored in the
        tensor's buffer viewed as uint8.
        """
        # Fail fast on a node without `val` before any name is reserved.
        assert node.meta.get("val") is not None

        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(node.name)
        assert tensor.name not in self.name_to_node
        self.name_to_node[tensor.name] = node
        tensor.type = extract_circle_dtype(node)
        tensor.shape = list(extract_shape(node))
        if QPARAM_KEY in node.meta:
            tensor.quantization = to_circle_qparam(node.meta[QPARAM_KEY])
            tensor.type = str_to_circle_dtype(node.meta[QPARAM_KEY].dtype)

        buffer = circle.Buffer.BufferT()
        if isinstance(data, np.ndarray):
            data = data.flatten()

            if QPARAM_KEY in node.meta and node.meta[QPARAM_KEY].dtype == "uint4":
                data = pack_buffer(data, "uint4")

            # Packing np.ndarray is faster than packing bytes
            buffer.data = data.view(np.uint8)  # type: ignore[assignment]
        else:
            # Anything that is not an ndarray must be "no data at all".
            assert data is None
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

    def add_const_tensor(
        self, data: ConstData, source_node: Optional[torch.fx.Node] = None
    ) -> circle.Tensor.TensorT:
        """Create, register and return a constant tensor holding `data`.

        `data` is converted with `torch.as_tensor`; the resulting dtype and
        shape become the tensor's type and shape, and its flattened contents
        are stored in a new buffer. If `source_node` is given, the new tensor
        name is associated with it in `name_to_node`.
        """
        assert is_const(data)
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix("const_tensor")
        assert tensor.name not in self.name_to_node
        if source_node is not None:
            self.name_to_node[tensor.name] = source_node
        assert not self.has_tensor(tensor.name)
        torch_t = torch.as_tensor(data=data)
        torch_t_shape = list(torch_t.size())
        tensor.type = to_circle_dtype(torch_dtype=torch_t.dtype)
        tensor.shape = torch_t_shape

        buffer = circle.Buffer.BufferT()
        buffer.data = torch_t.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    def add_tensor_from_scratch(
        self,
        prefix: str,
        shape: List[int],
        dtype: int,
        qparam: Optional[QuantParam] = None,
        source_node: Optional[torch.fx.Node] = None,
    ) -> circle.Tensor.TensorT:
        """
        Create a new tensor and register it into the Circle subgraph from scratch.

        This function is used to allocate tensors that are not directly derived from
        values in the FX graph, such as those created by padding or shape-generating
        operators.

        If a `source_node` is provided, it is used to enrich the tensor's metadata
        (e.g., by associating the tensor with the module hierarchy path stored in
        the node's `nn_module_stack`). This enables better traceability and more
        informative tensor names in the final Circle model.

        Parameters
        ----------
        prefix : str
            A name prefix used to generate a unique tensor name.
        shape : List[int]
            The shape of the tensor.
        dtype : int
            The Circle-compatible dtype of the tensor. Use `to_circle_dtype()` to convert.
            Ignored in favor of `qparam.dtype` when `qparam` is given.
        qparam : Optional[QuantParam]
            Optional quantization parameters to apply to the tensor.
        source_node : Optional[torch.fx.Node]
            If provided, the FX node from which this tensor originates. Used to generate
            a richer name and track module origin.

        Returns
        -------
        circle.Tensor.TensorT
            The newly created and registered tensor.
        """
        assert isinstance(dtype, int), f"{dtype} must be integer. Use to_circle_dtype."
        tensor = circle.Tensor.TensorT()
        tensor.name = self._gen_unique_name_with_prefix(prefix)
        assert tensor.name not in self.name_to_node
        if source_node is not None:
            self.name_to_node[tensor.name] = source_node
        tensor.shape = shape
        if qparam is not None:
            tensor.quantization = to_circle_qparam(qparam)
            tensor.type = str_to_circle_dtype(qparam.dtype)
        else:
            tensor.type = dtype

        # The tensor gets an empty buffer of its own.
        buffer = circle.Buffer.BufferT()
        bid = self.model.add_buffer(buffer)
        tensor.buffer = bid
        self._add_tensor(tensor)

        return tensor

    # Some operators like `full`, `arange_start_step` or `scalar_tensor` needs buffers to be in-place updated.
    # TODO remove this function
    def update_tensor_buffer(
        self, data: ConstData, tensor_name: str = str()
    ) -> circle.Tensor.TensorT:
        """Point the registered tensor `tensor_name` at a new buffer holding `data`.

        `data` must convert (via `torch.as_tensor`) to the same Circle dtype and
        shape as the existing tensor. A new buffer is added to the model and the
        tensor's buffer id is redirected to it; the updated tensor is returned.
        """
        assert is_const(data)
        assert self.has_tensor(tensor_name)
        data_tensor = torch.as_tensor(data=data)
        data_shape = list(data_tensor.size())
        op_tensor = self.tensors[self.name_to_tid[tensor_name]]
        assert op_tensor.type == to_circle_dtype(
            data_tensor.dtype
        ), f"{op_tensor.type}, {data_tensor.dtype}"
        assert op_tensor.shape == data_shape

        buffer = circle.Buffer.BufferT()
        # Packing np.ndarray is faster than packing bytes
        buffer.data = data_tensor.flatten().cpu().numpy().view(np.uint8)  # type: ignore[assignment]
        bid = self.model.add_buffer(buffer)
        op_tensor.buffer = bid

        return op_tensor

    def get_tid_registered(
        self, node: Union[torch.fx.node.Node, circle.Tensor.TensorT]
    ) -> int:
        """Return the tensor id registered under `node.name`.

        Raises
        ------
        KeyError
            If no tensor with that name is registered.
        """
        assert hasattr(node, "name"), "FIX CALLER UNLESS"

        tid = self.name_to_tid.get(node.name, None)

        if tid is None:
            raise KeyError(f"{node}({node.name}) is not registered.")

        assert tid < len(self.tensors)

        return tid

    def get_tensor(self, node: torch.fx.node.Node) -> circle.Tensor.TensorT:
        """Return the tensor registered for `node` (KeyError if there is none)."""
        tid = self.get_tid_registered(node)

        return self.tensors[tid]

    def get_buffer(self, node: torch.fx.Node) -> circle.Buffer.BufferT:
        """Return the model buffer referenced by the tensor registered for `node`."""
        buf_id = self.get_tensor(node).buffer
        return self.model.buffers[buf_id]

    # TODO Rename, it doesn't only get_tid but also possibly add a new const tensor
    def get_tid(
        self, node: Union[torch.fx.Node, circle.Tensor.TensorT, ConstData]
    ) -> int:
        """Return the tensor id for `node`, creating a const tensor if needed.

        * `None` yields -1.
        * An object whose `name` is registered yields the registered id.
        * Constant data (see `is_const`) is added as a new const tensor and
          that tensor's id is returned.

        Raises
        ------
        RuntimeError
            If `node` is none of the above, e.g. an FX node that was never
            registered as a tensor.
        """
        # return -1 if node is None. This is for generating CircleOutputExclude
        # Identity test: `==` would dispatch to the operand's own __eq__.
        if node is None:
            return -1

        if hasattr(node, "name") and node.name in self.name_to_tid:
            return self.name_to_tid[node.name]

        if is_const(node):
            node_name = self.add_const_tensor(cast(ConstData, node)).name
            return self.name_to_tid[node_name]

        # Reached when `node` is neither registered nor constant data.
        raise RuntimeError("fx Node was not converted to tensor.")
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Tuple, TYPE_CHECKING, Union
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import numpy as np
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
|
24
|
+
# Convert torch dtype to circle dtype
def to_circle_dtype(
    torch_dtype: torch.dtype,
) -> int:
    """Look up the Circle `TensorType` value that corresponds to `torch_dtype`.

    Raises
    ------
    RuntimeError
        If `torch_dtype` has no entry in the conversion table.
    """
    assert isinstance(torch_dtype, torch.dtype)

    tensor_type = circle.TensorType.TensorType
    torch_to_circle = {
        torch.float32: tensor_type.FLOAT32,
        torch.float: tensor_type.FLOAT32,
        torch.uint8: tensor_type.UINT8,
        torch.int8: tensor_type.INT8,
        torch.int16: tensor_type.INT16,
        torch.short: tensor_type.INT16,
        torch.int32: tensor_type.INT32,
        torch.int: tensor_type.INT32,
        torch.int64: tensor_type.INT64,
        torch.bool: tensor_type.BOOL,
    }

    # No table value is None, so None reliably signals "not found".
    circle_type = torch_to_circle.get(torch_dtype)
    if circle_type is None:
        raise RuntimeError(f"Unsupported dtype {torch_dtype}")
    return circle_type
|
48
|
+
|
49
|
+
|
50
|
+
# Convert str dtype used in QuantParam to circle dtype
def str_to_circle_dtype(
    str_dtype: str,
) -> int:
    """Look up the Circle `TensorType` value named by the string `str_dtype`.

    Raises
    ------
    RuntimeError
        If `str_dtype` has no entry in the conversion table.
    """
    tensor_type = circle.TensorType.TensorType
    name_to_circle = {
        "float32": tensor_type.FLOAT32,
        "float": tensor_type.FLOAT32,
        "uint8": tensor_type.UINT8,
        "int8": tensor_type.INT8,
        "int16": tensor_type.INT16,
        "short": tensor_type.INT16,
        "int32": tensor_type.INT32,
        "int": tensor_type.INT32,
        "int64": tensor_type.INT64,
        "bool": tensor_type.BOOL,
        "uint4": tensor_type.UINT4,
        # TODO Add more dtypes
    }

    # No table value is None, so None reliably signals "not found".
    circle_type = name_to_circle.get(str_dtype)
    if circle_type is None:
        raise RuntimeError(f"Unsupported dtype {str_dtype}")
    return circle_type
|
75
|
+
|
76
|
+
|
77
|
+
# Convert circle dtype to numpy dtype
def np_dtype_from_circle_dtype(circle_dtype: int):
    """Look up the numpy scalar type that corresponds to a Circle `TensorType` value.

    Raises
    ------
    RuntimeError
        If `circle_dtype` has no entry in the conversion table.
    """
    tensor_type = circle.TensorType.TensorType
    circle_to_np = {
        tensor_type.FLOAT32: np.float32,
        tensor_type.UINT8: np.uint8,
        tensor_type.INT8: np.int8,
        tensor_type.INT16: np.int16,
        tensor_type.INT32: np.int32,
        tensor_type.INT64: np.int64,
        tensor_type.BOOL: np.bool_,
    }

    # No table value is None, so None reliably signals "not found".
    np_dtype = circle_to_np.get(circle_dtype)
    if np_dtype is None:
        raise RuntimeError(f"Unsupported dtype {circle_dtype}")
    return np_dtype
|
95
|
+
|
96
|
+
|
97
|
+
# Return dtype of node
def extract_torch_dtype(node: torch.fx.Node) -> torch.dtype:
    """Return the torch dtype of the value stored in `node.meta["val"]`.

    A tensor value reports its own dtype; any other value is passed through
    `torch.tensor()` and the dtype of the result is returned.
    """
    assert node.meta is not None
    val = node.meta.get("val")
    assert val is not None

    if not isinstance(val, torch.Tensor):
        return torch.tensor(val).dtype

    assert isinstance(val.dtype, torch.dtype)
    return val.dtype
|
110
|
+
|
111
|
+
|
112
|
+
def extract_circle_dtype(node: torch.fx.Node) -> int:
    """Return the Circle dtype of `node`: its torch dtype run through `to_circle_dtype`."""
    torch_dtype = extract_torch_dtype(node)
    return to_circle_dtype(torch_dtype)
|
114
|
+
|
115
|
+
|
116
|
+
# Return shape of node
def extract_shape(node: torch.fx.Node) -> torch.Size:
    """Return the shape of the value stored in `node.meta["val"]`.

    A tensor value reports its own size; any other value is passed through
    `torch.tensor()` and the shape of the result is returned.
    """
    assert node.meta is not None
    val = node.meta.get("val")
    assert val is not None

    if isinstance(val, torch.Tensor):
        return val.size()
    return torch.tensor(val).shape
|
129
|
+
|
130
|
+
|
131
|
+
# Return stride of node
def extract_stride(node: torch.fx.Node) -> Tuple[int, ...]:
    """Return the stride of the tensor stored in `node.meta["val"]`.

    Unlike the dtype/shape extractors, the value must already be a tensor.
    """
    assert node.meta is not None
    val = node.meta.get("val")
    assert val is not None
    assert isinstance(val, torch.Tensor)

    return val.stride()
|
142
|
+
|
143
|
+
|
144
|
+
def traverse_elements(iter, container_types=(list, tuple)):
    """Yield the leaves of an arbitrarily nested container, depth first, in order.

    Anything that is not an instance of `container_types` is a leaf; a
    non-container argument therefore yields just itself.
    """
    if not isinstance(iter, container_types):
        yield iter
        return

    for element in iter:
        yield from traverse_elements(element, container_types)
|
151
|
+
|
152
|
+
|
153
|
+
def check_if_i32_range(axis: Union[list, int]):
    """Return True when every value in `axis` lies in the signed 32-bit range.

    `axis` may be a single value or a nested list/tuple of values; it is
    flattened with `traverse_elements` before checking.
    """
    lower_bound, upper_bound = -(2**31), 2**31 - 1
    flattened = list(traverse_elements(axis))
    for val in flattened:
        if not (lower_bound <= val <= upper_bound):
            return False
    return True
|
158
|
+
|
159
|
+
|
160
|
+
def circle_legalize_dtype_to(values, *, dtype: torch.dtype):
    """
    Convert constant integer values to a `torch.int32` tensor.

    PyTorch treats Python's built-in integers as `torch.int64`, but many of the
    Circle infrastructures (e.g. circle-interpreter) support only int32. When
    every value lies inside [INT32_MIN, INT32_MAX], the values are legalized to
    int32.

    Only `dtype=torch.int32` is accepted for now.

    TODO support more types

    NOTE. This function must be applied only to constant values.

    Raises
    ------
    RuntimeError
        If `dtype` is not `torch.int32`, or if any value is outside the int32 range.
    """
    if dtype != torch.int32:
        raise RuntimeError("Not supported data types.")

    if check_if_i32_range(values):
        return torch.as_tensor(values, dtype=dtype)

    raise RuntimeError("'size' cannot be converted from int64 to int32.")
|