tico 0.1.0.dev250605__py3-none-any.whl → 0.1.0.dev250609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/passes/ops.py +4 -1
- tico/passes/remove_redundant_permute.py +46 -29
- tico/passes/remove_redundant_reshape.py +16 -11
- tico/utils/utils.py +37 -0
- {tico-0.1.0.dev250605.dist-info → tico-0.1.0.dev250609.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250605.dist-info → tico-0.1.0.dev250609.dist-info}/RECORD +11 -11
- {tico-0.1.0.dev250605.dist-info → tico-0.1.0.dev250609.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250605.dist-info → tico-0.1.0.dev250609.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250605.dist-info → tico-0.1.0.dev250609.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250605.dist-info → tico-0.1.0.dev250609.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250609"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
tico/passes/ops.py
CHANGED
@@ -54,7 +54,10 @@ class AtenOps:
|
|
54
54
|
self.reshape = [torch.ops.aten.reshape.default]
|
55
55
|
self.select = [torch.ops.aten.select_copy.int, torch.ops.aten.select.int]
|
56
56
|
self.slice = [torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor]
|
57
|
-
self.softmax = [
|
57
|
+
self.softmax = [
|
58
|
+
torch.ops.aten._softmax.default,
|
59
|
+
torch.ops.aten._safe_softmax.default,
|
60
|
+
]
|
58
61
|
self.squeeze = [torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims]
|
59
62
|
self.to_copy = [
|
60
63
|
torch.ops.aten._to_copy.default,
|
@@ -24,6 +24,22 @@ from tico.serialize.circle_mapping import extract_shape, extract_stride
|
|
24
24
|
from tico.utils import logging
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.validate_args_kwargs import PermuteArgs
|
28
|
+
|
29
|
+
|
30
|
+
def _compose_permutation(dims1: list[int], dims2: list[int]):
|
31
|
+
"""
|
32
|
+
Compose two permutation vectors.
|
33
|
+
|
34
|
+
Given y = x.permute(dims1) and z = y.permute(dims2),
|
35
|
+
the overall permutation p = dims2 ∘ dims1 is
|
36
|
+
|
37
|
+
p[i] = dims1[dims2[i]]
|
38
|
+
"""
|
39
|
+
assert len(dims1) == len(
|
40
|
+
dims2
|
41
|
+
), f"len(dims1): {len(dims1)}, len(dims2): {len(dims2)}"
|
42
|
+
return [dims1[i] for i in dims2]
|
27
43
|
|
28
44
|
|
29
45
|
def passes():
|
@@ -45,9 +61,13 @@ class RemoveRedundantPermutePattern1(PassBase):
|
|
45
61
|
def call(self, exported_program: ExportedProgram) -> PassResult:
|
46
62
|
"""
|
47
63
|
[BEFORE]
|
48
|
-
(AxBxC) - aten.
|
64
|
+
(AxBxC) - aten.permute_1 - aten.permute_2 - (OUT_SHAPE)
|
49
65
|
[AFTER]
|
50
|
-
(AxBxC)
|
66
|
+
if OUT_SHAPE == (AxBxC):
|
67
|
+
(AxBxC)
|
68
|
+
else:
|
69
|
+
(AxBxC) - aten.permute (fused dims) - (OUT_SHAPE)
|
70
|
+
|
51
71
|
"""
|
52
72
|
logger = logging.getLogger(__name__)
|
53
73
|
|
@@ -61,39 +81,36 @@ class RemoveRedundantPermutePattern1(PassBase):
|
|
61
81
|
continue
|
62
82
|
if len(permute2.users) != 1:
|
63
83
|
continue
|
64
|
-
|
65
|
-
permute1, permute2_dims =
|
66
|
-
assert isinstance(permute1, torch.fx.Node), type(permute1)
|
67
|
-
assert isinstance(permute2_dims, list), type(permute2_dims)
|
68
|
-
for dim in permute2_dims:
|
69
|
-
assert isinstance(dim, int), type(dim)
|
84
|
+
permute2_args = PermuteArgs(*permute2.args, **permute2.kwargs) # type: ignore[arg-type]
|
85
|
+
permute1, permute2_dims = permute2_args.input, permute2_args.dims
|
70
86
|
|
71
87
|
if not permute1.target in ops.aten.permute:
|
72
88
|
continue
|
73
89
|
if len(permute1.users) != 1:
|
74
90
|
continue
|
75
|
-
|
76
|
-
permute1_input, permute1_dims =
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
91
|
+
permute1_args = PermuteArgs(*permute1.args, **permute1.kwargs) # type: ignore[arg-type]
|
92
|
+
permute1_input, permute1_dims = permute1_args.input, permute1_args.dims
|
93
|
+
|
94
|
+
fused_dims = _compose_permutation(permute1_dims, permute2_dims)
|
95
|
+
identity = list(range(len(fused_dims)))
|
96
|
+
|
97
|
+
if fused_dims == identity:
|
98
|
+
# shape
|
99
|
+
permute1_input_shape = extract_shape(permute1_input)
|
100
|
+
permute2_shape = extract_shape(permute2)
|
101
|
+
assert permute1_input_shape == permute2_shape
|
102
|
+
|
103
|
+
permute2.replace_all_uses_with(permute1_input, propagate_meta=False)
|
104
|
+
logger.debug(f"{permute1.name} and {permute2.name} are removed.")
|
105
|
+
else:
|
106
|
+
with graph.inserting_after(permute2):
|
107
|
+
new_args = (permute1_input, fused_dims)
|
108
|
+
fused_permute = graph.call_function(
|
109
|
+
torch.ops.aten.permute.default, args=new_args
|
110
|
+
)
|
111
|
+
permute2.replace_all_uses_with(fused_permute, propagate_meta=True)
|
112
|
+
logger.debug(f"{permute1.name} and {permute2.name} are fused.")
|
95
113
|
modified = True
|
96
|
-
logger.debug(f"{permute1.name} and {permute2.name} are removed.")
|
97
114
|
|
98
115
|
graph.eliminate_dead_code()
|
99
116
|
graph.lint()
|
@@ -24,11 +24,12 @@ from tico.serialize.circle_mapping import extract_shape
|
|
24
24
|
from tico.utils import logging
|
25
25
|
from tico.utils.passes import PassBase, PassResult
|
26
26
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
-
from tico.utils.utils import set_new_meta_val
|
27
|
+
from tico.utils.utils import broadcastable, set_new_meta_val
|
28
28
|
from tico.utils.validate_args_kwargs import (
|
29
29
|
AddTensorArgs,
|
30
30
|
PermuteArgs,
|
31
31
|
ReshapeArgs,
|
32
|
+
SafeSoftmaxArgs,
|
32
33
|
SoftmaxArgs,
|
33
34
|
)
|
34
35
|
|
@@ -253,15 +254,14 @@ class RemoveRedundantReshapePattern3(PassBase):
|
|
253
254
|
continue
|
254
255
|
if not softmax.target in ops.aten.softmax:
|
255
256
|
continue
|
256
|
-
|
257
|
-
|
257
|
+
if softmax.target == torch.ops.aten._softmax.default:
|
258
|
+
softmax_args = SoftmaxArgs(*softmax.args, **softmax.kwargs) # type: ignore[arg-type, assignment]
|
259
|
+
elif softmax.target == torch.ops.aten._safe_softmax.default:
|
260
|
+
softmax_args = SafeSoftmaxArgs(*softmax.args, **softmax.kwargs) # type: ignore[arg-type, assignment]
|
261
|
+
add, softmax_dim = (
|
258
262
|
softmax_args.input,
|
259
263
|
softmax_args.dim,
|
260
|
-
softmax_args.half_to_float,
|
261
264
|
)
|
262
|
-
assert isinstance(add, torch.fx.Node), type(add)
|
263
|
-
assert isinstance(softmax_dim, int), type(softmax_dim)
|
264
|
-
assert isinstance(softmax_half_to_float, bool), type(softmax_half_to_float)
|
265
265
|
softmax_shape = extract_shape(softmax)
|
266
266
|
# TODO support other dimension
|
267
267
|
if softmax_dim != -1 and softmax_dim != len(softmax_shape) - 1:
|
@@ -295,10 +295,16 @@ class RemoveRedundantReshapePattern3(PassBase):
|
|
295
295
|
# Check condition
|
296
296
|
reshape_2_input_shape = extract_shape(reshape_2_input)
|
297
297
|
reshape_3_input_shape = extract_shape(reshape_3_input)
|
298
|
-
if reshape_2_input_shape
|
298
|
+
if not broadcastable(reshape_2_input_shape, reshape_3_input_shape):
|
299
299
|
continue
|
300
300
|
reshape_1_shape = extract_shape(reshape_1)
|
301
|
-
if
|
301
|
+
if (
|
302
|
+
reshape_2_input_shape != reshape_1_shape
|
303
|
+
and reshape_3_input_shape != reshape_1_shape
|
304
|
+
):
|
305
|
+
continue
|
306
|
+
# Make sure the softmax axis length is unchanged.
|
307
|
+
if softmax_shape[-1] != reshape_1_shape[-1]:
|
302
308
|
continue
|
303
309
|
# Assume `aten.add` and `aten.softmax` have only one user.
|
304
310
|
if len(add.users) != 1:
|
@@ -311,8 +317,7 @@ class RemoveRedundantReshapePattern3(PassBase):
|
|
311
317
|
set_new_meta_val(add)
|
312
318
|
# Update softmax
|
313
319
|
if softmax_dim == len(softmax_shape) - 1:
|
314
|
-
|
315
|
-
softmax.args = (add, updated_dim, softmax_half_to_float)
|
320
|
+
softmax.update_arg(1, -1) # (index, last_dim)
|
316
321
|
set_new_meta_val(softmax)
|
317
322
|
|
318
323
|
reshape_1.replace_all_uses_with(softmax, propagate_meta=False)
|
tico/utils/utils.py
CHANGED
@@ -17,6 +17,7 @@ import subprocess
|
|
17
17
|
import typing
|
18
18
|
import warnings
|
19
19
|
from functools import wraps
|
20
|
+
from typing import List
|
20
21
|
|
21
22
|
import torch
|
22
23
|
from circle_schema import circle
|
@@ -341,3 +342,39 @@ def get_quant_dtype(qmin: int, qmax: int):
|
|
341
342
|
return known_ranges[(qmin, qmax)]
|
342
343
|
else:
|
343
344
|
raise ValueError(f"Unsupported quantization range: ({qmin}, {qmax})")
|
345
|
+
|
346
|
+
|
347
|
+
def broadcastable(
|
348
|
+
shape_a: List[int] | torch.Size, shape_b: List[int] | torch.Size
|
349
|
+
) -> bool:
|
350
|
+
"""
|
351
|
+
Return **True** if two shapes are broadcast-compatible under the standard
|
352
|
+
NumPy/PyTorch rules.
|
353
|
+
|
354
|
+
Broadcasting rule
|
355
|
+
--------------------------------
|
356
|
+
- Align the shapes **right-to-left**.
|
357
|
+
- For each aligned dimension `(a, b)` one of the following must hold
|
358
|
+
- `a == b` (sizes match)
|
359
|
+
- `a == 1` (shape-A can repeat along that dim)
|
360
|
+
- `b == 1` (shape-B can repeat along that dim)
|
361
|
+
- When one shape is shorter, treat its missing leading dims as `1`.
|
362
|
+
|
363
|
+
Examples
|
364
|
+
--------
|
365
|
+
>>> _broadcastable([8, 16, 32], [16, 32])
|
366
|
+
True
|
367
|
+
>>> _broadcastable([8, 16, 32], [1, 32])
|
368
|
+
True
|
369
|
+
>>> _broadcastable([8, 16, 32], [8, 32, 16])
|
370
|
+
False
|
371
|
+
"""
|
372
|
+
# Walk from the last dim to the front
|
373
|
+
len_a, len_b = len(shape_a), len(shape_b)
|
374
|
+
max_len = max(len_a, len_b)
|
375
|
+
for i in range(1, max_len + 1):
|
376
|
+
dim_a = shape_a[-i] if i <= len_a else 1
|
377
|
+
dim_b = shape_b[-i] if i <= len_b else 1
|
378
|
+
if dim_a != 1 and dim_b != 1 and dim_a != dim_b:
|
379
|
+
return False
|
380
|
+
return True
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=mEQHwb4j_KQrsDMO6gp2USb5RCYNhd0MOxL4P_EGMW8,1743
|
2
2
|
tico/pt2_to_circle.py,sha256=PPmFNw20jw2Z2VyM3ln9pX__jTzBOAZiv0gT5a-p-Y8,2666
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -82,12 +82,12 @@ tico/passes/lower_pow2_to_mul.py,sha256=imx9CoKG4bLyNqv4F-Z203s_P0-0SdRH-y4_Q0PT
|
|
82
82
|
tico/passes/lower_to_resize_nearest_neighbor.py,sha256=4bIxPSyNEzznTw8f8D9hMNXbZ0KmMPPPfRvITXonEz0,8881
|
83
83
|
tico/passes/lower_to_slice.py,sha256=6xK7A2hdqdBdN2XRyt2rFGnqJGcaXxmgndis4kn2q2w,7112
|
84
84
|
tico/passes/merge_consecutive_cat.py,sha256=u1E-7axmX4yw0TMlFO5jXB0rn3xe5ASmqXcR42c9eNA,2722
|
85
|
-
tico/passes/ops.py,sha256=
|
85
|
+
tico/passes/ops.py,sha256=XzaKC_FpsfJLpnU4JlL9X-HVarWKm8cX0iiRgx9bMOs,2909
|
86
86
|
tico/passes/remove_nop.py,sha256=5QE3inFsXgzyPT_t7pKeXNqD1LRf6ed_Mp7YMadA6AI,2707
|
87
87
|
tico/passes/remove_redundant_assert_nodes.py,sha256=3a2xEQ2iPY7Gqg8jZi8G5bfDDrK2kOO1OHCMv_gJGz0,1592
|
88
88
|
tico/passes/remove_redundant_expand.py,sha256=7st92AbWOl7yzM0Y5seaZJQKMFHqkYpH3qYMOlAU5lk,2234
|
89
|
-
tico/passes/remove_redundant_permute.py,sha256=
|
90
|
-
tico/passes/remove_redundant_reshape.py,sha256=
|
89
|
+
tico/passes/remove_redundant_permute.py,sha256=PIS-ag1EiSLlXJeRjcZoapbvM_4LyGtLBzunomfkAYE,4236
|
90
|
+
tico/passes/remove_redundant_reshape.py,sha256=fDTMIrQFoRHYWrgGqEHgIg-azy4Z1Pod-Gfv7Gwplho,17010
|
91
91
|
tico/passes/remove_redundant_slice.py,sha256=BAfSkA5jDIEhYx4nMnu6cJadQle3YTw5y39ZLiYfJJ8,2109
|
92
92
|
tico/passes/remove_redundant_to_copy.py,sha256=uTIjAn3Eli_RvXC-QOqxBAkV_whDBkkNhu-mvNKAEhs,3136
|
93
93
|
tico/passes/restore_linear.py,sha256=UMMHdLmRGq9bfJx_0L9lL2UQBd51PGNP0WywO8KdrDM,4066
|
@@ -189,15 +189,15 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
|
189
189
|
tico/utils/register_custom_op.py,sha256=iRQvdqlBqrJxq_pNkvJyDIJD_SYtCUl88wwbbuvSwlk,22952
|
190
190
|
tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
|
191
191
|
tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
|
192
|
-
tico/utils/utils.py,sha256=
|
192
|
+
tico/utils/utils.py,sha256=7t4U3Bs7SV7y3PCdgE9bEAAAI2u1mzQEUP81oP_XHHU,12334
|
193
193
|
tico/utils/validate_args_kwargs.py,sha256=P4aMnr9EhNCtc_AgJPpuezfQbqFfDn0lhJSWqmumLZ8,25054
|
194
194
|
tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
195
195
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
196
196
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
197
197
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
198
|
-
tico-0.1.0.
|
199
|
-
tico-0.1.0.
|
200
|
-
tico-0.1.0.
|
201
|
-
tico-0.1.0.
|
202
|
-
tico-0.1.0.
|
203
|
-
tico-0.1.0.
|
198
|
+
tico-0.1.0.dev250609.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
199
|
+
tico-0.1.0.dev250609.dist-info/METADATA,sha256=ZpNUuShOOV1hMrNe84u05IoeX4UZn1fj1_TEU4ZlI1Y,8633
|
200
|
+
tico-0.1.0.dev250609.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
201
|
+
tico-0.1.0.dev250609.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
202
|
+
tico-0.1.0.dev250609.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
203
|
+
tico-0.1.0.dev250609.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|