tico 0.1.0.dev251106__py3-none-any.whl → 0.2.0.dev260122__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- tico/__init__.py +2 -2
- tico/_version.py +1 -0
- tico/passes/convert_conv3d_to_conv2d.py +435 -0
- tico/passes/convert_sym_size_to_circle_shape.py +99 -0
- tico/passes/decompose_batch_norm.py +9 -5
- tico/passes/lower_copy.py +95 -0
- tico/passes/ops.py +4 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +251 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +180 -0
- tico/quantization/algorithm/gptq/gptq.py +231 -11
- tico/quantization/algorithm/gptq/quantizer.py +18 -6
- tico/quantization/config/{pt2e.py → fpi_gptq.py} +11 -4
- tico/quantization/config/gptq.py +27 -4
- tico/quantization/public_interface.py +0 -10
- tico/quantization/wrapq/quantizer.py +2 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +51 -11
- tico/serialize/operators/adapters/onert/llama_attention.py +51 -0
- tico/serialize/operators/op_attention.py +58 -0
- tico/serialize/operators/op_circle_shape.py +64 -0
- tico/serialize/operators/op_dequantize_per_channel.py +1 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +1 -0
- tico/serialize/operators/op_transpose_conv.py +66 -50
- tico/utils/convert.py +16 -1
- tico/utils/padding.py +13 -5
- tico/utils/record_input.py +2 -2
- tico/utils/register_custom_op.py +63 -0
- tico/utils/validate_args_kwargs.py +49 -4
- tico-0.2.0.dev260122.dist-info/METADATA +631 -0
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/RECORD +35 -46
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/WHEEL +1 -1
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/entry_points.txt +0 -1
- tico/quantization/algorithm/pt2e/annotation/annotator.py +0 -208
- tico/quantization/algorithm/pt2e/annotation/config.py +0 -26
- tico/quantization/algorithm/pt2e/annotation/op/__init__.py +0 -21
- tico/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +0 -63
- tico/quantization/algorithm/pt2e/annotation/op/add.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/op/conv2d.py +0 -90
- tico/quantization/algorithm/pt2e/annotation/op/div.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/op/linear.py +0 -92
- tico/quantization/algorithm/pt2e/annotation/op/mean.py +0 -51
- tico/quantization/algorithm/pt2e/annotation/op/mul.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/op/relu6.py +0 -51
- tico/quantization/algorithm/pt2e/annotation/op/rsqrt.py +0 -51
- tico/quantization/algorithm/pt2e/annotation/op/sub.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/spec.py +0 -45
- tico/quantization/algorithm/pt2e/annotation/utils.py +0 -88
- tico/quantization/algorithm/pt2e/quantizer.py +0 -81
- tico/quantization/algorithm/pt2e/transformation/__init__.py +0 -1
- tico/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -58
- tico/quantization/algorithm/pt2e/utils.py +0 -135
- tico/serialize/operators/op_copy.py +0 -187
- tico-0.1.0.dev251106.dist-info/METADATA +0 -392
- /tico/quantization/algorithm/{pt2e → fpi_gptq}/__init__.py +0 -0
- /tico/{quantization/algorithm/pt2e/annotation → serialize/operators/adapters/onert}/__init__.py +0 -0
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info/licenses}/LICENSE +0 -0
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/top_level.txt +0 -0
tico/serialize/operators/op_circle_shape.py
ADDED

@@ -0,0 +1,64 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+
+
+@register_node_visitor
+class CircleShapeVisitor(NodeVisitor):
+    """
+    Visitor for circle_custom::shape operator.
+
+    This operator extracts the shape of a tensor.
+    It's implemented using Circle's SHAPE operator.
+    """
+
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.shape,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        # Args: (input)
+        input_node = node.args[0]
+
+        # SHAPE operator to get the full shape of input
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes
+        )
+
+        inputs = [input_node]
+        outputs = [node]
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions
+        operator.builtinOptions = circle.ShapeOptions.ShapeOptionsT()
+        operator.builtinOptions.outType = circle.TensorType.TensorType.INT32
+
+        return operator
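Note: `circle_custom::shape` carries no attributes beyond the output dtype; the visitor maps it 1:1 onto Circle's SHAPE builtin with an int32 result. A minimal sketch of the equivalent eager-mode behaviour (the `shape_as_int32` helper is illustrative, not a tico API):

```python
import torch

def shape_as_int32(x: torch.Tensor) -> torch.Tensor:
    # What Circle's SHAPE with outType=INT32 produces:
    # a rank-1 int32 tensor holding the runtime dimensions of `x`.
    return torch.tensor(list(x.shape), dtype=torch.int32)

print(shape_as_int32(torch.randn(2, 5, 7)))  # tensor([2, 5, 7], dtype=torch.int32)
```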
tico/serialize/operators/op_transpose_conv.py
CHANGED

@@ -66,6 +66,29 @@ class TransposeConvVisitor(NodeVisitor):
         set_transpose_conv_option(operator, stride)
         return operator
 
+    def define_slice_node(
+        self,
+        src_tensor,
+        begin_vals: List[int],
+        size_vals: List[int],
+        dst_tensor,
+    ) -> circle.Operator.OperatorT:
+        slice_op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.SLICE, self._op_codes
+        )
+
+        # Begin / Size as int32 const tensors
+        begin_arr = circle_legalize_dtype_to(begin_vals, dtype=torch.int32)
+        size_arr = circle_legalize_dtype_to(size_vals, dtype=torch.int32)
+
+        operator = create_builtin_operator(
+            self.graph, slice_op_index, [src_tensor, begin_arr, size_arr], [dst_tensor]
+        )
+        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.SliceOptions
+        option = circle.SliceOptions.SliceOptionsT()
+        operator.builtinOptions = option
+        return operator
+
     def __init__(self, op_codes: Dict[OpCode, int], graph):
         super().__init__(op_codes, graph)
 
@@ -88,65 +111,58 @@ class TransposeConvVisitor(NodeVisitor):
         assert len(output_shape) == 4, len(output_shape)
         assert len(weight_shape) == 4, len(weight_shape)
 
-        pad_decision = identify_padding(
+        pad_decision = identify_padding(
+            padding, input_shape, output_shape, stride, is_transpose=True
+        )
 
         conv_input: torch.fx.Node | circle.Tensor.TensorT = input_
-        if pad_decision.explicit_pad_hw is not None:
-            pad_h, pad_w = pad_decision.explicit_pad_hw
-            paddings = torch.tensor(
-                [
-                    [0, 0],
-                    [pad_h, pad_h],
-                    [pad_w, pad_w],
-                    [0, 0],
-                ],
-                dtype=torch.int32,
-            )
-            pad_output_shape: List[int | torch.SymInt] = [
-                input_shape[0],
-                input_shape[1] + pad_h * 2,
-                input_shape[2] + pad_w * 2,
-                input_shape[3],
-            ]
-            pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
-                pad_output_shape
-            )
-            # create padded output tensor
-            input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
-            pad_output = self.graph.add_tensor_from_scratch(
-                prefix=f"{node.name}_input_pad_output",
-                shape=pad_output_cshape,
-                shape_signature=pad_output_cshape_signature,
-                dtype=extract_circle_dtype(input_),
-                qparam=input_qparam,
-                source_node=node,
-            )
-            # CirclePad
-            pad_operator = define_pad_node(
-                self.graph, self._op_codes, [input_, paddings], [pad_output]
-            )
-            self.graph.add_operator(pad_operator)
-            conv_input = pad_output
-
         if bias is None:
             # luci-interpreter can't run no bias conv. Let's add zero vector for bias.
             bias = [0.0] * weight_shape[0]  # type: ignore[assignment]
 
-        #
-
-
-
-
-
-
-
+        # Compute pre-crop output shape if we need to apply an explicit crop.
+        if pad_decision.output_crop_hw is not None:
+            pad_h, pad_w = pad_decision.output_crop_hw
+            pre_h = int(output_shape[1]) + 2 * pad_h
+            pre_w = int(output_shape[2]) + 2 * pad_w
+            pre_out_shape = [output_shape[0], pre_h, pre_w, output_shape[3]]  # NHWC
+        else:
+            pre_out_shape = list(output_shape)
+
+        tconv_output = circle_legalize_dtype_to(pre_out_shape, dtype=torch.int32)
+
+        pre_out_cshape, pre_out_csig = to_circle_shape(pre_out_shape)
+        tconv_tmp = node  # type: ignore[assignment]
+        if pad_decision.output_crop_hw is not None:
+            tconv_tmp = self.graph.add_tensor_from_scratch(  # type: ignore[assignment]
+                prefix=f"{node.name}_tconv_out_pre_crop",
+                shape=pre_out_cshape,
+                shape_signature=pre_out_csig,
+                dtype=extract_circle_dtype(node),
+                qparam=node.meta.get(QPARAM_KEY),
+                source_node=node,
+            )
 
-        # TConv2D
         tconv2d_operator = self.define_transpose_conv_node(
-            pad_decision.conv_padding_type,
+            pad_decision.conv_padding_type,
             stride,
-            [
-            [
+            [tconv_output, weight, conv_input, bias],
+            [tconv_tmp],
         )
 
+        # If we need an output crop, insert a SLICE to produce the final tensor.
+        if pad_decision.output_crop_hw is not None:
+            self.graph.add_operator(tconv2d_operator)
+            pad_h, pad_w = pad_decision.output_crop_hw
+            begin = [0, pad_h, pad_w, 0]
+            size = [
+                int(output_shape[0]),
+                int(output_shape[1]),
+                int(output_shape[2]),
+                int(output_shape[3]),
+            ]
+
+            slice_op = self.define_slice_node(tconv_tmp, begin, size, node)
+            return slice_op
+
         return tconv2d_operator
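The rewritten lowering above replaces the old "pad the input, then TRANSPOSE_CONV" scheme: it runs TRANSPOSE_CONV with VALID padding into a larger pre-crop tensor and appends a SLICE that crops `pad` rows/columns from each side. This is sound because, for a transposed convolution, the `padding` argument simply trims the zero-padding output. A hedged sketch verifying that identity in plain PyTorch (no tico APIs; sizes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)   # N, C_in, H, W
w = torch.randn(3, 4, 3, 3)   # C_in, C_out, kH, kW
pad = 1

# Transposed conv with padding=0, then crop `pad` from each spatial border ...
full = F.conv_transpose2d(x, w, stride=2, padding=0)
cropped = full[:, :, pad:full.shape[2] - pad, pad:full.shape[3] - pad]

# ... equals asking for the padding directly.
direct = F.conv_transpose2d(x, w, stride=2, padding=pad)
assert torch.allclose(cropped, direct)
print(full.shape, direct.shape)  # pre-crop shape vs. final shape, as in the SLICE above
```

The `begin = [0, pad_h, pad_w, 0]` / `size = output_shape` pair built for `define_slice_node` expresses exactly this crop in NHWC.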
tico/utils/convert.py
CHANGED
@@ -25,10 +25,12 @@ from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_conv3d_to_conv2d import ConvertConv3dToConv2d
 from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
 from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
+from tico.passes.convert_sym_size_to_circle_shape import ConvertSymSizeToCircleShape
 from tico.passes.convert_to_relu6 import ConvertToReLU6
 from tico.passes.decompose_addmm import DecomposeAddmm
 from tico.passes.decompose_batch_norm import DecomposeBatchNorm
@@ -47,6 +49,7 @@ from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
 from tico.passes.legalize_predefined_layout_operators import (
     LegalizePreDefinedLayoutOperators,
 )
+from tico.passes.lower_copy import LowerCopy
 from tico.passes.lower_pow2_to_mul import LowerPow2ToMul
 from tico.passes.lower_to_resize_nearest_neighbor import LowerToResizeNearestNeighbor
 from tico.passes.lower_to_slice import passes as LowerToSlicePasses
@@ -97,12 +100,15 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv1d.padding,
+        torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv3d.padding,
         torch.ops.aten.conv_transpose2d.input,
         torch.ops.aten.instance_norm.default,
         torch.ops.aten._safe_softmax.default,
         torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
         torch.ops.aten.linear.default,
         torch.ops.aten.upsample_nearest2d.vec,
+        torch.ops.aten.rms_norm.default,
     )
     ep = ep.run_decompositions(_preserve_ops=_preserve_ops)
 
@@ -115,6 +121,8 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv1d.padding,
+        torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv3d.padding,
         torch.ops.aten.conv_transpose2d.input,
         torch.ops.aten.instance_norm.default,
         torch.ops.aten._safe_softmax.default,
@@ -138,6 +146,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         or torch.__version__.startswith("2.8")
         or torch.__version__.startswith("2.9")
         or torch.__version__.startswith("2.10")
+        or torch.__version__.startswith("2.11")
     ):
         return run_decompositions(exported_program)
     else:
@@ -224,6 +233,8 @@ def convert_exported_module_to_circle(
         FillMetaVal(),
         ExtractDtypeKwargsPass(),
         RemoveNop(),
+        LowerCopy(),
+        ConvertSymSizeToCircleShape(),
         ConvertLayoutOpToReshape(),
         RestoreLinear(),
         ConvertToReLU6(),
@@ -258,6 +269,7 @@
         LegalizePreDefinedLayoutOperators(),
         LowerPow2ToMul(),
         ConvertConv1dToConv2d(),
+        ConvertConv3dToConv2d(),
         *LowerToSlicePasses(),
         FuseLeadingUnsqueezeReshape(),
         CastClampMixedTypeArgs(),
@@ -316,7 +328,10 @@ def convert(
         mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
     )
 
-
+    with SuppressWarning(FutureWarning, ".*LeafSpec*"):
+        circle_binary = convert_exported_module_to_circle(
+            exported_program, config=config
+        )
 
     return CircleModel(circle_binary)
 
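With `ConvertConv3dToConv2d` added to the pass pipeline and `aten.conv3d` preserved through decomposition, 3-D convolutions can now be exported end to end. A hedged sketch of what that looks like, assuming `tico.convert` is the public entry point and the returned `CircleModel` exposes `save` (as in the project README):

```python
import torch
import tico

class TinyConv3d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv3d(2, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = TinyConv3d().eval()
example_inputs = (torch.randn(1, 2, 8, 16, 16),)  # N, C, D, H, W

# conv3d is kept intact during decomposition and lowered by ConvertConv3dToConv2d.
circle_model = tico.convert(model, example_inputs)
circle_model.save("tiny_conv3d.circle")  # assumed CircleModel API
```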
tico/utils/padding.py
CHANGED
@@ -35,6 +35,7 @@ class ConvPaddingInfo(NamedTuple):
 
     conv_padding_type: ConvPadding
     explicit_pad_hw: Optional[Tuple[int, int]]  # None -> no extra Pad() op needed
+    output_crop_hw: Optional[Tuple[int, int]]
 
 
 def identify_padding(
@@ -42,6 +43,7 @@ def identify_padding(
     input_shape: Sequence[int | torch.SymInt] | torch.Size,
     output_shape: Sequence[int | torch.SymInt] | torch.Size,
     stride: Sequence[int],
+    is_transpose: bool = False,
 ) -> ConvPaddingInfo:
     """
     Normalizes all PyTorch `padding` variants to a single decision.
@@ -61,9 +63,9 @@
     if isinstance(padding, str):
         pad = padding.lower()
         if pad == "valid":
-            return ConvPaddingInfo(ConvPadding.VALID, None)
+            return ConvPaddingInfo(ConvPadding.VALID, None, None)
         if pad == "same":
-            return ConvPaddingInfo(ConvPadding.SAME, None)
+            return ConvPaddingInfo(ConvPadding.SAME, None, None)
         raise InvalidArgumentError(f"Unknown padding string: {padding}")
 
     # ─── 2. List / tuple form ─────────────────────────────────────────────
@@ -73,15 +75,21 @@
     )
 
     pad_h, pad_w = padding
+
+    if is_transpose:
+        if pad_h == 0 and pad_w == 0:
+            return ConvPaddingInfo(ConvPadding.VALID, None, None)
+        return ConvPaddingInfo(ConvPadding.VALID, None, (pad_h, pad_w))
+
     # [0, 0] → VALID
     if pad_h == 0 and pad_w == 0:
-        return ConvPaddingInfo(ConvPadding.VALID, None)
+        return ConvPaddingInfo(ConvPadding.VALID, None, None)
 
     # SAME heuristic: output H/W already match input when stride is 1
     hw_in = tuple(input_shape[1:3])
     hw_out = tuple(output_shape[1:3])
     if hw_in == hw_out and stride == [1, 1]:
-        return ConvPaddingInfo(ConvPadding.SAME, None)
+        return ConvPaddingInfo(ConvPadding.SAME, None, None)
 
     # Anything else = explicit symmetric padding
-    return ConvPaddingInfo(ConvPadding.VALID, (pad_h, pad_w))
+    return ConvPaddingInfo(ConvPadding.VALID, (pad_h, pad_w), None)
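The result is a three-way decision the visitors can act on directly: `conv_padding_type` selects Circle's padding mode, `explicit_pad_hw` requests a Pad() in front of a regular convolution, and the new `output_crop_hw` requests a crop (SLICE) after a transposed convolution. A hedged sketch of the expected outcomes (shapes are NHWC and purely illustrative):

```python
from tico.utils.padding import identify_padding

# String form maps 1:1 onto Circle's padding mode; no extra ops.
identify_padding("same", [1, 8, 8, 3], [1, 8, 8, 16], [1, 1])
# -> ConvPaddingInfo(ConvPadding.SAME, explicit_pad_hw=None, output_crop_hw=None)

# Regular conv with explicit padding that is not SAME-like:
# stay VALID and pad the input before the conv.
identify_padding([1, 1], [1, 8, 8, 3], [1, 4, 4, 16], [2, 2])
# -> ConvPaddingInfo(ConvPadding.VALID, explicit_pad_hw=(1, 1), output_crop_hw=None)

# Transposed conv with explicit padding: stay VALID and crop the output instead.
identify_padding([1, 1], [1, 4, 4, 16], [1, 8, 8, 3], [2, 2], is_transpose=True)
# -> ConvPaddingInfo(ConvPadding.VALID, explicit_pad_hw=None, output_crop_hw=(1, 1))
```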
tico/utils/record_input.py
CHANGED
@@ -49,13 +49,13 @@ class RecordingInput:
     def __init__(
         self,
         module: nn.Module,
-        condition: Callable[[dict], bool] =
+        condition: Optional[Callable[[dict], bool]] = None,
         *,
         input_to_remove: Optional[List[str]] = [],
     ):
         self.module = module
         self.forward_org = module.forward
-        self.condition = condition
+        self.condition = condition or (lambda args_dict: True)
         self.input_to_remove = input_to_remove
         self.sig = inspect.signature(self.forward_org)
 
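The `condition` default was previously baked into the signature (its value is truncated in the old line above); it is now `Optional` with an accept-all fallback. A small, generic sketch of the same pattern (the `Recorder` class is a hypothetical stand-in, not tico code):

```python
from typing import Callable, Optional

class Recorder:
    # Default the predicate to None and substitute an accept-all lambda,
    # mirroring the RecordingInput constructor change above.
    def __init__(self, condition: Optional[Callable[[dict], bool]] = None):
        self.condition = condition or (lambda args_dict: True)

assert Recorder().condition({"any": "input"}) is True
assert Recorder(lambda d: "input_ids" in d).condition({"mask": 1}) is False
```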
tico/utils/register_custom_op.py
CHANGED
@@ -727,6 +727,67 @@ def CircleRMSNorm():
         return hidden_states.new_empty(hidden_states.size())
 
 
+def CircleAttention():
+    @custom_op("circle_custom::attention", mutates_args=())
+    def attention(
+        hidden_states: torch.Tensor,
+        wq: torch.Tensor,
+        wk: torch.Tensor,
+        wv: torch.Tensor,
+        wo: torch.Tensor,
+        position_cos: torch.Tensor,
+        position_sin: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key: torch.Tensor,
+        past_value: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        return None
+
+    @register_fake("circle_custom::attention")
+    def _(
+        hidden_states: torch.Tensor,
+        wq: torch.Tensor,
+        wk: torch.Tensor,
+        wv: torch.Tensor,
+        wo: torch.Tensor,
+        position_cos: torch.Tensor,
+        position_sin: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key: torch.Tensor,
+        past_value: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        return hidden_states
+
+
+def CircleShape():
+    """
+    Custom operator to extract the shape of a tensor.
+    This is similar to TensorFlow's shape operator and is used to preserve
+    dynamic shape information in the Circle model.
+
+    Args:
+        input_: Input tensor
+
+    Returns:
+        A 1D tensor containing the shape of the input tensor
+    """
+
+    @custom_op("circle_custom::shape", mutates_args=())
+    def shape(input_: torch.Tensor) -> torch.Tensor:
+        # Return the shape of the input tensor as a 1D tensor
+        shape_val = list(input_.size())
+        return torch.tensor(shape_val, dtype=torch.int32)
+
+    @register_fake("circle_custom::shape")
+    def _(input_: torch.Tensor) -> torch.Tensor:
+        # Return a 1D tensor with symbolic shape
+        # The actual value will be determined at runtime
+        rank = len(input_.size())
+        return torch.empty([rank], dtype=torch.int32)
+
+
 # Add custom ops to the torch namespace
 def RegisterOps():
     CircleResizeNearestNeighbor()
@@ -740,3 +801,5 @@ def RegisterOps():
     CircleInstanceNorm()
     CircleQuantizeMX()
     CircleRMSNorm()
+    CircleAttention()
+    CircleShape()
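These registrations follow the standard `torch.library` pattern: the eager implementation (or a stub) is registered with `custom_op`, and a "fake" (meta) implementation with `register_fake` so `torch.export` can trace the op without running it. A self-contained sketch of the same pattern under a hypothetical `demo::` namespace (not part of tico; requires a torch version that ships `torch.library.custom_op`):

```python
import torch
from torch.library import custom_op, register_fake

@custom_op("demo::shape_i32", mutates_args=())
def shape_i32(x: torch.Tensor) -> torch.Tensor:
    # Eager implementation: runtime shape as a rank-1 int32 tensor.
    return torch.tensor(list(x.shape), dtype=torch.int32)

@register_fake("demo::shape_i32")
def _(x: torch.Tensor) -> torch.Tensor:
    # Meta implementation: only the output's shape/dtype matter for tracing.
    return torch.empty([x.dim()], dtype=torch.int32)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.ops.demo.shape_i32(x)

ep = torch.export.export(M(), (torch.randn(2, 3),))
print(ep.graph)  # the graph calls demo.shape_i32 instead of decomposed aten ops
```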
tico/utils/validate_args_kwargs.py
CHANGED

@@ -171,6 +171,26 @@ class CatArgs:
     dim: int = 0
 
 
+@enforce_type
+@dataclass
+class CircleAttentionArgs:
+    """
+    For circle.BuiltinOperator.BuiltinOperator.ATTENTION
+    """
+
+    hidden_states: torch.fx.Node
+    wq: torch.fx.Node
+    wk: torch.fx.Node
+    wv: torch.fx.Node
+    wo: torch.fx.Node
+    position_cos: torch.fx.Node
+    position_sin: torch.fx.Node
+    attention_mask: torch.fx.Node
+    past_key: torch.fx.Node
+    past_value: torch.fx.Node
+    cache_position: torch.fx.Node
+
+
 @enforce_type
 @dataclass
 class CircleRMSNormArgs:
@@ -299,6 +319,27 @@ class Conv1DArgs:
         assert len(self.dilation) == 1, len(self.dilation)
 
 
+@enforce_type
+@dataclass
+class Conv3DArgs:
+    """
+    conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: torch.fx.Node
+    bias: Union[torch.fx.Node, None] = None
+    stride: List[int] = field(default_factory=lambda: [1, 1, 1])
+    padding: Union[List[int], str] = field(default_factory=lambda: [0, 0, 0])
+    dilation: List[int] = field(default_factory=lambda: [1, 1, 1])
+    groups: int = 1
+
+    def __post_init__(self):
+        assert len(self.stride) == 3, len(self.stride)
+        assert len(self.dilation) == 3, len(self.dilation)
+
+
 @enforce_type
 @dataclass
 class CopyArgs:
@@ -930,7 +971,7 @@ class RepeatArgs:
     """
 
     input: torch.fx.Node
-    repeats: List[int]
+    repeats: List[Union[int, torch.SymInt, torch.fx.Node]]
 
 
 @enforce_type
@@ -938,10 +979,14 @@
 class ReshapeArgs:
     """
     reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+
+    Note: After PrepareReshapeDynamicShape pass, shape can be either:
+    - A list of int/SymInt/Node (original or static)
+    - A single Node (dynamic shape tensor prepared by the pass)
     """
 
     input: torch.fx.Node
-    shape: List[int]
+    shape: Union[List[Union[int, torch.SymInt, torch.fx.Node]], torch.fx.Node]
 
 
 @enforce_type
@@ -1077,7 +1122,7 @@ class SplitWithSizesArgs:
     """
 
     input: torch.fx.Node
-    split_sizes: List[int]
+    split_sizes: List[Union[int, torch.SymInt, torch.fx.Node]]
     dim: int = 0
 
 
@@ -1218,7 +1263,7 @@ class ViewArgs:
     """
 
    input: torch.fx.Node
-    size: List[int]
+    size: List[Union[int, torch.SymInt, torch.fx.Node]]
 
 
 @enforce_type