tico 0.1.0.dev250922__py3-none-any.whl → 0.1.0.dev250924__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
 ]

 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250922"
+__version__ = "0.1.0.dev250924"

 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
tico/config/v1.py CHANGED
@@ -24,6 +24,7 @@ class CompileConfigV1(CompileConfigBase):
     convert_lhs_const_mm_to_fc: bool = False
     convert_rhs_const_mm_to_fc: bool = True
     convert_single_batch_lhs_const_bmm_to_fc: bool = False
+    convert_expand_to_slice_cat: bool = False

     def get(self, name: str):
         return super().get(name)
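The new convert_expand_to_slice_cat flag defaults to False, so the ConvertExpandToSliceCat pass registered further down in tico/utils/convert.py stays disabled unless a compile config turns it on. A minimal sketch, assuming CompileConfigV1 is dataclass-like and accepts keyword overrides (not shown in this diff):

    from tico.config.v1 import CompileConfigV1

    # Hypothetical usage: opt in to the expand -> slice/cat rewrite.
    cfg = CompileConfigV1(convert_expand_to_slice_cat=True)
    assert cfg.get("convert_expand_to_slice_cat") is True  # read back via the get() shown above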
tico/experimental/quantization/algorithm/gptq/quantizer.py CHANGED
@@ -27,6 +27,7 @@ from tico.experimental.quantization.algorithm.gptq.utils import (
 )
 from tico.experimental.quantization.config.gptq import GPTQConfig
 from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.experimental.quantization.quantizer_registry import register_quantizer


 class StopForward(Exception):
@@ -35,6 +36,7 @@ class StopForward(Exception):
     pass


+@register_quantizer(GPTQConfig)
 class GPTQQuantizer(BaseQuantizer):
     """
     Quantizer for applying the GPTQ algorithm (typically for weight quantization).
tico/experimental/quantization/algorithm/pt2e/quantizer.py CHANGED
@@ -22,9 +22,12 @@ from tico.experimental.quantization.algorithm.pt2e.annotation.annotator import (
     get_asymmetric_quantization_config,
     PT2EAnnotator,
 )
+from tico.experimental.quantization.config.pt2e import PT2EConfig
 from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.experimental.quantization.quantizer_registry import register_quantizer


+@register_quantizer(PT2EConfig)
 class PT2EQuantizer(BaseQuantizer):
     """
     Quantizer for applying pytorch 2.0 export quantization (typically for activation quantization).
tico/experimental/quantization/algorithm/smoothquant/quantizer.py CHANGED
@@ -25,8 +25,10 @@ from tico.experimental.quantization.algorithm.smoothquant.smooth_quant import (
 )
 from tico.experimental.quantization.config.smoothquant import SmoothQuantConfig
 from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.experimental.quantization.quantizer_registry import register_quantizer


+@register_quantizer(SmoothQuantConfig)
 class SmoothQuantQuantizer(BaseQuantizer):
     """
     Quantizer for applying the SmoothQuant algorithm
tico/experimental/quantization/config/smoothquant.py CHANGED
@@ -38,4 +38,4 @@ class SmoothQuantConfig(BaseConfig):

     @property
     def name(self) -> str:
-        return "smooth_quant"
+        return "smoothquant"
tico/experimental/quantization/ptq/wrappers/fairseq/decoder_export_single_step.py ADDED
@@ -0,0 +1,234 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+"""
+Q) Why the name "SingleStep"?
+
+Fairseq's decoder already advances one token at a time during generation,
+but the default path is "stateful" and "shape-polymorphic": it owns and
+mutates K/V caches internally, prefix lengths and triangular masks grow with
+the step, and beam reordering updates hidden module state. That's friendly
+for eager execution, but hostile to `torch.export` and many accelerator
+backends.
+
+This export wrapper makes the per-token call truly "single-step" in the
+export sense: "stateless" and "fixed-shape" so every invocation has the
+exact same graph.
+
+Key invariants
+--------------
+• "Stateless": K/V caches come in as explicit inputs and go out as outputs.
+  The module does not store or mutate hidden state.
+• "Static shapes": Query is always [B, 1, C]; encoder features and masks
+  have fixed, predeclared sizes; K/V slots use fixed capacity (unused tail
+  is simply masked/ignored).
+• "External control": Step indexing, cache slot management (append/roll),
+  and beam reordering are handled outside the module.
+• "Prebuilt additive masks": Self-attention masks are provided by the
+  caller (0 for valid, large negative sentinel, e.g. -120, for masked),
+  avoiding data-dependent control flow.
+
+In short: still step-wise like fairseq, but restructured for export—no
+internal state, no data-dependent shapes, no dynamic control flow.
+"""
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+import tico
+
+# ----- 1) Export wrapper module -------------------------------------------
+class DecoderExportSingleStep(nn.Module):
+    """
+    Export-only single-step decoder module.
+
+    Inputs (example shapes; B=1, H=8, Dh=64, C=512, S=64, Tprev=63):
+      - prev_x: [B, 1, C] embedded decoder input for the current step
+      - enc_x: [S, B, C] encoder hidden states (fixed-length export input)
+      - enc_pad_additive: [B, 1, S] additive float key_padding_mask for enc-dec attn (0 for keep, -120 for pad)
+      - self_attn_mask: [B, 1, S] additive float mask for decoder self-attn at this step; pass zeros if unused
+      - prev_self_k_0..L-1: [B, H, Tprev, Dh] cached self-attn K per layer
+      - prev_self_v_0..L-1: [B, H, Tprev, Dh] cached self-attn V per layer
+
+    Outputs:
+      - x_out: [B, 1, C] new decoder features at the current step
+      - new_k_0..L-1: [H, B, Dh] per-layer new K (single-timestep; time dim squeezed)
+      - new_v_0..L-1: [H, B, Dh] per-layer new V (single-timestep; time dim squeezed)
+
+    Notes:
+      • We keep masks/additive semantics externally to avoid any mask-building inside the graph.
+      • We reshape the new K/V from [B,H,1,Dh] -> [H,B,Dh] to match the requested output spec (8,1,64).
+    """
+
+    def __init__(self, decoder: nn.Module):
+        super().__init__()
+        self.decoder = decoder
+        # Cache common meta for assertions
+        self.num_layers = len(getattr(decoder, "layers"))
+        # Infer heads/head_dim from the wrapped self_attn of layer 0
+        any_layer = getattr(decoder.layers[0], "wrapped", decoder.layers[0])  # type: ignore[index]
+        mha = getattr(any_layer, "self_attn", None)
+        assert mha is not None, "Decoder layer must expose self_attn"
+        self.num_heads = int(mha.num_heads)
+        self.head_dim = int(mha.head_dim)
+        # Embed dim (C)
+        self.embed_dim = int(getattr(decoder, "embed_dim"))
+
+    def forward(
+        self,
+        prev_x: torch.Tensor,  # [B,1,C]
+        enc_x: torch.Tensor,  # [S,B,C]
+        enc_pad_additive: torch.Tensor,  # [B,1,S]
+        *kv_args: torch.Tensor,  # prev_k_0..L-1, prev_v_0..L-1 (total 2L tensors)
+        self_attn_mask: torch.Tensor,  # [B,1,S] (or zeros)
+    ):
+        L = self.num_layers
+        H = self.num_heads
+        Dh = self.head_dim
+        B, one, C = prev_x.shape
+        S, B2, C2 = enc_x.shape
+        assert (
+            one == 1 and C == self.embed_dim and B == B2 and C2 == C
+        ), "Shape mismatch in prev_x/enc_x"
+        assert len(kv_args) == 2 * L, f"Expected {2*L} KV tensors, got {len(kv_args)}"
+
+        # Unpack previous self-attn caches
+        prev_k_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        prev_v_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        for i in range(L):
+            prev_k_list.append(kv_args[2 * i])
+            prev_v_list.append(kv_args[2 * i + 1])
+        for i in range(L):
+            assert (
+                prev_k_list[i].dim() == 4 and prev_v_list[i].dim() == 4
+            ), "KV must be [B,H,Tprev,Dh]"
+            assert (
+                prev_k_list[i].shape[0] == B
+                and prev_k_list[i].shape[1] == H
+                and prev_k_list[i].shape[3] == Dh
+            )
+
+        # Call decoder's external single-step path
+        # Returns:
+        #   x_step: [B,1,C]
+        #   newk/newv: lists of length L, each [B*H,1,Dh]
+        x_step, newk_list, newv_list = self.decoder.forward_external_step(  # type: ignore[operator]
+            prev_output_x=prev_x,
+            encoder_out_x=enc_x,
+            encoder_padding_mask=enc_pad_additive,
+            self_attn_mask=self_attn_mask,
+            prev_self_k_list=prev_k_list,
+            prev_self_v_list=prev_v_list,
+        )
+
+        out_tensors: List[torch.Tensor] = [
+            x_step
+        ]  # first output is the new decoder features
+        for i in range(L):
+            nk = newk_list[i]  # [B*H, Tnew, Dh]
+            nv = newv_list[i]  # [B*H, Tnew, Dh]
+            out_tensors.append(nk)
+            out_tensors.append(nv)
+
+        # Return tuple: (x_step, new_k_0, new_v_0, new_k_1, new_v_1, ..., new_k_{L-1}, new_v_{L-1})
+        return tuple(out_tensors)
+
+
+# ----- 2) Example inputs (B=1, S=64, H=8, Dh=64, C=512, L=4) ---------------
+def make_example_inputs(*, L=4, B=1, S=64, H=8, Dh=64, C=512, Tprev=63, device="cpu"):
+    """
+    Build example tensors that match the export I/O spec.
+    Shapes follow the request:
+      prev_x: [1,1,512]
+      enc_x: [64,1,512]
+      enc_pad_additive: [1,1,64] (additive float; zeros -> keep)
+      prev_k_i / prev_v_i (for i in 0..L-1): [1,8,63,64]
+      self_attn_mask: [1,1,64] (additive float; zeros -> keep)
+    """
+    g = torch.Generator(device=device).manual_seed(0)
+
+    prev_x = torch.randn(B, 1, C, device=device, dtype=torch.float32, generator=g)
+    enc_x = torch.randn(S, B, C, device=device, dtype=torch.float32, generator=g)
+
+    # Additive masks (0 for allowed, -120 for masked)
+    enc_pad_additive = torch.full((B, 1, S), float(-120), device=device)
+    self_attn_mask = torch.full((B, 1, S), float(-120), device=device)
+    enc_pad_additive[0, :27] = 0  # 27 is a random example.
+    self_attn_mask[0, :27] = 0  # 27 is a random example.
+
+    # Previous self-attn caches for each layer
+    prev_k_list = []
+    prev_v_list = []
+    for _ in range(L):
+        prev_k = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_v = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_k_list.append(prev_k)
+        prev_v_list.append(prev_v)
+
+    # Pack inputs as the export function will expect:
+    #   (prev_x, enc_x, enc_pad_additive, self_attn_mask, prev_k_0..L-1, prev_v_0..L-1)
+    example_args: Tuple[torch.Tensor, ...] = (
+        prev_x,
+        enc_x,
+        enc_pad_additive,
+        *prev_k_list,
+        *prev_v_list,
+    )
+    example_kwargs = {"self_attn_mask": self_attn_mask}
+    return example_args, example_kwargs


+# ----- 3) Export driver -----------------------------------------------------
+def export_decoder_single_step(translator, *, save_path="decoder_step_export.circle"):
+    """
+    Wrap the QuantFairseqDecoder into the export-friendly single-step module
+    and export with torch.export.export using example inputs.
+    """
+    # Grab the wrapped decoder
+    dec = translator.models[
+        0
+    ].decoder  # assumed QuantFairseqDecoder with forward_external_step
+    # Build export wrapper
+    wrapper = DecoderExportSingleStep(decoder=dec).eval()
+
+    # Example inputs (L inferred from wrapper/decoder)
+    L = wrapper.num_layers
+    H = wrapper.num_heads
+    Dh = wrapper.head_dim
+    C = wrapper.embed_dim
+    example_inputs, example_kwargs = make_example_inputs(L=L, H=H, Dh=Dh, C=C)
+
+    # Export circle (no dynamism assumed; shapes are fixed for export)
+    cm = tico.convert(
+        wrapper,
+        args=example_inputs,
+        kwargs=example_kwargs,
+        strict=True,  # fail if something cannot be captured
+    )
+
+    # Save .pte
+    cm.save(save_path)
+    print(f"Saved decoder single-step export to: {save_path}")
tico/experimental/quantization/public_interface.py CHANGED
@@ -13,25 +13,17 @@
 # limitations under the License.

 import copy
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional

 import torch

 from tico.experimental.quantization.algorithm.gptq.quantizer import GPTQQuantizer
 from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantizer
-from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
-    SmoothQuantQuantizer,
-)
 from tico.experimental.quantization.config.base import BaseConfig
 from tico.experimental.quantization.quantizer import BaseQuantizer
+from tico.experimental.quantization.quantizer_registry import get_quantizer


-config_to_quantizer: Dict[str, Type[BaseQuantizer]] = {
-    "pt2e": PT2EQuantizer,
-    "gptq": GPTQQuantizer,
-    "smooth_quant": SmoothQuantQuantizer,
-}
-
 QUANTIZER_ATTRIBUTE_NAME = "tico_quantizer"


@@ -61,14 +53,15 @@ def prepare(
     """
     if hasattr(model, QUANTIZER_ATTRIBUTE_NAME):
         raise RuntimeError("prepare() already has been called.")
-    if quant_config.name == "pt2e" and inplace:
+    quantizer = get_quantizer(quant_config)
+
+    if isinstance(quantizer, PT2EQuantizer) and inplace:
         raise RuntimeError(
             "In-place is not supported for PT2E quantization due to limitation in the underlying Torch APIs. Please set 'inplace=False' to proceed."
         )

     model = model if inplace else copy.deepcopy(model)

-    quantizer = config_to_quantizer[quant_config.name](quant_config)
     model = quantizer.prepare(model, args, kwargs)
     setattr(model, QUANTIZER_ATTRIBUTE_NAME, quantizer)

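prepare() now resolves the quantizer from the config object's type via get_quantizer instead of the removed name-keyed table. A hedged sketch of the call site; the GPTQConfig constructor arguments, the full prepare() signature, and whether GPTQ preparation accepts an arbitrary nn.Module are assumptions beyond what this hunk shows:

    import torch

    from tico.experimental.quantization.config.gptq import GPTQConfig
    from tico.experimental.quantization.public_interface import prepare

    model = torch.nn.Linear(4, 4)   # placeholder model for illustration
    calib = (torch.randn(1, 4),)    # hypothetical calibration inputs

    # get_quantizer(GPTQConfig(...)) picks GPTQQuantizer through the
    # @register_quantizer decorator added above; a deep copy is returned
    # because inplace defaults to False.
    prepared = prepare(model, GPTQConfig(), args=calib)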
tico/experimental/quantization/quantizer_registry.py ADDED
@@ -0,0 +1,72 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from typing import Dict, Optional, Type, TypeVar
+
+from tico.experimental.quantization.config.base import BaseConfig
+from tico.experimental.quantization.quantizer import BaseQuantizer
+
+TQ = TypeVar("TQ", bound=BaseQuantizer)
+
+# Mapping: Config type -> Quantizer type
+_REGISTRY: Dict[Type[BaseConfig], Type[BaseQuantizer]] = {}
+
+
+def register_quantizer(config_cls: Type[BaseConfig]):
+    """
+    Decorator to register a quantizer for a given config class.
+    Usage:
+        @register_quantizer(GPTQConfig)
+        class GPTQQuantizer(BaseQuantizer): ...
+    """
+
+    def wrapper(quantizer_cls: Type[TQ]) -> Type[TQ]:
+        _REGISTRY[config_cls] = quantizer_cls
+        return quantizer_cls
+
+    return wrapper
+
+
+def _lookup(cfg: BaseConfig) -> Optional[Type[BaseQuantizer]]:
+    """Return a quantizer class only if the exact config type is registered."""
+    return _REGISTRY.get(type(cfg))
+
+
+def get_quantizer(cfg: BaseConfig) -> BaseQuantizer:
+    """Factory to return a quantizer instance for the given config."""
+    qcls = _lookup(cfg)
+    if qcls is not None:
+        return qcls(cfg)
+
+    # Lazy import by naming convention
+    name = getattr(cfg, "name", None)
+    if name:
+        try:
+            importlib.import_module(
+                f"tico.experimental.quantization.algorithm.{name}.quantizer"
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to import quantizer module for config name='{name}': {e}"
+            )
+
+        qcls = _lookup(cfg)
+        if qcls is not None:
+            return qcls(cfg)
+
+    raise RuntimeError(
+        f"No quantizer registered for config type {type(cfg).__name__} "
+        f"(name='{getattr(cfg,'name',None)}')."
+    )
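The registry keys on the exact config type and falls back to a lazy import keyed on cfg.name, which is presumably why SmoothQuantConfig.name was renamed to "smoothquant" above: it now matches the algorithm/smoothquant package path that get_quantizer imports. A minimal sketch with a hypothetical config/quantizer pair; it assumes BaseConfig and BaseQuantizer can be subclassed with only the members visible in this diff:

    from tico.experimental.quantization.config.base import BaseConfig
    from tico.experimental.quantization.quantizer import BaseQuantizer
    from tico.experimental.quantization.quantizer_registry import (
        get_quantizer,
        register_quantizer,
    )

    class MyConfig(BaseConfig):  # hypothetical config
        @property
        def name(self) -> str:
            return "my_algorithm"

    @register_quantizer(MyConfig)  # exact config type -> quantizer class
    class MyQuantizer(BaseQuantizer):  # hypothetical quantizer
        pass

    quantizer = get_quantizer(MyConfig())  # instantiates MyQuantizer(cfg)
    assert isinstance(quantizer, MyQuantizer)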
tico/passes/convert_expand_to_slice_cat.py ADDED
@@ -0,0 +1,153 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node
+from tico.utils.validate_args_kwargs import ExpandArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertExpandToSliceCat(PassBase):
+    """
+    This pass replaces `aten.reshape` + `aten.expand` pattern by rewriting it using
+    a series of `aten.slice` and `aten.cat` operations.
+
+    This pass is specialized for expand of KVCache.
+    - Expects (batch, num_key_value_heads, seq_len, head_dim) as input shape of reshape
+    """
+
+    def __init__(self, enabled: bool = False):
+        super().__init__()
+        self.enabled = enabled
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        if not self.enabled:
+            return PassResult(False)
+
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+
+        # This pass handles expand on EXPAND_DIM only
+        CAT_DIM = 1
+        EXPAND_DIM = 2
+
+        for node in graph.nodes:
+            if not isinstance(node, torch.fx.Node) or not is_target_node(
+                node, ops.aten.reshape
+            ):
+                continue
+
+            post_reshape = node
+            post_reshape_args = ReshapeArgs(*post_reshape.args, **post_reshape.kwargs)
+            post_reshape_input = post_reshape_args.input
+
+            if not isinstance(post_reshape_input, torch.fx.Node) or not is_target_node(
+                post_reshape_input, ops.aten.expand
+            ):
+                continue
+
+            expand = post_reshape_input
+            expand_args = ExpandArgs(*expand.args, **expand.kwargs)
+            expand_input = expand_args.input
+            expand_shape = extract_shape(expand)
+
+            if not isinstance(expand_input, torch.fx.Node) or not is_target_node(
+                expand_input, ops.aten.reshape
+            ):
+                continue
+
+            pre_reshape = expand_input
+            pre_reshape_args = ReshapeArgs(*pre_reshape.args, **pre_reshape.kwargs)
+            pre_reshape_input = pre_reshape_args.input
+            pre_reshape_shape = extract_shape(pre_reshape)
+
+            if pre_reshape_shape[EXPAND_DIM] != 1:
+                continue
+
+            reshape_input_shape = extract_shape(pre_reshape_input)
+
+            if len(expand_shape) != len(pre_reshape_shape):
+                continue
+
+            # Ensure all dimensions *except* at EXPAND_DIM are identical.
+            if not (
+                expand_shape[:EXPAND_DIM] == pre_reshape_shape[:EXPAND_DIM]
+                and expand_shape[EXPAND_DIM + 1 :]
+                == pre_reshape_shape[EXPAND_DIM + 1 :]
+            ):
+                continue
+
+            # Ensure the expansion dimension is a clean multiple.
+            if expand_shape[EXPAND_DIM] % pre_reshape_shape[EXPAND_DIM] != 0:
+                continue
+
+            expand_ratio = expand_shape[EXPAND_DIM] // pre_reshape_shape[EXPAND_DIM]
+
+            if expand_ratio <= 1:
+                continue
+
+            cat_nodes = []
+
+            for i in range(reshape_input_shape[CAT_DIM]):
+                with graph.inserting_before(expand):
+                    slice_copy_args = (pre_reshape_input, CAT_DIM, i, i + 1, 1)
+                    slice_node = create_node(
+                        graph,
+                        torch.ops.aten.slice.Tensor,
+                        args=slice_copy_args,
+                        origin=expand,
+                    )
+                with graph.inserting_after(slice_node):
+                    cat_args = ([slice_node] * expand_ratio, CAT_DIM)
+                    cat_node = create_node(
+                        graph,
+                        torch.ops.aten.cat.default,
+                        args=cat_args,
+                        origin=expand,
+                    )
+                    cat_nodes.append(cat_node)
+
+            with graph.inserting_after(expand):
+                cat_args = (cat_nodes, CAT_DIM)
+                cat_node = create_node(
+                    graph,
+                    torch.ops.aten.cat.default,
+                    args=cat_args,
+                    origin=expand,
+                )
+                expand.replace_all_uses_with(cat_node)
+
+            modified = True
+            logger.debug(f"{expand.name} is replaced with {cat_node.name} operators")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
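The pattern this pass matches is the reshape -> expand -> reshape sequence used to repeat KV-cache heads (grouped-query-attention style broadcasting). A standalone sketch with hypothetical small shapes, showing that slicing each head and concatenating copies reproduces the expand result exactly:

    import torch

    B, n_kv, n_rep, S, Dh = 1, 2, 4, 8, 16
    kv = torch.randn(B, n_kv, S, Dh)  # (batch, num_key_value_heads, seq_len, head_dim)

    # reshape -> expand on dim 2 -> reshape, as matched by the pass above
    expanded = (
        kv.reshape(B, n_kv, 1, S, Dh)
        .expand(B, n_kv, n_rep, S, Dh)
        .reshape(B, n_kv * n_rep, S, Dh)
    )

    # slice + cat rewrite: slice each head on dim 1, repeat it n_rep times, then cat all heads
    slices = [kv[:, i : i + 1] for i in range(n_kv)]
    repeated = [torch.cat([s] * n_rep, dim=1) for s in slices]
    rewritten = torch.cat(repeated, dim=1)

    assert torch.equal(expanded, rewritten)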
tico/passes/ops.py CHANGED
@@ -69,7 +69,6 @@ class AtenOps:
             torch.ops.aten.unsqueeze_copy.default,
         ]
         self.view = [
-            torch.ops.aten.view,
             torch.ops.aten.view.default,
             torch.ops.aten.view_copy.default,
         ]
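Dropping torch.ops.aten.view here removes the overload packet and keeps only the concrete overloads; the reasoning below is an inference, not stated in the diff, but exported FX graphs carry OpOverload objects (e.g. aten.view.default) as node targets rather than packets:

    import torch

    # The packet groups all overloads; the .default attribute is the concrete
    # OpOverload that appears as node.target in torch.export'ed graphs.
    print(type(torch.ops.aten.view).__name__)          # OpOverloadPacket
    print(type(torch.ops.aten.view.default).__name__)  # OpOverload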
tico/utils/convert.py CHANGED
@@ -39,6 +39,7 @@ from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
 from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
@@ -250,6 +251,7 @@ def convert_exported_module_to_circle(
         ConstPropPass(),
         SegmentIndexSelectConst(),
         LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+        ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
         ConvertMatmulToLinear(
             enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
             enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
tico-0.1.0.dev250922.dist-info/METADATA → tico-0.1.0.dev250924.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tico
-Version: 0.1.0.dev250922
+Version: 0.1.0.dev250924
 Summary: Convert exported Torch module to circle
 Home-page: UNKNOWN
 License: UNKNOWN
tico-0.1.0.dev250922.dist-info/RECORD → tico-0.1.0.dev250924.dist-info/RECORD CHANGED
@@ -1,21 +1,22 @@
-tico/__init__.py,sha256=aXzPnAgp_3hFd-ia92oDhfjfZ1NABlYkgUlEbFs5Pb0,1883
+tico/__init__.py,sha256=QZao9QkVmcSoCMri9OngdTEi5qQ-fR7joim-Mp04_Hk,1883
 tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
 tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
 tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
 tico/config/factory.py,sha256=il0zqB6Lm5NX2LnG-TUhmiP9vVeZ_3TucJMorVZIodY,1324
-tico/config/v1.py,sha256=lEyKemeKGrJ0bA5w-LPkMWVlnAiJRDm9mM48TJle-e4,1296
+tico/config/v1.py,sha256=uB5d39fkmuBACwjBVGtdWb_HGXfXsvmw6nw64xZcC-8,1342
 tico/experimental/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/__init__.py,sha256=IaJPZegVJp0P3luutBo907Kp5sOJensE1Mm-XBG_jBs,122
-tico/experimental/quantization/public_interface.py,sha256=y-iwaeuedBvHwTh5hflQg4u2ZCdqf46IlTl9ntHq8pU,4425
+tico/experimental/quantization/public_interface.py,sha256=TGo3bTapwLA8KpsoEwBhuzI0LQUO6y3-sUM1VZvkLo8,4220
 tico/experimental/quantization/quantizer.py,sha256=pDTQGzR-BcQJeGZ7O4cXRQdCme4q_POpxHetwnv0bYg,2370
+tico/experimental/quantization/quantizer_registry.py,sha256=7wm2JcuPRribu7c8dCSZeYVcVqWQO1S-tHoinDDt11s,2345
 tico/experimental/quantization/algorithm/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/algorithm/gptq/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/algorithm/gptq/gptq.py,sha256=Qn9b_2ki7B64DcVEY25NMkww3PdZ5EqYQQXfYhNDQ6I,5555
 tico/experimental/quantization/algorithm/gptq/quant.py,sha256=Rl4wAOCmlE0U09BtNCDbccaSNohRHCNLwFi3zCqZfNo,5127
-tico/experimental/quantization/algorithm/gptq/quantizer.py,sha256=ZKeQQWm6eMUyRgntQxVR-QVjxJOc2pW4Dc_mrEPZA64,11686
+tico/experimental/quantization/algorithm/gptq/quantizer.py,sha256=CDAo7M5Xi8Oa2EjzNtCb9i6IWwpkxWzfP2fe8_VTM8M,11799
 tico/experimental/quantization/algorithm/gptq/utils.py,sha256=leGKayf-xbSjVwwAGTA5RsxUKrhDiklOQdlsLifjdrs,1811
 tico/experimental/quantization/algorithm/pt2e/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
-tico/experimental/quantization/algorithm/pt2e/quantizer.py,sha256=mdTvsG87bo8fu0GaWqSM8iBCs-4f4EfUlVtk-Ko6M34,2546
+tico/experimental/quantization/algorithm/pt2e/quantizer.py,sha256=PXfCQWCDYjMHTmEA6txHKh5miwruEZwDGsgjPYFBB9o,2725
 tico/experimental/quantization/algorithm/pt2e/utils.py,sha256=URjTGgsnDdhUC2Nr0-YJ9GWbVOKmjElfLr83Y8eCz-M,4806
 tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py,sha256=lFfblxglPxcN2IcrjAVYq7GOECIAQ4rr7M4euPp3yWc,7551
@@ -37,13 +38,13 @@ tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py,sha256=
 tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py,sha256=Idtoya2RcGKlgUJgC9WqNz0jH3gf6ViuPmsD9ySHbls,2253
 tico/experimental/quantization/algorithm/smoothquant/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/algorithm/smoothquant/observer.py,sha256=OWBKQ3ox6PqeqgevxOjpXvb7uApoqE4YbUBelGhVSN8,3435
-tico/experimental/quantization/algorithm/smoothquant/quantizer.py,sha256=14-QrKAW-Rw6pIbbNaD5eORcH2fqi40-TNFGaWVakIg,3649
+tico/experimental/quantization/algorithm/smoothquant/quantizer.py,sha256=VHc-_23VZWKCKZlcZvG5ESRKALgH4zU_Q9Tr-EEW4mk,3769
 tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py,sha256=fxCy4m-BsSjraciSVPFlPhgsOT46RjrOgczQGb7B9TA,11561
 tico/experimental/quantization/config/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/config/base.py,sha256=xg_HCDSuMgYvMd6ENZe4Sm2SYJgMaCBj4cmqaz_lhAs,816
 tico/experimental/quantization/config/gptq.py,sha256=IUIEz5bLhsTXqoBCE1rfPec99zsRjwgpDbPW5YJqOPg,973
 tico/experimental/quantization/config/pt2e.py,sha256=9HCrraTGGZeKEN9puKV-ODi7ncV2Wjc3oe_JCO1D_Rs,850
-tico/experimental/quantization/config/smoothquant.py,sha256=fcyhu3YlOTM7fDW9lGTXh-uJOUD6CeykZj7AMCNVbak,1415
+tico/experimental/quantization/config/smoothquant.py,sha256=b92dz4-MiBbkaLzXb47bVoO29d2P416woFQUZ1wpO_s,1414
 tico/experimental/quantization/evaluation/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/experimental/quantization/evaluation/backend.py,sha256=CZL9rZOA0t8cH7PHp6u9l7dGqWNvTj9bKOvwo0PVul0,692
 tico/experimental/quantization/evaluation/evaluate.py,sha256=kfa_GvFaX6DoSTAmuCImMJqF2jgqtnor5UpC7wVmGPI,7877
@@ -90,6 +91,7 @@ tico/experimental/quantization/ptq/wrappers/quant_elementwise.py,sha256=LhEoobfv
 tico/experimental/quantization/ptq/wrappers/quant_module_base.py,sha256=vkcDos_knGSS29rIZuEIWkAJLHrENbGz8nCH2-iara8,5969
 tico/experimental/quantization/ptq/wrappers/registry.py,sha256=OVO5nev6J8Br9zsIX-Ut7ZgWzA9f_jk0Np9bGioXgQM,5171
 tico/experimental/quantization/ptq/wrappers/fairseq/__init__.py,sha256=Mc8FLd9DusyB_IT1vk1OYrRkngOYnYd05IvtA9ORVQc,160
+tico/experimental/quantization/ptq/wrappers/fairseq/decoder_export_single_step.py,sha256=d7ZieKiSbZ2ffkaLYMg2PJl1OyAxkKjB3OHKB4poxJs,9796
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder.py,sha256=CILYvxPhW7xLkroWW_hunQBGAYGexLqnPnO5xmMnK-E,17877
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py,sha256=JT79shxOhDtRFgm8jrrN6HKvyVotiytLjMjAxX-Cztg,20416
 tico/experimental/quantization/ptq/wrappers/fairseq/quant_encoder.py,sha256=r9DPUAbL2KRJ8zpMJ39Y9n6Oe79nte-mFcdjG2qEP-w,13809
@@ -112,6 +114,7 @@ tico/passes/cast_clamp_mixed_type_args.py,sha256=m3_HpXLywWmWERfE5lM5PgvjBod7C4B
 tico/passes/cast_mixed_type_args.py,sha256=Wd3sCDKJZwdb8GiMWKljm8X5CLFRd8eCz-dmWks15Hc,7763
 tico/passes/const_prop_pass.py,sha256=hDxGgJNiRjsgOArdaoeAOcOOA-nKBvA1W1zcMZQA5yg,11531
 tico/passes/convert_conv1d_to_conv2d.py,sha256=ktS3h158y9rg1sQiW8BZZbflV_dk_UdjBPQnuiOKyzg,5303
+tico/passes/convert_expand_to_slice_cat.py,sha256=Fa6b5pqiQNq-QBiEC0e3WkQYf2UEhMgzSTIt4hlzdjc,5470
 tico/passes/convert_layout_op_to_reshape.py,sha256=sCAFjkmVtiKjvDQSAgnjNBHl3_hWXJZElGDXQiTH-7s,2963
 tico/passes/convert_matmul_to_linear.py,sha256=WATtsHk_GzsU0HYovc3UMyEj8ApF2qLbInAsNlQj0nE,9759
 tico/passes/convert_repeat_to_expand_copy.py,sha256=JbtFTmWyfJS2SSd_higP1IEhQeh7wHdN5dmTbbiFVCs,3237
@@ -133,7 +136,7 @@ tico/passes/lower_pow2_to_mul.py,sha256=nfJXa9ZTZMiLg6ownSyvkM4KF2z9tZW34Q3CCWI_
 tico/passes/lower_to_resize_nearest_neighbor.py,sha256=gbrvTmWSXDPdJ1XJtWGI5mo-uEiauXEG3ELwbKYVPLI,9013
 tico/passes/lower_to_slice.py,sha256=OzlFzK3lBYyYwC3WThsWd94Ob4JINIJF8UaLAtnumzU,7262
 tico/passes/merge_consecutive_cat.py,sha256=ayZNLDA1DFM7Fxxi2Dmk1CujkgUuaVCH1rhQgLrvvOQ,2701
-tico/passes/ops.py,sha256=cSj3Sk2x2cOE9b8oU5pmSa_rHr-iX2lORzu3N_UHMSQ,2967
+tico/passes/ops.py,sha256=7IGRnxIJl-nLO4huVk_mgBfD4VGUNQRyeuM8K1L2u1U,2934
 tico/passes/remove_nop.py,sha256=Hf91p_EJAOC6DyWNthash0_UWtEcNc_M7znamQfYQ5Y,2686
 tico/passes/remove_redundant_assert_nodes.py,sha256=rYbTCyuNIXIC-2NreHKBVCuaSUkEQvB_iSRzb26P_EA,1821
 tico/passes/remove_redundant_expand.py,sha256=8yhlMnbog-T9gIK6LKIU0tu0__gfhZzO36g_fJIVVP4,2162
@@ -235,7 +238,7 @@ tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT
 tico/serialize/operators/adapters/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/serialize/operators/adapters/llama_rmsnorm.py,sha256=6t3dhfNpR03eIjsmhymF2JKd6lCf7PvInqMf77c_BOE,1139
 tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
-tico/utils/convert.py,sha256=XbogVXO-QS0UTFNvEDyADvhCp87kTUpGAUalN8I8eRQ,13645
+tico/utils/convert.py,sha256=10YufXpuqHz274ACUb1_F5594uClUFhBEh8SY6gYp7w,13809
 tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
 tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
 tico/utils/dtype.py,sha256=L5Qb7qgbt0eQ5frUTvHYrRtTJb1dg4-JNEopcxCNg1U,1389
@@ -259,9 +262,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
 tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
 tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
 tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
-tico-0.1.0.dev250922.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
-tico-0.1.0.dev250922.dist-info/METADATA,sha256=2JnBgGh089dLyvlk3CyDQyTraHh_vDRRcPZla7pmuus,8450
-tico-0.1.0.dev250922.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
-tico-0.1.0.dev250922.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
-tico-0.1.0.dev250922.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
-tico-0.1.0.dev250922.dist-info/RECORD,,
+tico-0.1.0.dev250924.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+tico-0.1.0.dev250924.dist-info/METADATA,sha256=v5AiawRevK3MQpg4CBHHj7pStUz7PpagrCX4pq3Exns,8450
+tico-0.1.0.dev250924.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+tico-0.1.0.dev250924.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+tico-0.1.0.dev250924.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+tico-0.1.0.dev250924.dist-info/RECORD,,