tico 0.1.0.dev250608__py3-none-any.whl → 0.1.0.dev250609__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
 from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
 
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250608"
+__version__ = "0.1.0.dev250609"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
tico/passes/ops.py CHANGED
@@ -54,7 +54,10 @@ class AtenOps:
         self.reshape = [torch.ops.aten.reshape.default]
         self.select = [torch.ops.aten.select_copy.int, torch.ops.aten.select.int]
         self.slice = [torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor]
-        self.softmax = [torch.ops.aten._softmax.default]
+        self.softmax = [
+            torch.ops.aten._softmax.default,
+            torch.ops.aten._safe_softmax.default,
+        ]
         self.squeeze = [torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims]
         self.to_copy = [
            torch.ops.aten._to_copy.default,
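Note: aten._safe_softmax is the masking-tolerant softmax variant that recent PyTorch releases (roughly 2.5+) can emit when exporting attention, and, as the name suggests, it is meant to return zeros rather than NaNs when an entire row is masked to -inf; registering it next to aten._softmax lets the existing softmax handling pick it up. A minimal probe, assuming a PyTorch build that actually ships this overload:

import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [float("-inf")] * 3])  # second row is fully masked

regular = torch.ops.aten._softmax.default(x, -1, False)
safe = torch.ops.aten._safe_softmax.default(x, -1)

print(regular[1])  # expected: NaNs, since exp(-inf) sums to zero and 0/0 is undefined
print(safe[1])     # expected: zeros for the fully masked row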
tico/passes/remove_redundant_permute.py CHANGED
@@ -24,6 +24,22 @@ from tico.serialize.circle_mapping import extract_shape, extract_stride
 from tico.utils import logging
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import PermuteArgs
+
+
+def _compose_permutation(dims1: list[int], dims2: list[int]):
+    """
+    Compose two permutation vectors.
+
+    Given y = x.permute(dims1) and z = y.permute(dims2),
+    the overall permutation p = dims2 ∘ dims1 is
+
+        p[i] = dims1[dims2[i]]
+    """
+    assert len(dims1) == len(
+        dims2
+    ), f"len(dims1): {len(dims1)}, len(dims2): {len(dims2)}"
+    return [dims1[i] for i in dims2]
 
 
 def passes():
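The composition rule implemented by the new helper can be checked directly against PyTorch. A standalone sanity check, not part of the package:

import torch

x = torch.randn(2, 3, 4, 5)
dims1, dims2 = [2, 0, 3, 1], [0, 2, 1, 3]

# p[i] = dims1[dims2[i]]: the single permutation equivalent to applying
# dims1 first and dims2 second.
fused = [dims1[i] for i in dims2]
assert torch.equal(x.permute(dims1).permute(dims2), x.permute(fused))

# When the second permutation undoes the first, the composition is the
# identity and the pass below can drop both nodes entirely.
inverse = [dims1.index(i) for i in range(len(dims1))]
assert [dims1[i] for i in inverse] == [0, 1, 2, 3]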
@@ -45,9 +61,13 @@ class RemoveRedundantPermutePattern1(PassBase):
     def call(self, exported_program: ExportedProgram) -> PassResult:
         """
         [BEFORE]
-            (AxBxC) - aten.permute - aten.permute - (AxBxC)
+            (AxBxC) - aten.permute_1 - aten.permute_2 - (OUT_SHAPE)
         [AFTER]
-            (AxBxC)
+            if OUT_SHAPE == (AxBxC):
+                (AxBxC)
+            else:
+                (AxBxC) - aten.permute (fused dims) - (OUT_SHAPE)
+
         """
         logger = logging.getLogger(__name__)
 
@@ -61,39 +81,36 @@ class RemoveRedundantPermutePattern1(PassBase):
                 continue
             if len(permute2.users) != 1:
                 continue
-            assert len(permute2.args) == 2
-            permute1, permute2_dims = permute2.args
-            assert isinstance(permute1, torch.fx.Node), type(permute1)
-            assert isinstance(permute2_dims, list), type(permute2_dims)
-            for dim in permute2_dims:
-                assert isinstance(dim, int), type(dim)
+            permute2_args = PermuteArgs(*permute2.args, **permute2.kwargs)  # type: ignore[arg-type]
+            permute1, permute2_dims = permute2_args.input, permute2_args.dims
 
             if not permute1.target in ops.aten.permute:
                 continue
             if len(permute1.users) != 1:
                 continue
-            assert len(permute1.args) == 2
-            permute1_input, permute1_dims = permute1.args
-            assert isinstance(permute1_input, torch.fx.Node), type(permute1_input)
-            assert isinstance(permute1_dims, list), type(permute1_dims)
-            for dim in permute1_dims:
-                assert isinstance(dim, int), type(dim)
-
-            # shape
-            permute1_input_shape = extract_shape(permute1_input)
-            permute2_shape = extract_shape(permute2)
-            if permute1_input_shape != permute2_shape:
-                continue
-            # stride
-            permute1_input_stride = extract_stride(permute1_input)
-            permute2_stride = extract_stride(permute2)
-            if permute1_input_stride != permute2_stride:
-                continue
-
-            permute2.replace_all_uses_with(permute1_input, propagate_meta=False)
-
+            permute1_args = PermuteArgs(*permute1.args, **permute1.kwargs)  # type: ignore[arg-type]
+            permute1_input, permute1_dims = permute1_args.input, permute1_args.dims
+
+            fused_dims = _compose_permutation(permute1_dims, permute2_dims)
+            identity = list(range(len(fused_dims)))
+
+            if fused_dims == identity:
+                # shape
+                permute1_input_shape = extract_shape(permute1_input)
+                permute2_shape = extract_shape(permute2)
+                assert permute1_input_shape == permute2_shape
+
+                permute2.replace_all_uses_with(permute1_input, propagate_meta=False)
+                logger.debug(f"{permute1.name} and {permute2.name} are removed.")
+            else:
+                with graph.inserting_after(permute2):
+                    new_args = (permute1_input, fused_dims)
+                    fused_permute = graph.call_function(
+                        torch.ops.aten.permute.default, args=new_args
+                    )
+                    permute2.replace_all_uses_with(fused_permute, propagate_meta=True)
+                    logger.debug(f"{permute1.name} and {permute2.name} are fused.")
             modified = True
-            logger.debug(f"{permute1.name} and {permute2.name} are removed.")
 
         graph.eliminate_dead_code()
         graph.lint()
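For readers unfamiliar with the FX rewriting idiom used above (graph.inserting_after, graph.call_function, replace_all_uses_with, eliminate_dead_code), here is a simplified standalone torch.fx sketch of the same fuse-or-remove mechanics. It traces plain tensor method calls rather than the ATen-level exported program the real pass operates on, so it is an illustration only, not tico code:

import torch
import torch.fx as fx

class TwoPermutes(torch.nn.Module):
    def forward(self, x):
        return x.permute(2, 0, 1).permute(0, 2, 1)

gm = fx.symbolic_trace(TwoPermutes())

for node in list(gm.graph.nodes):
    # Look for a permute whose input is itself a permute.
    if node.op != "call_method" or node.target != "permute":
        continue
    src = node.args[0]
    if not (isinstance(src, fx.Node) and src.op == "call_method" and src.target == "permute"):
        continue
    dims1, dims2 = list(src.args[1:]), list(node.args[1:])
    fused = [dims1[i] for i in dims2]  # compose the two permutations
    with gm.graph.inserting_after(node):
        fused_node = gm.graph.call_method("permute", (src.args[0], *fused))
    node.replace_all_uses_with(fused_node)

gm.graph.eliminate_dead_code()  # both original permutes are now dead
gm.recompile()

x = torch.randn(2, 3, 4)
assert torch.equal(gm(x), x.permute(2, 1, 0))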
tico/passes/remove_redundant_reshape.py CHANGED
@@ -24,11 +24,12 @@ from tico.serialize.circle_mapping import extract_shape
 from tico.utils import logging
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
-from tico.utils.utils import set_new_meta_val
+from tico.utils.utils import broadcastable, set_new_meta_val
 from tico.utils.validate_args_kwargs import (
     AddTensorArgs,
     PermuteArgs,
     ReshapeArgs,
+    SafeSoftmaxArgs,
     SoftmaxArgs,
 )
 
@@ -253,15 +254,14 @@ class RemoveRedundantReshapePattern3(PassBase):
                 continue
             if not softmax.target in ops.aten.softmax:
                 continue
-            softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs)  # type: ignore[arg-type]
-            add, softmax_dim, softmax_half_to_float = (
+            if softmax.target == torch.ops.aten._softmax.default:
+                softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs)  # type: ignore[arg-type, assignment]
+            elif softmax.target == torch.ops.aten._safe_softmax.default:
+                softmax_args = SafeSoftmaxArgs(*softmax.args, **softmax.kwargs)  # type: ignore[arg-type, assignment]
+            add, softmax_dim = (
                 softmax_args.input,
                 softmax_args.dim,
-                softmax_args.half_to_float,
             )
-            assert isinstance(add, torch.fx.Node), type(add)
-            assert isinstance(softmax_dim, int), type(softmax_dim)
-            assert isinstance(softmax_half_to_float, bool), type(softmax_half_to_float)
             softmax_shape = extract_shape(softmax)
             # TODO support other dimension
             if softmax_dim != -1 and softmax_dim != len(softmax_shape) - 1:
@@ -295,10 +295,16 @@ class RemoveRedundantReshapePattern3(PassBase):
             # Check condition
             reshape_2_input_shape = extract_shape(reshape_2_input)
             reshape_3_input_shape = extract_shape(reshape_3_input)
-            if reshape_2_input_shape != reshape_3_input_shape:
+            if not broadcastable(reshape_2_input_shape, reshape_3_input_shape):
                 continue
             reshape_1_shape = extract_shape(reshape_1)
-            if reshape_2_input_shape != reshape_1_shape:
+            if (
+                reshape_2_input_shape != reshape_1_shape
+                and reshape_3_input_shape != reshape_1_shape
+            ):
+                continue
+            # Make sure the softmax axis length is unchanged.
+            if softmax_shape[-1] != reshape_1_shape[-1]:
                 continue
             # Assume `aten.add` and `aten.softmax` have only one user.
             if len(add.users) != 1:
@@ -311,8 +317,7 @@ class RemoveRedundantReshapePattern3(PassBase):
             set_new_meta_val(add)
             # Update softmax
             if softmax_dim == len(softmax_shape) - 1:
-                updated_dim = len(extract_shape(reshape_2_input)) - 1
-                softmax.args = (add, updated_dim, softmax_half_to_float)
+                softmax.update_arg(1, -1)  # (index, last_dim)
             set_new_meta_val(softmax)
 
             reshape_1.replace_all_uses_with(softmax, propagate_meta=False)
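softmax.update_arg(1, -1) rewrites only the dim argument in place instead of rebuilding the whole args tuple, which also sidesteps the half_to_float argument that aten._safe_softmax does not appear to take (SafeSoftmaxArgs above exposes only input and dim). A minimal torch.fx illustration of Node.update_arg; not tico code, the traced function below exists only for demonstration:

import torch
import torch.fx as fx

def f(x):
    return torch.softmax(x, 0)

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    # Positional arg 1 of the traced softmax call is `dim`; point it at the last axis.
    if node.op == "call_function" and getattr(node.target, "__name__", "") == "softmax":
        node.update_arg(1, -1)
gm.recompile()

x = torch.randn(2, 3)
assert torch.allclose(gm(x), torch.softmax(x, -1))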
tico/utils/utils.py CHANGED
@@ -17,6 +17,7 @@ import subprocess
 import typing
 import warnings
 from functools import wraps
+from typing import List
 
 import torch
 from circle_schema import circle
@@ -341,3 +342,39 @@ def get_quant_dtype(qmin: int, qmax: int):
         return known_ranges[(qmin, qmax)]
     else:
         raise ValueError(f"Unsupported quantization range: ({qmin}, {qmax})")
+
+
+def broadcastable(
+    shape_a: List[int] | torch.Size, shape_b: List[int] | torch.Size
+) -> bool:
+    """
+    Return **True** if two shapes are broadcast-compatible under the standard
+    NumPy/PyTorch rules.
+
+    Broadcasting rule
+    --------------------------------
+    - Align the shapes **right-to-left**.
+    - For each aligned dimension `(a, b)`, one of the following must hold:
+      - `a == b` (sizes match)
+      - `a == 1` (shape-A can repeat along that dim)
+      - `b == 1` (shape-B can repeat along that dim)
+    - When one shape is shorter, treat its missing leading dims as `1`.
+
+    Examples
+    --------
+    >>> broadcastable([8, 16, 32], [16, 32])
+    True
+    >>> broadcastable([8, 16, 32], [1, 32])
+    True
+    >>> broadcastable([8, 16, 32], [8, 32, 16])
+    False
+    """
+    # Walk from the last dim to the front
+    len_a, len_b = len(shape_a), len(shape_b)
+    max_len = max(len_a, len_b)
+    for i in range(1, max_len + 1):
+        dim_a = shape_a[-i] if i <= len_a else 1
+        dim_b = shape_b[-i] if i <= len_b else 1
+        if dim_a != 1 and dim_b != 1 and dim_a != dim_b:
+            return False
+    return True
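A quick cross-check of the new helper against torch.broadcast_shapes; a sanity-test sketch that assumes the wheel is installed, not part of the package:

import torch
from tico.utils.utils import broadcastable

cases = [
    ([8, 16, 32], [16, 32]),     # trailing dims match
    ([8, 16, 32], [1, 32]),      # size-1 dim broadcasts
    ([8, 16, 32], [8, 32, 16]),  # mismatched last dim -> not broadcastable
]
for a, b in cases:
    try:
        torch.broadcast_shapes(torch.Size(a), torch.Size(b))
        torch_ok = True
    except RuntimeError:
        torch_ok = False
    assert broadcastable(a, b) == torch_ok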
tico-0.1.0.dev250609.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tico
-Version: 0.1.0.dev250608
+Version: 0.1.0.dev250609
 Summary: Convert exported Torch module to circle
 Home-page: UNKNOWN
 License: UNKNOWN
tico-0.1.0.dev250609.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
-tico/__init__.py,sha256=FjyTxaQjfrgfUCPfMzx5WnfpSY0Vbksu3gx8m6Tc5xY,1743
+tico/__init__.py,sha256=mEQHwb4j_KQrsDMO6gp2USb5RCYNhd0MOxL4P_EGMW8,1743
 tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -82,12 +82,12 @@ tico/passes/lower_pow2_to_mul.py,sha256=imx9CoKG4bLyNqv4F-Z203s_P0-0SdRH-y4_Q0PT
 tico/passes/lower_to_resize_nearest_neighbor.py,sha256=4bIxPSyNEzznTw8f8D9hMNXbZ0KmMPPPfRvITXonEz0,8881
 tico/passes/lower_to_slice.py,sha256=6xK7A2hdqdBdN2XRyt2rFGnqJGcaXxmgndis4kn2q2w,7112
 tico/passes/merge_consecutive_cat.py,sha256=u1E-7axmX4yw0TMlFO5jXB0rn3xe5ASmqXcR42c9eNA,2722
-tico/passes/ops.py,sha256=BmUtfBPrSmCd2TsFc6PtOmhuznuFNBNHl12RS7CS0KI,2836
+tico/passes/ops.py,sha256=XzaKC_FpsfJLpnU4JlL9X-HVarWKm8cX0iiRgx9bMOs,2909
 tico/passes/remove_nop.py,sha256=5QE3inFsXgzyPT_t7pKeXNqD1LRf6ed_Mp7YMadA6AI,2707
 tico/passes/remove_redundant_assert_nodes.py,sha256=3a2xEQ2iPY7Gqg8jZi8G5bfDDrK2kOO1OHCMv_gJGz0,1592
 tico/passes/remove_redundant_expand.py,sha256=7st92AbWOl7yzM0Y5seaZJQKMFHqkYpH3qYMOlAU5lk,2234
-tico/passes/remove_redundant_permute.py,sha256=sS53eTY4sSnpZWDaaHN8czUmzNwmqh1lF90nYamXzac,3566
-tico/passes/remove_redundant_reshape.py,sha256=aPZcDR0kBExEsWCYfBbLulm_wcjJNnGjn4mgrUIPdpU,16810
+tico/passes/remove_redundant_permute.py,sha256=PIS-ag1EiSLlXJeRjcZoapbvM_4LyGtLBzunomfkAYE,4236
+tico/passes/remove_redundant_reshape.py,sha256=fDTMIrQFoRHYWrgGqEHgIg-azy4Z1Pod-Gfv7Gwplho,17010
 tico/passes/remove_redundant_slice.py,sha256=BAfSkA5jDIEhYx4nMnu6cJadQle3YTw5y39ZLiYfJJ8,2109
 tico/passes/remove_redundant_to_copy.py,sha256=uTIjAn3Eli_RvXC-QOqxBAkV_whDBkkNhu-mvNKAEhs,3136
 tico/passes/restore_linear.py,sha256=UMMHdLmRGq9bfJx_0L9lL2UQBd51PGNP0WywO8KdrDM,4066
@@ -189,15 +189,15 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
 tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
 tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
 tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
-tico/utils/utils.py,sha256=NAa3ZX5G-UCQwmz5WnFl0iCEra24PMY5wC0MyX7smUg,11156
+tico/utils/utils.py,sha256=7t4U3Bs7SV7y3PCdgE9bEAAAI2u1mzQEUP81oP_XHHU,12334
 tico/utils/validate_args_kwargs.py,sha256=P4aMnr9EhNCtc_AgJPpuezfQbqFfDn0lhJSWqmumLZ8,25054
 tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.dev250608.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
-tico-0.1.0.dev250608.dist-info/METADATA,sha256=LJYAjDzOfuHcUGzJvc-Cgc_QKCnVUHFY2yr9IY8EkgU,8633
-tico-0.1.0.dev250608.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
-tico-0.1.0.dev250608.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
-tico-0.1.0.dev250608.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
-tico-0.1.0.dev250608.dist-info/RECORD,,
+tico-0.1.0.dev250609.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250609.dist-info/METADATA,sha256=ZpNUuShOOV1hMrNe84u05IoeX4UZn1fj1_TEU4ZlI1Y,8633
+tico-0.1.0.dev250609.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250609.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250609.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250609.dist-info/RECORD,,