tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_expand.py +3 -1
  10. tico/quantization/__init__.py +6 -0
  11. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  12. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  14. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  29. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  31. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  32. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  33. tico/quantization/config/base.py +26 -0
  34. tico/quantization/config/fpi_gptq.py +29 -0
  35. tico/quantization/config/gptq.py +29 -0
  36. tico/quantization/config/pt2e.py +25 -0
  37. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  38. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
  39. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  40. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  41. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  42. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  47. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  48. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  52. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  59. tico/quantization/wrapq/quantizer.py +179 -0
  60. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  62. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  63. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  64. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  65. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  66. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  67. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  68. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  69. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  70. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
  71. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
  72. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  73. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  74. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  75. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  76. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  77. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  78. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  79. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  80. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
  81. tico/serialize/circle_serializer.py +11 -4
  82. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  83. tico/serialize/operators/op_le.py +54 -0
  84. tico/serialize/operators/op_mm.py +15 -132
  85. tico/utils/convert.py +20 -15
  86. tico/utils/register_custom_op.py +6 -4
  87. tico/utils/signature.py +7 -8
  88. tico/utils/validate_args_kwargs.py +12 -0
  89. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  90. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
  91. tico/experimental/quantization/__init__.py +0 -6
  92. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  93. tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
  94. tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
  95. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
  96. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  97. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  98. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  99. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  100. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  101. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  102. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  103. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  104. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  105. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  106. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  107. /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
  108. /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
  109. /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  111. /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
  112. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  113. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  114. /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
  115. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  116. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  117. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  118. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  119. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  120. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  121. /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
  122. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  123. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
  124. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  125. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
  126. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  127. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
  128. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  129. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
  130. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
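
The bulk of this release is a package reorganization: everything under tico/experimental/quantization moves to tico/quantization, and the former ptq subpackage lives on as wrapq (with new quantizer, fairseq wrappers, and FPI-GPTQ/config modules added on top). As an illustration only (not part of the diff), a downstream import would migrate roughly like this, using module and class names taken from the renames above:

# Illustration only: import-path migration implied by the renames above.

# 0.1.0.dev250904 (old layout)
#   from tico.experimental.quantization.ptq.observers.minmax import MinMaxObserver
#   from tico.experimental.quantization.ptq.qscheme import QScheme

# 0.1.0.dev251109 (new layout)
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
from tico.quantization.wrapq.qscheme import QScheme
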
tico/quantization/wrapq/observers/affine_base.py
@@ -17,9 +17,9 @@ from typing import Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.ptq.dtypes import DType, UINT8
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
-from tico.experimental.quantization.ptq.qscheme import QScheme
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.observers.base import ObserverBase
+from tico.quantization.wrapq.qscheme import QScheme
 
 
 class AffineObserverBase(ObserverBase):

tico/quantization/wrapq/observers/base.py
@@ -17,8 +17,8 @@ from typing import Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.ptq.dtypes import DType, UINT8
-from tico.experimental.quantization.ptq.qscheme import QScheme
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.qscheme import QScheme
 
 
 class ObserverBase(ABC):

tico/quantization/wrapq/observers/ema.py
@@ -14,8 +14,8 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
 
 
 class EMAObserver(AffineObserverBase):

tico/quantization/wrapq/observers/identity.py
@@ -24,7 +24,7 @@ performing any statistics gathering or fake-quantization.
 """
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
 
 
 class IdentityObserver(AffineObserverBase):

tico/quantization/wrapq/observers/minmax.py
@@ -14,8 +14,8 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
 
 
 class MinMaxObserver(AffineObserverBase):

tico/quantization/wrapq/observers/mx.py
@@ -14,7 +14,7 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
+from tico.quantization.wrapq.observers.base import ObserverBase
 from tico.utils.mx.mx_ops import quantize_mx
 
 

tico/quantization/wrapq/quantizer.py (new file)
@@ -0,0 +1,179 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
+
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+@register_quantizer(PTQConfig)
+class PTQQuantizer(BaseQuantizer):
+    """
+    Post-Training Quantization (PTQ) quantizer integrated with the public interface.
+
+    Features
+    --------
+    • Automatically wraps quantizable modules using PTQWrapper.
+    • Supports leaf-level (single-module) quantization (e.g., prepare(model.fc, PTQConfig())).
+    • Enforces strict wrapping if `strict_wrap=True`: raises NotImplementedError if
+      no quantizable module was found at any boundary.
+    • If `strict_wrap=False`, unquantizable modules are silently skipped.
+    """
+
+    def __init__(self, config: PTQConfig):
+        super().__init__(config)
+        self.qcfg: PTQConfig = config
+        self.strict_wrap: bool = bool(getattr(config, "strict_wrap", True))
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # Wrap the tree (or single module) according to strictness policy
+        model = self._wrap_supported(model, self.qcfg)
+
+        # Switch all quant modules into calibration mode
+        if isinstance(model, QuantModuleBase):
+            model.enable_calibration()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.enable_calibration()
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        # Freeze qparams across the tree (QUANT mode)
+        if isinstance(model, QuantModuleBase):
+            model.freeze_qparams()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.freeze_qparams()
+        return model
+
+    def _wrap_supported(
+        self,
+        root: nn.Module,
+        qcfg: PTQConfig,
+    ) -> nn.Module:
+        """
+        Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
+        """
+        assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+
+        # Case A: HuggingFace-style transformers: model.model.layers
+        lm = getattr(root, "model", None)
+        layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
+        if isinstance(layers, nn.ModuleList):
+            new_list = nn.ModuleList()
+            for idx, layer in enumerate(layers):
+                child_scope = f"layer{idx}"
+                child_cfg = qcfg.child(child_scope)
+
+                # Enforce strictness at the child boundary
+                wrapped = self._try_wrap(
+                    layer,
+                    child_cfg,
+                    fp_name=child_scope,
+                    raise_on_fail=self.strict_wrap,
+                )
+                new_list.append(wrapped)
+            lm.layers = new_list  # type: ignore[union-attr]
+            return root
+
+        # Case B: Containers
+        if isinstance(root, (nn.Sequential, nn.ModuleList)):
+            for i, child in enumerate(list(root)):
+                name = str(i)
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[i] = wrapped  # type: ignore[index]
+
+        if isinstance(root, nn.ModuleDict):
+            for k, child in list(root.items()):
+                name = k
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[k] = wrapped  # type: ignore[index]
+
+        # Case C: Leaf node
+        root_name = getattr(root, "_get_name", lambda: None)()
+        wrapped = self._try_wrap(
+            root, qcfg, fp_name=root_name, raise_on_fail=self.strict_wrap
+        )
+        if wrapped is not root:
+            return wrapped
+
+        assert not self.strict_wrap
+        # Case D: Named children
+        for name, child in list(root.named_children()):
+            child_cfg = qcfg.child(name)
+
+            wrapped = self._try_wrap(
+                child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+            )
+            if wrapped is child:
+                assert not self.strict_wrap
+                wrapped = self._wrap_supported(wrapped, child_cfg)
+            setattr(root, name, wrapped)
+
+        return root
+
+    def _try_wrap(
+        self,
+        module: nn.Module,
+        qcfg_for_child: PTQConfig,
+        *,
+        fp_name: Optional[str],
+        raise_on_fail: bool,
+    ) -> nn.Module:
+        """
+        Attempt to wrap a boundary with PTQWrapper.
+
+        Behavior:
+        • If PTQWrapper succeeds: return wrapped module.
+        • If PTQWrapper raises NotImplementedError:
+          - raise_on_fail=True -> re-raise (strict)
+          - raise_on_fail=False -> return original module (permissive)
+        """
+        try:
+            return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
+        except NotImplementedError as e:
+            if raise_on_fail:
+                raise NotImplementedError(
+                    f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
+                ) from e
+            return module

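
With the registry added in tico/quantization/quantizer_registry.py, a PTQConfig now resolves to this quantizer. A minimal sketch (not part of the diff) of driving it directly, assuming PTQConfig's defaults, that the class is importable from tico.quantization.wrapq.quantizer as shown above, and that a plain nn.Linear is covered by the quant_linear wrapper listed in the file table:

# Sketch only; import paths and PTQConfig defaults are assumptions.
import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.quantizer import PTQQuantizer

fc = nn.Linear(16, 32).eval()

q = PTQQuantizer(PTQConfig())   # strict_wrap defaults to True
fc_q = q.prepare(fc)            # leaf-level wrap with PTQWrapper, calibration mode on

with torch.no_grad():           # feed a few calibration batches
    for _ in range(8):
        fc_q(torch.randn(4, 16))

fc_q = q.convert(fc_q)          # freeze qparams (QUANT mode)
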
tico/quantization/wrapq/utils/introspection.py
@@ -16,11 +16,9 @@ from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.evaluation.metric import MetricCalculator
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
+from tico.quantization.evaluation.metric import MetricCalculator
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
 
 
 def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:

tico/quantization/wrapq/utils/metrics.py
@@ -98,7 +98,8 @@ def perplexity(
 
         input_ids = input_ids_full[:, begin:end]
         target_ids = input_ids.clone()
-        target_ids[:, :-trg_len] = ignore_index  # mask previously-scored tokens
+        # mask previously-scored tokens
+        target_ids[:, :-trg_len] = ignore_index  # type: ignore[assignment]
 
         with torch.no_grad():
             outputs = model(input_ids, labels=target_ids)
@@ -106,7 +107,7 @@ def perplexity(
 
             neg_log_likelihood = outputs.loss
 
         # exact number of labels that contributed to loss
-        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
+        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()  # type: ignore[attr-defined]
         nll_sum += neg_log_likelihood * loss_tokens
         n_tokens += int(loss_tokens)
 

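For context, the sliding-window masking these hunks annotate works like the following toy example (not from the package; ignore_index = -100 is an assumed value matching the usual labels convention): only the trg_len newest tokens in each window keep their labels, and loss_tokens counts the shifted labels that are not ignored.

import torch

ignore_index = -100                         # assumed default
input_ids = torch.arange(8).unsqueeze(0)    # one window of 8 tokens
trg_len = 3                                 # tokens not scored by a previous window

target_ids = input_ids.clone()
target_ids[:, :-trg_len] = ignore_index     # mask previously-scored tokens
# target_ids == [[-100, -100, -100, -100, -100, 5, 6, 7]]

loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
print(loss_tokens)  # 3 labels feed the (shifted) LM loss for this window
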
tico/quantization/wrapq/wrappers/fairseq/__init__.py (new file)
@@ -0,0 +1,5 @@
+from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
+    QuantFairseqMultiheadAttention,
+)
+
+__all__ = ["QuantFairseqMultiheadAttention"]

tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py (new file)
@@ -0,0 +1,234 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+"""
+Q) Why the name "SingleStep"?
+
+Fairseq's decoder already advances one token at a time during generation,
+but the default path is "stateful" and "shape-polymorphic": it owns and
+mutates K/V caches internally, prefix lengths and triangular masks grow with
+the step, and beam reordering updates hidden module state. That's friendly
+for eager execution, but hostile to `torch.export` and many accelerator
+backends.
+
+This export wrapper makes the per-token call truly "single-step" in the
+export sense: "stateless" and "fixed-shape" so every invocation has the
+exact same graph.
+
+Key invariants
+--------------
+• "Stateless": K/V caches come in as explicit inputs and go out as outputs.
+  The module does not store or mutate hidden state.
+• "Static shapes": Query is always [B, 1, C]; encoder features and masks
+  have fixed, predeclared sizes; K/V slots use fixed capacity (unused tail
+  is simply masked/ignored).
+• "External control": Step indexing, cache slot management (append/roll),
+  and beam reordering are handled outside the module.
+• "Prebuilt additive masks": Self-attention masks are provided by the
+  caller (0 for valid, large negative sentinel, e.g. -120, for masked),
+  avoiding data-dependent control flow.
+
+In short: still step-wise like fairseq, but restructured for export—no
+internal state, no data-dependent shapes, no dynamic control flow.
+"""
+
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+import tico
+
+# ----- 1) Export wrapper module -------------------------------------------
+class DecoderExportSingleStep(nn.Module):
+    """
+    Export-only single-step decoder module.
+
+    Inputs (example shapes; B=1, H=8, Dh=64, C=512, S=64, Tprev=63):
+      - prev_x: [B, 1, C] embedded decoder input for the current step
+      - enc_x: [S, B, C] encoder hidden states (fixed-length export input)
+      - enc_pad_additive: [B, 1, S] additive float key_padding_mask for enc-dec attn (0 for keep, -120 for pad)
+      - self_attn_mask: [B, 1, S] additive float mask for decoder self-attn at this step; pass zeros if unused
+      - prev_self_k_0..L-1: [B, H, Tprev, Dh] cached self-attn K per layer
+      - prev_self_v_0..L-1: [B, H, Tprev, Dh] cached self-attn V per layer
+
+    Outputs:
+      - x_out: [B, 1, C] new decoder features at the current step
+      - new_k_0..L-1: [H, B, Dh] per-layer new K (single-timestep; time dim squeezed)
+      - new_v_0..L-1: [H, B, Dh] per-layer new V (single-timestep; time dim squeezed)
+
+    Notes:
+      • We keep masks/additive semantics externally to avoid any mask-building inside the graph.
+      • We reshape the new K/V from [B,H,1,Dh] -> [H,B,Dh] to match the requested output spec (8,1,64).
+    """
+
+    def __init__(self, decoder: nn.Module):
+        super().__init__()
+        self.decoder = decoder
+        # Cache common meta for assertions
+        self.num_layers = len(getattr(decoder, "layers"))
+        # Infer heads/head_dim from the wrapped self_attn of layer 0
+        any_layer = getattr(decoder.layers[0], "wrapped", decoder.layers[0])  # type: ignore[index]
+        mha = getattr(any_layer, "self_attn", None)
+        assert mha is not None, "Decoder layer must expose self_attn"
+        self.num_heads = int(mha.num_heads)
+        self.head_dim = int(mha.head_dim)
+        # Embed dim (C)
+        self.embed_dim = int(getattr(decoder, "embed_dim"))
+
+    def forward(
+        self,
+        prev_x: torch.Tensor,  # [B,1,C]
+        enc_x: torch.Tensor,  # [S,B,C]
+        enc_pad_additive: torch.Tensor,  # [B,1,S]
+        *kv_args: torch.Tensor,  # prev_k_0..L-1, prev_v_0..L-1 (total 2L tensors)
+        self_attn_mask: torch.Tensor,  # [B,1,S] (or zeros)
+    ):
+        L = self.num_layers
+        H = self.num_heads
+        Dh = self.head_dim
+        B, one, C = prev_x.shape
+        S, B2, C2 = enc_x.shape
+        assert (
+            one == 1 and C == self.embed_dim and B == B2 and C2 == C
+        ), "Shape mismatch in prev_x/enc_x"
+        assert len(kv_args) == 2 * L, f"Expected {2*L} KV tensors, got {len(kv_args)}"
+
+        # Unpack previous self-attn caches
+        prev_k_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        prev_v_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+        for i in range(L):
+            prev_k_list.append(kv_args[2 * i])
+            prev_v_list.append(kv_args[2 * i + 1])
+        for i in range(L):
+            assert (
+                prev_k_list[i].dim() == 4 and prev_v_list[i].dim() == 4
+            ), "KV must be [B,H,Tprev,Dh]"
+            assert (
+                prev_k_list[i].shape[0] == B
+                and prev_k_list[i].shape[1] == H
+                and prev_k_list[i].shape[3] == Dh
+            )
+
+        # Call decoder's external single-step path
+        # Returns:
+        #   x_step: [B,1,C]
+        #   newk/newv: lists of length L, each [B*H,1,Dh]
+        x_step, newk_list, newv_list = self.decoder.forward_external_step(  # type: ignore[operator]
+            prev_output_x=prev_x,
+            encoder_out_x=enc_x,
+            encoder_padding_mask=enc_pad_additive,
+            self_attn_mask=self_attn_mask,
+            prev_self_k_list=prev_k_list,
+            prev_self_v_list=prev_v_list,
+        )
+
+        out_tensors: List[torch.Tensor] = [
+            x_step
+        ]  # first output is the new decoder features
+        for i in range(L):
+            nk = newk_list[i]  # [B*H, Tnew, Dh]
+            nv = newv_list[i]  # [B*H, Tnew, Dh]
+            out_tensors.append(nk)
+            out_tensors.append(nv)
+
+        # Return tuple: (x_step, new_k_0, new_v_0, new_k_1, new_v_1, ..., new_k_{L-1}, new_v_{L-1})
+        return tuple(out_tensors)
+
+
+# ----- 2) Example inputs (B=1, S=64, H=8, Dh=64, C=512, L=4) ---------------
+def make_example_inputs(*, L=4, B=1, S=64, H=8, Dh=64, C=512, Tprev=63, device="cpu"):
+    """
+    Build example tensors that match the export I/O spec.
+    Shapes follow the request:
+      prev_x: [1,1,512]
+      enc_x: [64,1,512]
+      enc_pad_additive: [1,1,64] (additive float; zeros -> keep)
+      prev_k_i / prev_v_i (for i in 0..L-1): [1,8,63,64]
+      self_attn_mask: [1,1,64] (additive float; zeros -> keep)
+    """
+    g = torch.Generator(device=device).manual_seed(0)
+
+    prev_x = torch.randn(B, 1, C, device=device, dtype=torch.float32, generator=g)
+    enc_x = torch.randn(S, B, C, device=device, dtype=torch.float32, generator=g)
+
+    # Additive masks (0 for allowed, -120 for masked)
+    enc_pad_additive = torch.full((B, 1, S), float(-120), device=device)
+    self_attn_mask = torch.full((B, 1, S), float(-120), device=device)
+    enc_pad_additive[0, :27] = 0  # 27 is a random example.
+    self_attn_mask[0, :27] = 0  # 27 is a random example.
+
+    # Previous self-attn caches for each layer
+    prev_k_list = []
+    prev_v_list = []
+    for _ in range(L):
+        prev_k = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_v = torch.randn(
+            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+        )
+        prev_k_list.append(prev_k)
+        prev_v_list.append(prev_v)
+
+    # Pack inputs as the export function will expect:
+    #   (prev_x, enc_x, enc_pad_additive, self_attn_mask, prev_k_0..L-1, prev_v_0..L-1)
+    example_args: Tuple[torch.Tensor, ...] = (
+        prev_x,
+        enc_x,
+        enc_pad_additive,
+        *prev_k_list,
+        *prev_v_list,
+    )
+    example_kwargs = {"self_attn_mask": self_attn_mask}
+    return example_args, example_kwargs
+
+
+# ----- 3) Export driver -----------------------------------------------------
+def export_decoder_single_step(translator, *, save_path="decoder_step_export.circle"):
+    """
+    Wrap the QuantFairseqDecoder into the export-friendly single-step module
+    and export with torch.export.export using example inputs.
+    """
+    # Grab the wrapped decoder
+    dec = translator.models[
+        0
+    ].decoder  # assumed QuantFairseqDecoder with forward_external_step
+    # Build export wrapper
+    wrapper = DecoderExportSingleStep(decoder=dec).eval()
+
+    # Example inputs (L inferred from wrapper/decoder)
+    L = wrapper.num_layers
+    H = wrapper.num_heads
+    Dh = wrapper.head_dim
+    C = wrapper.embed_dim
+    example_inputs, example_kwargs = make_example_inputs(L=L, H=H, Dh=Dh, C=C)
+
+    # Export circle (no dynamism assumed; shapes are fixed for export)
+    cm = tico.convert(
+        wrapper,
+        args=example_inputs,
+        kwargs=example_kwargs,
+        strict=True,  # fail if something cannot be captured
+    )
+
+    # Save .pte
+    cm.save(save_path)
+    print(f"Saved decoder single-step export to: {save_path}")