tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from tico.quantization.evaluation.metric import MetricCalculator
|
|
20
|
+
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
|
|
21
|
+
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:
    """
    Build a reverse lookup from module object to its fully-qualified name.

    The names are exactly those yielded by `root.named_modules()`; the
    modules are only enumerated, never modified.
    """
    fqn_by_module: dict[torch.nn.Module, str] = {}
    for qualified_name, module in root.named_modules():
        fqn_by_module[module] = qualified_name
    return fqn_by_module
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def save_fp_outputs(
    model: torch.nn.Module,
) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[str, torch.Tensor]]:
    """
    Attach an output-recording forward hook to every `QuantModuleBase`
    found in `model.modules()`.

    Parameters
    ----------
    model : torch.nn.Module
        Model containing the quantization wrappers. The caller is expected
        to have switched them to calibration mode beforehand; this function
        does not change any mode itself.

    Returns
    -------
    handles : list[RemovableHandle]
        One handle per registered hook; call `.remove()` on each to detach.
    cache : dict[str, torch.Tensor]
        Filled lazily as the model runs. The key is `wrapper.fp_name`, or
        `str(id(wrapper))` when `fp_name` is falsy. The value is the detached
        output of the wrapper (the first element when the output is a
        tuple). A later forward call overwrites the entry stored under the
        same key.
    """
    captured: Dict[str, torch.Tensor] = {}

    def _make_recorder(key: str):
        def _record(_module, _inputs, output: torch.Tensor | Tuple):
            # Tuple outputs are represented by their first element only.
            tensor = output[0] if isinstance(output, tuple) else output
            assert isinstance(tensor, torch.Tensor)
            captured[key] = tensor.detach()

        return _record

    hook_handles: List[torch.utils.hooks.RemovableHandle] = [
        wrapper.register_forward_hook(
            _make_recorder(wrapper.fp_name or str(id(wrapper)))
        )
        for wrapper in model.modules()
        if isinstance(wrapper, QuantModuleBase)
    ]

    return hook_handles, captured
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def compare_layer_outputs(
    model: torch.nn.Module,
    cache: Dict[str, torch.Tensor],
    *,
    metrics: Optional[List[str]] = None,
    custom_metrics: Optional[Dict[str, Callable]] = None,
    rtol: float = 1e-3,
    atol: float = 1e-3,
    collect: bool = False,
):
    """
    Attach comparison hooks that measure each wrapper's current output
    against a previously cached reference tensor.

    A hook is registered on every `QuantModuleBase` in `model.modules()`
    except instances of `PTQWrapper`. When a hooked module runs, its output
    (first element if it is a tuple) is compared with `cache[layer_name]`
    through `MetricCalculator.compute`.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose wrappers should be compared (expected to be in
        quantized mode by the time it is run).
    cache : dict[str, torch.Tensor]
        Reference activations, e.g. the dict returned by
        `save_fp_outputs()`. Each reference is looked up once, when the
        hook is created, so the cache must already be populated.
    metrics
        Names of the metrics to compute. Defaults to `["diff"]`.
    custom_metrics
        Optional extra metric functions, forwarded to `MetricCalculator`.
    rtol, atol : float, optional
        A layer is flagged when its "diff" (or, failing that,
        "max_abs_diff") value exceeds `atol + rtol * max(|reference|)`.
    collect : bool, optional
        • False (default) → print one report line per layer.
        • True → print nothing and store `{layer_name: {metric: value}}`.

    Returns
    -------
    handles
        Hook handles; call `.remove()` on each once comparison is done.
        Returned alone when `collect` is False.
    (handles, results)
        Returned when `collect` is True; `results` is filled as the model
        runs.
    """
    requested = metrics or ["diff"]
    calculator = MetricCalculator(custom_metrics)
    hook_handles: List[torch.utils.hooks.RemovableHandle] = []
    # layer name -> {metric name -> value}
    per_layer: Dict[str, Dict[str, float]] = {}

    def _make_comparator(layer: str):
        reference = cache.get(layer)

        def _compare(_module, _inputs, output):
            # Guard: nothing to compare against.
            if reference is None:
                if not collect:
                    print(f"[{layer}] no cached reference")
                return

            actual = output[0] if isinstance(output, tuple) else output
            assert isinstance(actual, torch.Tensor)

            # The calculator works on lists of tensors; unwrap the single
            # entry it returns for each metric.
            raw = calculator.compute([reference], [actual], requested)
            scores = {metric: values[0] for metric, values in raw.items()}

            if collect:
                per_layer[layer] = scores  # type: ignore[assignment]
                return

            # One-line report, flagged when the deviation is too large.
            worst = scores.get("diff") or scores.get("max_abs_diff")
            limit = atol + rtol * reference.abs().max().item()
            exceeded = worst is not None and worst > limit  # type: ignore[operator]
            marker = "⚠️" if exceeded else "✓"

            report = [f"{marker} {layer:45s}"]
            report.extend(
                f"{metric}={value:<7.4}" for metric, value in scores.items()
            )
            print(" ".join(report))

        return _compare

    for module in model.modules():
        # PTQWrapper instances are deliberately left without a hook.
        if isinstance(module, PTQWrapper) or not isinstance(module, QuantModuleBase):
            continue
        layer_name = module.fp_name or str(id(module))
        hook_handles.append(
            module.register_forward_hook(_make_comparator(layer_name))
        )

    return (hook_handles, per_layer) if collect else hook_handles
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import tqdm
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def perplexity(
    model: torch.nn.Module,
    encodings: torch.Tensor,
    device: torch.device | str,
    *,
    max_length: Optional[int] = None,
    stride: int = 512,
    ignore_index: int | None = -100,
    show_progress: bool = True,
) -> float:
    """
    Compute corpus-level perplexity (PPL) with a strided sliding window.

    The token sequence is cut into windows of at most `max_length` tokens
    whose start advances by `stride`. In every window the tokens that an
    earlier window already covered are replaced by `ignore_index` in the
    labels, so only the fresh tail of the window contributes to the loss.
    The per-window mean loss is re-weighted by the number of contributing
    labels, summed, and the exponential of the overall mean is returned.

    Parameters
    ----------
    model : torch.nn.Module
        Called as `model(input_ids, labels=...)`; the result must expose a
        `.loss` attribute (mean loss over non-ignored labels).
    encodings : torch.Tensor | object with `.input_ids`
        Either the token-id tensor itself or an object (e.g. a tokenizer
        encoding) whose `.input_ids` is that tensor. The sequence length is
        read from dimension 1.
    device : torch.device | str
        Device the token ids are moved to before evaluation.
    max_length : int, optional
        Window size. Defaults to `model.config.max_position_embeddings`.
    stride : int, default = 512
        Step between window starts; must satisfy `1 <= stride <= max_length`.
    ignore_index : int, default = -100
        Label value written over already-scored positions; it must be the
        value the model's loss ignores.
    show_progress : bool, default = True
        If True, a tqdm progress bar is shown while evaluating.

    Returns
    -------
    float
        Corpus-level perplexity.
    """
    # -------- input preparation -------- #
    # Objects carrying `input_ids` are unwrapped; plain tensors pass through.
    token_ids = getattr(encodings, "input_ids", encodings)
    assert isinstance(token_ids, torch.Tensor)
    token_ids = token_ids.to(device)

    if max_length is None:
        assert hasattr(model, "config")
        assert hasattr(model.config, "max_position_embeddings")
        assert isinstance(model.config.max_position_embeddings, int)
        max_length = model.config.max_position_embeddings
    assert max_length is not None
    assert (
        1 <= stride <= max_length
    ), f"stride ({stride}) must be in [1, max_length ({max_length})]"

    total_len = token_ids.size(1)
    total_nll = 0.0  # becomes a tensor after the first accumulation
    scored_tokens = 0
    scored_upto = 0  # exclusive end of the region covered so far

    # -------- main loop -------- #
    window_starts = tqdm.trange(
        0, total_len, stride, desc="PPL", disable=not show_progress
    )
    for start in window_starts:
        stop = min(start + max_length, total_len)
        fresh = stop - scored_upto  # tokens not covered by earlier windows

        window = token_ids[:, start:stop]
        labels = window.clone()
        # Everything before the fresh tail was scored by a previous window.
        labels[:, :-fresh] = ignore_index  # type: ignore[assignment]

        with torch.no_grad():
            # `.loss` is already averaged over the non-masked labels.
            mean_nll = model(window, labels=labels).loss

        # Labels that actually contributed (position 0 has no prediction).
        counted = (labels[:, 1:] != ignore_index).sum().item()  # type: ignore[attr-defined]
        total_nll += mean_nll * counted
        scored_tokens += int(counted)

        scored_upto = stop
        if stop == total_len:
            break

    mean: float | torch.Tensor = total_nll / scored_tokens
    if not isinstance(mean, torch.Tensor):
        mean = torch.tensor(mean)

    return torch.exp(mean).item()
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def channelwise_minmax(x: torch.Tensor, channel_axis: int):
    """
    Compute per-channel (min, max) by reducing all axes except `channel_axis`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor with at least one dimension.
    channel_axis : int
        Axis that is kept. It is wrapped with `% x.ndim`, so negative
        values count from the last axis.

    Returns
    -------
    (min, max) : tuple[torch.Tensor, torch.Tensor]
        Two 1-D tensors of length `x.shape[channel_axis]`.
    """
    channel_axis = channel_axis % x.ndim  # handle negative indices safely
    dims = tuple(d for d in range(x.ndim) if d != channel_axis)

    if not dims:
        # 1-D input: each element is its own channel, so there is nothing to
        # reduce. An empty `dim` tuple must not be handed to amin/amax: it
        # is not a no-op there and collapses the tensor to one global value.
        # Copies are returned so the results never alias the input, just
        # like the reduced results below.
        return x.clone(), x.clone()

    return x.amin(dim=dims), x.amax(dim=dims)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
# -----------------------------------------------------------------------------
|
|
16
|
+
# This file includes modifications based on fairseq
|
|
17
|
+
# (https://github.com/facebookresearch/fairseq), originally licensed under
|
|
18
|
+
# the MIT License. See the LICENSE file in the fairseq repository for details.
|
|
19
|
+
# -----------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
"""
|
|
22
|
+
Q) Why the name "SingleStep"?
|
|
23
|
+
|
|
24
|
+
Fairseq's decoder already advances one token at a time during generation,
|
|
25
|
+
but the default path is "stateful" and "shape-polymorphic": it owns and
|
|
26
|
+
mutates K/V caches internally, prefix lengths and triangular masks grow with
|
|
27
|
+
the step, and beam reordering updates hidden module state. That's friendly
|
|
28
|
+
for eager execution, but hostile to `torch.export` and many accelerator
|
|
29
|
+
backends.
|
|
30
|
+
|
|
31
|
+
This export wrapper makes the per-token call truly "single-step" in the
|
|
32
|
+
export sense: "stateless" and "fixed-shape" so every invocation has the
|
|
33
|
+
exact same graph.
|
|
34
|
+
|
|
35
|
+
Key invariants
|
|
36
|
+
--------------
|
|
37
|
+
• "Stateless": K/V caches come in as explicit inputs and go out as outputs.
|
|
38
|
+
The module does not store or mutate hidden state.
|
|
39
|
+
• "Static shapes": Query is always [B, 1, C]; encoder features and masks
|
|
40
|
+
have fixed, predeclared sizes; K/V slots use fixed capacity (unused tail
|
|
41
|
+
is simply masked/ignored).
|
|
42
|
+
• "External control": Step indexing, cache slot management (append/roll),
|
|
43
|
+
and beam reordering are handled outside the module.
|
|
44
|
+
• "Prebuilt additive masks": Self-attention masks are provided by the
|
|
45
|
+
caller (0 for valid, large negative sentinel, e.g. -120, for masked),
|
|
46
|
+
avoiding data-dependent control flow.
|
|
47
|
+
|
|
48
|
+
In short: still step-wise like fairseq, but restructured for export—no
|
|
49
|
+
internal state, no data-dependent shapes, no dynamic control flow.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
from typing import List, Tuple
|
|
53
|
+
|
|
54
|
+
import torch
|
|
55
|
+
import torch.nn as nn
|
|
56
|
+
|
|
57
|
+
import tico
|
|
58
|
+
|
|
59
|
+
# ----- 1) Export wrapper module -------------------------------------------
|
|
60
|
+
class DecoderExportSingleStep(nn.Module):
    """
    Thin export wrapper that runs one decoding step of a wrapped decoder.

    The wrapped decoder must provide `layers`, `embed_dim` and a
    `forward_external_step(...)` method; its first layer (or that layer's
    `wrapped` attribute, when present) must expose `self_attn` with
    `num_heads` and `head_dim`.

    Inputs of `forward`:
      - prev_x:            [B, 1, C] decoder input for the current step
      - enc_x:             [S, B, C] encoder hidden states
      - enc_pad_additive:  additive padding mask for encoder-decoder attention
      - *kv_args:          2L cached self-attention tensors, each 4-D
                           [B, H, Tprev, Dh]
      - self_attn_mask:    additive mask for decoder self-attention
                           (keyword-only)

    Output: the tuple
      (x_step, new_k_0, new_v_0, new_k_1, new_v_1, ..., new_k_{L-1}, new_v_{L-1})
    holding exactly what `forward_external_step` returned; no reshaping is
    done here, so the K/V output shapes are whatever the decoder produces.

    NOTE(review): `kv_args` is consumed in interleaved order
    (k_0, v_0, k_1, v_1, ...). Callers that pack all keys first and all
    values afterwards would have K and V mis-assigned without any error,
    because every tensor has the same shape. Confirm the packing order at
    the call sites.
    """

    def __init__(self, decoder: nn.Module):
        super().__init__()
        self.decoder = decoder

        # Static metadata used by the shape assertions in forward().
        self.num_layers = len(getattr(decoder, "layers"))

        # Heads / head size come from layer 0's self-attention; a layer may
        # be a wrapper that keeps the real layer under `wrapped`.
        first_layer = decoder.layers[0]  # type: ignore[index]
        first_layer = getattr(first_layer, "wrapped", first_layer)
        attention = getattr(first_layer, "self_attn", None)
        assert attention is not None, "Decoder layer must expose self_attn"
        self.num_heads = int(attention.num_heads)
        self.head_dim = int(attention.head_dim)

        # Embed dim (C)
        self.embed_dim = int(getattr(decoder, "embed_dim"))

    def forward(
        self,
        prev_x: torch.Tensor,  # [B,1,C]
        enc_x: torch.Tensor,  # [S,B,C]
        enc_pad_additive: torch.Tensor,
        *kv_args: torch.Tensor,  # k_0, v_0, k_1, v_1, ... (total 2L tensors)
        self_attn_mask: torch.Tensor,
    ):
        layer_count = self.num_layers
        head_count = self.num_heads
        head_size = self.head_dim

        batch, step_len, channels = prev_x.shape
        _src_len, enc_batch, enc_channels = enc_x.shape
        assert (
            step_len == 1
            and channels == self.embed_dim
            and batch == enc_batch
            and enc_channels == channels
        ), "Shape mismatch in prev_x/enc_x"
        assert (
            len(kv_args) == 2 * layer_count
        ), f"Expected {2*layer_count} KV tensors, got {len(kv_args)}"

        # De-interleave the caches: even slots are keys, odd slots values.
        prev_k_list: List[torch.Tensor] = list(kv_args[0::2])
        prev_v_list: List[torch.Tensor] = list(kv_args[1::2])

        for cached_k, cached_v in zip(prev_k_list, prev_v_list):
            assert (
                cached_k.dim() == 4 and cached_v.dim() == 4
            ), "KV must be [B,H,Tprev,Dh]"
            # Only the key tensor's sizes are checked against B / H / Dh.
            assert (
                cached_k.shape[0] == batch
                and cached_k.shape[1] == head_count
                and cached_k.shape[3] == head_size
            )

        # Delegate the actual computation to the decoder's single-step path.
        x_step, new_keys, new_values = self.decoder.forward_external_step(  # type: ignore[operator]
            prev_output_x=prev_x,
            encoder_out_x=enc_x,
            encoder_padding_mask=enc_pad_additive,
            self_attn_mask=self_attn_mask,
            prev_self_k_list=prev_k_list,
            prev_self_v_list=prev_v_list,
        )

        # Flatten to (x_step, k_0, v_0, k_1, v_1, ...).
        outputs: List[torch.Tensor] = [x_step]
        for layer_idx in range(layer_count):
            outputs.extend((new_keys[layer_idx], new_values[layer_idx]))

        return tuple(outputs)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# ----- 2) Example inputs (B=1, S=64, H=8, Dh=64, C=512, L=4) ---------------
|
|
157
|
+
def make_example_inputs(*, L=4, B=1, S=64, H=8, Dh=64, C=512, Tprev=63, device="cpu"):
    """
    Build example tensors that match the export I/O spec.

    Shapes (values in brackets are for the default arguments):
      prev_x:               [B,1,C]        -> [1,1,512]
      enc_x:                [S,B,C]        -> [64,1,512]
      enc_pad_additive:     [B,1,S]        -> [1,1,64]  (additive float; zeros -> keep)
      prev_k_i / prev_v_i:  [B,H,Tprev,Dh] -> [1,8,63,64]  (for i in 0..L-1)
      self_attn_mask:       [B,1,S]        -> [1,1,64]  (additive float; zeros -> keep)

    All random tensors are drawn from a generator seeded with 0, so repeated
    calls with the same arguments return identical values.

    Returns:
      (example_args, example_kwargs) where
        example_args   = (prev_x, enc_x, enc_pad_additive,
                          prev_k_0, prev_v_0, ..., prev_k_{L-1}, prev_v_{L-1})
        example_kwargs = {"self_attn_mask": self_attn_mask}
    """
    g = torch.Generator(device=device).manual_seed(0)

    prev_x = torch.randn(B, 1, C, device=device, dtype=torch.float32, generator=g)
    enc_x = torch.randn(S, B, C, device=device, dtype=torch.float32, generator=g)

    # Additive masks (0 for allowed, -120 for masked).
    # NOTE(review): self_attn_mask is sized with S; with the defaults
    # S == Tprev + 1 -- confirm which length the decoder expects.
    num_allowed = 27  # 27 is a random example.
    enc_pad_additive = torch.full((B, 1, S), float(-120), device=device)
    self_attn_mask = torch.full((B, 1, S), float(-120), device=device)
    # The masks are 3-D, so the position axis is the LAST one. Indexing with
    # `[0, :27]` would slice the singleton middle axis and clear the whole row.
    enc_pad_additive[0, :, :num_allowed] = 0
    self_attn_mask[0, :, :num_allowed] = 0

    # Previous self-attn caches for each layer, interleaved as (k_i, v_i).
    # This is the layout the export wrapper's forward unpacks:
    # kv_args[2 * i] is K and kv_args[2 * i + 1] is V for layer i.
    kv_args = []
    for _ in range(L):
        prev_k = torch.randn(
            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
        )
        prev_v = torch.randn(
            B, H, Tprev, Dh, device=device, dtype=torch.float32, generator=g
        )
        kv_args.extend((prev_k, prev_v))

    # Pack inputs as the export function will expect:
    #   positional: (prev_x, enc_x, enc_pad_additive, k_0, v_0, ..., k_{L-1}, v_{L-1})
    #   keyword:    self_attn_mask
    example_args: Tuple[torch.Tensor, ...] = (
        prev_x,
        enc_x,
        enc_pad_additive,
        *kv_args,
    )
    example_kwargs = {"self_attn_mask": self_attn_mask}
    return example_args, example_kwargs
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# ----- 3) Export driver -----------------------------------------------------
|
|
205
|
+
def export_decoder_single_step(translator, *, save_path="decoder_step_export.circle"):
    """
    Wrap the decoder of `translator.models[0]` in the export-friendly
    single-step module, convert it with `tico.convert` using fixed-shape
    example inputs, and save the result to `save_path`.
    """
    # The decoder is assumed to be a QuantFairseqDecoder exposing
    # forward_external_step.
    first_model = translator.models[0]
    step_module = DecoderExportSingleStep(decoder=first_model.decoder).eval()

    # Size the example inputs from the wrapper's own hyper-parameters.
    positional_inputs, keyword_inputs = make_example_inputs(
        L=step_module.num_layers,
        H=step_module.num_heads,
        Dh=step_module.head_dim,
        C=step_module.embed_dim,
    )

    # Convert to circle (no dynamism assumed; shapes are fixed for export).
    circle_model = tico.convert(
        step_module,
        args=positional_inputs,
        kwargs=keyword_inputs,
        strict=True,  # fail if something cannot be captured
    )

    # Write the converted model to disk and report where it went.
    circle_model.save(save_path)
    print(f"Saved decoder single-step export to: {save_path}")
|