tico 0.1.0.dev250728__py3-none-any.whl → 0.1.0.dev250729__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/serialize/circle_graph.py +2 -1
- tico/serialize/circle_mapping.py +29 -11
- tico/serialize/operators/op_avg_pool2d.py +11 -14
- tico/serialize/operators/op_conv2d.py +14 -11
- tico/serialize/operators/op_depthwise_conv2d.py +15 -11
- tico/serialize/operators/op_max_pool2d_with_indices.py +12 -12
- tico/serialize/operators/op_repeat.py +13 -8
- tico/serialize/operators/op_transpose_conv.py +7 -7
- tico/utils/padding.py +4 -2
- tico/utils/record_input.py +92 -0
- tico/utils/serialize.py +6 -1
- {tico-0.1.0.dev250728.dist-info → tico-0.1.0.dev250729.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250728.dist-info → tico-0.1.0.dev250729.dist-info}/RECORD +18 -17
- {tico-0.1.0.dev250728.dist-info → tico-0.1.0.dev250729.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250728.dist-info → tico-0.1.0.dev250729.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250728.dist-info → tico-0.1.0.dev250729.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250728.dist-info → tico-0.1.0.dev250729.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/serialize/circle_graph.py
CHANGED
@@ -27,6 +27,7 @@ from tico.serialize.circle_mapping import (
|
|
27
27
|
extract_circle_shape,
|
28
28
|
str_to_circle_dtype,
|
29
29
|
to_circle_dtype,
|
30
|
+
to_circle_shape,
|
30
31
|
)
|
31
32
|
from tico.serialize.pack import pack_buffer
|
32
33
|
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
|
@@ -186,7 +187,7 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
|
|
186
187
|
torch_t = torch.as_tensor(data=data)
|
187
188
|
torch_t_shape = list(torch_t.size())
|
188
189
|
tensor.type = to_circle_dtype(torch_dtype=torch_t.dtype)
|
189
|
-
tensor.shape = torch_t_shape
|
190
|
+
tensor.shape, tensor.shapeSignature = to_circle_shape(torch_t_shape)
|
190
191
|
|
191
192
|
buffer = circle.Buffer.BufferT()
|
192
193
|
buffer.data = torch_t.flatten().cpu().numpy().view(np.uint8) # type: ignore[assignment]
|
tico/serialize/circle_mapping.py
CHANGED
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
15
|
+
from typing import List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
16
16
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
import torch.fx
|
@@ -132,18 +132,36 @@ def extract_circle_shape(node: torch.fx.Node) -> Tuple[List[int], Optional[List[
|
|
132
132
|
return to_circle_shape(extract_shape(node))
|
133
133
|
|
134
134
|
|
135
|
-
def to_circle_shape(
|
136
|
-
|
137
|
-
|
135
|
+
def to_circle_shape(
|
136
|
+
torch_shape: Union[
|
137
|
+
torch.Size, Sequence[int | torch.SymInt]
|
138
|
+
], # Sequence[int | torch.SymInt] is added for type covariance
|
139
|
+
) -> Tuple[List[int], Optional[List[int]]]:
|
138
140
|
|
139
|
-
if any(isinstance(s, torch.SymInt) for s in
|
140
|
-
|
141
|
-
|
141
|
+
if any(isinstance(s, torch.SymInt) for s in torch_shape):
|
142
|
+
# Follow dynamic shape spec
|
143
|
+
shape = []
|
144
|
+
shape_signature = []
|
145
|
+
for s in torch_shape:
|
142
146
|
if isinstance(s, torch.SymInt):
|
143
|
-
shape
|
144
|
-
shape_signature
|
145
|
-
|
146
|
-
|
147
|
+
shape.append(1)
|
148
|
+
shape_signature.append(-1)
|
149
|
+
elif isinstance(s, int):
|
150
|
+
shape.append(s)
|
151
|
+
shape_signature.append(s)
|
152
|
+
else:
|
153
|
+
raise RuntimeError(f"Unsupported shape {torch_shape}")
|
154
|
+
return shape, shape_signature
|
155
|
+
else:
|
156
|
+
# Follow static shape spec
|
157
|
+
shape = []
|
158
|
+
shape_signature = None
|
159
|
+
for s in torch_shape:
|
160
|
+
if isinstance(s, int):
|
161
|
+
shape.append(s)
|
162
|
+
else:
|
163
|
+
assert False, "Cannot reach here"
|
164
|
+
return shape, shape_signature
|
147
165
|
|
148
166
|
|
149
167
|
def validate_circle_shape(shape: List[int], shape_signature: Optional[List[int]]):
|
@@ -22,7 +22,11 @@ import torch
|
|
22
22
|
from circle_schema import circle
|
23
23
|
|
24
24
|
from tico.serialize.circle_graph import CircleSubgraph
|
25
|
-
from tico.serialize.circle_mapping import
|
25
|
+
from tico.serialize.circle_mapping import (
|
26
|
+
extract_circle_dtype,
|
27
|
+
extract_shape,
|
28
|
+
to_circle_shape,
|
29
|
+
)
|
26
30
|
from tico.serialize.operators.hashable_opcode import OpCode
|
27
31
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
28
32
|
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
@@ -57,12 +61,7 @@ class AvgPool2DVisitor(NodeVisitor):
|
|
57
61
|
return True
|
58
62
|
|
59
63
|
def has_same_padding(self, args: AvgPool2dArgs) -> bool:
|
60
|
-
input_shape
|
61
|
-
|
62
|
-
if input_shape_signature is not None:
|
63
|
-
# TODO: support dynamic shapes
|
64
|
-
raise NotImplementedError("Dynamic shape is not supported yet")
|
65
|
-
|
64
|
+
input_shape: torch.Size = extract_shape(args.input)
|
66
65
|
kernel_size = args.kernel_size
|
67
66
|
stride = args.stride
|
68
67
|
assert stride
|
@@ -142,11 +141,7 @@ class AvgPool2DVisitor(NodeVisitor):
|
|
142
141
|
],
|
143
142
|
dtype=torch.int32,
|
144
143
|
)
|
145
|
-
input_shape
|
146
|
-
|
147
|
-
if input_shape_signature is not None:
|
148
|
-
raise RuntimeError("Dynamic shape is not supported yet.")
|
149
|
-
|
144
|
+
input_shape = extract_shape(input)
|
150
145
|
input_dtype: int = extract_circle_dtype(input)
|
151
146
|
padded_input_shape = [
|
152
147
|
input_shape[0],
|
@@ -156,11 +151,13 @@ class AvgPool2DVisitor(NodeVisitor):
|
|
156
151
|
]
|
157
152
|
padded_input_shape[1] += padding[0] * 2
|
158
153
|
padded_input_shape[2] += padding[1] * 2
|
154
|
+
|
159
155
|
# create padded input tensor
|
156
|
+
padded_cshape, padded_cshape_signature = to_circle_shape(padded_input_shape)
|
160
157
|
padded_input_tensor = self.graph.add_tensor_from_scratch(
|
161
158
|
prefix=f"{input.name}_pad_output",
|
162
|
-
shape=
|
163
|
-
shape_signature=
|
159
|
+
shape=padded_cshape,
|
160
|
+
shape_signature=padded_cshape_signature,
|
164
161
|
dtype=input_dtype,
|
165
162
|
source_node=node,
|
166
163
|
)
|
@@ -20,7 +20,11 @@ if TYPE_CHECKING:
|
|
20
20
|
import torch
|
21
21
|
from circle_schema import circle
|
22
22
|
|
23
|
-
from tico.serialize.circle_mapping import
|
23
|
+
from tico.serialize.circle_mapping import (
|
24
|
+
extract_circle_dtype,
|
25
|
+
extract_shape,
|
26
|
+
to_circle_shape,
|
27
|
+
)
|
24
28
|
from tico.serialize.operators.hashable_opcode import OpCode
|
25
29
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
30
|
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
@@ -111,17 +115,13 @@ class Conv2dVisitor(NodeVisitor):
|
|
111
115
|
|
112
116
|
assert groups == 1, "Only support group 1 conv2d"
|
113
117
|
|
114
|
-
input_shape
|
115
|
-
output_shape
|
116
|
-
weight_shape
|
118
|
+
input_shape = extract_shape(input_)
|
119
|
+
output_shape = extract_shape(node)
|
120
|
+
weight_shape = extract_shape(weight)
|
117
121
|
assert len(input_shape) == 4, len(input_shape)
|
118
122
|
assert len(output_shape) == 4, len(output_shape)
|
119
123
|
assert len(weight_shape) == 4, len(weight_shape)
|
120
124
|
|
121
|
-
if input_shape_signature is not None:
|
122
|
-
# TODO: support dynamic shapes
|
123
|
-
raise NotImplementedError("Dynamic shape is not supported yet")
|
124
|
-
|
125
125
|
pad_decision = identify_padding(padding, input_shape, output_shape, stride)
|
126
126
|
|
127
127
|
conv_input: torch.fx.Node | circle.Tensor.TensorT = input_
|
@@ -136,18 +136,21 @@ class Conv2dVisitor(NodeVisitor):
|
|
136
136
|
],
|
137
137
|
dtype=torch.int32,
|
138
138
|
)
|
139
|
-
pad_output_shape = [
|
139
|
+
pad_output_shape: List[int | torch.SymInt] = [
|
140
140
|
input_shape[0],
|
141
141
|
input_shape[1] + pad_h * 2,
|
142
142
|
input_shape[2] + pad_w * 2,
|
143
143
|
input_shape[3],
|
144
144
|
]
|
145
|
+
pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
|
146
|
+
pad_output_shape
|
147
|
+
)
|
145
148
|
# create padded output tensor
|
146
149
|
input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
|
147
150
|
pad_output = self.graph.add_tensor_from_scratch(
|
148
151
|
prefix=f"{node.name}_input_pad_output",
|
149
|
-
shape=
|
150
|
-
shape_signature=
|
152
|
+
shape=pad_output_cshape,
|
153
|
+
shape_signature=pad_output_cshape_signature,
|
151
154
|
dtype=extract_circle_dtype(input_),
|
152
155
|
qparam=input_qparam,
|
153
156
|
source_node=node,
|
@@ -20,7 +20,11 @@ if TYPE_CHECKING:
|
|
20
20
|
import torch
|
21
21
|
from circle_schema import circle
|
22
22
|
|
23
|
-
from tico.serialize.circle_mapping import
|
23
|
+
from tico.serialize.circle_mapping import (
|
24
|
+
extract_circle_dtype,
|
25
|
+
extract_shape,
|
26
|
+
to_circle_shape,
|
27
|
+
)
|
24
28
|
from tico.serialize.operators.hashable_opcode import OpCode
|
25
29
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
30
|
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
@@ -115,17 +119,13 @@ class DepthwiseConv2dVisitor(NodeVisitor):
|
|
115
119
|
dilation = args.dilation
|
116
120
|
groups = args.groups
|
117
121
|
|
118
|
-
input_shape
|
119
|
-
output_shape
|
120
|
-
weight_shape
|
122
|
+
input_shape = extract_shape(input_) # OHWI
|
123
|
+
output_shape = extract_shape(node) # OHWI
|
124
|
+
weight_shape = extract_shape(weight) # 1HWO
|
121
125
|
assert len(input_shape) == 4, len(input_shape)
|
122
126
|
assert len(output_shape) == 4, len(output_shape)
|
123
127
|
assert len(weight_shape) == 4, len(weight_shape)
|
124
128
|
|
125
|
-
if input_shape_signature is not None:
|
126
|
-
# TODO: support dynamic shapes
|
127
|
-
raise NotImplementedError("Dynamic shape is not supported yet")
|
128
|
-
|
129
129
|
assert weight_shape[0] == 1
|
130
130
|
assert weight_shape[3] == output_shape[3]
|
131
131
|
assert input_shape[3] == groups
|
@@ -150,18 +150,22 @@ class DepthwiseConv2dVisitor(NodeVisitor):
|
|
150
150
|
],
|
151
151
|
dtype=torch.int32,
|
152
152
|
)
|
153
|
-
pad_output_shape = [
|
153
|
+
pad_output_shape: List[int | torch.SymInt] = [
|
154
154
|
input_shape[0],
|
155
155
|
input_shape[1] + pad_h * 2,
|
156
156
|
input_shape[2] + pad_w * 2,
|
157
157
|
input_shape[3],
|
158
158
|
]
|
159
|
+
|
160
|
+
pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
|
161
|
+
pad_output_shape
|
162
|
+
)
|
159
163
|
# create padded output tensor
|
160
164
|
input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
|
161
165
|
pad_output = self.graph.add_tensor_from_scratch(
|
162
166
|
prefix=f"{node.name}_input_pad_output",
|
163
|
-
shape=
|
164
|
-
shape_signature=
|
167
|
+
shape=pad_output_cshape,
|
168
|
+
shape_signature=pad_output_cshape_signature,
|
165
169
|
dtype=extract_circle_dtype(input_),
|
166
170
|
qparam=input_qparam,
|
167
171
|
source_node=node,
|
@@ -24,8 +24,8 @@ from circle_schema import circle
|
|
24
24
|
from tico.serialize.circle_graph import CircleSubgraph
|
25
25
|
from tico.serialize.circle_mapping import (
|
26
26
|
extract_circle_dtype,
|
27
|
-
extract_circle_shape,
|
28
27
|
extract_shape,
|
28
|
+
to_circle_shape,
|
29
29
|
)
|
30
30
|
from tico.serialize.operators.hashable_opcode import OpCode
|
31
31
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
@@ -92,13 +92,15 @@ class MaxPool2DWithIndicesVisitor(NodeVisitor):
|
|
92
92
|
],
|
93
93
|
dtype=torch.int32,
|
94
94
|
)
|
95
|
-
input_shape, input_shape_signature = extract_circle_shape(input)
|
96
|
-
|
97
|
-
if input_shape_signature is not None:
|
98
|
-
# TODO: support dynamic shape
|
99
|
-
raise NotImplementedError("Padding with dynamic shape is not supported")
|
100
95
|
|
96
|
+
input_shape = extract_shape(input)
|
101
97
|
input_dtype: int = extract_circle_dtype(input)
|
98
|
+
|
99
|
+
input_qparam: Optional[QuantParam] = (
|
100
|
+
input.meta[QPARAM_KEY] if QPARAM_KEY in input.meta else None
|
101
|
+
)
|
102
|
+
|
103
|
+
# create padded input tensor
|
102
104
|
padded_input_shape = [
|
103
105
|
input_shape[0],
|
104
106
|
input_shape[1],
|
@@ -107,18 +109,16 @@ class MaxPool2DWithIndicesVisitor(NodeVisitor):
|
|
107
109
|
]
|
108
110
|
padded_input_shape[1] += padding[0] * 2
|
109
111
|
padded_input_shape[2] += padding[1] * 2
|
110
|
-
|
111
|
-
input.meta[QPARAM_KEY] if QPARAM_KEY in input.meta else None
|
112
|
-
)
|
113
|
-
# create padded input tensor
|
112
|
+
padded_cshape, padded_cshape_signature = to_circle_shape(padded_input_shape)
|
114
113
|
padded_input_tensor = self.graph.add_tensor_from_scratch(
|
115
114
|
prefix=f"{input.name}_pad_output",
|
116
|
-
shape=
|
117
|
-
shape_signature=
|
115
|
+
shape=padded_cshape,
|
116
|
+
shape_signature=padded_cshape_signature,
|
118
117
|
dtype=input_dtype,
|
119
118
|
qparam=input_qparam,
|
120
119
|
source_node=node,
|
121
120
|
)
|
121
|
+
|
122
122
|
if input_qparam is not None:
|
123
123
|
padding_value = get_integer_dtype_min(input_qparam.dtype)
|
124
124
|
else:
|
@@ -21,7 +21,11 @@ import torch
|
|
21
21
|
from circle_schema import circle
|
22
22
|
|
23
23
|
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
-
from tico.serialize.circle_mapping import
|
24
|
+
from tico.serialize.circle_mapping import (
|
25
|
+
extract_circle_dtype,
|
26
|
+
extract_shape,
|
27
|
+
to_circle_shape,
|
28
|
+
)
|
25
29
|
from tico.serialize.operators.hashable_opcode import OpCode
|
26
30
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
31
|
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
@@ -51,10 +55,7 @@ class RepeatVisitor(NodeVisitor):
|
|
51
55
|
elif r < 0:
|
52
56
|
raise InvalidArgumentError("Only support positive repeat value")
|
53
57
|
|
54
|
-
tensor_shape
|
55
|
-
if tensor_shape_signature is not None:
|
56
|
-
# TODO: support dynamic shape
|
57
|
-
raise NotYetSupportedError("Repeat does not support dynamic shape yet.")
|
58
|
+
tensor_shape = extract_shape(input)
|
58
59
|
assert len(tensor_shape) <= len(repeats)
|
59
60
|
if len(tensor_shape) != len(repeats):
|
60
61
|
# TODO Support len(tensor_shape) < len(repeats)
|
@@ -73,12 +74,16 @@ class RepeatVisitor(NodeVisitor):
|
|
73
74
|
if r > 1:
|
74
75
|
# Except last created concat, a tensor should be created.
|
75
76
|
if repeat_dim_cnt > 1:
|
76
|
-
repeated_shape = list(tensor_shape)
|
77
|
+
repeated_shape: List[int | torch.SymInt] = list(tensor_shape)
|
77
78
|
repeated_shape[idx] = repeated_shape[idx] * r
|
79
|
+
|
80
|
+
repeated_cshape, repeated_cshape_signature = to_circle_shape(
|
81
|
+
repeated_shape
|
82
|
+
)
|
78
83
|
concat_output = self.graph.add_tensor_from_scratch(
|
79
84
|
prefix=f"{node.name}_concat_{idx}",
|
80
|
-
shape=
|
81
|
-
shape_signature=
|
85
|
+
shape=repeated_cshape,
|
86
|
+
shape_signature=repeated_cshape_signature,
|
82
87
|
dtype=tensor_dtype,
|
83
88
|
source_node=node,
|
84
89
|
)
|
@@ -24,6 +24,7 @@ from tico.serialize.circle_mapping import (
|
|
24
24
|
circle_legalize_dtype_to,
|
25
25
|
extract_circle_dtype,
|
26
26
|
extract_circle_shape,
|
27
|
+
to_circle_shape,
|
27
28
|
)
|
28
29
|
from tico.serialize.operators.hashable_opcode import OpCode
|
29
30
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
@@ -87,10 +88,6 @@ class TransposeConvVisitor(NodeVisitor):
|
|
87
88
|
assert len(output_shape) == 4, len(output_shape)
|
88
89
|
assert len(weight_shape) == 4, len(weight_shape)
|
89
90
|
|
90
|
-
if input_shape_signature is not None:
|
91
|
-
# TODO: support dynamic shapes
|
92
|
-
raise NotImplementedError("Dynamic shape is not supported yet")
|
93
|
-
|
94
91
|
pad_decision = identify_padding(padding, input_shape, output_shape, stride)
|
95
92
|
|
96
93
|
conv_input: torch.fx.Node | circle.Tensor.TensorT = input_
|
@@ -105,18 +102,21 @@ class TransposeConvVisitor(NodeVisitor):
|
|
105
102
|
],
|
106
103
|
dtype=torch.int32,
|
107
104
|
)
|
108
|
-
pad_output_shape = [
|
105
|
+
pad_output_shape: List[int | torch.SymInt] = [
|
109
106
|
input_shape[0],
|
110
107
|
input_shape[1] + pad_h * 2,
|
111
108
|
input_shape[2] + pad_w * 2,
|
112
109
|
input_shape[3],
|
113
110
|
]
|
111
|
+
pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
|
112
|
+
pad_output_shape
|
113
|
+
)
|
114
114
|
# create padded output tensor
|
115
115
|
input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
|
116
116
|
pad_output = self.graph.add_tensor_from_scratch(
|
117
117
|
prefix=f"{node.name}_input_pad_output",
|
118
|
-
shape=
|
119
|
-
shape_signature=
|
118
|
+
shape=pad_output_cshape,
|
119
|
+
shape_signature=pad_output_cshape_signature,
|
120
120
|
dtype=extract_circle_dtype(input_),
|
121
121
|
qparam=input_qparam,
|
122
122
|
source_node=node,
|
tico/utils/padding.py
CHANGED
@@ -15,6 +15,8 @@
|
|
15
15
|
from enum import IntEnum
|
16
16
|
from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
17
17
|
|
18
|
+
import torch
|
19
|
+
|
18
20
|
from tico.utils.errors import InvalidArgumentError
|
19
21
|
|
20
22
|
|
@@ -37,8 +39,8 @@ class ConvPaddingInfo(NamedTuple):
|
|
37
39
|
|
38
40
|
def identify_padding(
|
39
41
|
padding: PaddingValue,
|
40
|
-
input_shape: Sequence[int],
|
41
|
-
output_shape: Sequence[int],
|
42
|
+
input_shape: Sequence[int | torch.SymInt] | torch.Size,
|
43
|
+
output_shape: Sequence[int | torch.SymInt] | torch.Size,
|
42
44
|
stride: Sequence[int],
|
43
45
|
) -> ConvPaddingInfo:
|
44
46
|
"""
|
@@ -0,0 +1,92 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import copy
|
16
|
+
|
17
|
+
import inspect
|
18
|
+
from contextlib import contextmanager
|
19
|
+
from typing import Callable, List, Optional
|
20
|
+
|
21
|
+
import torch.nn as nn
|
22
|
+
|
23
|
+
|
24
|
+
class RecordingInput:
    r"""Context manager that records the input values of ``model.forward()``.

    Recording input is useful for preparing example inputs for ``torch.export``.
    While the context is active, ``module.forward`` is replaced by a wrapper
    that binds the call arguments to the original ``forward`` signature
    (applying defaults), records them if ``condition`` holds and nothing has
    been recorded yet, and then delegates to the original ``forward``.
    On exit the original ``forward`` is restored.

    Args:
        module: the module whose ``forward`` is wrapped.
        condition: predicate called with a dict that maps parameter names to
            the bound argument values (defaults applied). A call is recorded
            only when it returns True. Only the first matching call is kept.

            For example, to capture only calls where ``past_key_values`` is
            not None::

                condition = lambda args_dict: args_dict["past_key_values"] is not None

        input_to_remove: names of arguments to drop from the recording; each
            such argument appears as ``None`` in ``captured_input``. ``None``
            (the default) removes nothing.

            Sometimes you would like to remove some argument values to make
            the exported graph tidy or correct. For example,
            ``past_key_values`` may be not None but just an empty cache; then
            ``input_to_remove=["past_key_values"]`` makes life easy.

    Attributes:
        captured_input: ``None`` until a call is recorded; afterwards a
            deep-copied tuple of argument values ordered like the named
            parameters of ``forward`` (variadic ``*args`` / ``**kwargs``
            parameters are not included).

    Example::

        >>> with RecordingInput(model, input_to_remove=input_to_remove) as rec:
        ...     outputs = model.generate(
        ...         **inputs,
        ...     )
        ...     captured_input = rec.captured_input
        >>> circle_model = tico.convert(model, captured_input)
    """

    def __init__(
        self,
        module: nn.Module,
        condition: Callable[[dict], bool] = lambda args_dict: True,
        *,
        input_to_remove: Optional[List[str]] = None,
    ):
        self.module = module
        self.forward_org = module.forward
        # Remember whether ``forward`` already lived in the instance dict so
        # __exit__ can restore the module exactly instead of leaving a bound
        # method (and a module -> method -> module reference cycle) behind.
        self._forward_in_instance_dict = "forward" in vars(module)
        self.condition = condition
        # Copy to avoid aliasing the caller's list; None means "remove nothing".
        self.input_to_remove: List[str] = (
            list(input_to_remove) if input_to_remove else []
        )
        self.sig = inspect.signature(self.forward_org)
        # Only named parameters can be replayed positionally, so variadic
        # parameters (``*args`` / ``**kwargs`` under any name) are skipped.
        self.args_names = [
            name
            for name, param in self.sig.parameters.items()
            if name not in ("self", "kwargs")
            and param.kind
            not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        ]
        self.captured_input: Optional[tuple] = None

    def _populate_args(self, args_dict: dict) -> tuple:
        """Drop ``input_to_remove`` entries and return a deep-copied args tuple."""
        for key in self.input_to_remove:
            args_dict.pop(key, None)
        args_tuple = tuple(args_dict.get(name, None) for name in self.args_names)
        # Deep copy so later in-place mutation of the arguments does not alter
        # the recorded example input.
        return copy.deepcopy(args_tuple)

    def __enter__(self):
        def capture_and_forward(*args, **kwargs):
            bound = self.sig.bind(*args, **kwargs)
            bound.apply_defaults()
            args_dict = dict(bound.arguments)

            if self.condition(args_dict) and self.captured_input is None:
                self.captured_input = self._populate_args(args_dict)

            return self.forward_org(*args, **kwargs)

        self.module.forward = capture_and_forward
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._forward_in_instance_dict:
            # ``forward`` was an instance attribute before; put it back.
            self.module.forward = self.forward_org
        else:
            # ``forward`` came from the class: drop the wrapper so attribute
            # lookup falls back to the class method again.
            vars(self.module).pop("forward", None)
|
tico/utils/serialize.py
CHANGED
@@ -47,4 +47,9 @@ def validate_tensor_shapes(
|
|
47
47
|
Let's validate all tensors' shapes against their shape signatures.
|
48
48
|
"""
|
49
49
|
for tensor in graph.tensors:
|
50
|
-
|
50
|
+
try:
|
51
|
+
validate_circle_shape(tensor.shape, tensor.shapeSignature)
|
52
|
+
except Exception as e:
|
53
|
+
raise ValueError(
|
54
|
+
f"Tensor {tensor.name} has invalid shape ({tensor.shape}), shape_signature ({tensor.shapeSignature})"
|
55
|
+
) from e
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=quSJ4KKyh76eIG3MrN15qViNKdudexAIk-h6x_L-hRc,1883
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
|
@@ -96,8 +96,8 @@ tico/passes/remove_redundant_to_copy.py,sha256=tKy4XKkO2l33fMxVPQ_iFkUeFvP15kbPv
|
|
96
96
|
tico/passes/restore_linear.py,sha256=xGJdNb-1CrkOKS9BnLbcblkZc6P2vVjKIi-7lRcs7Bk,4111
|
97
97
|
tico/passes/segment_index_select.py,sha256=VVCKNLtYRkr9n5lGnlzEuQsQ0WVxEYXGchFrDnB1C40,5189
|
98
98
|
tico/serialize/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
99
|
-
tico/serialize/circle_graph.py,sha256=
|
100
|
-
tico/serialize/circle_mapping.py,sha256=
|
99
|
+
tico/serialize/circle_graph.py,sha256=qvyul_HULoz7B_6RFKQ8s9RjEvMgPq-ynMVkZe8aqE4,12034
|
100
|
+
tico/serialize/circle_mapping.py,sha256=c__AIHPi23lPugNJFolgMAKrw8j7gEeMaUQ1LAMSFnY,8542
|
101
101
|
tico/serialize/circle_serializer.py,sha256=BGK9tltKkoL1h4rcrJUgDJIGlHst7aF3cZAKJk_GPWc,10950
|
102
102
|
tico/serialize/pack.py,sha256=5HZ9kX3x6C6CyT_FWS6FRmvx_P7Dx21orjUNQxJ2xlo,1297
|
103
103
|
tico/serialize/quant_param.py,sha256=6nbGKdqwMI9Cx9BLXJ9A9JU4qb770S8vTM1vCZRX3Eo,1342
|
@@ -110,17 +110,17 @@ tico/serialize/operators/op_alias_copy.py,sha256=Xu9OiILbGf8oddh8yTqovvLfgVs8XYV
|
|
110
110
|
tico/serialize/operators/op_any.py,sha256=wrTXFQ1TWl-2ET2NGXAXI1dzfDDJsYtA71pyj2numPE,4968
|
111
111
|
tico/serialize/operators/op_arange_start_step.py,sha256=0T5lWwh3TfsFStmVv0v5qG03KENRDBmMix08RXQ4D-U,2132
|
112
112
|
tico/serialize/operators/op_argmax.py,sha256=ARyGHlmWVmzwCct93V5x1-VyKqhxMOvV8GuM8yQWXdo,2290
|
113
|
-
tico/serialize/operators/op_avg_pool2d.py,sha256=
|
113
|
+
tico/serialize/operators/op_avg_pool2d.py,sha256=eZl232FqYQsn2jUN_XjombHq_lzp2hf_AKwCLbZBxh8,7720
|
114
114
|
tico/serialize/operators/op_bmm.py,sha256=AELjHC9ISFPIzEEl5Kr1s4GSNLZElwZmVZJWkEyCEoA,2189
|
115
115
|
tico/serialize/operators/op_cat.py,sha256=XDYOh0XAyrM0TlxVm6Sa0OFFGrKk7aSDcGXC-hYX4gs,2204
|
116
116
|
tico/serialize/operators/op_clamp.py,sha256=RRQVrzayDfN3PioCVJqa_yYOtcYwb5HHwkMe4E_YPmE,4408
|
117
117
|
tico/serialize/operators/op_clone.py,sha256=vzDYJ8TS3tc2BAyd_z8nt5VqT1inpymSseMEhd9dva0,2394
|
118
118
|
tico/serialize/operators/op_constant_pad_nd.py,sha256=OpP4AP-d1IFcWZolNa-o9ZxzXJQkMdG9WQ66soX3s-E,2675
|
119
|
-
tico/serialize/operators/op_conv2d.py,sha256=
|
119
|
+
tico/serialize/operators/op_conv2d.py,sha256=1_vouWXaF51gDLYg8z5Zlup0Tecq_ggAzvguiHzFffw,6828
|
120
120
|
tico/serialize/operators/op_copy.py,sha256=boXHfl0bcvdBVl0tpzPMA_KBonh80vVqv61N3H5-PRU,6941
|
121
121
|
tico/serialize/operators/op_cos.py,sha256=N12bNyuTQIxRnD0eHRPdFVzRQPMy1NFM4iM8oQ4lYzw,2034
|
122
122
|
tico/serialize/operators/op_cumsum.py,sha256=px9ZGUDDsdWjrql8Z1FdXfF-7CJhditxyNz5QRZbLiM,3948
|
123
|
-
tico/serialize/operators/op_depthwise_conv2d.py,sha256=
|
123
|
+
tico/serialize/operators/op_depthwise_conv2d.py,sha256=U6_nX2V31Evm-HLN9b3RKIVg-m8jyD-Nw1GdePUPPjY,7284
|
124
124
|
tico/serialize/operators/op_dequantize_per_channel.py,sha256=aPcVxjdgvfSFoLnv9NL-RxO5vZYj8ulqriMP5LHIWs0,3133
|
125
125
|
tico/serialize/operators/op_dequantize_per_tensor.py,sha256=u9aK_Xle9rDN0EHLE0YrCTlXY4Q53Ch9Di4qmx7ynps,2304
|
126
126
|
tico/serialize/operators/op_div.py,sha256=WjeM2Ux7TyGlSNx2aVC783JvcL0xnY6FBYo1Q_kdb5Q,2201
|
@@ -144,7 +144,7 @@ tico/serialize/operators/op_logical_and.py,sha256=WhQ8knuq32BO-WhAqkOgpcUStPkjoP
|
|
144
144
|
tico/serialize/operators/op_logical_not.py,sha256=ugrVcRqR3IvUUaiRVW5cArCYJbzmkcXp88QM846jCww,2129
|
145
145
|
tico/serialize/operators/op_lt.py,sha256=_vA7dWpV9wVBxB7JL9pLQT9BsV91NGQBq_0auAtHK5Y,2080
|
146
146
|
tico/serialize/operators/op_max_dim.py,sha256=nS_TZl5uq4uv1LwgBD9Wddyac4atKqBiIWKIyeXse2s,2519
|
147
|
-
tico/serialize/operators/op_max_pool2d_with_indices.py,sha256=
|
147
|
+
tico/serialize/operators/op_max_pool2d_with_indices.py,sha256=i4iKZ262ytDKUt7bG9MiXuoKn--bgi-HWG24U5lvPPc,5919
|
148
148
|
tico/serialize/operators/op_maximum.py,sha256=JjBr6gWEnuakLuk1_feotTHfIIm3s5YqWmqhUMpSPI0,1873
|
149
149
|
tico/serialize/operators/op_mean.py,sha256=rVQZOxCJkHFY4kQBAS1HVK0HkcqxgkSy6zvEDLX_WYQ,2267
|
150
150
|
tico/serialize/operators/op_minimum.py,sha256=fASjQVcTPCin02umQwFPdq2ss-Ve7S5A33J3QmmQ_wQ,1873
|
@@ -159,7 +159,7 @@ tico/serialize/operators/op_quantize_per_tensor.py,sha256=w-vYxSPnN2gtx-pEkkcMGU
|
|
159
159
|
tico/serialize/operators/op_reciprocal.py,sha256=6b9_bxjg_0EvgAitSv1MgBi4PJSEgm-21s5qtWI1UR4,2394
|
160
160
|
tico/serialize/operators/op_relu.py,sha256=WXCR_chwEUBqjFIQ_4E2avwk-Acy76pmX20rJQCBTQo,1832
|
161
161
|
tico/serialize/operators/op_relu6.py,sha256=ZWqEolfAKjOdUC1ZCg0iuu4dBhkJRxVYR2tUzpbvKQM,1829
|
162
|
-
tico/serialize/operators/op_repeat.py,sha256=
|
162
|
+
tico/serialize/operators/op_repeat.py,sha256=VrRxD31pT3hRGH-5n6ia3PJBXh_u0GvIl1hZZYFrKTQ,4507
|
163
163
|
tico/serialize/operators/op_reshape.py,sha256=6wErQpmDX9mAmfJRCTg_cg1uOdJZqHm8Nux8dNI53Vg,2559
|
164
164
|
tico/serialize/operators/op_resize_nearest_neighbor.py,sha256=dXaAnZ5M_ko_tH-HolxNpHFXkDUQ8x45myskojP5XZE,2771
|
165
165
|
tico/serialize/operators/op_round.py,sha256=pe6w_TB4xGLu0iPv4Qo0a0fIkY9DgCgXk5127TWt8pE,1837
|
@@ -177,7 +177,7 @@ tico/serialize/operators/op_sub.py,sha256=yZskQJF0ylXVk02Uid8djPNIWDJ-0uHJar4UYh
|
|
177
177
|
tico/serialize/operators/op_sum.py,sha256=B5aSwQMhyoBe2JYdE5nVQ3QeVDSzL-yuZZujsG08OdQ,2294
|
178
178
|
tico/serialize/operators/op_tanh.py,sha256=rs7FsbQeUQ7Ak8RoQV9ymNGXHXRObojfY_SiqJiyqdA,1846
|
179
179
|
tico/serialize/operators/op_to_copy.py,sha256=a8T0uPMavMO_md1a-4_0dlvDHyZS_xew0qB6xjf69rI,3934
|
180
|
-
tico/serialize/operators/op_transpose_conv.py,sha256=
|
180
|
+
tico/serialize/operators/op_transpose_conv.py,sha256=9NLnWpitfQzSDF-iAgw2fBA3YHL5y2Y8DQipeo8OvYA,5826
|
181
181
|
tico/serialize/operators/op_unsqueeze.py,sha256=ZHhfVXSWEiwb2VDYX5uhxbGQyzZjKT7CrbBpVGxVHBU,2310
|
182
182
|
tico/serialize/operators/op_view.py,sha256=xxE-GvTJ1UpcHst5KXYz3qKY-eJQvXKKrSZiA2O7E40,2593
|
183
183
|
tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
|
@@ -192,11 +192,12 @@ tico/utils/graph.py,sha256=jD6m58m5JmN9mPfaROA9CW3406iJxmnukke00AuwRqI,9131
|
|
192
192
|
tico/utils/installed_packages.py,sha256=J0FTwnkCGs0MxRWoCMYAqiwH7Z0GWFDLV--x-IndSp4,1017
|
193
193
|
tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
|
194
194
|
tico/utils/model.py,sha256=pPOIjD0qjQirLibiRxxfjOR6efimOcDAd9R-74eus-k,1282
|
195
|
-
tico/utils/padding.py,sha256=
|
195
|
+
tico/utils/padding.py,sha256=qKke-dJeeLHiRaePjDS66txrGyiYuipLVQeqLYad8uk,3349
|
196
196
|
tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
197
197
|
tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
|
198
|
+
tico/utils/record_input.py,sha256=FBtV00WWcXMXmg-Ujgvci9HjOmRJC1cVzx_WRNIF4MI,3324
|
198
199
|
tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
|
199
|
-
tico/utils/serialize.py,sha256=
|
200
|
+
tico/utils/serialize.py,sha256=mEuusEzi82WFsz3AkowgWwxSLeo50JDxyOj6yYDQhEI,1914
|
200
201
|
tico/utils/torch_compat.py,sha256=oc6PztVsXdHcQ3iaVR90wLLxrGaj6zFHWZ8K9rRS6q8,1795
|
201
202
|
tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
|
202
203
|
tico/utils/utils.py,sha256=A5p3iAAxRGDsZJh4ybp-Qo3MX3vk5RrmSY-R3rXqVeI,12976
|
@@ -205,9 +206,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
205
206
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
206
207
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
207
208
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
208
|
-
tico-0.1.0.
|
209
|
-
tico-0.1.0.
|
210
|
-
tico-0.1.0.
|
211
|
-
tico-0.1.0.
|
212
|
-
tico-0.1.0.
|
213
|
-
tico-0.1.0.
|
209
|
+
tico-0.1.0.dev250729.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
210
|
+
tico-0.1.0.dev250729.dist-info/METADATA,sha256=ZcnGD8K56o04Pt172XYRpuh-DHPoFWDhUhEpaCsy23k,8430
|
211
|
+
tico-0.1.0.dev250729.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
212
|
+
tico-0.1.0.dev250729.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
213
|
+
tico-0.1.0.dev250729.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
214
|
+
tico-0.1.0.dev250729.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|