tico 0.1.0.dev250916__py3-none-any.whl → 0.1.0.dev250918__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
  ]
 
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
- __version__ = "0.1.0.dev250916"
+ __version__ = "0.1.0.dev250918"
 
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
  SECURE_TORCH_VERSION = "2.6.0"
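Going by their names, the two constants below the version string gate the installed torch version. A minimal sketch of the kind of check they enable (the actual enforcement lives inside tico and is not shown in this diff; the semantics of the constants are an assumption here):

import torch
from packaging import version

import tico

# Assumption: MINIMUM_SUPPORTED_VERSION / SECURE_TORCH_VERSION refer to torch versions.
if version.parse(torch.__version__) < version.parse(tico.MINIMUM_SUPPORTED_VERSION):
    raise RuntimeError(
        f"torch {torch.__version__} is older than the minimum supported "
        f"{tico.MINIMUM_SUPPORTED_VERSION}"
    )
if version.parse(torch.__version__) < version.parse(tico.SECURE_TORCH_VERSION):
    print(f"Note: torch {torch.__version__} predates SECURE_TORCH_VERSION {tico.SECURE_TORCH_VERSION}")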
tico/config/v1.py CHANGED
@@ -20,6 +20,9 @@ from tico.config.base import CompileConfigBase
  @dataclass
  class CompileConfigV1(CompileConfigBase):
      legalize_causal_mask_value: bool = False
+     remove_constant_input: bool = False
+     convert_lhs_const_mm_to_fc: bool = False
+     convert_rhs_const_mm_to_fc: bool = True
 
      def get(self, name: str):
          return super().get(name)
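This release adds three compile flags to CompileConfigV1: remove_constant_input, convert_lhs_const_mm_to_fc, and convert_rhs_const_mm_to_fc (only the right-hand-side conversion is on by default). A minimal sketch of overriding them, assuming the dataclass can be constructed directly with keyword arguments; how the config is handed to the converter entry point is an API assumption, not taken from this diff:

from tico.config.v1 import CompileConfigV1

# Override two of the new flags; convert_rhs_const_mm_to_fc keeps its True default.
cfg = CompileConfigV1(
    remove_constant_input=True,
    convert_lhs_const_mm_to_fc=True,
)
print(cfg.convert_rhs_const_mm_to_fc)  # True (new default)
# Hypothetical usage: pass cfg to the compile entry point,
# e.g. tico.convert(model, example_inputs, config=cfg)  (assumed signature).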
@@ -77,7 +77,7 @@ def main():
          "--dtype",
          choices=list(DTYPE_MAP.keys()),
          default="float32",
-         help="Model dtype for load (float32|bfloat16|float16).",
+         help=f"Model dtype for load.",
      )
      parser.add_argument(
          "--stride", type=int, default=512, help="Sliding-window stride for perplexity."
@@ -126,7 +126,9 @@ def main():
      print(f"DType : {args.dtype}")
      print(f"Stride : {args.stride}")
      print(f"Use HF cache? : {args.use_cache}")
-     print(f"Calib preset : {args.calib_preset}")
+     print(
+         f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+     )
      print()
 
      # -------------------------------------------------------------------------
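The reworked print also reports the token budget behind the chosen preset, using Python's "," format specifier for thousands grouping. For example, with the "production" value from the TOKENS tables shown elsewhere in this diff:

preset = "production"
tokens = 200_000  # "production" entry of the TOKENS presets
print(f"Calib preset : {preset} ({tokens:,} tokens)")
# Calib preset : production (200,000 tokens)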
@@ -12,19 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import torch
- import tqdm
- from datasets import load_dataset
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.utils.introspection import (
-     build_fqn_map,
-     compare_layer_outputs,
-     save_fp_outputs,
- )
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-
  # ============================================================================
  # LAYER-WISE DIFF DEBUGGING PIPELINE
  # ----------------------------------------------------------------------------
@@ -43,12 +30,21 @@ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
  # problematic modules during post-training quantization.
  # ============================================================================
 
- # -------------------------------------------------------------------------
- # 0. Global configuration
- # -------------------------------------------------------------------------
- MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- STRIDE = 512
+ import argparse
+ import sys
+
+ import torch
+ import tqdm
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from tico.experimental.quantization.ptq.quant_config import QuantConfig
+ from tico.experimental.quantization.ptq.utils.introspection import (
+     build_fqn_map,
+     compare_layer_outputs,
+     save_fp_outputs,
+ )
+ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
 
  # Token-budget presets for activation calibration
  TOKENS: dict[str, int] = {
@@ -59,71 +55,185 @@ TOKENS: dict[str, int] = {
      # Production / 4-bit observer smoothing
      "production": 200_000,
  }
- CALIB_TOKENS = TOKENS["baseline"]
- print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
-
- # -------------------------------------------------------------------------
- # 1. Load the FP backbone
- # -------------------------------------------------------------------------
- print("Loading FP model …")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
- model.config.use_cache = False  # disable KV-cache → full forward
- m_to_fqn = build_fqn_map(model)  # map modules → fully-qualified names
-
- # Use Wikitext-2 train split for calibration.
- dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-
- # -------------------------------------------------------------------------
- # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
- # -------------------------------------------------------------------------
- print("Wrapping layers with PTQWrapper …")
- qcfg = QuantConfig()  # default: per-tensor UINT8
-
- new_layers = torch.nn.ModuleList()
- for idx, fp_layer in enumerate(model.model.layers):
-     layer_cfg = qcfg.child(f"layer{idx}")
-     q_layer = PTQWrapper(
-         fp_layer,
-         qcfg=layer_cfg,
-         fp_name=m_to_fqn.get(fp_layer),
-     )
-     new_layers.append(q_layer)
 
- model.model.layers = new_layers  # swap in quant wrappers
-
- # -------------------------------------------------------------------------
- # 3. Activation calibration plus FP-vs-UINT8 diffing
- # -------------------------------------------------------------------------
- print("Calibrating UINT-8 observers …")
- calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
- ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+ DTYPE_MAP = {
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+ }
 
- # (a) Enable CALIB mode on every QuantModuleBase
- for l in model.model.layers:
-     l.enable_calibration()
+ # Hardcoded dataset settings
+ DATASET_NAME = "wikitext"
+ DATASET_CONFIG = "wikitext-2-raw-v1"
+ TRAIN_SPLIT = "train"
 
- # Save reference FP activations before observers clamp/quantize
- save_handles, act_cache = save_fp_outputs(model)
 
- with torch.no_grad():
-     for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
-         inputs = ids[:, i : i + STRIDE]
-         model(inputs)  # observers collect act. ranges
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Layer-wise diff debugging pipeline for PTQ"
+     )
+     parser.add_argument(
+         "--model", type=str, required=True, help="HF repo name or local path."
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         help="Device to run on (cuda|cpu|mps).",
+     )
+     parser.add_argument(
+         "--dtype",
+         choices=list(DTYPE_MAP.keys()),
+         default="float32",
+         help=f"Model dtype for load.",
+     )
+     parser.add_argument(
+         "--stride",
+         type=int,
+         default=512,
+         help="Sliding-window stride used during calibration.",
+     )
+     parser.add_argument(
+         "--calib-preset",
+         choices=list(TOKENS.keys()),
+         default="debug",
+         help="Calibration token budget preset.",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+     parser.add_argument(
+         "--trust-remote-code",
+         action="store_true",
+         help="Enable only if you trust the model repo code.",
+     )
+     parser.add_argument(
+         "--hf-token",
+         type=str,
+         default=None,
+         help="Optional HF token for gated/private repos.",
+     )
+     parser.add_argument(
+         "--use-cache",
+         dest="use_cache",
+         action="store_true",
+         default=False,
+         help="Use model KV cache if enabled (off by default).",
+     )
+     parser.add_argument(
+         "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+     )
 
- # Remove save hooks now that FP activations are cached
- for h in save_handles:
-     h.remove()
+     args = parser.parse_args()
 
- # (b) Freeze (scale, zero-point) after calibration
- for l in model.model.layers:
-     l.freeze_qparams()
+     # Basic setup
+     torch.manual_seed(args.seed)
+     device = torch.device(args.device)
+     dtype = DTYPE_MAP[args.dtype]  # noqa: E999 (kept readable)
 
- # (c) Register diff hooks and measure per-layer deltas
- cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
- # Use same inputs for comparison.
- model(inputs)
+     print("=== Config ===")
+     print(f"Model : {args.model}")
+     print(f"Device : {device.type}")
+     print(f"DType : {args.dtype}")
+     print(f"Stride : {args.stride}")
+     print(
+         f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+     )
+     print(f"Use HF cache? : {args.use_cache}")
+     print()
+
+     # -------------------------------------------------------------------------
+     # 1. Load the FP backbone and tokenizer
+     # -------------------------------------------------------------------------
+     print("Loading FP model …")
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model,
+         trust_remote_code=args.trust_remote_code,
+         token=args.hf_token,
+     )
+     model = (
+         AutoModelForCausalLM.from_pretrained(
+             args.model,
+             torch_dtype=dtype,
+             trust_remote_code=args.trust_remote_code,
+             token=args.hf_token,
+         )
+         .to(device)
+         .eval()
+     )
 
- assert isinstance(cmp_handles, list)
- for h in cmp_handles:
-     h.remove()
+     # Disable KV cache to force full forward passes for introspection
+     model.config.use_cache = args.use_cache
+
+     # Build module -> FQN map before wrapping
+     m_to_fqn = build_fqn_map(model)
+
+     # Prepare calibration inputs (HF Wikitext-2 train split)
+     CALIB_TOKENS = TOKENS[args.calib_preset]
+     print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+     # Use Wikitext-2 train split for calibration.
+     dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+
+     # -------------------------------------------------------------------------
+     # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
+     # -------------------------------------------------------------------------
+     print("Wrapping layers with PTQWrapper …")
+     qcfg = QuantConfig()  # default: per-tensor UINT8
+
+     new_layers = torch.nn.ModuleList()
+     for idx, fp_layer in enumerate(model.model.layers):
+         layer_cfg = qcfg.child(f"layer{idx}")
+         q_layer = PTQWrapper(
+             fp_layer,
+             qcfg=layer_cfg,
+             fp_name=m_to_fqn.get(fp_layer),
+         )
+         new_layers.append(q_layer)
+
+     model.model.layers = new_layers  # swap in quant wrappers
+
+     # -------------------------------------------------------------------------
+     # 3. Activation calibration plus FP-vs-UINT8 diffing
+     # -------------------------------------------------------------------------
+     print("Calibrating UINT-8 observers …")
+     calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
+     ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+     # (a) Enable CALIB mode on every QuantModuleBase
+     for l in model.model.layers:
+         l.enable_calibration()
+
+     # Save reference FP activations before observers clamp/quantize
+     save_handles, act_cache = save_fp_outputs(model)
+
+     iterator = range(0, ids.size(1) - 1, args.stride)
+     if not args.no_tqdm:
+         iterator = tqdm.tqdm(iterator, desc="Act-Calibration")
+     with torch.no_grad():
+         for i in iterator:
+             inputs = ids[:, i : i + args.stride]
+             model(inputs)  # observers collect act. ranges
+
+     # Remove save hooks now that FP activations are cached
+     for h in save_handles:
+         h.remove()
+
+     # (b) Freeze (scale, zero-point) after calibration
+     for l in model.model.layers:
+         l.freeze_qparams()
+
+     # (c) Register diff hooks and measure per-layer deltas
+     cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
+     # Use same inputs for comparison.
+     with torch.no_grad():
+         model(inputs)
+
+     assert isinstance(cmp_handles, list)
+     for h in cmp_handles:
+         h.remove()
+
+
+ if __name__ == "__main__":
+     try:
+         main()
+     except Exception as e:
+         print(f"\n[Error] {e}", file=sys.stderr)
+         sys.exit(1)
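The debugging script registers its comparison hooks with metrics=["diff", "peir"]. As a conceptual illustration only — this is not tico's implementation, and the exact PEIR formula here is an assumption — the two metrics can be read roughly as mean absolute error and peak error relative to the FP output's dynamic range:

import torch

def diff_metric(fp_out: torch.Tensor, q_out: torch.Tensor) -> float:
    # Mean absolute difference between FP and fake-quantized layer outputs.
    return (fp_out - q_out).abs().mean().item()

def peir_metric(fp_out: torch.Tensor, q_out: torch.Tensor) -> float:
    # Peak-error-to-interval ratio (assumed definition): worst-case error
    # normalized by the dynamic range of the FP reference output.
    peak_err = (fp_out - q_out).abs().max()
    interval = fp_out.max() - fp_out.min()
    return (peak_err / interval.clamp_min(1e-12)).item()

# A layer whose PEIR is large relative to its neighbours is a natural candidate
# for a higher-precision or per-channel observer.
fp = torch.randn(2, 8)
q = fp + 0.01 * torch.randn(2, 8)
print(diff_metric(fp, q), peir_metric(fp, q))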
@@ -24,6 +24,8 @@
  # 6. Freeze all Q-params and compute Wikitext-2 perplexity.
  # =============================================================================
 
+ import argparse
+ import sys
  from typing import Any
 
  import torch
@@ -42,12 +44,6 @@ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
      QuantModuleBase,
  )
 
- # -------------------------------------------------------------------------
- # 0. Global configuration
- # -------------------------------------------------------------------------
- MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- STRIDE = 512
 
  # Token-budget presets for activation calibration
  TOKENS: dict[str, int] = {
@@ -58,7 +54,18 @@ TOKENS: dict[str, int] = {
      # Production / 4-bit observer smoothing
      "production": 200_000,
  }
- CALIB_TOKENS = TOKENS["baseline"]
+
+ DTYPE_MAP = {
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+     "float16": torch.float16,
+ }
+
+ # Hardcoded dataset settings
+ DATASET_NAME = "wikitext"
+ DATASET_CONFIG = "wikitext-2-raw-v1"
+ TRAIN_SPLIT = "train"
+ TEST_SPLIT = "test"
 
  # -------------------------------------------------------------------------
  # 1. Helper — copy GPTQ (scale, zp) into PTQ observers
@@ -89,77 +96,191 @@ def inject_gptq_qparams(
          obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
 
 
- # -------------------------------------------------------------------------
- # 2. Load the FP backbone
- # -------------------------------------------------------------------------
- print("Loading FP model …")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
- model.config.use_cache = False  # disable KV-cache → full forward
- m_to_fqn = build_fqn_map(model)  # map modules → fully-qualified names
+ def main():
+     parser = argparse.ArgumentParser(
+         description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
+     )
+     parser.add_argument(
+         "--model", type=str, required=True, help="HF repo name or local path."
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda" if torch.cuda.is_available() else "cpu",
+         help="Device to run on (cuda|cpu|mps).",
+     )
+     parser.add_argument(
+         "--dtype",
+         choices=list(DTYPE_MAP.keys()),
+         default="float32",
+         help="Model dtype for load.",
+     )
+     parser.add_argument(
+         "--stride",
+         type=int,
+         default=512,
+         help="Sliding-window stride used for calibration and eval.",
+     )
+     parser.add_argument(
+         "--calib-preset",
+         choices=list(TOKENS.keys()),
+         default="debug",
+         help="Activation calibration token budget preset.",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+     parser.add_argument(
+         "--trust-remote-code",
+         action="store_true",
+         help="Enable only if you trust the model repo code.",
+     )
+     parser.add_argument(
+         "--hf-token",
+         type=str,
+         default=None,
+         help="Optional HF token for gated/private repos.",
+     )
+     parser.add_argument(
+         "--use-cache",
+         dest="use_cache",
+         action="store_true",
+         default=False,
+         help="Use model KV cache if enabled (off by default).",
+     )
+     parser.add_argument(
+         "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+     )
 
- # -------------------------------------------------------------------------
- # 3. Run GPTQ (weight-only) pass
- # -------------------------------------------------------------------------
- print("Applying GPTQ …")
- dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="test")
- q_m = prepare(model, GPTQConfig(), inplace=True)
+     args = parser.parse_args()
 
- for d in tqdm.tqdm(dataset, desc="GPTQ calibration"):
-     ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(DEVICE)
-     q_m(ids)  # observers gather weight stats
+     # Basic setup
+     torch.manual_seed(args.seed)
+     device = torch.device(args.device)
+     dtype = DTYPE_MAP[args.dtype]
 
- q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
+     print("=== Config ===")
+     print(f"Model : {args.model}")
+     print(f"Device : {device.type}")
+     print(f"DType : {args.dtype}")
+     print(f"Stride : {args.stride}")
+     print(
+         f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+     )
+     print(f"Use HF cache? : {args.use_cache}")
+     print()
 
- # -------------------------------------------------------------------------
- # 4. Wrap every layer with PTQWrapper (activation UINT-8)
- # -------------------------------------------------------------------------
- qcfg = QuantConfig()  # default: per-tensor UINT8
- new_layers = torch.nn.ModuleList()
-
- for idx, fp_layer in enumerate(q_m.model.layers):
-     layer_cfg = qcfg.child(f"layer{idx}")
-     q_layer = PTQWrapper(
-         fp_layer,
-         qcfg=layer_cfg,
-         fp_name=m_to_fqn.get(fp_layer),
+     # -------------------------------------------------------------------------
+     # 2. Load the FP backbone and tokenizer
+     # -------------------------------------------------------------------------
+     print("Loading FP model …")
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model,
+         trust_remote_code=args.trust_remote_code,
+         token=args.hf_token,
+     )
+     model = (
+         AutoModelForCausalLM.from_pretrained(
+             args.model,
+             torch_dtype=dtype,
+             trust_remote_code=args.trust_remote_code,
+             token=args.hf_token,
+         )
+         .to(device)
+         .eval()
      )
-     new_layers.append(q_layer)
 
- q_m.model.layers = new_layers
+     model.config.use_cache = args.use_cache
 
- # -------------------------------------------------------------------------
- # 5. Single-pass activation calibration
- # -------------------------------------------------------------------------
- print("Calibrating UINT-8 observers …")
- calib_txt = " ".join(
-     load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
- )[:CALIB_TOKENS]
- ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+     # Build module -> FQN map BEFORE wrapping
+     m_to_fqn = build_fqn_map(model)
 
- # (a) Enable CALIB mode on every QuantModuleBase
- for l in q_m.model.layers:
-     l.enable_calibration()
+     # -------------------------------------------------------------------------
+     # 3. Run GPTQ (weight-only) pass
+     # -------------------------------------------------------------------------
+     print("Applying GPTQ …")
+     dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+     q_m = prepare(model, GPTQConfig(), inplace=True)
 
- # (b) Overwrite weight observers with GPTQ statistics
- inject_gptq_qparams(q_m, q_m.quantizers)
+     it = (
+         dataset_test
+         if args.no_tqdm
+         else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
+     )
+     for d in it:
+         ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
+         q_m(ids)  # observers gather weight stats
 
- with torch.no_grad():
-     for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
-         q_m(ids[:, i : i + STRIDE])  # observers collect act. ranges
+     q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
 
- # (c) Freeze all Q-params (scale, zp)
- for l in q_m.model.layers:
-     l.freeze_qparams()
+     # -------------------------------------------------------------------------
+     # 4. Wrap every layer with PTQWrapper (activation UINT-8)
+     # -------------------------------------------------------------------------
+     print("Wrapping layers with PTQWrapper …")
+     layers = q_m.model.layers
+     if not isinstance(layers, (list, torch.nn.ModuleList)):
+         raise TypeError(f"'model.layers' must be list/ModuleList, got {type(layers)}")
 
- # -------------------------------------------------------------------------
- # 6. Evaluate perplexity on Wikitext-2
- # -------------------------------------------------------------------------
- print("\nCalculating perplexities ")
- test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
- enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
- ppl_uint8 = perplexity(q_m, enc, DEVICE, stride=STRIDE)
-
- print("\n┌── Wikitext-2 test perplexity ─────────────")
- print(f"│ UINT-8 : {ppl_uint8:8.2f}")
- print("└───────────────────────────────────────────")
+     qcfg = QuantConfig()  # default: per-tensor UINT8
+     wrapped = torch.nn.ModuleList()
+     for idx, fp_layer in enumerate(layers):
+         layer_cfg = qcfg.child(f"layer{idx}")
+         wrapped.append(
+             PTQWrapper(
+                 fp_layer,
+                 qcfg=layer_cfg,
+                 fp_name=m_to_fqn.get(fp_layer),
+             )
+         )
+     q_m.model.layers = wrapped
+
+     # -------------------------------------------------------------------------
+     # 5. Single-pass activation calibration
+     # -------------------------------------------------------------------------
+     print("Calibrating UINT-8 observers …")
+     CALIB_TOKENS = TOKENS[args.calib_preset]
+     print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+     dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+     calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
+     train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+     # (a) Enable CALIB mode on every QuantModuleBase
+     for l in q_m.model.layers:
+         l.enable_calibration()
+
+     # (b) Overwrite weight observers with GPTQ statistics
+     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
+         inject_gptq_qparams(q_m, q_m.quantizers)
+     else:
+         print(
+             "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
+         )
+
+     # (c) Forward passes to collect activation ranges
+     iterator = range(0, train_ids.size(1) - 1, args.stride)
+     if not args.no_tqdm:
+         iterator = tqdm.tqdm(iterator, desc="Act-calibration")
+     with torch.no_grad():
+         for i in iterator:
+             q_m(train_ids[:, i : i + args.stride])
+
+     # (d) Freeze all Q-params (scale, zero-point)
+     for l in q_m.model.layers:
+         l.freeze_qparams()
+
+     # -------------------------------------------------------------------------
+     # 6. Evaluate perplexity on Wikitext-2
+     # -------------------------------------------------------------------------
+     print("\nCalculating perplexities …")
+     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
+     ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)
+
+     print("\n┌── Wikitext-2 test perplexity ─────────────")
+     print(f"│ UINT-8 : {ppl_uint8:8.2f}")
+     print("└───────────────────────────────────────────")
+
+
+ if __name__ == "__main__":
+     try:
+         main()
+     except Exception as e:
+         print(f"\n[Error] {e}", file=sys.stderr)
+         sys.exit(1)
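The perplexity helper used in step 6 is imported from the package and not shown in this diff. For readers following along, a standard sliding-window perplexity over a tokenized corpus looks roughly like this (an illustrative sketch for Hugging Face causal LMs, not tico's implementation):

import torch

@torch.no_grad()
def sliding_window_perplexity(model, input_ids, device, stride=512, max_len=1024):
    # input_ids: (1, seq_len) tensor, e.g.
    # tokenizer("\n\n".join(texts), return_tensors="pt").input_ids
    seq_len = input_ids.size(1)
    nlls, n_scored, prev_end = [], 0, 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_len, seq_len)
        trg_len = end - prev_end            # tokens newly scored in this window
        ids = input_ids[:, begin:end].to(device)
        labels = ids.clone()
        labels[:, :-trg_len] = -100         # mask the overlapping prefix
        out = model(ids, labels=labels)     # HF causal LMs return mean CE over scored tokens
        nlls.append(out.loss * trg_len)
        n_scored += trg_len
        prev_end = end
        if end == seq_len:
            break
    return torch.exp(torch.stack(nlls).sum() / n_scored).item()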