tico 0.1.0.dev250629__py3-none-any.whl → 0.1.0.dev250701__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +63 -1
- tico/passes/legalize_predefined_layout_operators.py +84 -1
- tico/serialize/operators/op_conv2d.py +1 -1
- tico/serialize/operators/op_transpose_conv.py +165 -0
- tico/utils/convert.py +10 -2
- tico/utils/padding.py +2 -2
- tico/utils/register_custom_op.py +108 -0
- tico/utils/validate_args_kwargs.py +39 -0
- {tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/RECORD +15 -14
- {tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
 from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
 
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250629"
+__version__ = "0.1.0.dev250701"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py
CHANGED
@@ -29,10 +29,12 @@ from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
 from tico.utils.utils import quant_min_max, set_new_meta_val
 from tico.utils.validate_args_kwargs import (
+    AddTensorArgs,
     BmmArgs,
     LinearArgs,
     MulTensorArgs,
     PermuteArgs,
+    ReluArgs,
     ReshapeArgs,
 )
 
@@ -77,7 +79,7 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
     max_ = u8_scale * (255 - u8_zerop)
     min_ = u8_scale * (-u8_zerop)
 
-    abs_max = max([max_, min_], key=abs)
+    abs_max = abs(max([max_, min_], key=abs))
     s16_scale = abs_max / 32767
     s16_zerop = 0
 
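
A quick numeric check of why the added `abs()` matters (made-up qparams): `max(..., key=abs)` returns the element with the largest magnitude, which can be negative, and a negative int16 scale would be invalid.

    u8_scale, u8_zerop = 0.02, 200
    max_ = u8_scale * (255 - u8_zerop)         # 1.1
    min_ = u8_scale * (-u8_zerop)              # -4.0
    abs_max = abs(max([max_, min_], key=abs))  # picks -4.0, abs() -> 4.0
    s16_scale = abs_max / 32767                # ~1.22e-4, positive as required
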
@@ -210,6 +212,42 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
                     logger.debug(
                         f"quantize_per_tensor.default is inserted after {node.name}."
                     )
+                else:
+                    raise NotYetSupportedError(
+                        f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
+                    )
+
+            elif node.target == torch.ops.aten.add.Tensor:
+                add_args = AddTensorArgs(*node.args, **node.kwargs)
+                x = add_args.input
+                y = add_args.other
+
+                if not isinstance(x, torch.fx.Node):
+                    continue
+                if not isinstance(y, torch.fx.Node):
+                    continue
+
+                if QPARAM_KEY not in x.meta:
+                    continue
+                if QPARAM_KEY not in y.meta:
+                    continue
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(x) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(x) != qparam_dtype(y):
+                    continue
+
+                if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
                 else:
                     raise NotYetSupportedError("Unsupported dtype")
 
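
For orientation, the inserted `quantize_per_tensor.default` node computes an ordinary affine quantization at runtime. A standalone sketch with made-up qparams (in recent torch versions the decomposed quant ops are registered by importing `torch.ao.quantization.fx._decomposed`):

    import torch
    import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed ops

    x = torch.tensor([-4.0, 0.0, 1.1])
    q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        x, 0.02, 200, 0, 255, torch.uint8  # scale, zero_point, qmin, qmax, dtype
    )
    print(q)  # tensor([  0, 200, 255], dtype=torch.uint8)
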
@@ -335,6 +373,30 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
                 else:
                     raise NotYetSupportedError("Unsupported dtype")
 
+            elif node.target == torch.ops.aten.relu.default:
+                relu_args = ReluArgs(*node.args, **node.kwargs)
+                inp = relu_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
         # TODO Support more ops.
 
         graph.eliminate_dead_code()
tico/passes/legalize_predefined_layout_operators.py
CHANGED
@@ -30,6 +30,7 @@ from tico.utils.utils import is_target_node
 from tico.utils.validate_args_kwargs import (
     AvgPool2dArgs,
     Conv2DArgs,
+    ConvTranspose2DArgs,
     DequantizePerChannelArgs,
     DequantizePerTensorArgs,
     InstanceNormArgs,
@@ -37,7 +38,9 @@ from tico.utils.validate_args_kwargs import (
 )
 
 
-def get_permute_weight_input(conv_args: Conv2DArgs) -> torch.fx.Node:
+def get_permute_weight_input(
+    conv_args: Conv2DArgs | ConvTranspose2DArgs,
+) -> torch.fx.Node:
     """
     Retrieves the weight input for the permute operation.
 
@@ -194,6 +197,85 @@ class LegalizePreDefinedLayoutOperators(PassBase):
             modified = True
         return modified
 
+    def legalize_conv_transpose2d(self, exported_program, node) -> bool:
+        logger = logging.getLogger(__name__)
+        modified = False
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        args = ConvTranspose2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        padding = args.padding
+        groups = args.groups
+        dilation = args.dilation
+
+        input_shape = extract_shape(input)
+        if not (len(input_shape) == 4):
+            raise NotYetSupportedError(
+                f"Only support 4D input tensor: node's input shape: {input_shape}"
+            )
+
+        if groups != 1:
+            raise NotYetSupportedError(
+                f"Only support groups=1: node's groups: {groups}"
+            )
+
+        if dilation != [1, 1]:
+            raise NotYetSupportedError(
+                f"Only support dilation=[1, 1]: node's dilation: {dilation}"
+            )
+
+        NCHW_to_NHWC = [0, 2, 3, 1]
+        # input permute
+        with graph.inserting_after(input):
+            input_permute = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(input, NCHW_to_NHWC),
+                origin=input,
+            )
+        node.update_arg(node.args.index(input), input_permute)
+
+        # weight permute
+        weight = get_permute_weight_input(args)
+        with graph.inserting_after(weight):
+            perm = [1, 2, 3, 0]  # IOHW_to_OHWI
+            weight_permute = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(weight, perm),
+                origin=weight,
+            )
+        if args.weight.target in [
+            torch.ops.quantized_decomposed.dequantize_per_channel.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]:
+            dq = args.weight
+            dq.update_arg(dq.args.index(weight), weight_permute)
+            # Need to update dq.meta["val"] in FillMetaVal pass.
+            del dq.meta["val"]
+        else:
+            node.update_arg(node.args.index(weight), weight_permute)
+
+        with graph.inserting_before(node):
+            legalized_op = torch.ops.circle_custom.transpose_conv
+            circle_op = create_node(
+                graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
+            )
+            # output permute
+            NHWC_to_NCHW = [0, 3, 1, 2]
+            conv_out_permute = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(circle_op, NHWC_to_NCHW),
+            )
+            node.replace_all_uses_with(conv_out_permute, propagate_meta=True)
+
+        logger.debug(f"{node.name} is replaced with {circle_op.name}")
+        modified = True
+        return modified
+
     def legalize_instance_norm(self, exported_program, node) -> bool:
         logger = logging.getLogger(__name__)
         modified = False
@@ -365,6 +447,7 @@ class LegalizePreDefinedLayoutOperators(PassBase):
         target_to_legalize_func = {
             torch.ops.aten.conv2d.default: self.legalize_conv2d,
             torch.ops.aten.conv2d.padding: self.legalize_conv2d,
+            torch.ops.aten.conv_transpose2d.input: self.legalize_conv_transpose2d,
             torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
             torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
             torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
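
The legalize pass above wraps the aten op in layout permutes: activations move NCHW to NHWC, the transposed-conv weight moves IOHW to OHWI, and an inverse permute is appended on the output. A standalone sanity check of that bookkeeping (not tico code):

    import torch

    x_nchw = torch.randn(1, 3, 8, 8)
    w_iohw = torch.randn(3, 5, 2, 2)  # aten conv_transpose2d weight: (I, O, kH, kW)
    ref = torch.ops.aten.conv_transpose2d.input(x_nchw, w_iohw, None, [2, 2])

    x_nhwc = x_nchw.permute(0, 2, 3, 1)  # NCHW_to_NHWC = [0, 2, 3, 1]
    w_ohwi = w_iohw.permute(1, 2, 3, 0)  # IOHW_to_OHWI = [1, 2, 3, 0]
    out = torch.ops.aten.conv_transpose2d.input(
        x_nhwc.permute(0, 3, 1, 2),      # NHWC_to_NCHW = [0, 3, 1, 2]
        w_ohwi.permute(3, 0, 1, 2),      # OHWI_to_IOHW = [3, 0, 1, 2]
        None,
        [2, 2],
    )
    assert torch.allclose(ref, out)
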
tico/serialize/operators/op_conv2d.py
CHANGED
@@ -122,7 +122,7 @@ class Conv2dVisitor(NodeVisitor):
 
         if is_valid_padding(padding):
             conv2d_padding_type = VALID
-        elif is_same_padding(padding, input_shape, output_shape):
+        elif is_same_padding(padding, input_shape, output_shape) and stride == [1, 1]:
             conv2d_padding_type = SAME
         else:
             assert isinstance(padding, list) and len(padding) == 2
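
Why the `stride == [1, 1]` guard matters (hypothetical numbers): a strided conv can coincidentally return the input's spatial size even though it is not SAME padding in the circle/TFLite sense, where SAME with stride 2 on H=3 would give ceil(3/2) == 2.

    import torch

    x = torch.randn(1, 1, 3, 3)
    w = torch.randn(1, 1, 1, 1)  # 1x1 kernel
    y = torch.ops.aten.conv2d.default(x, w, None, [2, 2], [1, 1])
    print(y.shape)  # torch.Size([1, 1, 3, 3]): same H/W, yet stride != 1
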
tico/serialize/operators/op_transpose_conv.py
ADDED
@@ -0,0 +1,165 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_mapping import (
+    circle_legalize_dtype_to,
+    extract_circle_dtype,
+    extract_shape,
+)
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+from tico.utils.define import define_pad_node
+from tico.utils.padding import is_same_padding, is_valid_padding, SAME, VALID
+from tico.utils.validate_args_kwargs import ConvTranspose2DArgs
+
+
+@register_node_visitor
+class TransposeConvVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.transpose_conv,
+    ]
+
+    def define_transpose_conv_node(
+        self, padding: int, stride: List, inputs: List, outputs: List
+    ) -> circle.Operator.OperatorT:
+        def set_transpose_conv_option(operator, stride):
+            operator.builtinOptionsType = (
+                circle.BuiltinOptions.BuiltinOptions.TransposeConvOptions
+            )
+            option = circle.TransposeConvOptions.TransposeConvOptionsT()
+            option.padding = padding
+            option.strideH = stride[0]
+            option.strideW = stride[1]
+            option.fusedActivationFunction = (
+                circle.ActivationFunctionType.ActivationFunctionType.NONE
+            )
+            operator.builtinOptions = option
+
+        transpose_conv_op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE_CONV, self._op_codes
+        )
+        operator = create_builtin_operator(
+            self.graph, transpose_conv_op_index, inputs, outputs
+        )
+        set_transpose_conv_option(operator, stride)
+        return operator
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph):
+        super().__init__(op_codes, graph)
+
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+        args = ConvTranspose2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        input_ = args.input
+        weight = args.weight
+        bias = args.bias
+        stride = args.stride
+        padding = args.padding
+        output_padding = args.output_padding
+        groups = args.groups
+        dilation = args.dilation
+
+        assert groups == 1, "Only support group 1"
+
+        input_dtype: int = extract_circle_dtype(input_)
+        input_shape = list(extract_shape(input_))
+        assert len(input_shape) == 4, len(input_shape)
+        output_shape = extract_shape(node)
+        assert len(output_shape) == 4, len(output_shape)
+
+        conv_input: torch.fx.Node | circle.Tensor.TensorT = input_
+        weight_shape = list(extract_shape(weight))
+
+        if is_valid_padding(padding):
+            tconv2d_padding_type = VALID
+        elif is_same_padding(padding, input_shape, output_shape) and stride == [1, 1]:
+            tconv2d_padding_type = SAME
+        else:
+            assert isinstance(padding, list) and len(padding) == 2
+
+            tconv2d_padding_type = VALID
+
+            # Padding is neither valid nor same, so we use valid padding and add a padding operator before the tconv2d operator.
+            # When data_format is "NHWC", padding should be [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].
+            paddings = torch.tensor(
+                [
+                    [0, 0],
+                    [padding[0], padding[0]],
+                    [padding[1], padding[1]],
+                    [0, 0],
+                ],
+                dtype=torch.int32,
+            )
+            pad_output_shape = [
+                input_shape[0],
+                input_shape[1],
+                input_shape[2],
+                input_shape[3],
+            ]
+            # Add (pad_top + pad_bottom) to pad_output_shape_h
+            pad_output_shape[1] += padding[0] * 2
+            # Add (pad_left + pad_right) to pad_output_shape_w
+            pad_output_shape[2] += padding[1] * 2
+            # Create the padded output tensor
+            input_qparam: Optional[QuantParam] = (
+                input_.meta[QPARAM_KEY] if QPARAM_KEY in input_.meta else None
+            )
+            pad_output = self.graph.add_tensor_from_scratch(
+                prefix=f"{node.name}_input_pad_output",
+                shape=pad_output_shape,
+                dtype=input_dtype,
+                qparam=input_qparam,
+                source_node=node,
+            )
+            # CirclePad
+            pad_operator = define_pad_node(
+                self.graph, self._op_codes, [input_, paddings], [pad_output]
+            )
+            self.graph.add_operator(pad_operator)
+            conv_input = pad_output
+
+        if bias is None:
+            # luci-interpreter can't run a conv without bias, so add a zero vector for the bias.
+            assert len(weight_shape) == 4
+            out_channel = weight_shape[0]
+            bias = [0.0] * out_channel  # type: ignore[assignment]
+
+        # The first argument is the output shape of tconv.
+        assert output_shape[0] == input_shape[0]
+        assert output_shape[3] == weight_shape[0]
+        tconv_output = circle_legalize_dtype_to(output_shape, dtype=torch.int32)
+
+        tconv_output_tensor = self.graph.add_const_tensor(
+            tconv_output, source_node=node
+        )
+
+        # TConv2D
+        tconv2d_operator = self.define_transpose_conv_node(
+            tconv2d_padding_type,  # 'SAME'(0) or 'VALID'(1)
+            stride,
+            [tconv_output_tensor, weight, conv_input, bias],
+            [node],
+        )
+
+        return tconv2d_operator
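
The const `tconv_output` tensor above carries the output extent that Circle's TRANSPOSE_CONV expects as its first input. For reference, with dilation 1 and output_padding 0 the spatial extent follows the standard transposed-conv formula (a sketch, not tico code):

    def tconv_out_size(in_size: int, stride: int, kernel: int, pad: int = 0) -> int:
        # Standard transposed-conv output extent (dilation=1, output_padding=0).
        return (in_size - 1) * stride - 2 * pad + kernel

    assert tconv_out_size(8, stride=2, kernel=2) == 16
    assert tconv_out_size(8, stride=1, kernel=3, pad=1) == 8  # shape-preserving
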
tico/utils/convert.py
CHANGED
@@ -100,6 +100,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv1d.padding,
+        torch.ops.aten.conv_transpose2d.input,
         torch.ops.aten.instance_norm.default,
         torch.ops.aten._safe_softmax.default,
         torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
@@ -116,6 +117,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv1d.padding,
+        torch.ops.aten.conv_transpose2d.input,
         torch.ops.aten.instance_norm.default,
         torch.ops.aten._safe_softmax.default,
         torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
@@ -174,7 +176,7 @@ def check_training_ops(exported_program: ExportedProgram):
 
     if found:
         raise RuntimeError(
-            f"Detected training-mode ops {
+            f"Detected training-mode ops {found}. Call `model.eval()` before export."
         )
 
 
@@ -186,7 +188,6 @@ def convert_exported_module_to_circle(
     logger.debug("Input ExportedProgram (must be core aten)")
     logger.debug(exported_program)
 
-    check_training_ops(exported_program)
     # PRE-EDGE PASSES
     #
     # Here are the passes that run before to_edge() conversion.
@@ -275,6 +276,7 @@ def convert_exported_module_to_circle(
     quantize_graph.run(exported_program)
 
     check_unsupported_target(exported_program)
+    check_training_ops(exported_program)
     circle_program = build_circle(exported_program)
 
     return circle_program
@@ -287,6 +289,12 @@ def convert(
     strict: bool = True,
     config: CompileConfigBase = get_default_config(),
 ) -> CircleModel:
+    if hasattr(mod, "training") and mod.training:
+        logger = logging.getLogger(__name__)
+        logger.fatal(
+            "Your model is in TRAINING MODE. PLEASE CHECK IF YOU FORGOT `model.eval()`."
+        )
+
     with torch.no_grad():
         exported_program = export(mod, args, kwargs, strict=strict)
 
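
In practice the new warning fires when a model is converted without `model.eval()`. A minimal sketch of the intended call pattern (argument names assumed from the hunk above):

    import torch
    import tico

    model = torch.nn.ConvTranspose2d(3, 5, kernel_size=2, stride=2)
    example_inputs = (torch.randn(1, 3, 8, 8),)

    model.eval()  # skipping this now logs the TRAINING MODE warning
    circle_model = tico.convert(model, example_inputs)
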
tico/utils/padding.py
CHANGED
@@ -40,8 +40,8 @@ def is_same_padding(
     if isinstance(padding, list):
         assert len(padding) == 2, "Padding should be a list of length 2."
 
-        input_HW = input_shape[1:
-        output_HW = output_shape[1:
+        input_HW = tuple(input_shape[1:3])  # N H W C
+        output_HW = tuple(output_shape[1:3])  # N H W C
         return input_HW == output_HW
 
     raise InvalidArgumentError("Invalid padding.")
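
The `[1:3]` slice picks H and W out of NHWC shapes, so the check compares spatial dims only; for example:

    input_shape = [1, 8, 8, 3]    # N H W C
    output_shape = [1, 8, 8, 16]  # channel count may differ
    assert tuple(input_shape[1:3]) == tuple(output_shape[1:3])  # (8, 8) == (8, 8)
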
tico/utils/register_custom_op.py
CHANGED
@@ -371,6 +371,113 @@ def CircleDepthwiseConv2dPadding():
         return NHWC_output
 
 
+def CircleTransposeConv():
+    """
+    Note that this op follows the input spec of `aten.conv_transpose2d.input`, whose
+    number of arguments meets the (2 <= node.args <= 8) condition.
+    [RESTRICTION]
+    Therefore, I tried to define its spec as transpose_conv(input, weight, *args).
+    But custom operators in torch do not support positional-only args, so I set
+    them to None by default.
+    """
+
+    @custom_op("circle_custom::transpose_conv", mutates_args=())
+    def transpose_conv(
+        input_: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        stride: Optional[List[int]] = None,
+        padding: Optional[List[int]] = None,
+        output_padding: Optional[List[int]] = None,
+        groups: Optional[int] = None,
+        dilation: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        """
+        Set default values.
+        Custom operators have limited types when it comes to default values.
+        So, let's set them to None in the input spec, and then fill in the defaults here.
+        https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+        """
+        stride = [1, 1] if stride is None else stride
+        padding = [0, 0] if padding is None else padding
+        output_padding = [0, 0] if output_padding is None else output_padding
+        groups = 1 if groups is None else groups
+        dilation = [1, 1] if dilation is None else dilation
+        if groups != 1:
+            raise RuntimeError(
+                f"CircleTransposeConv only supports 1 'groups'. the node's groups: {groups}"
+            )
+
+        NHWC_to_NCHW = [0, 3, 1, 2]
+        OHWI_to_IOHW = [3, 0, 1, 2]
+        NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+        OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_IOHW)
+
+        args = [
+            NCHW_input,
+            OIHW_weight,
+            bias,
+            stride,
+            padding,
+            output_padding,
+            groups,
+            dilation,
+        ]
+        NCHW_output = torch.ops.aten.conv_transpose2d.input(*args)
+        NCHW_to_NHWC = [0, 2, 3, 1]
+        NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+        return NHWC_output
+
+    @register_fake("circle_custom::transpose_conv")
+    def _(
+        input_: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        stride: Optional[List[int]] = None,
+        padding: Optional[List[int]] = None,
+        output_padding: Optional[List[int]] = None,
+        groups: Optional[int] = None,
+        dilation: Optional[List[int]] = None,
+    ):
+        """
+        Set default values.
+        Custom operators have limited types when it comes to default values.
+        So, let's set them to None in the input spec, and then fill in the defaults here.
+        https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+        """
+        stride = [1, 1] if stride is None else stride
+        padding = [0, 0] if padding is None else padding
+        output_padding = [0, 0] if output_padding is None else output_padding
+        groups = 1 if groups is None else groups
+        dilation = [1, 1] if dilation is None else dilation
+        if groups != 1:
+            raise RuntimeError(
+                f"CircleTransposeConv only supports 1 'groups'. the node's groups: {groups}"
+            )
+
+        NHWC_to_NCHW = [0, 3, 1, 2]
+        OHWI_to_IOHW = [3, 0, 1, 2]
+        NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+        OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_IOHW)
+
+        args = [
+            NCHW_input,
+            OIHW_weight,
+            bias,
+            stride,
+            padding,
+            output_padding,
+            groups,
+            dilation,
+        ]
+        NCHW_output = torch.ops.aten.conv_transpose2d.input(*args)
+        NCHW_to_NHWC = [0, 2, 3, 1]
+        NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+        return NHWC_output
+
+
 def CircleMaxPool2D():
     """
     Note that this op follows the input spec of `aten.max_pool2d_with_indices.default` whose number
@@ -603,6 +710,7 @@ def RegisterOps():
     CircleDepthwiseConv2dPadding()
     CircleConv2d()
     CircleConv2dPadding()
+    CircleTransposeConv()
     CircleMaxPool2D()
     CircleAvgPool2D()
     CircleInstanceNorm()
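
Once `RegisterOps()` has run, the new op is callable eagerly, which is handy for spot-checking layouts. A sketch (shapes assumed):

    import torch
    from tico.utils.register_custom_op import RegisterOps

    RegisterOps()  # registers circle_custom::transpose_conv, among others

    x = torch.randn(1, 8, 8, 3)  # NHWC
    w = torch.randn(5, 2, 2, 3)  # OHWI
    y = torch.ops.circle_custom.transpose_conv(x, w, stride=[2, 2])
    print(y.shape)  # torch.Size([1, 16, 16, 5]), still NHWC
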
tico/utils/validate_args_kwargs.py
CHANGED
@@ -208,6 +208,45 @@ class ConstantPadNdArgs:
     value: int | float
 
 
+@enforce_type
+@dataclass
+class ConvArgs:
+    """
+    convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: torch.fx.Node
+    bias: Union[torch.fx.Node, None]
+    stride: List[int]
+    padding: List[int]
+    dilation: List[int]
+    transposed: bool
+    output_padding: List[int]
+    groups: int
+
+
+@enforce_type
+@dataclass
+class ConvTranspose2DArgs:
+    """
+    conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: torch.fx.Node
+    bias: Union[torch.fx.Node, None] = None
+    stride: List[int] = field(default_factory=lambda: [1, 1])
+    padding: List[int] = field(default_factory=lambda: [0, 0])
+    output_padding: List[int] = field(default_factory=lambda: [0, 0])
+    groups: int = 1
+    dilation: List[int] = field(default_factory=lambda: [1, 1])
+
+    def __post_init__(self):
+        assert len(self.stride) == 2, len(self.stride)
+        assert len(self.dilation) == 2, len(self.dilation)
+
+
 @enforce_type
 @dataclass
 class Conv2DArgs:
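
A sketch of how a pass unpacks a node with this dataclass; whether `aten.conv_transpose2d.input` appears un-decomposed depends on the torch version (tico also preserves it across run_decompositions, per convert.py above):

    import torch
    from tico.utils.validate_args_kwargs import ConvTranspose2DArgs

    m = torch.nn.ConvTranspose2d(3, 5, 2, stride=2).eval()
    ep = torch.export.export(m, (torch.randn(1, 3, 8, 8),))
    for node in ep.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.conv_transpose2d.input:
            args = ConvTranspose2DArgs(*node.args, **node.kwargs)
            # Trailing arguments omitted at the call site fall back to the defaults.
            print(args.stride, args.padding, args.groups)
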
{tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=na5S7s-9ToqvW2qHQsFkHg9ST_gMeKGTyCnymdNnbo0,1743
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -51,7 +51,7 @@ tico/experimental/quantization/evaluation/executor/circle_executor.py,sha256=eCC
 tico/experimental/quantization/evaluation/executor/triv24_executor.py,sha256=sUoXl6oOO2arAKaNjOBg7HiQja145_Jv6qgY7XtR7A8,5159
 tico/experimental/quantization/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/passes/fold_quant_ops.py,sha256=iaBMyO49CwVkhebMz3rjkHWfWE2LhwH6fORe7n4S6XQ,7040
-tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=
+tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=lQeN6VJNYQyDeucB4KpyyWvIiuhGRq7wjIeaCKdM7ck,15462
 tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
 tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
 tico/experimental/quantization/passes/quantize_bias.py,sha256=ZQ3rETYStpW28JUbODRixbq5sDEOiIOB_qWA-Jzuu-Y,4337
@@ -79,7 +79,7 @@ tico/passes/fill_meta_val.py,sha256=Xbam6Aq90ZfWItZw1dgLIwH_q8RCiU5JodKNqkj-ink,
 tico/passes/fuse_leading_unsqueeze_reshape.py,sha256=88jwTP35yRyXOk9xdO6YW2OEfdKAws3KFRT16WQz0RI,4291
 tico/passes/fuse_redundant_reshape_to_mean.py,sha256=GhJS1ZKB6Ns4AhwcW3uUQ6q-0N-AzlD32B2EwusUJHg,3761
 tico/passes/legalize_causal_mask_value.py,sha256=xKdFwwMaSFCSQpSk8xISOAqFpZ1jIhgbBIqf7KTSGuk,4017
-tico/passes/legalize_predefined_layout_operators.py,sha256=
+tico/passes/legalize_predefined_layout_operators.py,sha256=MNx7L2dAlsxSazb-F7c0onPqHleI17zAc7AzQAa9aJ4,18934
 tico/passes/lower_pow2_to_mul.py,sha256=nfJXa9ZTZMiLg6ownSyvkM4KF2z9tZW34Q3CCWI_vmQ,2402
 tico/passes/lower_to_resize_nearest_neighbor.py,sha256=N6F56Of8Aiv-KIiYLHnh33WX72W60ZVQSBEYWHdYqNQ,9005
 tico/passes/lower_to_slice.py,sha256=0qAX3WzZdyMFDW4DiO9b5JFXd4rL1-0doBT6lJvaw_I,7260
@@ -115,7 +115,7 @@ tico/serialize/operators/op_cat.py,sha256=XDYOh0XAyrM0TlxVm6Sa0OFFGrKk7aSDcGXC-h
 tico/serialize/operators/op_clamp.py,sha256=ZRAsXLGsZqJEh4wXxESEpRJkRtUuJWTDgAem6lr9_5I,4298
 tico/serialize/operators/op_clone.py,sha256=vzDYJ8TS3tc2BAyd_z8nt5VqT1inpymSseMEhd9dva0,2394
 tico/serialize/operators/op_constant_pad_nd.py,sha256=OpP4AP-d1IFcWZolNa-o9ZxzXJQkMdG9WQ66soX3s-E,2675
-tico/serialize/operators/op_conv2d.py,sha256=
+tico/serialize/operators/op_conv2d.py,sha256=BmSCunhziD9EhXEkWwFrWkaQ_t3cIhrJJQSRLbgqmxI,7338
 tico/serialize/operators/op_copy.py,sha256=vaianLQ19-2ZQZ-MdQ07YuOPeFeo_HAx2a0Qfn7I5Kk,6122
 tico/serialize/operators/op_cos.py,sha256=N12bNyuTQIxRnD0eHRPdFVzRQPMy1NFM4iM8oQ4lYzw,2034
 tico/serialize/operators/op_cumsum.py,sha256=3fmOf1mIeCX1uhTBcSJmRGXejzLtO8UwaI1eEQDC6nA,3798
@@ -175,32 +175,33 @@ tico/serialize/operators/op_sub.py,sha256=yZskQJF0ylXVk02Uid8djPNIWDJ-0uHJar4UYh
 tico/serialize/operators/op_sum.py,sha256=B5aSwQMhyoBe2JYdE5nVQ3QeVDSzL-yuZZujsG08OdQ,2294
 tico/serialize/operators/op_tanh.py,sha256=rs7FsbQeUQ7Ak8RoQV9ymNGXHXRObojfY_SiqJiyqdA,1846
 tico/serialize/operators/op_to_copy.py,sha256=a8T0uPMavMO_md1a-4_0dlvDHyZS_xew0qB6xjf69rI,3934
+tico/serialize/operators/op_transpose_conv.py,sha256=ZpVNWH58FqbQahExeZIpW51w3VtfGg49JlHCDSOv7wg,6370
 tico/serialize/operators/op_unsqueeze.py,sha256=ZHhfVXSWEiwb2VDYX5uhxbGQyzZjKT7CrbBpVGxVHBU,2310
 tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FOrlm9_o,2546
 tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
 tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
 tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
-tico/utils/convert.py,sha256=
+tico/utils/convert.py,sha256=5C8Z2ia2XN4k3XgtJrFZYJSEejoeMllyr8YW6gwu9mw,12763
 tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
 tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
 tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
 tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
 tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
 tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
-tico/utils/padding.py,sha256=
+tico/utils/padding.py,sha256=jNMX2KFoZ3c6HTlMU8BAwG3Fyrqpq4F3ytKP13Pg4ps,1498
 tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
-tico/utils/register_custom_op.py,sha256=
+tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
 tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
 tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
 tico/utils/utils.py,sha256=fnbZ2RLH6-J-wqb32O4qsR1ce4BJU0wYNrk84QXa6_E,13158
-tico/utils/validate_args_kwargs.py,sha256=
+tico/utils/validate_args_kwargs.py,sha256=cJAK6aqdzK3_Xccu6K1FQ32WGdmwWA_SqJ--TPavIuk,26614
 tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.dev250629.dist-info/LICENSE,sha256=
-tico-0.1.0.dev250629.dist-info/METADATA,sha256=
-tico-0.1.0.dev250629.dist-info/WHEEL,sha256=
-tico-0.1.0.dev250629.dist-info/entry_points.txt,sha256=
-tico-0.1.0.dev250629.dist-info/top_level.txt,sha256=
-tico-0.1.0.dev250629.dist-info/RECORD,,
+tico-0.1.0.dev250701.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250701.dist-info/METADATA,sha256=FKpHfi_Mwf8vsk5athXrwRHTOulyZP0H45y_zUJ6uqM,8846
+tico-0.1.0.dev250701.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250701.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250701.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250701.dist-info/RECORD,,
{tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/LICENSE
File without changes
{tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/WHEEL
File without changes
{tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/entry_points.txt
File without changes
{tico-0.1.0.dev250629.dist-info → tico-0.1.0.dev250701.dist-info}/top_level.txt
File without changes