tico 0.1.0.dev250923__py3-none-any.whl → 0.1.0.dev250925__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +1 -0
- tico/experimental/quantization/ptq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/ops.py +0 -1
- tico/utils/convert.py +2 -0
- {tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/RECORD +12 -10
- {tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
tico/config/v1.py
CHANGED
@@ -24,6 +24,7 @@ class CompileConfigV1(CompileConfigBase):
     convert_lhs_const_mm_to_fc: bool = False
     convert_rhs_const_mm_to_fc: bool = True
     convert_single_batch_lhs_const_bmm_to_fc: bool = False
+    convert_expand_to_slice_cat: bool = False
 
     def get(self, name: str):
         return super().get(name)
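The new convert_expand_to_slice_cat flag defaults to False, so existing conversions are unaffected. A minimal opt-in sketch (not taken from this diff: the dataclass-style constructor and the config keyword of tico.convert are assumptions based on how convert_exported_module_to_circle reads config.get(...)):

import torch
import tico
from tico.config.v1 import CompileConfigV1  # module path shown in this diff

# Assumption: CompileConfigV1 accepts its flags as keyword arguments.
config = CompileConfigV1(convert_expand_to_slice_cat=True)

model = torch.nn.Linear(4, 4).eval()      # placeholder model for illustration
example_inputs = (torch.randn(1, 4),)

# Assumption: tico.convert forwards a config object down to the pass pipeline.
cm = tico.convert(model, args=example_inputs, config=config)
cm.save("model.circle")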
tico/experimental/quantization/ptq/wrappers/fairseq/decoder_export_single_step.py
ADDED
@@ -0,0 +1,234 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+"""
+Q) Why the name "SingleStep"?
+
+Fairseq's decoder already advances one token at a time during generation,
+but the default path is "stateful" and "shape-polymorphic": it owns and
+mutates K/V caches internally, prefix lengths and triangular masks grow with
+the step, and beam reordering updates hidden module state. That's friendly
+for eager execution, but hostile to `torch.export` and many accelerator
+backends.
+
+This export wrapper makes the per-token call truly "single-step" in the
+export sense: "stateless" and "fixed-shape" so every invocation has the
+exact same graph.
+
+Key invariants
+--------------
+• "Stateless": K/V caches come in as explicit inputs and go out as outputs.
+  The module does not store or mutate hidden state.
+• "Static shapes": Query is always [B, 1, C]; encoder features and masks
+  have fixed, predeclared sizes; K/V slots use fixed capacity (unused tail
+  is simply masked/ignored).
+• "External control": Step indexing, cache slot management (append/roll),
+  and beam reordering are handled outside the module.
+• "Prebuilt additive masks": Self-attention masks are provided by the
+  caller (0 for valid, large negative sentinel, e.g. -120, for masked),
+  avoiding data-dependent control flow.
+
+In short: still step-wise like fairseq, but restructured for export—no
+internal state, no data-dependent shapes, no dynamic control flow.
+"""
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+import tico
+
+# ----- 1) Export wrapper module -------------------------------------------
+class DecoderExportSingleStep(nn.Module):
+    """
+    Export-only single-step decoder module.
+
+    Inputs (example shapes; B=1, H=8, Dh=64, C=512, S=64, Tprev=63):
+      - prev_x: [B, 1, C] embedded decoder input for the current step
+      - enc_x: [S, B, C] encoder hidden states (fixed-length export input)
+      - enc_pad_additive: [B, 1, S] additive float key_padding_mask for enc-dec attn (0 for keep, -120 for pad)
+      - self_attn_mask: [B, 1, S] additive float mask for decoder self-attn at this step; pass zeros if unused
+      - prev_self_k_0..L-1: [B, H, Tprev, Dh] cached self-attn K per layer
+      - prev_self_v_0..L-1: [B, H, Tprev, Dh] cached self-attn V per layer
+
+    Outputs:
+      - x_out: [B, 1, C] new decoder features at the current step
+      - new_k_0..L-1: [H, B, Dh] per-layer new K (single-timestep; time dim squeezed)
+      - new_v_0..L-1: [H, B, Dh] per-layer new V (single-timestep; time dim squeezed)
+
+    Notes:
+      • We keep masks/additive semantics externally to avoid any mask-building inside the graph.
+      • We reshape the new K/V from [B,H,1,Dh] -> [H,B,Dh] to match the requested output spec (8,1,64).
+    """
+
+    def __init__(self, decoder: nn.Module):
+        super().__init__()
+        self.decoder = decoder
+        # Cache common meta for assertions
+        self.num_layers = len(getattr(decoder, "layers"))
+        # Infer heads/head_dim from the wrapped self_attn of layer 0
+        any_layer = getattr(decoder.layers[0], "wrapped", decoder.layers[0])  # type: ignore[index]
+        mha = getattr(any_layer, "self_attn", None)
+        assert mha is not None, "Decoder layer must expose self_attn"
+        self.num_heads = int(mha.num_heads)
+        self.head_dim = int(mha.head_dim)
+        # Embed dim (C)
+        self.embed_dim = int(getattr(decoder, "embed_dim"))
+
+    def forward(
+        self,
+        prev_x: torch.Tensor,  # [B,1,C]
+        enc_x: torch.Tensor,  # [S,B,C]
+        enc_pad_additive: torch.Tensor,  # [B,1,S]
+        *kv_args: torch.Tensor,  # prev_k_0..L-1, prev_v_0..L-1 (total 2L tensors)
+        self_attn_mask: torch.Tensor,  # [B,1,S] (or zeros)
+    ):
+        L = self.num_layers
+        H = self.num_heads
+        Dh = self.head_dim
+        B, one, C = prev_x.shape
+        S, B2, C2 = enc_x.shape
+        assert (
+            one == 1 and C == self.embed_dim and B == B2 and C2 == C
+        ), "Shape mismatch in prev_x/enc_x"
+        assert len(kv_args) == 2 * L, f"Expected {2*L} KV tensors, got {len(kv_args)}"
+
+        # Unpack previous self-attn caches
+        prev_k_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        prev_v_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        for i in range(L):
+            prev_k_list.append(kv_args[2 * i])
+            prev_v_list.append(kv_args[2 * i + 1])
+        for i in range(L):
+            assert (
+                prev_k_list[i].dim() == 4 and prev_v_list[i].dim() == 4
+            ), "KV must be [B,H,Tprev,Dh]"
+            assert (
+                prev_k_list[i].shape[0] == B
+                and prev_k_list[i].shape[1] == H
+                and prev_k_list[i].shape[3] == Dh
+            )
+
+        # Call decoder's external single-step path
+        # Returns:
+        #   x_step: [B,1,C]
+        #   newk/newv: lists of length L, each [B*H,1,Dh]
+        x_step, newk_list, newv_list = self.decoder.forward_external_step(  # type: ignore[operator]
+            prev_output_x=prev_x,
+            encoder_out_x=enc_x,
+            encoder_padding_mask=enc_pad_additive,
+            self_attn_mask=self_attn_mask,
+            prev_self_k_list=prev_k_list,
+            prev_self_v_list=prev_v_list,
+        )
+
+        out_tensors: List[torch.Tensor] = [
+            x_step
+        ]  # first output is the new decoder features
+        for i in range(L):
+            nk = newk_list[i]  # [B*H, Tnew, Dh]
+            nv = newv_list[i]  # [B*H, Tnew, Dh]
+            out_tensors.append(nk)
+            out_tensors.append(nv)
+
+        # Return tuple: (x_step, new_k_0, new_v_0, new_k_1, new_v_1, ..., new_k_{L-1}, new_v_{L-1})
+        return tuple(out_tensors)
+
+
+# ----- 2) Example inputs (B=1, S=64, H=8, Dh=64, C=512, L=4) ---------------
+def make_example_inputs(*, L=4, B=1, S=64, H=8, Dh=64, C=512, Tprev=63, device="cpu"):
+    """
+    Build example tensors that match the export I/O spec.
+    Shapes follow the request:
+      prev_x: [1,1,512]
+      enc_x: [64,1,512]
+      enc_pad_additive: [1,1,64] (additive float; zeros -> keep)
+      prev_k_i / prev_v_i (for i in 0..L-1): [1,8,63,64]
+      self_attn_mask: [1,1,64] (additive float; zeros -> keep)
+    """
+    g = torch.Generator(device=device).manual_seed(0)
+
+    prev_x = torch.randn(B, 1, C, device=device, dtype=torch.float32, generator=g)
+    enc_x = torch.randn(S, B, C, device=device, dtype=torch.float32, generator=g)
+
+    # Additive masks (0 for allowed, -120 for masked)
+    enc_pad_additive = torch.full((B, 1, S), float(-120), device=device)
+    self_attn_mask = torch.full((B, 1, S), float(-120), device=device)
+    enc_pad_additive[0, :27] = 0  # 27 is a random example.
+    self_attn_mask[0, :27] = 0  # 27 is a random example.
+
+    # Previous self-attn caches for each layer
+    prev_k_list = []
+    prev_v_list = []
+    for _ in range(L):
+        prev_k = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_v = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_k_list.append(prev_k)
+        prev_v_list.append(prev_v)
+
+    # Pack inputs as the export function will expect:
+    #   (prev_x, enc_x, enc_pad_additive, self_attn_mask, prev_k_0..L-1, prev_v_0..L-1)
+    example_args: Tuple[torch.Tensor, ...] = (
+        prev_x,
+        enc_x,
+        enc_pad_additive,
+        *prev_k_list,
+        *prev_v_list,
+    )
+    example_kwargs = {"self_attn_mask": self_attn_mask}
+    return example_args, example_kwargs
+
+
+# ----- 3) Export driver -----------------------------------------------------
+def export_decoder_single_step(translator, *, save_path="decoder_step_export.circle"):
+    """
+    Wrap the QuantFairseqDecoder into the export-friendly single-step module
+    and export with torch.export.export using example inputs.
+    """
+    # Grab the wrapped decoder
+    dec = translator.models[
+        0
+    ].decoder  # assumed QuantFairseqDecoder with forward_external_step
+    # Build export wrapper
+    wrapper = DecoderExportSingleStep(decoder=dec).eval()
+
+    # Example inputs (L inferred from wrapper/decoder)
+    L = wrapper.num_layers
+    H = wrapper.num_heads
+    Dh = wrapper.head_dim
+    C = wrapper.embed_dim
+    example_inputs, example_kwargs = make_example_inputs(L=L, H=H, Dh=Dh, C=C)
+
+    # Export circle (no dynamism assumed; shapes are fixed for export)
+    cm = tico.convert(
+        wrapper,
+        args=example_inputs,
+        kwargs=example_kwargs,
+        strict=True,  # fail if something cannot be captured
+    )
+
+    # Save .pte
+    cm.save(save_path)
+    print(f"Saved decoder single-step export to: {save_path}")
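Per the module docstring, step indexing, cache-slot management, and beam reordering stay with the caller. A minimal sketch of one possible driver loop follows (not part of this diff: the [B*H, 1, Dh] shape of the returned K/V follows the in-code comments, and the drop-oldest slot policy is an assumption):

import torch

# Assumed sizes matching make_example_inputs(); `wrapper` is a DecoderExportSingleStep.
L_LAYERS, B, H, Dh, Tprev = 4, 1, 8, 64, 63

k_cache = [torch.zeros(B, H, Tprev, Dh) for _ in range(L_LAYERS)]
v_cache = [torch.zeros(B, H, Tprev, Dh) for _ in range(L_LAYERS)]

def decode_one_token(wrapper, prev_x, enc_x, enc_pad_additive, self_attn_mask):
    # forward() unpacks the caches interleaved as k_0, v_0, k_1, v_1, ...
    kv = [t for pair in zip(k_cache, v_cache) for t in pair]
    outs = wrapper(prev_x, enc_x, enc_pad_additive, *kv, self_attn_mask=self_attn_mask)
    x_step, kv_new = outs[0], outs[1:]
    for i in range(L_LAYERS):
        nk = kv_new[2 * i].view(B, H, 1, Dh)      # fold [B*H, 1, Dh] back to [B, H, 1, Dh]
        nv = kv_new[2 * i + 1].view(B, H, 1, Dh)
        # One possible fixed-capacity policy: drop the oldest slot, append the newest.
        k_cache[i] = torch.cat([k_cache[i][:, :, 1:], nk], dim=2)
        v_cache[i] = torch.cat([v_cache[i][:, :, 1:], nv], dim=2)
    return x_step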
tico/passes/convert_expand_to_slice_cat.py
ADDED
@@ -0,0 +1,153 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node
+from tico.utils.validate_args_kwargs import ExpandArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertExpandToSliceCat(PassBase):
+    """
+    This pass replaces `aten.reshape` + `aten.expand` pattern by rewriting it using
+    a series of `aten.slice` and `aten.cat` operations.
+
+    This pass is specialized for expand of KVCache.
+    - Expects (batch, num_key_value_heads, seq_len, head_dim) as input shape of reshape
+    """
+
+    def __init__(self, enabled: bool = False):
+        super().__init__()
+        self.enabled = enabled
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        if not self.enabled:
+            return PassResult(False)
+
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+
+        # This pass handles expand on EXPAND_DIM only
+        CAT_DIM = 1
+        EXPAND_DIM = 2
+
+        for node in graph.nodes:
+            if not isinstance(node, torch.fx.Node) or not is_target_node(
+                node, ops.aten.reshape
+            ):
+                continue
+
+            post_reshape = node
+            post_reshape_args = ReshapeArgs(*post_reshape.args, **post_reshape.kwargs)
+            post_reshape_input = post_reshape_args.input
+
+            if not isinstance(post_reshape_input, torch.fx.Node) or not is_target_node(
+                post_reshape_input, ops.aten.expand
+            ):
+                continue
+
+            expand = post_reshape_input
+            expand_args = ExpandArgs(*expand.args, **expand.kwargs)
+            expand_input = expand_args.input
+            expand_shape = extract_shape(expand)
+
+            if not isinstance(expand_input, torch.fx.Node) or not is_target_node(
+                expand_input, ops.aten.reshape
+            ):
+                continue
+
+            pre_reshape = expand_input
+            pre_reshape_args = ReshapeArgs(*pre_reshape.args, **pre_reshape.kwargs)
+            pre_reshape_input = pre_reshape_args.input
+            pre_reshape_shape = extract_shape(pre_reshape)
+
+            if pre_reshape_shape[EXPAND_DIM] != 1:
+                continue
+
+            reshape_input_shape = extract_shape(pre_reshape_input)
+
+            if len(expand_shape) != len(pre_reshape_shape):
+                continue
+
+            # Ensure all dimensions *except* at EXPAND_DIM are identical.
+            if not (
+                expand_shape[:EXPAND_DIM] == pre_reshape_shape[:EXPAND_DIM]
+                and expand_shape[EXPAND_DIM + 1 :]
+                == pre_reshape_shape[EXPAND_DIM + 1 :]
+            ):
+                continue
+
+            # Ensure the expansion dimension is a clean multiple.
+            if expand_shape[EXPAND_DIM] % pre_reshape_shape[EXPAND_DIM] != 0:
+                continue
+
+            expand_ratio = expand_shape[EXPAND_DIM] // pre_reshape_shape[EXPAND_DIM]
+
+            if expand_ratio <= 1:
+                continue
+
+            cat_nodes = []
+
+            for i in range(reshape_input_shape[CAT_DIM]):
+                with graph.inserting_before(expand):
+                    slice_copy_args = (pre_reshape_input, CAT_DIM, i, i + 1, 1)
+                    slice_node = create_node(
+                        graph,
+                        torch.ops.aten.slice.Tensor,
+                        args=slice_copy_args,
+                        origin=expand,
+                    )
+                with graph.inserting_after(slice_node):
+                    cat_args = ([slice_node] * expand_ratio, CAT_DIM)
+                    cat_node = create_node(
+                        graph,
+                        torch.ops.aten.cat.default,
+                        args=cat_args,
+                        origin=expand,
+                    )
+                    cat_nodes.append(cat_node)
+
+            with graph.inserting_after(expand):
+                cat_args = (cat_nodes, CAT_DIM)
+                cat_node = create_node(
+                    graph,
+                    torch.ops.aten.cat.default,
+                    args=cat_args,
+                    origin=expand,
+                )
+                expand.replace_all_uses_with(cat_node)
+
+            modified = True
+            logger.debug(f"{expand.name} is replaced with {cat_node.name} operators")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
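The reshape → expand → reshape pattern this pass matches is the usual grouped-query-attention key/value head repeat. A short eager-mode equivalence check (illustrative only, with made-up sizes, not part of this diff) showing that the slice-and-cat rewrite yields the same tensor:

import torch

B, KVH, S, Dh, ratio = 1, 2, 4, 8, 3   # batch, kv heads, seq len, head dim, expand ratio
x = torch.randn(B, KVH, S, Dh)

# Pattern matched by the pass: reshape -> expand -> reshape.
expanded = (
    x.reshape(B, KVH, 1, S, Dh)
    .expand(B, KVH, ratio, S, Dh)
    .reshape(B, KVH * ratio, S, Dh)
)

# Rewrite emitted by the pass: per-head slices repeated via cat, then concatenated.
groups = [torch.cat([x[:, i : i + 1]] * ratio, dim=1) for i in range(KVH)]
rewritten = torch.cat(groups, dim=1)

assert torch.equal(expanded, rewritten)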
tico/passes/ops.py
CHANGED
tico/utils/convert.py
CHANGED
@@ -39,6 +39,7 @@ from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
 from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
@@ -250,6 +251,7 @@ def convert_exported_module_to_circle(
             ConstPropPass(),
             SegmentIndexSelectConst(),
             LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+            ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
             ConvertMatmulToLinear(
                 enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
                 enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
{tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-tico/__init__.py,sha256=
+tico/__init__.py,sha256=Dh35NwoATHo8LtQfaClJ4ecMmZjFjQAmdKJ9nJvnQb0,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
 tico/config/factory.py,sha256=il0zqB6Lm5NX2LnG-TUhmiP9vVeZ_3TucJMorVZIodY,1324
-tico/config/v1.py,sha256=
+tico/config/v1.py,sha256=uB5d39fkmuBACwjBVGtdWb_HGXfXsvmw6nw64xZcC-8,1342
 tico/experimental/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/__init__.py,sha256=IaJPZegVJp0P3luutBo907Kp5sOJensE1Mm-XBG_jBs,122
 tico/experimental/quantization/public_interface.py,sha256=TGo3bTapwLA8KpsoEwBhuzI0LQUO6y3-sUM1VZvkLo8,4220
@@ -91,6 +91,7 @@ tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfv
 tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
 tico/experimental/quantization/ptq/wrappers/registry.py,sha256=OVO5nev6J8Br9zsIX-Ut7ZgWzA9f_jk0Np9bGioXgQM,5171
 tico/experimental/quantization/ptq/wrappers/fairseq/__init__.py,sha256=Mc8FLd9DusyB_IT1vk1OYrRkngOYnYd05IvtA9ORVQc,160
+tico/experimental/quantization/ptq/wrappers/fairseq/decoder_export_single_step.py,sha256=d7ZieKiSbZ2ffkaLYMg2PJl1OyAxkKjB3OHKB4poxJs,9796
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder.py,sha256=CILYvxPhW7xLkroWW_hunQBGAYGexLqnPnO5xmMnK-E,17877
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py,sha256=JT79shxOhDtRFgm8jrrN6HKvyVotiytLjMjAxX-Cztg,20416
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder.py,sha256=r9DPUAbL2KRJ8zpMJ39Y9n6Oe79nte-mFcdjG2qEP-w,13809
@@ -113,6 +114,7 @@ tico/passes/cast_clamp_mixed_type_args.py,sha256=m3_HpXLywWmWERfE5lM5PgvjBod7C4B
 tico/passes/cast_mixed_type_args.py,sha256=Wd3sCDKJZwdb8GiMWKljm8X5CLFRd8eCz-dmWks15Hc,7763
 tico/passes/const_prop_pass.py,sha256=hDxGgJNiRjsgOArdaoeAOcOOA-nKBvA1W1zcMZQA5yg,11531
 tico/passes/convert_conv1d_to_conv2d.py,sha256=ktS3h158y9rg1sQiW8BZZbflV_dk_UdjBPQnuiOKyzg,5303
+tico/passes/convert_expand_to_slice_cat.py,sha256=Fa6b5pqiQNq-QBiEC0e3WkQYf2UEhMgzSTIt4hlzdjc,5470
 tico/passes/convert_layout_op_to_reshape.py,sha256=sCAFjkmVtiKjvDQSAgnjNBHl3_hWXJZElGDXQiTH-7s,2963
 tico/passes/convert_matmul_to_linear.py,sha256=WATtsHk_GzsU0HYovc3UMyEj8ApF2qLbInAsNlQj0nE,9759
 tico/passes/convert_repeat_to_expand_copy.py,sha256=JbtFTmWyfJS2SSd_higP1IEhQeh7wHdN5dmTbbiFVCs,3237
@@ -134,7 +136,7 @@ tico/passes/lower_pow2_to_mul.py,sha256=nfJXa9ZTZMiLg6ownSyvkM4KF2z9tZW34Q3CCWI_
 tico/passes/lower_to_resize_nearest_neighbor.py,sha256=gbrvTmWSXDPdJ1XJtWGI5mo-uEiauXEG3ELwbKYVPLI,9013
 tico/passes/lower_to_slice.py,sha256=OzlFzK3lBYyYwC3WThsWd94Ob4JINIJF8UaLAtnumzU,7262
 tico/passes/merge_consecutive_cat.py,sha256=ayZNLDA1DFM7Fxxi2Dmk1CujkgUuaVCH1rhQgLrvvOQ,2701
-tico/passes/ops.py,sha256=
+tico/passes/ops.py,sha256=7IGRnxIJl-nLO4huVk_mgBfD4VGUNQRyeuM8K1L2u1U,2934
 tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
 tico/passes/remove_redundant_assert_nodes.py,sha256=rYbTCyuNIXIC-2NreHKBVCuaSUkEQvB_iSRzb26P_EA,1821
 tico/passes/remove_redundant_expand.py,sha256=8yhlMnbog-T9gIK6LKIU0tu0__gfhZzO36g_fJIVVP4,2162
@@ -236,7 +238,7 @@ tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT
 tico/serialize/operators/adapters/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/serialize/operators/adapters/llama_rmsnorm.py,sha256=6t3dhfNpR03eIjsmhymF2JKd6lCf7PvInqMf77c_BOE,1139
 tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
-tico/utils/convert.py,sha256=
+tico/utils/convert.py,sha256=10YufXpuqHz274ACUb1_F5594uClUFhBEh8SY6gYp7w,13809
 tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
 tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
 tico/utils/dtype.py,sha256=L5Qb7qgbt0eQ5frUTvHYrRtTJb1dg4-JNEopcxCNg1U,1389
@@ -260,9 +262,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
-tico-0.1.0.
+tico-0.1.0.dev250925.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250925.dist-info/METADATA,sha256=LReZIvZizVYMd9esGIdPgEiaO250UU59SBEoEbAe0rM,8450
+tico-0.1.0.dev250925.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250925.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250925.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250925.dist-info/RECORD,,
{tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/LICENSE
File without changes
{tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/WHEEL
File without changes
{tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/entry_points.txt
File without changes
{tico-0.1.0.dev250923.dist-info → tico-0.1.0.dev250925.dist-info}/top_level.txt
File without changes