tico 0.1.0.dev251106__py3-none-any.whl → 0.2.0.dev260122__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +2 -2
- tico/_version.py +1 -0
- tico/passes/convert_conv3d_to_conv2d.py +435 -0
- tico/passes/convert_sym_size_to_circle_shape.py +99 -0
- tico/passes/decompose_batch_norm.py +9 -5
- tico/passes/lower_copy.py +95 -0
- tico/passes/ops.py +4 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +251 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +180 -0
- tico/quantization/algorithm/gptq/gptq.py +231 -11
- tico/quantization/algorithm/gptq/quantizer.py +18 -6
- tico/quantization/config/{pt2e.py → fpi_gptq.py} +11 -4
- tico/quantization/config/gptq.py +27 -4
- tico/quantization/public_interface.py +0 -10
- tico/quantization/wrapq/quantizer.py +2 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +51 -11
- tico/serialize/operators/adapters/onert/llama_attention.py +51 -0
- tico/serialize/operators/op_attention.py +58 -0
- tico/serialize/operators/op_circle_shape.py +64 -0
- tico/serialize/operators/op_dequantize_per_channel.py +1 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +1 -0
- tico/serialize/operators/op_transpose_conv.py +66 -50
- tico/utils/convert.py +16 -1
- tico/utils/padding.py +13 -5
- tico/utils/record_input.py +2 -2
- tico/utils/register_custom_op.py +63 -0
- tico/utils/validate_args_kwargs.py +49 -4
- tico-0.2.0.dev260122.dist-info/METADATA +631 -0
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/RECORD +35 -46
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/WHEEL +1 -1
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/entry_points.txt +0 -1
- tico/quantization/algorithm/pt2e/annotation/annotator.py +0 -208
- tico/quantization/algorithm/pt2e/annotation/config.py +0 -26
- tico/quantization/algorithm/pt2e/annotation/op/__init__.py +0 -21
- tico/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +0 -63
- tico/quantization/algorithm/pt2e/annotation/op/add.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/op/conv2d.py +0 -90
- tico/quantization/algorithm/pt2e/annotation/op/div.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/op/linear.py +0 -92
- tico/quantization/algorithm/pt2e/annotation/op/mean.py +0 -51
- tico/quantization/algorithm/pt2e/annotation/op/mul.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/op/relu6.py +0 -51
- tico/quantization/algorithm/pt2e/annotation/op/rsqrt.py +0 -51
- tico/quantization/algorithm/pt2e/annotation/op/sub.py +0 -55
- tico/quantization/algorithm/pt2e/annotation/spec.py +0 -45
- tico/quantization/algorithm/pt2e/annotation/utils.py +0 -88
- tico/quantization/algorithm/pt2e/quantizer.py +0 -81
- tico/quantization/algorithm/pt2e/transformation/__init__.py +0 -1
- tico/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -58
- tico/quantization/algorithm/pt2e/utils.py +0 -135
- tico/serialize/operators/op_copy.py +0 -187
- tico-0.1.0.dev251106.dist-info/METADATA +0 -392
- /tico/quantization/algorithm/{pt2e → fpi_gptq}/__init__.py +0 -0
- /tico/{quantization/algorithm/pt2e/annotation → serialize/operators/adapters/onert}/__init__.py +0 -0
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info/licenses}/LICENSE +0 -0
- {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
|
@@ -28,8 +28,8 @@ __all__ = [
|
|
|
28
28
|
"convert_from_pt2",
|
|
29
29
|
]
|
|
30
30
|
|
|
31
|
-
# THIS LINE IS AUTOMATICALLY GENERATED
|
|
32
|
-
__version__ = "0.
|
|
31
|
+
# THIS LINE IS AUTOMATICALLY GENERATED
|
|
32
|
+
__version__ = "0.2.0"
|
|
33
33
|
|
|
34
34
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
|
35
35
|
SECURE_TORCH_VERSION = "2.6.0"
|
tico/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.2.0.dev260122"
|
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
|
|
14
|
+
from typing import List, TYPE_CHECKING
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
import torch.fx
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from torch.export import ExportedProgram
|
|
21
|
+
|
|
22
|
+
from tico.serialize.circle_mapping import extract_shape
|
|
23
|
+
from tico.utils import logging
|
|
24
|
+
from tico.utils.errors import NotYetSupportedError
|
|
25
|
+
from tico.utils.graph import create_node
|
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
|
28
|
+
from tico.utils.utils import is_target_node
|
|
29
|
+
from tico.utils.validate_args_kwargs import Conv3DArgs
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@trace_graph_diff_on_pass
class ConvertConv3dToConv2d(PassBase):
    """
    This pass converts `torch.ops.aten.conv3d` to multiple `torch.ops.aten.conv2d` operations

    The 3D convolution is decomposed along the temporal (depth) axis. For every
    output time step ``t_out`` and every temporal kernel tap ``i``:

        t_idx = t_out * stride[0] + i * dilation[0]
        output[:, :, t_out] = bias + sum over i of
                              conv2d(padded_input[:, :, t_idx], weight[:, :, i])

    [before]   input(dim=5)      weight(dim=5)
                    │                  │
                    └──────► conv3d ◄──┘
                               │
                         output(dim=5)

    [after]
      weight [C_out, C_in/groups, kT, kH, kW]
          slice(dim=2, i, i+1) ─► squeeze(dim=2) ─► weight_2d[i]     (i = 0..kT-1)

      input [N, C_in, T, H, W]
          (only if temporal padding p > 0)
          cat([zeros[N, C_in, p, H, W], input, zeros[N, C_in, p, H, W]], dim=2)
          ─► padded_input [N, C_in, T + 2p, H, W]

      For t_out = 0..T_out-1:
          acc = sum over i of conv2d(
                    squeeze(slice(padded_input, 2, t_idx, t_idx + 1), 2),
                    weight_2d[i])                         [N, C_out, H_out, W_out]
          acc = acc + bias.reshape(1, C_out, 1, 1)        (if bias is given)
          out[t_out] = unsqueeze(acc, dim=2)              [N, C_out, 1, H_out, W_out]

      cat(out[0..T_out-1], dim=2) ─► output [N, C_out, T_out, H_out, W_out]

    ``T_out`` is derived from the padded temporal length, so every ``t_idx``
    visited above lies inside ``padded_input``; no runtime bounds mask is needed.
    Input time slices shared by several output steps are created only once.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _to_triple(value):
        """
        Normalize a conv3d int-list argument (stride/dilation) to a 3-tuple.

        Accepts an int, or a sequence with 1 or 3 elements.

        Raises:
            NotYetSupportedError: if a sequence of any other length is given.
        """
        if isinstance(value, int):
            return (value, value, value)
        values = tuple(value)
        if len(values) == 1:
            return values * 3
        if len(values) != 3:
            raise NotYetSupportedError(f"Expected 1 or 3 values, got: {value}")
        return values

    @staticmethod
    def _dtype_and_device(node):
        """
        Return ``(dtype, device)`` for helper tensors created next to ``node``.

        Prefers the example value stored in ``node.meta["val"]``, then
        ``node.meta["dtype"]`` / ``node.meta["device"]``, and finally falls back
        to ``torch.float32`` / ``"cpu"``.
        """
        val = node.meta.get("val")
        dtype = getattr(val, "dtype", None)
        if dtype is None:
            dtype = node.meta.get("dtype", torch.float32)
        device = getattr(val, "device", None)
        if device is None:
            device = node.meta.get("device", "cpu")
        return dtype, device

    def _parse_3d_padding(self, padding, kernel_size, dilation=(1, 1, 1)):
        """
        Parse 3D padding parameter and return (temporal, H, W) tuple.

        Args:
            padding: Can be str ('same', 'valid'), int, list, or tuple
            kernel_size: 3D kernel size (kT, kH, kW) or a single int
            dilation: 3D dilation (dT, dH, dW) or a single int; only used for
                'same' padding. Defaults to no dilation.

        Returns:
            Tuple of 3 padding values: (temporal_padding, H_padding, W_padding)

        Raises:
            NotYetSupportedError: for unknown padding strings/formats, or for
                'same' padding that would need asymmetric padding, i.e. when
                ``dilation * (kernel - 1)`` is odd in some dimension.
        """
        if isinstance(padding, str):
            if padding == "valid":
                return 0, 0, 0
            if padding != "same":
                raise NotYetSupportedError(f"Unsupported padding string: {padding}")

            kernel = self._to_triple(kernel_size)
            dil = self._to_triple(dilation)
            pads = []
            for k, d in zip(kernel, dil):
                # 'same' pads dilation * (k - 1) elements in total. Only the
                # symmetric case can be expressed with conv2d/cat padding here.
                total = d * (k - 1)
                if total % 2 != 0:
                    raise NotYetSupportedError(
                        "'same' padding requires asymmetric padding for "
                        f"kernel_size={kernel}, dilation={dil}"
                    )
                pads.append(total // 2)
            return pads[0], pads[1], pads[2]

        if isinstance(padding, (list, tuple)):
            if len(padding) == 1:
                return padding[0], padding[0], padding[0]
            if len(padding) == 3:
                return padding[0], padding[1], padding[2]
            raise NotYetSupportedError(f"Unsupported padding format: {padding}")

        # int
        return padding, padding, padding

    def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
        """
        Replace a single conv3d ``node`` with its conv2d decomposition.

        Returns:
            True; the graph is always modified when no error is raised.

        Raises:
            NotYetSupportedError: on non-5D input/weight, unsupported padding,
                or a temporal kernel that does not fit the padded input length.
            RuntimeError: if no insertion point can be found.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # Extract conv3d arguments
        args = Conv3DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        input_ = args.input
        weight = args.weight
        bias = args.bias
        stride = self._to_triple(args.stride)
        dilation = self._to_triple(args.dilation)
        groups = args.groups

        input_shape = extract_shape(input_)
        weight_shape = extract_shape(weight)

        if len(input_shape) != 5:
            raise NotYetSupportedError(
                f"Only support 5D input tensor: node's input shape: {input_shape}"
            )
        if len(weight_shape) != 5:
            raise NotYetSupportedError(
                f"Only support 5D weight tensor: node's weight shape: {weight_shape}"
            )

        N, C_in, T_in, H_in, W_in = input_shape
        C_out, _, kT, kH, kW = weight_shape

        temporal_padding, h_padding, w_padding = self._parse_3d_padding(
            args.padding, (kT, kH, kW), dilation
        )

        T_padded = T_in + 2 * temporal_padding
        T_out = (T_padded - dilation[0] * (kT - 1) - 1) // stride[0] + 1
        if T_out <= 0:
            raise NotYetSupportedError(
                f"Temporal kernel (kT={kT}, dilation={dilation[0]}) does not fit "
                f"into padded temporal length {T_padded}"
            )

        dtype, device = self._dtype_and_device(input_)

        # Insert before the node that follows conv3d so that new nodes keep
        # their creation order and precede every user of conv3d.
        next_node = node.next
        if next_node is None:
            next_node = next((n for n in graph.nodes if n.op == "output"), None)
        if next_node is None:
            raise RuntimeError("Could not find insertion point for temporal outputs")

        with graph.inserting_before(next_node):
            # Step 1: per-tap 2D weights:
            # [C_out, C_in_w, kT, kH, kW] -> kT x [C_out, C_in_w, kH, kW]
            weight_2d_layers = []
            for t in range(kT):
                weight_slice = create_node(
                    graph,
                    torch.ops.aten.slice.Tensor,
                    args=(weight, 2, t, t + 1, 1),
                    origin=weight,
                )
                weight_2d = create_node(
                    graph,
                    torch.ops.aten.squeeze.dims,
                    args=(weight_slice, [2]),
                    origin=weight_slice,
                )
                weight_2d_layers.append(weight_2d)

            # Step 2: symmetric temporal zero padding via cat (if needed).
            if temporal_padding > 0:
                zero_padding = create_node(
                    graph,
                    torch.ops.aten.zeros.default,
                    args=([N, C_in, temporal_padding, H_in, W_in],),
                    kwargs={"dtype": dtype, "device": device},
                    origin=input_,
                )
                # [zero_padding, input, zero_padding] -> [N, C, T+2*padding, H, W]
                padded_input = create_node(
                    graph,
                    torch.ops.aten.cat.default,
                    args=([zero_padding, input_, zero_padding], 2),
                    origin=input_,
                )
            else:
                padded_input = input_

            # Bias broadcast shape [1, C_out, 1, 1]; built once and shared by
            # every temporal output.
            bias_reshaped = None
            if bias is not None:
                bias_reshaped = create_node(
                    graph,
                    torch.ops.aten.reshape.default,
                    args=(bias, [1, C_out, 1, 1]),
                    origin=bias,
                )

            # Step 3: accumulate conv2d over the temporal taps of each output step.
            input_2d_by_t_idx = {}  # t_idx -> [N, C_in, H_in, W_in] slice node
            temporal_outputs = []
            for t_out in range(T_out):
                acc = None
                for i, weight_2d in enumerate(weight_2d_layers):
                    t_idx = t_out * stride[0] + i * dilation[0]

                    input_2d = input_2d_by_t_idx.get(t_idx)
                    if input_2d is None:
                        input_slice = create_node(
                            graph,
                            torch.ops.aten.slice.Tensor,
                            args=(padded_input, 2, t_idx, t_idx + 1, 1),
                            origin=padded_input,
                        )
                        # [N, C_in, 1, H_in, W_in] -> [N, C_in, H_in, W_in]
                        input_2d = create_node(
                            graph,
                            torch.ops.aten.squeeze.dims,
                            args=(input_slice, [2]),
                            origin=input_slice,
                        )
                        input_2d_by_t_idx[t_idx] = input_2d

                    conv2d = create_node(
                        graph,
                        torch.ops.aten.conv2d.default,
                        args=(
                            input_2d,
                            weight_2d,
                            None,  # bias is added once per output step below
                            [stride[1], stride[2]],
                            [h_padding, w_padding],
                            [dilation[1], dilation[2]],
                            groups,
                        ),
                        origin=node,
                    )

                    if acc is None:
                        acc = conv2d
                    else:
                        acc = create_node(
                            graph,
                            torch.ops.aten.add.Tensor,
                            args=(acc, conv2d),
                            origin=acc,
                        )

                if bias_reshaped is not None:
                    acc = create_node(
                        graph,
                        torch.ops.aten.add.Tensor,
                        args=(acc, bias_reshaped),
                        origin=acc,
                    )

                # [N, C_out, H_out, W_out] -> [N, C_out, 1, H_out, W_out]
                unsqueezed = create_node(
                    graph,
                    torch.ops.aten.unsqueeze.default,
                    args=(acc, 2),
                    origin=acc,
                )
                temporal_outputs.append(unsqueezed)

            # Step 4: concatenate along time: [N, C_out, T_out, H_out, W_out]
            stacked_output = create_node(
                graph,
                torch.ops.aten.cat.default,
                args=(temporal_outputs, 2),
                origin=node,
            )

        # Replace the original node
        node.replace_all_uses_with(stacked_output, propagate_meta=False)
        logger.debug(f"{node.name} is replaced with conv2d decomposition")
        return True

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Decompose every conv3d node in the graph into conv2d operations."""
        target_conv_op = [torch.ops.aten.conv3d.default, torch.ops.aten.conv3d.padding]
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        modified = False

        # Process all Conv3D nodes in forward pass order
        for node in graph.nodes:
            if not is_target_node(node, target_conv_op):
                continue
            modified |= self.convert(exported_program, node)

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import torch.fx
|
|
19
|
+
import torch
|
|
20
|
+
from torch.export import ExportedProgram
|
|
21
|
+
|
|
22
|
+
from tico.utils import logging
|
|
23
|
+
from tico.utils.graph import create_node
|
|
24
|
+
from tico.utils.passes import PassBase, PassResult
|
|
25
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@trace_graph_diff_on_pass
class ConvertSymSizeToCircleShape(PassBase):
    """
    This pass converts torch.ops.aten.sym_size.int operations to circle_custom::shape.

    The circle_custom::shape operator allows preserving dynamic shape information
    in the Circle model. This is essential for models with dynamic batch sizes or other dynamic dimensions.

    Example:
        Before: %sym_size_int_1 = call_function[target=torch.ops.aten.sym_size.int](args=(%x, 0))
        After:  %shape_0 = call_function[target=torch.ops.circle_custom.shape](args=(%x,))
                %slice_0 = call_function[target=torch.ops.aten.slice.Tensor](args=(%shape_0, 0, 0, 1, 1))

    Negative ``dim`` values are normalized with the input rank (taken from
    ``meta["val"]``) before slicing. Without normalization, ``dim=-1`` would
    produce the empty slice ``[-1:0]``.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace each ``aten.sym_size.int(x, dim)`` with ``slice(shape(x), 0, dim, dim + 1)``.

        Raises:
            RuntimeError: if ``dim`` is negative and the input rank is unknown.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if node.op != "call_function" or node.target != torch.ops.aten.sym_size.int:
                continue

            # sym_size.int has args: (input, dim)
            input_tensor = node.args[0]
            dim = node.args[1]

            input_val = input_tensor.meta.get("val")
            rank = len(input_val.shape) if input_val is not None else None

            if dim < 0:
                if rank is None:
                    raise RuntimeError(
                        f"{node.name}: cannot normalize negative dim {dim} "
                        "because the input rank is unknown"
                    )
                dim += rank

            # Create circle_custom::shape node
            with graph.inserting_after(node):
                shape_node = create_node(
                    graph,
                    torch.ops.circle_custom.shape,
                    args=(input_tensor,),
                )

            if rank is not None:
                # shape output is a 1D int32 tensor with `rank` elements; a real
                # tensor is used here only as a metadata placeholder.
                shape_node.meta["val"] = torch.zeros(rank, dtype=torch.int32)

            # Extract the requested dimension as a 1-element 1D tensor.
            with graph.inserting_after(shape_node):
                slice_node = create_node(
                    graph,
                    torch.ops.aten.slice.Tensor,
                    args=(shape_node, 0, dim, dim + 1, 1),
                )
            slice_node.meta["val"] = torch.zeros(1, dtype=torch.int32)

            node.replace_all_uses_with(slice_node, propagate_meta=False)
            modified = True

            logger.debug(
                f"Converted {node.name} (sym_size.int) to {shape_node.name} (circle_custom::shape) + {slice_node.name} (slice)"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
|
@@ -115,7 +115,7 @@ class DecomposeBatchNorm(PassBase):
|
|
|
115
115
|
continue
|
|
116
116
|
|
|
117
117
|
input_shape = extract_shape(input_)
|
|
118
|
-
assert len(input_shape)
|
|
118
|
+
assert len(input_shape) >= 2, len(input_shape)
|
|
119
119
|
C = input_shape[1]
|
|
120
120
|
|
|
121
121
|
weight_value = (
|
|
@@ -145,11 +145,15 @@ class DecomposeBatchNorm(PassBase):
|
|
|
145
145
|
# Calculate constants for mul and add
|
|
146
146
|
mul_const = weight_value / torch.sqrt(var_value + eps)
|
|
147
147
|
add_const = bias_value - (mul_const * mean_value)
|
|
148
|
-
|
|
148
|
+
|
|
149
|
+
# Make sure channel count matches
|
|
149
150
|
assert len(mul_const) == len(add_const) == C
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
151
|
+
|
|
152
|
+
# Build a broadcastable shape like (1, C, 1, ...)
|
|
153
|
+
view_shape = [1] * len(input_shape)
|
|
154
|
+
view_shape[1] = C
|
|
155
|
+
mul_const = mul_const.view(*view_shape)
|
|
156
|
+
add_const = add_const.view(*view_shape)
|
|
153
157
|
|
|
154
158
|
# Placeholder nodes must be the first N nodes in the nodes list of a graph.
|
|
155
159
|
# Therefore, insert the newly created placeholders at the start of the node list.
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import torch.fx
|
|
19
|
+
import torch
|
|
20
|
+
from torch.export import ExportedProgram
|
|
21
|
+
|
|
22
|
+
from tico.passes import ops
|
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
|
24
|
+
from tico.utils import logging
|
|
25
|
+
from tico.utils.graph import create_node
|
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
|
28
|
+
from tico.utils.validate_args_kwargs import CopyArgs
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@trace_graph_diff_on_pass
class LowerCopy(PassBase):
    """
    This pass lowers `aten.copy.default` to simpler broadcast operations.

    `aten.copy(dst, src)` yields a tensor with dst's shape and dtype holding
    src's values (broadcast to dst's shape). dst's values are discarded.

    - If src and dst dtypes differ, src is first cast with `aten._to_copy`.
    - If src and dst shapes are the same (and no cast is needed), the copy is
      redundant and folded away.
    - If src and dst shapes differ, it's replaced with expand (broadcast).

    This simplifies serialization by handling copy logic at the pass level.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _meta_dtype(node):
        """Return the dtype recorded in ``node.meta["val"]``, or None if unknown."""
        val = getattr(node, "meta", {}).get("val")
        return getattr(val, "dtype", None)

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Replace every ``aten.copy.default`` node with cast/expand/identity."""
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if node.op != "call_function" or node.target != torch.ops.aten.copy.default:
                continue

            args = CopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            dst = args.dst
            src = args.src

            dst_shape = list(extract_shape(dst))
            src_shape = list(extract_shape(src))

            # copy converts src to dst's dtype; only cast when both are known.
            dst_dtype = self._meta_dtype(dst)
            src_dtype = self._meta_dtype(src)
            needs_cast = (
                dst_dtype is not None
                and src_dtype is not None
                and dst_dtype != src_dtype
            )

            # Case 1: Same shape and dtype - copy is redundant, just use src
            if dst_shape == src_shape and not needs_cast:
                logger.debug(
                    f"{node.name}: Same shape {dst_shape}, replacing with src directly"
                )
                node.replace_all_uses_with(src, propagate_meta=False)
                modified = True
                continue

            # Case 2: cast and/or expand. Casting first keeps the cast on the
            # smaller (pre-broadcast) tensor.
            with graph.inserting_before(node):
                result = src
                if needs_cast:
                    logger.debug(
                        f"{node.name}: Different dtypes src={src_dtype} dst={dst_dtype}, "
                        f"inserting _to_copy"
                    )
                    result = create_node(
                        graph,
                        torch.ops.aten._to_copy.default,
                        args=(result,),
                        kwargs={"dtype": dst_dtype},
                    )
                if dst_shape != src_shape:
                    logger.debug(
                        f"{node.name}: Different shapes src={src_shape} dst={dst_shape}, "
                        f"inserting expand"
                    )
                    result = create_node(
                        graph,
                        torch.ops.aten.expand.default,
                        args=(result, dst_shape),
                    )

            # The replacement produces exactly copy's output, so its meta applies.
            node.replace_all_uses_with(result, propagate_meta=True)
            modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|