tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/quantization/wrapq/utils/introspection.py
@@ -0,0 +1,167 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ import torch
+
+ from tico.quantization.evaluation.metric import MetricCalculator
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+ def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:
+     """
+     Return {module_object: full_qualified_name} without touching the modules.
+     """
+     return {m: n for n, m in root.named_modules()}
+
+
+ def save_fp_outputs(
+     model: torch.nn.Module,
+ ) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[str, torch.Tensor]]:
+     """
+     Register forward-hooks on every `QuantModuleBase` wrapper itself (not the
+     wrapped `module`) and cache its output while the wrapper runs in CALIB mode.
+
+     Parameters
+     ----------
+     model : torch.nn.Module
+         The model whose wrappers are already switched to CALIB mode
+         (`enable_calibration()` has been called).
+
+     Returns
+     -------
+     handles : list[RemovableHandle]
+         Hook handles; call `.remove()` on each one to detach the hooks.
+     cache : dict[str, torch.Tensor]
+         Mapping "wrapper-name → cached FP32 activation" captured from the first
+         forward pass. Keys default to `wrapper.fp_name`; if that attribute is
+         `None`, the `id(wrapper)` string is used instead.
+     """
+     cache: Dict[str, torch.Tensor] = {}
+     handles: List[torch.utils.hooks.RemovableHandle] = []
+
+     def _save(name: str):
+         def hook(_, __, out: torch.Tensor | Tuple):
+             if isinstance(out, tuple):
+                 out = out[0]
+             assert isinstance(out, torch.Tensor)
+             cache[name] = out.detach()
+
+         return hook
+
+     for m in model.modules():
+         if isinstance(m, QuantModuleBase):
+             name = m.fp_name or str(id(m))
+             handles.append(m.register_forward_hook(_save(name)))
+
+     return handles, cache
+
+
+ def compare_layer_outputs(
+     model: torch.nn.Module,
+     cache: Dict[str, torch.Tensor],
+     *,
+     metrics: Optional[List[str]] = None,
+     custom_metrics: Optional[Dict[str, Callable]] = None,
+     rtol: float = 1e-3,
+     atol: float = 1e-3,
+     collect: bool = False,
+ ):
+     """
+     Register forward-hooks on every `QuantModuleBase` wrapper to compare its
+     QUANT-mode output to the FP32 reference saved by `save_fp_outputs()`.
+
+     Each hook prints a per-layer diff report:
+
+         ✓ layer_name  max=1.23e-02  mean=8.45e-04  (within tolerance)
+         ⚠️ layer_name  max=3.07e+00  mean=5.12e-01  (exceeds tolerance)
+
+     Parameters
+     ----------
+     model : torch.nn.Module
+         The model whose wrappers are now in QUANT mode
+         (`freeze_qparams()` has been called).
+     cache : dict[str, torch.Tensor]
+         The reference activations captured during CALIB mode.
+     metrics
+         Metrics to compute. Defaults to `["diff"]`. Add `peir` to print PEIR.
+     custom_metrics
+         Optional user metric functions. Same signature as built-ins.
+     rtol, atol : float, optional
+         Relative / absolute tolerances used to flag large deviations
+         (similar to `torch.allclose` semantics).
+     collect : bool, optional
+         • False (default) → print one-line report per layer, return `None`
+         • True → suppress printing, return a nested dict
+           {layer_name -> {metric -> value}}
+
+     Returns
+     -------
+     handles
+         Hook handles; call `.remove()` once diffing is complete.
+     results
+         Only if *collect* is True.
+     """
+     metrics = metrics or ["diff"]
+     calc = MetricCalculator(custom_metrics)
+     handles: List[torch.utils.hooks.RemovableHandle] = []
+     results: Dict[
+         str, Dict[str, float]
+     ] = {}  # Dict[layer_name, Dict[metric_name, value]]
+
+     def _cmp(name: str):
+         ref = cache.get(name)
+
+         def hook(_, __, out):
+             if ref is None:
+                 if not collect:
+                     print(f"[{name}] no cached reference")
+                 return
+             if isinstance(out, tuple):
+                 out = out[0]
+             assert isinstance(out, torch.Tensor)
+
+             # Compute all requested metrics
+             res = calc.compute([ref], [out], metrics)  # lists with length-1 tensors
+             res = {k: v[0] for k, v in res.items()}  # flatten
+
+             if collect:
+                 results[name] = res  # type: ignore[assignment]
+                 return
+
+             # Pretty print ------------------------------------------------ #
+             diff_val = res.get("diff") or res.get("max_abs_diff")
+             thresh = atol + rtol * ref.abs().max().item()
+             flag = "⚠️" if (diff_val is not None and diff_val > thresh) else "✓"  # type: ignore[operator]
+
+             pieces = [f"{flag} {name:45s}"]
+             for key, val in res.items():
+                 pieces.append(f"{key}={val:<7.4}")
+             print(" ".join(pieces))
+
+         return hook
+
+     for m in model.modules():
+         if isinstance(m, PTQWrapper):
+             # skip the internal fp module inside the wrapper
+             continue
+         if isinstance(m, QuantModuleBase):
+             lname = m.fp_name or str(id(m))
+             handles.append(m.register_forward_hook(_cmp(lname)))
+
+     if collect:
+         return handles, results
+     return handles
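
The two hook helpers above are meant to be used as a pair: cache FP32 outputs while the wrappers run in CALIB mode, then re-run the same batch in QUANT mode and diff layer by layer. A minimal sketch of that flow follows; `wrapped_model`, `calib_batch`, and the per-module `enable_calibration()` / `freeze_qparams()` call sites are illustrative assumptions, not part of this diff.

import torch

from tico.quantization.wrapq.utils.introspection import (
    compare_layer_outputs,
    save_fp_outputs,
)
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase


def debug_quant_error(wrapped_model: torch.nn.Module, calib_batch: torch.Tensor):
    # 1) CALIB mode: cache every wrapper's FP32 output on one forward pass.
    for m in wrapped_model.modules():
        if isinstance(m, QuantModuleBase):
            m.enable_calibration()  # assumed wrapper API (named in the docstrings above)
    handles, cache = save_fp_outputs(wrapped_model)
    wrapped_model(calib_batch)
    for h in handles:
        h.remove()

    # 2) QUANT mode: re-run the same batch and collect per-layer metrics.
    for m in wrapped_model.modules():
        if isinstance(m, QuantModuleBase):
            m.freeze_qparams()  # assumed wrapper API (named in the docstrings above)
    handles, results = compare_layer_outputs(
        wrapped_model, cache, metrics=["diff", "peir"], collect=True
    )
    wrapped_model(calib_batch)
    for h in handles:
        h.remove()
    return results  # {layer_name: {metric: value}}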
tico/quantization/wrapq/utils/metrics.py
@@ -0,0 +1,124 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ import tqdm
+
+
+ def perplexity(
+     model: torch.nn.Module,
+     encodings: torch.Tensor,
+     device: torch.device | str,
+     *,
+     max_length: Optional[int] = None,
+     stride: int = 512,
+     ignore_index: int | None = -100,
+     show_progress: bool = True,
+ ) -> float:
+     """
+     Compute perplexity (PPL) using a "strided sliding-window"
+     evaluation strategy.
+
+     The function:
+       1. Splits the token sequence into overlapping windows of length
+          `max_length` (model context size).
+       2. Masks tokens that were already scored in previous windows
+          (`labels == -100`), so each token's negative log-likelihood (NLL)
+          is counted EXACTLY once.
+       3. Aggregates token-wise NLL to return corpus-level PPL.
+
+     Parameters
+     ----------
+     model : torch.nn.Module
+         Causal LM loaded in evaluation mode (`model.eval()`).
+     encodings : torch.Tensor | transformers.BatchEncoding
+         Tokenised corpus. If a `BatchEncoding` is passed, its
+         `.input_ids` field is used. Shape must be `(1, seq_len)`.
+     device : torch.device | str
+         CUDA or CPU device on which to run evaluation.
+     max_length : int, optional
+         Context window size. Defaults to `model.config.max_position_embeddings`.
+     stride : int, default = 512
+         Step size by which the sliding window advances. Must satisfy
+         `1 ≤ stride ≤ max_length`.
+     ignore_index : int, default = -100
+         Label value to ignore in loss computation. This should match
+         the `ignore_index` used by the model's internal
+         `CrossEntropyLoss`. For Hugging Face causal LMs, the
+         convention is `-100`.
+     show_progress : bool, default = True
+         If True, displays a tqdm progress bar while evaluating.
+
+     Returns
+     -------
+     float
+         Corpus-level perplexity.
+     """
+     # -------- input preparation -------- #
+     try:
+         # transformers.BatchEncoding has `input_ids`
+         input_ids_full = encodings.input_ids  # type: ignore[attr-defined]
+     except AttributeError:  # already a tensor
+         input_ids_full = encodings
+     assert isinstance(input_ids_full, torch.Tensor)
+     input_ids_full = input_ids_full.to(device)
+
+     if max_length is None:
+         assert hasattr(model, "config")
+         assert hasattr(model.config, "max_position_embeddings")
+         assert isinstance(model.config.max_position_embeddings, int)
+         max_length = model.config.max_position_embeddings
+     assert max_length is not None
+     assert (
+         1 <= stride <= max_length
+     ), f"stride ({stride}) must be in [1, max_length ({max_length})]"
+
+     seq_len = input_ids_full.size(1)
+     nll_sum = 0.0
+     n_tokens = 0
+     prev_end = 0
+
+     # -------- main loop -------- #
+     for begin in tqdm.trange(0, seq_len, stride, desc="PPL", disable=not show_progress):
+         end = min(begin + max_length, seq_len)
+         trg_len = end - prev_end  # fresh tokens in this window
+
+         input_ids = input_ids_full[:, begin:end]
+         target_ids = input_ids.clone()
+         # mask previously-scored tokens
+         target_ids[:, :-trg_len] = ignore_index  # type: ignore[assignment]
+
+         with torch.no_grad():
+             outputs = model(input_ids, labels=target_ids)
+             # loss is already averaged over non-masked labels
+             neg_log_likelihood = outputs.loss
+
+         # exact number of labels that contributed to loss
+         loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()  # type: ignore[attr-defined]
+         nll_sum += neg_log_likelihood * loss_tokens
+         n_tokens += int(loss_tokens)
+
+         prev_end = end
+         if end == seq_len:
+             break
+
+     avg_nll: float | torch.Tensor = nll_sum / n_tokens
+     if not isinstance(avg_nll, torch.Tensor):
+         avg_nll = torch.tensor(avg_nll)
+     assert isinstance(avg_nll, torch.Tensor)
+     ppl = torch.exp(avg_nll)
+
+     return ppl.item()
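
For reference, a hedged sketch of calling `perplexity()` on a Hugging Face causal LM; the `gpt2` checkpoint and the placeholder corpus are illustrative only and not part of this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.quantization.wrapq.utils.metrics import perplexity

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

# Any long evaluation text works; this placeholder just repeats a sentence.
text = "\n\n".join(["The quick brown fox jumps over the lazy dog."] * 500)
encodings = tokenizer(text, return_tensors="pt")  # BatchEncoding with (1, seq_len) input_ids

# GPT-2's context window is 1024, so max_length=1024 and stride=512 give overlapping windows.
ppl = perplexity(model, encodings, device, max_length=1024, stride=512)
print(f"perplexity: {ppl:.2f}")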
tico/quantization/wrapq/utils/reduce_utils.py
@@ -0,0 +1,25 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+
+ def channelwise_minmax(x: torch.Tensor, channel_axis: int):
+     """
+     Compute per-channel (min, max) by reducing all axes except `channel_axis`.
+     """
+     channel_axis = channel_axis % x.ndim  # handle negative indices safely
+     dims = tuple(d for d in range(x.ndim) if d != channel_axis)
+
+     return x.amin(dim=dims), x.amax(dim=dims)
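
A quick sanity check of `channelwise_minmax` on a conv-style weight tensor; the shapes below are chosen purely for illustration.

import torch

from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax

w = torch.randn(16, 3, 3, 3)  # [out_ch, in_ch, kH, kW]
mn, mx = channelwise_minmax(w, channel_axis=0)
assert mn.shape == (16,) and mx.shape == (16,)  # one (min, max) pair per output channel

# Negative axes are normalized internally, so -4 selects the same channel axis as 0.
mn2, mx2 = channelwise_minmax(w, channel_axis=-4)
assert torch.equal(mn, mn2) and torch.equal(mx, mx2)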
tico/quantization/wrapq/wrappers/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/quantization/wrapq/wrappers/fairseq/__init__.py
@@ -0,0 +1,5 @@
+ from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
+     QuantFairseqMultiheadAttention,
+ )
+
+ __all__ = ["QuantFairseqMultiheadAttention"]
tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py
@@ -0,0 +1,234 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # -----------------------------------------------------------------------------
+ # This file includes modifications based on fairseq
+ # (https://github.com/facebookresearch/fairseq), originally licensed under
+ # the MIT License. See the LICENSE file in the fairseq repository for details.
+ # -----------------------------------------------------------------------------
+
+ """
+ Q) Why the name "SingleStep"?
+
+ Fairseq's decoder already advances one token at a time during generation,
+ but the default path is "stateful" and "shape-polymorphic": it owns and
+ mutates K/V caches internally, prefix lengths and triangular masks grow with
+ the step, and beam reordering updates hidden module state. That's friendly
+ for eager execution, but hostile to `torch.export` and many accelerator
+ backends.
+
+ This export wrapper makes the per-token call truly "single-step" in the
+ export sense: "stateless" and "fixed-shape" so every invocation has the
+ exact same graph.
+
+ Key invariants
+ --------------
+ • "Stateless": K/V caches come in as explicit inputs and go out as outputs.
+   The module does not store or mutate hidden state.
+ • "Static shapes": Query is always [B, 1, C]; encoder features and masks
+   have fixed, predeclared sizes; K/V slots use fixed capacity (unused tail
+   is simply masked/ignored).
+ • "External control": Step indexing, cache slot management (append/roll),
+   and beam reordering are handled outside the module.
+ • "Prebuilt additive masks": Self-attention masks are provided by the
+   caller (0 for valid, large negative sentinel, e.g. -120, for masked),
+   avoiding data-dependent control flow.
+
+ In short: still step-wise like fairseq, but restructured for export—no
+ internal state, no data-dependent shapes, no dynamic control flow.
+ """
+
+ from typing import List, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ import tico
+
+ # ----- 1) Export wrapper module -------------------------------------------
+ class DecoderExportSingleStep(nn.Module):
+     """
+     Export-only single-step decoder module.
+
+     Inputs (example shapes; B=1, H=8, Dh=64, C=512, S=64, Tprev=63):
+       - prev_x: [B, 1, C] embedded decoder input for the current step
+       - enc_x: [S, B, C] encoder hidden states (fixed-length export input)
+       - enc_pad_additive: [B, 1, S] additive float key_padding_mask for enc-dec attn (0 for keep, -120 for pad)
+       - self_attn_mask: [B, 1, S] additive float mask for decoder self-attn at this step; pass zeros if unused
+       - prev_self_k_0..L-1: [B, H, Tprev, Dh] cached self-attn K per layer
+       - prev_self_v_0..L-1: [B, H, Tprev, Dh] cached self-attn V per layer
+
+     Outputs:
+       - x_out: [B, 1, C] new decoder features at the current step
+       - new_k_0..L-1: [H, B, Dh] per-layer new K (single-timestep; time dim squeezed)
+       - new_v_0..L-1: [H, B, Dh] per-layer new V (single-timestep; time dim squeezed)
+
+     Notes:
+       • We keep masks/additive semantics externally to avoid any mask-building inside the graph.
+       • We reshape the new K/V from [B,H,1,Dh] -> [H,B,Dh] to match the requested output spec (8,1,64).
+     """
+
+     def __init__(self, decoder: nn.Module):
+         super().__init__()
+         self.decoder = decoder
+         # Cache common meta for assertions
+         self.num_layers = len(getattr(decoder, "layers"))
+         # Infer heads/head_dim from the wrapped self_attn of layer 0
+         any_layer = getattr(decoder.layers[0], "wrapped", decoder.layers[0])  # type: ignore[index]
+         mha = getattr(any_layer, "self_attn", None)
+         assert mha is not None, "Decoder layer must expose self_attn"
+         self.num_heads = int(mha.num_heads)
+         self.head_dim = int(mha.head_dim)
+         # Embed dim (C)
+         self.embed_dim = int(getattr(decoder, "embed_dim"))
+
+     def forward(
+         self,
+         prev_x: torch.Tensor,  # [B,1,C]
+         enc_x: torch.Tensor,  # [S,B,C]
+         enc_pad_additive: torch.Tensor,  # [B,1,S]
+         *kv_args: torch.Tensor,  # prev_k_0..L-1, prev_v_0..L-1 (total 2L tensors)
+         self_attn_mask: torch.Tensor,  # [B,1,S] (or zeros)
+     ):
+         L = self.num_layers
+         H = self.num_heads
+         Dh = self.head_dim
+         B, one, C = prev_x.shape
+         S, B2, C2 = enc_x.shape
+         assert (
+             one == 1 and C == self.embed_dim and B == B2 and C2 == C
+         ), "Shape mismatch in prev_x/enc_x"
+         assert len(kv_args) == 2 * L, f"Expected {2*L} KV tensors, got {len(kv_args)}"
+
+         # Unpack previous self-attn caches
+         prev_k_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+         prev_v_list: List[torch.Tensor] = list()  # each [B,H,Tprev,Dh]
+         for i in range(L):
+             prev_k_list.append(kv_args[2 * i])
+             prev_v_list.append(kv_args[2 * i + 1])
+         for i in range(L):
+             assert (
+                 prev_k_list[i].dim() == 4 and prev_v_list[i].dim() == 4
+             ), "KV must be [B,H,Tprev,Dh]"
+             assert (
+                 prev_k_list[i].shape[0] == B
+                 and prev_k_list[i].shape[1] == H
+                 and prev_k_list[i].shape[3] == Dh
+             )
+
+         # Call decoder's external single-step path
+         # Returns:
+         #   x_step: [B,1,C]
+         #   newk/newv: lists of length L, each [B*H,1,Dh]
+         x_step, newk_list, newv_list = self.decoder.forward_external_step(  # type: ignore[operator]
+             prev_output_x=prev_x,
+             encoder_out_x=enc_x,
+             encoder_padding_mask=enc_pad_additive,
+             self_attn_mask=self_attn_mask,
+             prev_self_k_list=prev_k_list,
+             prev_self_v_list=prev_v_list,
+         )
+
+         out_tensors: List[torch.Tensor] = [
+             x_step
+         ]  # first output is the new decoder features
+         for i in range(L):
+             nk = newk_list[i]  # [B*H, Tnew, Dh]
+             nv = newv_list[i]  # [B*H, Tnew, Dh]
+             out_tensors.append(nk)
+             out_tensors.append(nv)
+
+         # Return tuple: (x_step, new_k_0, new_v_0, new_k_1, new_v_1, ..., new_k_{L-1}, new_v_{L-1})
+         return tuple(out_tensors)
+
+
+ # ----- 2) Example inputs (B=1, S=64, H=8, Dh=64, C=512, L=4) ---------------
+ def make_example_inputs(*, L=4, B=1, S=64, H=8, Dh=64, C=512, Tprev=63, device="cpu"):
+     """
+     Build example tensors that match the export I/O spec.
+     Shapes follow the request:
+       prev_x: [1,1,512]
+       enc_x: [64,1,512]
+       enc_pad_additive: [1,1,64] (additive float; zeros -> keep)
+       prev_k_i / prev_v_i (for i in 0..L-1): [1,8,63,64]
+       self_attn_mask: [1,1,64] (additive float; zeros -> keep)
+     """
+     g = torch.Generator(device=device).manual_seed(0)
+
+     prev_x = torch.randn(B, 1, C, device=device, dtype=torch.float32, generator=g)
+     enc_x = torch.randn(S, B, C, device=device, dtype=torch.float32, generator=g)
+
+     # Additive masks (0 for allowed, -120 for masked)
+     enc_pad_additive = torch.full((B, 1, S), float(-120), device=device)
+     self_attn_mask = torch.full((B, 1, S), float(-120), device=device)
+     enc_pad_additive[0, :27] = 0  # 27 is a random example.
+     self_attn_mask[0, :27] = 0  # 27 is a random example.
+
+     # Previous self-attn caches for each layer
+     prev_k_list = []
+     prev_v_list = []
+     for _ in range(L):
+         prev_k = torch.randn(
+             B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+         )
+         prev_v = torch.randn(
+             B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
+         )
+         prev_k_list.append(prev_k)
+         prev_v_list.append(prev_v)
+
+     # Pack inputs as the export function will expect:
+     #   (prev_x, enc_x, enc_pad_additive, self_attn_mask, prev_k_0..L-1, prev_v_0..L-1)
+     example_args: Tuple[torch.Tensor, ...] = (
+         prev_x,
+         enc_x,
+         enc_pad_additive,
+         *prev_k_list,
+         *prev_v_list,
+     )
+     example_kwargs = {"self_attn_mask": self_attn_mask}
+     return example_args, example_kwargs
+
+
+ # ----- 3) Export driver -----------------------------------------------------
+ def export_decoder_single_step(translator, *, save_path="decoder_step_export.circle"):
+     """
+     Wrap the QuantFairseqDecoder into the export-friendly single-step module
+     and export with torch.export.export using example inputs.
+     """
+     # Grab the wrapped decoder
+     dec = translator.models[
+         0
+     ].decoder  # assumed QuantFairseqDecoder with forward_external_step
+     # Build export wrapper
+     wrapper = DecoderExportSingleStep(decoder=dec).eval()
+
+     # Example inputs (L inferred from wrapper/decoder)
+     L = wrapper.num_layers
+     H = wrapper.num_heads
+     Dh = wrapper.head_dim
+     C = wrapper.embed_dim
+     example_inputs, example_kwargs = make_example_inputs(L=L, H=H, Dh=Dh, C=C)
+
+     # Export circle (no dynamism assumed; shapes are fixed for export)
+     cm = tico.convert(
+         wrapper,
+         args=example_inputs,
+         kwargs=example_kwargs,
+         strict=True,  # fail if something cannot be captured
+     )
+
+     # Save .pte
+     cm.save(save_path)
+     print(f"Saved decoder single-step export to: {save_path}")
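
A hedged sketch of how the export driver above might be invoked. Loading a fairseq translator and swapping its decoder for a QuantFairseqDecoder that implements `forward_external_step()` is application-specific; the checkpoint and data paths below are placeholders, not part of this diff.

from fairseq.models.transformer import TransformerModel

# Placeholder checkpoint/data paths; the decoder is assumed to have already
# been replaced with a QuantFairseqDecoder exposing forward_external_step().
translator = TransformerModel.from_pretrained(
    "checkpoints/", checkpoint_file="model.pt", data_name_or_path="data-bin/"
).eval()

export_decoder_single_step(translator, save_path="decoder_step_export.circle")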