tico-0.1.0.dev250924-py3-none-any.whl → tico-0.1.0.dev251111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tico might be problematic (details are available on the registry page).

Files changed (114)
  1. tico/__init__.py +1 -1
  2. tico/quantization/__init__.py +6 -0
  3. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  4. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  5. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  6. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +12 -6
  7. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  8. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  9. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  10. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  11. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  12. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  13. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +4 -4
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  22. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +6 -10
  23. tico/quantization/config/fpi_gptq.py +29 -0
  24. tico/{experimental/quantization → quantization}/config/gptq.py +1 -1
  25. tico/{experimental/quantization → quantization}/config/pt2e.py +1 -1
  26. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  27. tico/{experimental/quantization → quantization}/config/smoothquant.py +1 -1
  28. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  29. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +1 -3
  30. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  31. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  32. tico/{experimental/quantization → quantization}/public_interface.py +7 -7
  33. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  34. tico/{experimental/quantization → quantization}/quantizer_registry.py +11 -10
  35. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/compare_ppl.py +8 -19
  36. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/debug_quant_outputs.py +9 -24
  37. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  38. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  39. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  40. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  41. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_with_gptq.py +14 -35
  42. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  43. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  44. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  45. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  46. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  47. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  48. tico/quantization/wrapq/quantizer.py +179 -0
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/__init__.py +1 -1
  52. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder.py +6 -8
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder_layer.py +6 -8
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder.py +6 -8
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder_layer.py +6 -8
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_mha.py +5 -7
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +5 -7
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +8 -12
  59. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  60. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  62. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  63. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  64. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  65. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  66. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  67. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -16
  68. tico/utils/convert.py +9 -14
  69. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251111.dist-info}/METADATA +48 -2
  70. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251111.dist-info}/RECORD +113 -108
  71. tico/experimental/quantization/__init__.py +0 -6
  72. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  73. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  74. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  75. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  76. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  77. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  78. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  79. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  80. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  81. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  82. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  83. /tico/{experimental/quantization/config → quantization/algorithm/smoothquant}/__init__.py +0 -0
  84. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +0 -0
  85. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/smooth_quant.py +0 -0
  86. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  87. /tico/{experimental/quantization → quantization}/config/base.py +0 -0
  88. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  89. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  90. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  91. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  92. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  93. /tico/{experimental/quantization/ptq → quantization/passes}/__init__.py +0 -0
  94. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  95. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  96. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  97. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  98. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  99. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  100. /tico/{experimental/quantization/ptq/examples → quantization/wrapq}/__init__.py +0 -0
  101. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  102. /tico/{experimental/quantization/ptq/observers → quantization/wrapq/examples}/__init__.py +0 -0
  103. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  104. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/observers}/__init__.py +0 -0
  105. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  106. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/utils}/__init__.py +0 -0
  107. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  108. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/wrappers}/__init__.py +0 -0
  109. /tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/decoder_export_single_step.py +0 -0
  110. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers/llama}/__init__.py +0 -0
  111. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251111.dist-info}/LICENSE +0 -0
  112. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251111.dist-info}/WHEEL +0 -0
  113. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251111.dist-info}/entry_points.txt +0 -0
  114. {tico-0.1.0.dev250924.dist-info → tico-0.1.0.dev251111.dist-info}/top_level.txt +0 -0
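
The bulk of this release is a package-layout migration: everything under tico/experimental/quantization/ moves to tico/quantization/, the ptq subpackage becomes wrapq, and QuantConfig (formerly ptq/quant_config.py) is renamed PTQConfig at tico.quantization.config.ptq. A minimal before/after import sketch, assuming the relocated modules keep their original symbols:

# Old layout (0.1.0.dev250924), removed in this release:
# from tico.experimental.quantization import convert, prepare
# from tico.experimental.quantization.ptq.quant_config import QuantConfig
# from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper

# New layout (0.1.0.dev251111):
from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig  # replaces QuantConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper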
tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_with_gptq.py

@@ -33,16 +33,14 @@ import tqdm
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from tico.experimental.quantization import convert, prepare
-from tico.experimental.quantization.config.gptq import GPTQConfig
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.utils.introspection import build_fqn_map
-from tico.experimental.quantization.ptq.utils.metrics import perplexity
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
+from tico.quantization import convert, prepare
+from tico.quantization.config.gptq import GPTQConfig
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.introspection import build_fqn_map
+from tico.quantization.wrapq.utils.metrics import perplexity
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
 
 
 # Token-budget presets for activation calibration

@@ -215,22 +213,8 @@ def main():
     # 4. Wrap every layer with PTQWrapper (activation UINT-8)
     # -------------------------------------------------------------------------
     print("Wrapping layers with PTQWrapper …")
-    layers = q_m.model.layers
-    if not isinstance(layers, (list, torch.nn.ModuleList)):
-        raise TypeError(f"'model.layers' must be list/ModuleList, got {type(layers)}")
-
-    qcfg = QuantConfig()  # default: per-tensor UINT8
-    wrapped = torch.nn.ModuleList()
-    for idx, fp_layer in enumerate(layers):
-        layer_cfg = qcfg.child(f"layer{idx}")
-        wrapped.append(
-            PTQWrapper(
-                fp_layer,
-                qcfg=layer_cfg,
-                fp_name=m_to_fqn.get(fp_layer),
-            )
-        )
-    q_m.model.layers = wrapped
+    qcfg = PTQConfig()  # default: per-tensor UINT8
+    prepare(q_m, qcfg)
 
     # -------------------------------------------------------------------------
     # 5. Single-pass activation calibration

@@ -242,11 +226,7 @@
     calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
     train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
 
-    # (a) Enable CALIB mode on every QuantModuleBase
-    for l in q_m.model.layers:
-        l.enable_calibration()
-
-    # (b) Overwrite weight observers with GPTQ statistics
+    # Overwrite weight observers with GPTQ statistics
     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
         inject_gptq_qparams(q_m, q_m.quantizers)
     else:

@@ -254,7 +234,7 @@
            "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
        )
 
-    # (c) Forward passes to collect activation ranges
+    # Forward passes to collect activation ranges
     iterator = range(0, train_ids.size(1) - 1, args.stride)
     if not args.no_tqdm:
         iterator = tqdm.tqdm(iterator, desc="Act-calibration")

@@ -262,9 +242,8 @@
     for i in iterator:
         q_m(train_ids[:, i : i + args.stride])
 
-    # (d) Freeze all Q-params (scale, zero-point)
-    for l in q_m.model.layers:
-        l.freeze_qparams()
+    # Freeze all Q-params (scale, zero-point)
+    convert(q_m)
 
     # -------------------------------------------------------------------------
     # 6. Evaluate perplexity on Wikitext-2
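
Taken together, these hunks collapse the example's manual wrap/calibrate/freeze loops into the public three-call flow. A condensed sketch of the new sequence (the checkpoint name is illustrative and the GPTQ weight step is omitted; only prepare/calibrate/convert are shown):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig

name = "HuggingFaceTB/SmolLM2-135M"  # illustrative Llama-style checkpoint
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# 1. Wrap supported layers and switch them to calibration mode
#    (replaces the removed per-layer PTQWrapper loop).
prepare(model, PTQConfig())  # default: per-tensor UINT8

# 2. Forward passes collect activation ranges.
ids = tokenizer("some representative calibration text", return_tensors="pt").input_ids
with torch.no_grad():
    model(ids)

# 3. Freeze scales and zero-points (replaces the per-layer freeze_qparams loop).
convert(model)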
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py

@@ -17,9 +17,9 @@ from typing import Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.ptq.dtypes import DType, UINT8
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
-from tico.experimental.quantization.ptq.qscheme import QScheme
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.observers.base import ObserverBase
+from tico.quantization.wrapq.qscheme import QScheme
 
 
 class AffineObserverBase(ObserverBase):
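
The observer renames below are mechanical, but for context: AffineObserverBase-style observers track a running min/max and turn it into quantization parameters. A generic sketch of that derivation (standard asymmetric affine math for a UINT8 range, not tico's exact code):

import torch

def affine_qparams(min_val: torch.Tensor, max_val: torch.Tensor,
                   qmin: int = 0, qmax: int = 255):
    # Include zero in the range so 0.0 maps exactly onto an integer value.
    min_val = torch.clamp(min_val, max=0.0)
    max_val = torch.clamp(max_val, min=0.0)
    scale = (max_val - min_val) / float(qmax - qmin)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    zero_point = torch.clamp(torch.round(qmin - min_val / scale), qmin, qmax)
    return scale, zero_point.to(torch.int64)

# e.g. values observed in [-1.5, 3.0] give scale ≈ 0.0176, zero_point = 85
s, zp = affine_qparams(torch.tensor(-1.5), torch.tensor(3.0))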
tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py

@@ -17,8 +17,8 @@ from typing import Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.ptq.dtypes import DType, UINT8
-from tico.experimental.quantization.ptq.qscheme import QScheme
+from tico.quantization.wrapq.dtypes import DType, UINT8
+from tico.quantization.wrapq.qscheme import QScheme
 
 
 class ObserverBase(ABC):

tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py

@@ -14,8 +14,8 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
 
 
 class EMAObserver(AffineObserverBase):

tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py

@@ -24,7 +24,7 @@ performing any statistics gathering or fake-quantization.
 """
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
 
 
 class IdentityObserver(AffineObserverBase):

tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py

@@ -14,8 +14,8 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.affine_base import AffineObserverBase
-from tico.experimental.quantization.ptq.utils.reduce_utils import channelwise_minmax
+from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
+from tico.quantization.wrapq.utils.reduce_utils import channelwise_minmax
 
 
 class MinMaxObserver(AffineObserverBase):

tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py

@@ -14,7 +14,7 @@
 
 import torch
 
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
+from tico.quantization.wrapq.observers.base import ObserverBase
 from tico.utils.mx.mx_ops import quantize_mx
 
 
tico/quantization/wrapq/quantizer.py (new file)

@@ -0,0 +1,179 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.quantizer import BaseQuantizer
+from tico.quantization.quantizer_registry import register_quantizer
+
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+
+@register_quantizer(PTQConfig)
+class PTQQuantizer(BaseQuantizer):
+    """
+    Post-Training Quantization (PTQ) quantizer integrated with the public interface.
+
+    Features
+    --------
+    • Automatically wraps quantizable modules using PTQWrapper.
+    • Supports leaf-level (single-module) quantization (e.g., prepare(model.fc, PTQConfig())).
+    • Enforces strict wrapping if `strict_wrap=True`: raises NotImplementedError if
+      no quantizable module was found at any boundary.
+    • If `strict_wrap=False`, unquantizable modules are silently skipped.
+    """
+
+    def __init__(self, config: PTQConfig):
+        super().__init__(config)
+        self.qcfg: PTQConfig = config
+        self.strict_wrap: bool = bool(getattr(config, "strict_wrap", True))
+
+    @torch.no_grad()
+    def prepare(
+        self,
+        model: torch.nn.Module,
+        args: Optional[Any] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # Wrap the tree (or single module) according to strictness policy
+        model = self._wrap_supported(model, self.qcfg)
+
+        # Switch all quant modules into calibration mode
+        if isinstance(model, QuantModuleBase):
+            model.enable_calibration()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.enable_calibration()
+        return model
+
+    @torch.no_grad()
+    def convert(self, model):
+        # Freeze qparams across the tree (QUANT mode)
+        if isinstance(model, QuantModuleBase):
+            model.freeze_qparams()
+        for m in model.modules():
+            if isinstance(m, QuantModuleBase):
+                m.freeze_qparams()
+        return model
+
+    def _wrap_supported(
+        self,
+        root: nn.Module,
+        qcfg: PTQConfig,
+    ) -> nn.Module:
+        """
+        Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
+        """
+        assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+
+        # Case A: HuggingFace-style transformers: model.model.layers
+        lm = getattr(root, "model", None)
+        layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
+        if isinstance(layers, nn.ModuleList):
+            new_list = nn.ModuleList()
+            for idx, layer in enumerate(layers):
+                child_scope = f"layer{idx}"
+                child_cfg = qcfg.child(child_scope)
+
+                # Enforce strictness at the child boundary
+                wrapped = self._try_wrap(
+                    layer,
+                    child_cfg,
+                    fp_name=child_scope,
+                    raise_on_fail=self.strict_wrap,
+                )
+                new_list.append(wrapped)
+            lm.layers = new_list  # type: ignore[union-attr]
+            return root
+
+        # Case B: Containers
+        if isinstance(root, (nn.Sequential, nn.ModuleList)):
+            for i, child in enumerate(list(root)):
+                name = str(i)
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[i] = wrapped  # type: ignore[index]
+
+        if isinstance(root, nn.ModuleDict):
+            for k, child in list(root.items()):
+                name = k
+                child_cfg = qcfg.child(name)
+
+                wrapped = self._try_wrap(
+                    child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+                )
+                if wrapped is child:
+                    assert not self.strict_wrap
+                    wrapped = self._wrap_supported(wrapped, child_cfg)
+                root[k] = wrapped  # type: ignore[index]
+
+        # Case C: Leaf node
+        root_name = getattr(root, "_get_name", lambda: None)()
+        wrapped = self._try_wrap(
+            root, qcfg, fp_name=root_name, raise_on_fail=self.strict_wrap
+        )
+        if wrapped is not root:
+            return wrapped
+
+        assert not self.strict_wrap
+        # Case D: Named children
+        for name, child in list(root.named_children()):
+            child_cfg = qcfg.child(name)
+
+            wrapped = self._try_wrap(
+                child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
+            )
+            if wrapped is child:
+                assert not self.strict_wrap
+                wrapped = self._wrap_supported(wrapped, child_cfg)
+            setattr(root, name, wrapped)
+
+        return root
+
+    def _try_wrap(
+        self,
+        module: nn.Module,
+        qcfg_for_child: PTQConfig,
+        *,
+        fp_name: Optional[str],
+        raise_on_fail: bool,
+    ) -> nn.Module:
+        """
+        Attempt to wrap a boundary with PTQWrapper.
+
+        Behavior:
+          • If PTQWrapper succeeds: return wrapped module.
+          • If PTQWrapper raises NotImplementedError:
+              - raise_on_fail=True  -> re-raise (strict)
+              - raise_on_fail=False -> return original module (permissive)
+        """
+        try:
+            return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
+        except NotImplementedError as e:
+            if raise_on_fail:
+                raise NotImplementedError(
+                    f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
+                ) from e
+            return module
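
With @register_quantizer(PTQConfig), the public prepare/convert entry points can dispatch to this quantizer from the config type alone. A sketch of the leaf-level path the docstring mentions, assuming the public prepare/convert forward this quantizer's return value:

import torch
import torch.nn as nn

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

net = Net()
# Leaf-level quantization (Case C above): the single module is wrapped
# and returned, so reassign it into the parent module.
net.fc = prepare(net.fc, PTQConfig())

with torch.no_grad():
    net(torch.randn(8, 16))  # calibration pass

net.fc = convert(net.fc)  # freeze scale/zero-point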
tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py

@@ -16,11 +16,9 @@ from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 
-from tico.experimental.quantization.evaluation.metric import MetricCalculator
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
+from tico.quantization.evaluation.metric import MetricCalculator
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
 
 
 def build_fqn_map(root: torch.nn.Module) -> dict[torch.nn.Module, str]:

tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py

@@ -98,7 +98,8 @@ def perplexity(
 
         input_ids = input_ids_full[:, begin:end]
         target_ids = input_ids.clone()
-        target_ids[:, :-trg_len] = ignore_index  # mask previously-scored tokens
+        # mask previously-scored tokens
+        target_ids[:, :-trg_len] = ignore_index  # type: ignore[assignment]
 
         with torch.no_grad():
             outputs = model(input_ids, labels=target_ids)

@@ -106,7 +107,7 @@
         neg_log_likelihood = outputs.loss
 
         # exact number of labels that contributed to loss
-        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
+        loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()  # type: ignore[attr-defined]
         nll_sum += neg_log_likelihood * loss_tokens
         n_tokens += int(loss_tokens)
 
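
The two type-ignore hunks annotate the usual strided-perplexity bookkeeping: within each window only the trg_len newest tokens are scored, and earlier positions are masked with ignore_index so the HF loss skips them. A self-contained sketch of that accounting (the model call is stubbed out):

import torch

ignore_index = -100                                   # labels the HF loss skips
max_len, stride = 1024, 512
input_ids_full = torch.randint(0, 32000, (1, 4096))   # dummy token ids

n_tokens, prev_end = 0, 0
for begin in range(0, input_ids_full.size(1), stride):
    end = min(begin + max_len, input_ids_full.size(1))
    trg_len = end - prev_end                          # tokens not yet scored
    input_ids = input_ids_full[:, begin:end]
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = ignore_index           # mask previously-scored tokens
    # nll = model(input_ids, labels=target_ids).loss  # model omitted in this sketch
    loss_tokens = (target_ids[:, 1:] != ignore_index).sum().item()
    n_tokens += int(loss_tokens)
    prev_end = end
    if end == input_ids_full.size(1):
        break
# perplexity = exp(total_nll / n_tokens) once the per-window losses are summed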
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/__init__.py

@@ -1,4 +1,4 @@
-from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
+from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
     QuantFairseqMultiheadAttention,
 )

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder.py

@@ -25,12 +25,10 @@ import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.models.transformer.TransformerDecoderBase")

@@ -53,7 +51,7 @@ class QuantFairseqDecoder(QuantModuleBase):
         self,
         fp_decoder: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)

@@ -116,7 +114,7 @@ class QuantFairseqDecoder(QuantModuleBase):
 
         prefix = _safe_prefix(fp_name)
 
-        # Prepare child QuantConfig namespaces: layers/<idx>
+        # Prepare child PTQConfig namespaces: layers/<idx>
         layers_qcfg = qcfg.child("layers") if qcfg else None
         for i, layer in enumerate(fp_layers):
             child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
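
The layers/<idx> namespacing above is the pattern used throughout these wrappers: each nested module derives a scoped child config, so overrides can target e.g. layers/0 without touching siblings. A toy illustration of that scoping idea (hypothetical ScopedConfig, not tico's PTQConfig implementation):

# Toy stand-in for hierarchical config scoping:
class ScopedConfig:
    def __init__(self, overrides: dict, prefix: str = ""):
        self.overrides, self.prefix = overrides, prefix

    def child(self, name: str) -> "ScopedConfig":
        # Mirrors qcfg.child("layers") / layers_qcfg.child(str(i))
        path = f"{self.prefix}/{name}" if self.prefix else name
        return ScopedConfig(self.overrides, path)

    def get(self, key: str, default=None):
        return self.overrides.get(f"{self.prefix}:{key}", default)

cfg = ScopedConfig({"layers/0:act_dtype": "uint8"})
layer0 = cfg.child("layers").child("0")
assert layer0.get("act_dtype") == "uint8"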
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_decoder_layer.py

@@ -23,15 +23,13 @@ from typing import Dict, Iterable, List, Optional, Tuple
 import torch
 from torch import nn, Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
     QuantFairseqMultiheadAttention,
 )
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.modules.transformer_layer.TransformerDecoderLayerBase")

@@ -55,7 +53,7 @@ class QuantFairseqDecoderLayer(QuantModuleBase):
         self,
         fp_layer: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder.py

@@ -25,12 +25,10 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.models.transformer.TransformerEncoderBase")

@@ -56,7 +54,7 @@ class QuantFairseqEncoder(QuantModuleBase):
         self,
         fp_encoder: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
         use_external_inputs: bool = False,  # export-mode flag
         return_type: Literal["tensor", "dict"] = "dict",

@@ -100,7 +98,7 @@ class QuantFairseqEncoder(QuantModuleBase):
         fp_layers = list(fp_encoder.layers)  # type: ignore[arg-type]
         self.layers = nn.ModuleList()
 
-        # Prepare child QuantConfig namespaces: layers/<idx>
+        # Prepare child PTQConfig namespaces: layers/<idx>
         layers_qcfg = qcfg.child("layers") if qcfg else None
         for i, layer in enumerate(fp_layers):
             child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_encoder_layer.py

@@ -23,15 +23,13 @@ from typing import Optional
 import torch.nn as nn
 from torch import Tensor
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
     QuantFairseqMultiheadAttention,
 )
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.modules.transformer_layer.TransformerEncoderLayerBase")

@@ -49,7 +47,7 @@ class QuantFairseqEncoderLayer(QuantModuleBase):
         self,
         fp_layer: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/fairseq/quant_mha.py

@@ -24,12 +24,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("fairseq.modules.multihead_attention.MultiheadAttention")

@@ -59,7 +57,7 @@ class QuantFairseqMultiheadAttention(QuantModuleBase):
         self,
         fp_attn: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
         max_seq: int = 4096,
         use_static_causal: bool = False,

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py

@@ -17,12 +17,10 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register(

@@ -34,7 +32,7 @@ class QuantLlamaAttention(QuantModuleBase):
         self,
         fp_attn: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py

@@ -17,16 +17,12 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
-    QuantLlamaAttention,
-)
-from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
+from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("transformers.models.llama.modeling_llama.LlamaDecoderLayer")

@@ -56,7 +52,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         self,
         fp_layer: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
         return_type: Optional[str] = None,
     ):

@@ -165,7 +161,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
         # - If use_cache: always return (hidden_states, present_key_value)
         # - Else: return as configured (tuple/tensor) for HF compatibility
         if use_cache:
-            return hidden_states, present_key_value
+            return hidden_states, present_key_value  # type: ignore[return-value]
 
         if self.return_type == "tuple":
             return (hidden_states,)

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py

@@ -17,12 +17,10 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import try_register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
 @try_register("transformers.models.llama.modeling_llama.LlamaMLP")

@@ -31,7 +29,7 @@ class QuantLlamaMLP(QuantModuleBase):
         self,
         mlp_fp: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)

tico/quantization/wrapq/wrappers/nn/__init__.py (new file)

@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE

tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py

@@ -17,12 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.mode import Mode
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import register
 
 
 @register(nn.LayerNorm)

@@ -46,7 +45,7 @@ class QuantLayerNorm(QuantModuleBase):
         self,
         fp: nn.LayerNorm,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None
     ):
         super().__init__(qcfg, fp_name=fp_name)
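
Finally, the two registration decorators seen across these wrappers: @register(nn.LayerNorm) binds a wrapper to a class that is always importable, while @try_register("fairseq...") and @try_register("transformers...") bind by dotted path so fairseq and transformers can remain optional dependencies. A minimal sketch of that pattern (toy registry, not tico's actual wrappers/registry.py):

import importlib
import torch.nn as nn

_REGISTRY: dict = {}  # maps fp module class -> quant wrapper class

def register(module_cls):
    def deco(wrapper_cls):
        _REGISTRY[module_cls] = wrapper_cls
        return wrapper_cls
    return deco

def try_register(dotted_path: str):
    """Like register(), but a no-op when the optional dependency is absent."""
    def deco(wrapper_cls):
        mod_path, _, cls_name = dotted_path.rpartition(".")
        try:
            target = getattr(importlib.import_module(mod_path), cls_name)
        except (ImportError, AttributeError):
            return wrapper_cls  # fairseq/transformers not installed: skip
        _REGISTRY[target] = wrapper_cls
        return wrapper_cls
    return deco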