tico 0.1.0.dev251106__py3-none-any.whl → 0.2.0.dev260122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tico/__init__.py +2 -2
  2. tico/_version.py +1 -0
  3. tico/passes/convert_conv3d_to_conv2d.py +435 -0
  4. tico/passes/convert_sym_size_to_circle_shape.py +99 -0
  5. tico/passes/decompose_batch_norm.py +9 -5
  6. tico/passes/lower_copy.py +95 -0
  7. tico/passes/ops.py +4 -0
  8. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +251 -0
  9. tico/quantization/algorithm/fpi_gptq/quantizer.py +180 -0
  10. tico/quantization/algorithm/gptq/gptq.py +231 -11
  11. tico/quantization/algorithm/gptq/quantizer.py +18 -6
  12. tico/quantization/config/{pt2e.py → fpi_gptq.py} +11 -4
  13. tico/quantization/config/gptq.py +27 -4
  14. tico/quantization/public_interface.py +0 -10
  15. tico/quantization/wrapq/quantizer.py +2 -0
  16. tico/quantization/wrapq/wrappers/quant_elementwise.py +51 -11
  17. tico/serialize/operators/adapters/onert/llama_attention.py +51 -0
  18. tico/serialize/operators/op_attention.py +58 -0
  19. tico/serialize/operators/op_circle_shape.py +64 -0
  20. tico/serialize/operators/op_dequantize_per_channel.py +1 -0
  21. tico/serialize/operators/op_dequantize_per_tensor.py +1 -0
  22. tico/serialize/operators/op_transpose_conv.py +66 -50
  23. tico/utils/convert.py +16 -1
  24. tico/utils/padding.py +13 -5
  25. tico/utils/record_input.py +2 -2
  26. tico/utils/register_custom_op.py +63 -0
  27. tico/utils/validate_args_kwargs.py +49 -4
  28. tico-0.2.0.dev260122.dist-info/METADATA +631 -0
  29. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/RECORD +35 -46
  30. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/WHEEL +1 -1
  31. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/entry_points.txt +0 -1
  32. tico/quantization/algorithm/pt2e/annotation/annotator.py +0 -208
  33. tico/quantization/algorithm/pt2e/annotation/config.py +0 -26
  34. tico/quantization/algorithm/pt2e/annotation/op/__init__.py +0 -21
  35. tico/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +0 -63
  36. tico/quantization/algorithm/pt2e/annotation/op/add.py +0 -55
  37. tico/quantization/algorithm/pt2e/annotation/op/conv2d.py +0 -90
  38. tico/quantization/algorithm/pt2e/annotation/op/div.py +0 -55
  39. tico/quantization/algorithm/pt2e/annotation/op/linear.py +0 -92
  40. tico/quantization/algorithm/pt2e/annotation/op/mean.py +0 -51
  41. tico/quantization/algorithm/pt2e/annotation/op/mul.py +0 -55
  42. tico/quantization/algorithm/pt2e/annotation/op/relu6.py +0 -51
  43. tico/quantization/algorithm/pt2e/annotation/op/rsqrt.py +0 -51
  44. tico/quantization/algorithm/pt2e/annotation/op/sub.py +0 -55
  45. tico/quantization/algorithm/pt2e/annotation/spec.py +0 -45
  46. tico/quantization/algorithm/pt2e/annotation/utils.py +0 -88
  47. tico/quantization/algorithm/pt2e/quantizer.py +0 -81
  48. tico/quantization/algorithm/pt2e/transformation/__init__.py +0 -1
  49. tico/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -58
  50. tico/quantization/algorithm/pt2e/utils.py +0 -135
  51. tico/serialize/operators/op_copy.py +0 -187
  52. tico-0.1.0.dev251106.dist-info/METADATA +0 -392
  53. /tico/quantization/algorithm/{pt2e → fpi_gptq}/__init__.py +0 -0
  54. /tico/{quantization/algorithm/pt2e/annotation → serialize/operators/adapters/onert}/__init__.py +0 -0
  55. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info/licenses}/LICENSE +0 -0
  56. {tico-0.1.0.dev251106.dist-info → tico-0.2.0.dev260122.dist-info}/top_level.txt +0 -0
tico/__init__.py CHANGED
@@ -28,8 +28,8 @@ __all__ = [
28
28
  "convert_from_pt2",
29
29
  ]
30
30
 
31
- # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
32
- __version__ = "0.1.0.dev251106"
31
+ # THIS LINE IS AUTOMATICALLY GENERATED
32
+ __version__ = "0.2.0"
33
33
 
34
34
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
35
35
  SECURE_TORCH_VERSION = "2.6.0"
tico/_version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.2.0.dev260122"
tico/passes/convert_conv3d_to_conv2d.py ADDED
@@ -0,0 +1,435 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ from typing import List, TYPE_CHECKING
15
+
16
+ if TYPE_CHECKING:
17
+ import torch.fx
18
+
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.serialize.circle_mapping import extract_shape
23
+ from tico.utils import logging
24
+ from tico.utils.errors import NotYetSupportedError
25
+ from tico.utils.graph import create_node
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
29
+ from tico.utils.validate_args_kwargs import Conv3DArgs
30
+
31
+
32
@trace_graph_diff_on_pass
class ConvertConv3dToConv2d(PassBase):
    """
    This pass converts `torch.ops.aten.conv3d` to multiple `torch.ops.aten.conv2d` operations.

    Given an input [N, C_in, T, H, W] and a weight [C_out, C_in / groups, kT, kH, kW]:

      1. The weight is split along its temporal axis into kT 2D kernels
         (slice + squeeze): weight_2d[i] = weight[:, :, i] -> [C_out, C_in / groups, kH, kW].
      2. If the temporal padding p is non-zero, the input is padded along dim 2
         by concatenating zeros on both sides:
         cat([zeros, input, zeros], dim=2) -> [N, C_in, T + 2p, H, W].
      3. For each output time step t_out in [0, T_out) and each kernel tap
         i in [0, kT), the input frame at
             t_idx = t_out * stride[0] + i * dilation[0]
         is sliced out, squeezed to [N, C_in, H, W] and convolved with
         weight_2d[i] by conv2d (spatial stride/padding/dilation and `groups`
         are forwarded). The kT partial results are summed and the bias, if
         any, is added as a [1, C_out, 1, 1] tensor.
      4. Every per-step result is unsqueezed at dim 2 and all of them are
         concatenated along dim 2 into [N, C_out, T_out, H_out, W_out], which
         replaces the original conv3d node.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _to_3tuple(value, name):
        """
        Normalize an int or a 1-/3-element sequence into a 3-element list.

        Raises:
            NotYetSupportedError: if `value` is a sequence of any other length.
        """
        if isinstance(value, int):
            return [value] * 3
        values = list(value)
        if len(values) == 1:
            return values * 3
        if len(values) != 3:
            raise NotYetSupportedError(f"Unsupported {name} format: {value}")
        return values

    @staticmethod
    def _tensor_dtype_device(node):
        """
        Return the (dtype, device) to use for zero tensors combined with `node`.

        An explicit `meta["dtype"]` / `meta["device"]` entry wins (kept for
        backward compatibility); otherwise both are taken from the tensor stored
        in `meta["val"]`, falling back to float32 / "cpu" when it is missing.
        """
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor):
            default_dtype, default_device = val.dtype, val.device
        else:
            default_dtype, default_device = torch.float32, "cpu"
        return (
            node.meta.get("dtype", default_dtype),
            node.meta.get("device", default_device),
        )

    def _parse_3d_padding(self, padding, kernel_size, dilation=None):
        """
        Parse 3D padding parameter and return (temporal, H, W) tuple.

        Args:
            padding: Can be str ('same', 'valid'), int, list, or tuple
            kernel_size: 3D kernel size (kT, kH, kW) or a single int
            dilation: Optional dilation (int or 1-/3-element sequence). Only used
                for 'same' padding; defaults to 1 in every dimension.

        Returns:
            Tuple of 3 padding values: (temporal_padding, H_padding, W_padding)

        Raises:
            NotYetSupportedError: for unknown padding strings/formats, and for
                'same' padding whose total padding dilation * (k - 1) is odd,
                because that would require asymmetric padding.
        """
        if isinstance(padding, str):
            if padding == "valid":
                return 0, 0, 0
            if padding != "same":
                raise NotYetSupportedError(f"Unsupported padding string: {padding}")

            if isinstance(kernel_size, int):
                kernel = [kernel_size] * 3
            else:
                kernel = list(kernel_size)
            dil = [1, 1, 1] if dilation is None else self._to_3tuple(dilation, "dilation")

            # 'same' keeps the output size equal to the input size (stride 1),
            # which needs dilation * (k - 1) padding in total per axis. Only the
            # evenly-splittable case can be expressed with symmetric padding.
            pads = []
            for k, d in zip(kernel, dil):
                total = d * (k - 1)
                if total % 2 != 0:
                    raise NotYetSupportedError(
                        "'same' padding requiring asymmetric padding is not supported "
                        f"(kernel_size={kernel_size}, dilation={dilation})"
                    )
                pads.append(total // 2)
            return pads[0], pads[1], pads[2]

        if isinstance(padding, (list, tuple)):
            if len(padding) == 1:
                return padding[0], padding[0], padding[0]
            if len(padding) == 3:
                return padding[0], padding[1], padding[2]
            raise NotYetSupportedError(f"Unsupported padding format: {padding}")

        # int
        return padding, padding, padding

    def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
        """
        Replace a single conv3d `node` with its conv2d decomposition.

        Returns:
            True once the node has been replaced.

        Raises:
            NotYetSupportedError: if input/weight are not 5D, the padding or
                stride/dilation format is unsupported, or the temporal output
                size would be non-positive.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # Extract conv3d arguments
        args = Conv3DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        weight = args.weight
        bias = args.bias
        groups = args.groups
        # The `conv3d.padding` overload defaults stride/dilation to a plain int.
        stride = self._to_3tuple(args.stride, "stride")
        dilation = self._to_3tuple(args.dilation, "dilation")

        input_shape = extract_shape(input_)
        weight_shape = extract_shape(weight)

        if len(input_shape) != 5:
            raise NotYetSupportedError(
                f"Only support 5D input tensor: node's input shape: {input_shape}"
            )
        if len(weight_shape) != 5:
            raise NotYetSupportedError(
                f"Only support 5D weight tensor: node's weight shape: {weight_shape}"
            )

        N, C_in, T_in, H_in, W_in = input_shape
        C_out, _, kT, kH, kW = weight_shape

        temporal_padding, h_padding, w_padding = self._parse_3d_padding(
            args.padding, (kT, kH, kW), dilation
        )

        T_padded = T_in + 2 * temporal_padding
        T_out = (T_padded - dilation[0] * (kT - 1) - 1) // stride[0] + 1
        if T_out <= 0:
            raise NotYetSupportedError(
                f"conv3d has an empty temporal output (T_out={T_out}) "
                f"for input shape {input_shape}"
            )

        pad_dtype, pad_device = self._tensor_dtype_device(input_)

        # Insert the decomposition right before the conv3d node: every operand
        # it uses (input, weight, bias) is already defined at that point.
        with graph.inserting_before(node):
            # Step 1: per-tap 2D kernels [C_out, C_in / groups, kH, kW]
            weight_2d_layers = []
            for t in range(kT):
                weight_slice = create_node(
                    graph,
                    torch.ops.aten.slice.Tensor,
                    args=(weight, 2, t, t + 1, 1),
                    origin=weight,
                )
                weight_2d = create_node(
                    graph,
                    torch.ops.aten.squeeze.dims,
                    args=(weight_slice, [2]),
                    origin=weight_slice,
                )
                weight_2d_layers.append(weight_2d)

            # Step 2: temporal zero padding via cat (same dtype as the input so
            # the concatenation does not promote the graph's dtype).
            if temporal_padding > 0:
                zero_padding = create_node(
                    graph,
                    torch.ops.aten.zeros.default,
                    args=([N, C_in, temporal_padding, H_in, W_in],),
                    kwargs={"dtype": pad_dtype, "device": pad_device},
                    origin=input_,
                )
                padded_input = create_node(
                    graph,
                    torch.ops.aten.cat.default,
                    args=([zero_padding, input_, zero_padding], 2),
                    origin=input_,
                )
            else:
                padded_input = input_

            # Bias is shared by every output step; reshape it only once.
            bias_4d = None
            if bias is not None:
                bias_4d = create_node(
                    graph,
                    torch.ops.aten.reshape.default,
                    args=(bias, [1, C_out, 1, 1]),
                    origin=bias,
                )

            # Step 3: one accumulated conv2d result per output time step
            unsqueezed_outputs = []
            for t_out in range(T_out):
                acc = None
                for i, weight_2d in enumerate(weight_2d_layers):
                    t_idx = t_out * stride[0] + i * dilation[0]
                    # Guaranteed by the T_out formula above; every tap reads a
                    # frame inside the (padded) input.
                    assert t_idx < T_padded, (t_idx, T_padded)

                    # [N, C_in, T_padded, H_in, W_in] -> [N, C_in, 1, H_in, W_in]
                    input_slice = create_node(
                        graph,
                        torch.ops.aten.slice.Tensor,
                        args=(padded_input, 2, t_idx, t_idx + 1, 1),
                        origin=padded_input,
                    )
                    # [N, C_in, 1, H_in, W_in] -> [N, C_in, H_in, W_in]
                    input_2d = create_node(
                        graph,
                        torch.ops.aten.squeeze.dims,
                        args=(input_slice, [2]),
                        origin=input_slice,
                    )
                    conv2d = create_node(
                        graph,
                        torch.ops.aten.conv2d.default,
                        args=(
                            input_2d,
                            weight_2d,
                            None,  # bias is added once per step below
                            [stride[1], stride[2]],
                            [h_padding, w_padding],
                            [dilation[1], dilation[2]],
                            groups,
                        ),
                        origin=node,
                    )

                    if acc is None:
                        acc = conv2d
                    else:
                        acc = create_node(
                            graph,
                            torch.ops.aten.add.Tensor,
                            args=(acc, conv2d),
                            origin=acc,
                        )

                if bias_4d is not None:
                    acc = create_node(
                        graph,
                        torch.ops.aten.add.Tensor,
                        args=(acc, bias_4d),
                        origin=acc,
                    )

                # [N, C_out, H_out, W_out] -> [N, C_out, 1, H_out, W_out]
                unsqueezed_outputs.append(
                    create_node(
                        graph,
                        torch.ops.aten.unsqueeze.default,
                        args=(acc, 2),
                        origin=acc,
                    )
                )

            # Step 4: [N, C_out, T_out, H_out, W_out]
            stacked_output = create_node(
                graph,
                torch.ops.aten.cat.default,
                args=(unsqueezed_outputs, 2),
                origin=node,
            )

        node.replace_all_uses_with(stacked_output, propagate_meta=False)
        logger.debug(f"{node.name} is replaced with conv2d decomposition")
        return True

    def call(self, exported_program: ExportedProgram) -> PassResult:
        target_conv_op = [torch.ops.aten.conv3d.default, torch.ops.aten.conv3d.padding]
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        modified = False

        # Collect the Conv3D nodes (forward order) before rewriting, since
        # convert() inserts new nodes into the graph.
        conv3d_nodes = [n for n in graph.nodes if is_target_node(n, target_conv_op)]
        for node in conv3d_nodes:
            modified |= self.convert(exported_program, node)

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
tico/passes/convert_sym_size_to_circle_shape.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.utils import logging
23
+ from tico.utils.graph import create_node
24
+ from tico.utils.passes import PassBase, PassResult
25
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+
27
+
28
@trace_graph_diff_on_pass
class ConvertSymSizeToCircleShape(PassBase):
    """
    This pass converts torch.ops.aten.sym_size.int operations to circle_custom::shape.

    The circle_custom::shape operator allows preserving dynamic shape information
    in the Circle model. This is essential for models with dynamic batch sizes or
    other dynamic dimensions.

    One circle_custom::shape node is created per queried tensor and shared by all
    sym_size.int nodes on that tensor. Each query becomes a 1-element slice of it.
    Negative dims are normalized with the input's rank.

    Example:
        Before: %sym_size_int_1 = call_function[target=torch.ops.aten.sym_size.int](args=(%x, 0))
        After:  %shape_0 = call_function[target=torch.ops.circle_custom.shape](args=(%x,))
                %slice_0 = call_function[target=torch.ops.aten.slice.Tensor](args=(%shape_0, 0, 0, 1, 1))
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        # Queried tensor node -> its circle_custom::shape node.
        shape_nodes = {}

        # Iterate over a snapshot because new nodes are inserted while walking.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target != torch.ops.aten.sym_size.int:
                continue

            # sym_size.int has args: (input, dim)
            input_tensor = node.args[0]
            dim = node.args[1]

            rank = None
            if "val" in input_tensor.meta:
                rank = len(input_tensor.meta["val"].shape)

            # slice(shape, 0, dim, dim + 1) is empty for a negative dim such as
            # -1 (end == 0), so normalize it to a non-negative index first.
            if dim < 0:
                if rank is None:
                    raise RuntimeError(
                        f"{node.name}: cannot normalize negative dim {dim} "
                        "because the input's rank is unknown"
                    )
                dim += rank

            # Insert before `node`: its input is already defined there, and every
            # user of `node` comes after it.
            with graph.inserting_before(node):
                shape_node = shape_nodes.get(input_tensor)
                if shape_node is None:
                    shape_node = create_node(
                        graph,
                        torch.ops.circle_custom.shape,
                        args=(input_tensor,),
                    )
                    if rank is not None:
                        # shape output is a 1D int32 tensor of size rank; a real
                        # tensor is used here as a placeholder for metadata.
                        shape_node.meta["val"] = torch.zeros(rank, dtype=torch.int32)
                    shape_nodes[input_tensor] = shape_node

                # Extract the specific dimension using slice
                slice_node = create_node(
                    graph,
                    torch.ops.aten.slice.Tensor,
                    args=(shape_node, 0, dim, dim + 1, 1),
                )
            # slice output is 1D tensor of size 1
            slice_node.meta["val"] = torch.zeros(1, dtype=torch.int32)

            node.replace_all_uses_with(slice_node, propagate_meta=False)
            modified = True

            logger.debug(
                f"Converted {node.name} (sym_size.int) to {shape_node.name} (circle_custom::shape) + {slice_node.name} (slice)"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
+ return PassResult(modified)
tico/passes/decompose_batch_norm.py CHANGED
@@ -115,7 +115,7 @@ class DecomposeBatchNorm(PassBase):
115
115
  continue
116
116
 
117
117
  input_shape = extract_shape(input_)
118
- assert len(input_shape) == 4
118
+ assert len(input_shape) >= 2, len(input_shape)
119
119
  C = input_shape[1]
120
120
 
121
121
  weight_value = (
@@ -145,11 +145,15 @@ class DecomposeBatchNorm(PassBase):
145
145
  # Calculate constants for mul and add
146
146
  mul_const = weight_value / torch.sqrt(var_value + eps)
147
147
  add_const = bias_value - (mul_const * mean_value)
148
- # N, C, H, W
148
+
149
+ # Make sure channel count matches
149
150
  assert len(mul_const) == len(add_const) == C
150
- # reshape along with channel dimension
151
- mul_const = mul_const.view(1, mul_const.shape[0], 1, 1)
152
- add_const = add_const.view(1, add_const.shape[0], 1, 1)
151
+
152
+ # Build a broadcastable shape like (1, C, 1, ...)
153
+ view_shape = [1] * len(input_shape)
154
+ view_shape[1] = C
155
+ mul_const = mul_const.view(*view_shape)
156
+ add_const = add_const.view(*view_shape)
153
157
 
154
158
  # Placeholder nodes must be the first N nodes in the nodes list of a graph.
155
159
  # Therefore, insert the newly created placeholders at the start of the node list.
tico/passes/lower_copy.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.serialize.circle_mapping import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.graph import create_node
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.validate_args_kwargs import CopyArgs
29
+
30
+
31
@trace_graph_diff_on_pass
class LowerCopy(PassBase):
    """
    This pass lowers `aten.copy.default` to simpler operations.

    `aten.copy(dst, src)` yields a tensor with dst's shape and dtype that holds
    src's values (broadcast when the shapes differ). It is rewritten as:

    - Same shape and dtype: the copy is redundant and folded away (src is used).
    - Different shapes: src is broadcast to dst's shape with `aten.expand`.
    - Different dtypes: the result is additionally cast to dst's dtype with
      `aten._to_copy`, so the dtype conversion implied by the copy is kept.

    This simplifies serialization by handling copy logic at the pass level.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _dtype_of(node):
        """Return the dtype recorded in `node.meta["val"]`, or None if unavailable."""
        val = getattr(node, "meta", {}).get("val")
        return val.dtype if isinstance(val, torch.Tensor) else None

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        # Snapshot the node list: replacement nodes are inserted while walking.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target != torch.ops.aten.copy.default:
                continue

            args = CopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            dst = args.dst
            src = args.src

            dst_shape = list(extract_shape(dst))
            src_shape = list(extract_shape(src))

            # Only cast when both dtypes are known and actually differ.
            dst_dtype = self._dtype_of(dst)
            src_dtype = self._dtype_of(src)
            needs_cast = (
                dst_dtype is not None
                and src_dtype is not None
                and dst_dtype != src_dtype
            )

            # Case 1: Same shape and dtype - copy is redundant, just use src
            if dst_shape == src_shape and not needs_cast:
                logger.debug(
                    f"{node.name}: Same shape {dst_shape}, replacing with src directly"
                )
                node.replace_all_uses_with(src, propagate_meta=False)
                modified = True
                continue

            with graph.inserting_before(node):
                result = src

                # Case 2: Different shapes - need expand
                if dst_shape != src_shape:
                    logger.debug(
                        f"{node.name}: Different shapes src={src_shape} dst={dst_shape}, "
                        f"inserting expand"
                    )
                    result = create_node(
                        graph,
                        torch.ops.aten.expand.default,
                        args=(result, dst_shape),
                    )

                # Case 3: Different dtypes - keep dst's dtype like copy does.
                # NOTE(review): when both an expand and a cast are inserted, only
                # the final node receives the copy's meta below; the intermediate
                # expand relies on meta being filled later in the pipeline, like
                # other newly created nodes - confirm.
                if needs_cast:
                    logger.debug(
                        f"{node.name}: Different dtypes src={src_dtype} dst={dst_dtype}, "
                        f"inserting _to_copy"
                    )
                    result = create_node(
                        graph,
                        torch.ops.aten._to_copy.default,
                        args=(result,),
                        kwargs={"dtype": dst_dtype},
                    )

            # The final node has dst's shape and dtype, matching the copy's meta.
            node.replace_all_uses_with(result, propagate_meta=True)
            modified = True

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
+ return PassResult(modified)