tico-0.1.0.dev250917-py3-none-any.whl → tico-0.1.0.dev250918-py3-none-any.whl

tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
 ]
 
 # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
-__version__ = "0.1.0.dev250917"
+__version__ = "0.1.0.dev250918"
 
 MINIMUM_SUPPORTED_VERSION = "2.5.0"
 SECURE_TORCH_VERSION = "2.6.0"
tico/config/v1.py CHANGED
@@ -20,6 +20,9 @@ from tico.config.base import CompileConfigBase
 @dataclass
 class CompileConfigV1(CompileConfigBase):
     legalize_causal_mask_value: bool = False
+    remove_constant_input: bool = False
+    convert_lhs_const_mm_to_fc: bool = False
+    convert_rhs_const_mm_to_fc: bool = True
 
     def get(self, name: str):
         return super().get(name)
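
The three new options land as plain dataclass fields on CompileConfigV1, defaulting to the behavior shown above (RHS-const MatMul-to-FC conversion on, the other two off). A minimal sketch of how they might be toggled, as hypothetical usage; only the module path, class name, and field names come from the diff:

    from tico.config.v1 import CompileConfigV1

    # Hypothetical: enable both constant-side MatMul-to-FullyConnected conversions.
    cfg = CompileConfigV1(
        remove_constant_input=True,        # new field, default False
        convert_lhs_const_mm_to_fc=True,   # new field, default False
        convert_rhs_const_mm_to_fc=True,   # new field, default True
    )
    print(cfg.convert_lhs_const_mm_to_fc)  # dataclass fields are readable directly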
@@ -24,6 +24,8 @@
 # 6. Freeze all Q-params and compute Wikitext-2 perplexity.
 # =============================================================================
 
+import argparse
+import sys
 from typing import Any
 
 import torch
@@ -42,12 +44,6 @@ from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
     QuantModuleBase,
 )
 
-# -------------------------------------------------------------------------
-# 0. Global configuration
-# -------------------------------------------------------------------------
-MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-STRIDE = 512
 
 # Token-budget presets for activation calibration
 TOKENS: dict[str, int] = {
@@ -58,7 +54,18 @@ TOKENS: dict[str, int] = {
     # Production / 4-bit observer smoothing
     "production": 200_000,
 }
-CALIB_TOKENS = TOKENS["baseline"]
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+}
+
+# Hardcoded dataset settings
+DATASET_NAME = "wikitext"
+DATASET_CONFIG = "wikitext-2-raw-v1"
+TRAIN_SPLIT = "train"
+TEST_SPLIT = "test"
 
 # -------------------------------------------------------------------------
 # 1. Helper — copy GPTQ (scale, zp) into PTQ observers
@@ -89,77 +96,191 @@ def inject_gptq_qparams(
         obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
 
 
-# -------------------------------------------------------------------------
-# 2. Load the FP backbone
-# -------------------------------------------------------------------------
-print("Loading FP model …")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
-model.config.use_cache = False  # disable KV-cache → full forward
-m_to_fqn = build_fqn_map(model)  # map modules → fully-qualified names
+def main():
+    parser = argparse.ArgumentParser(
+        description="GPTQ+PTQ pipeline (weight-only + activation UINT8)"
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="HF repo name or local path."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run on (cuda|cpu|mps).",
+    )
+    parser.add_argument(
+        "--dtype",
+        choices=list(DTYPE_MAP.keys()),
+        default="float32",
+        help="Model dtype for load.",
+    )
+    parser.add_argument(
+        "--stride",
+        type=int,
+        default=512,
+        help="Sliding-window stride used for calibration and eval.",
+    )
+    parser.add_argument(
+        "--calib-preset",
+        choices=list(TOKENS.keys()),
+        default="debug",
+        help="Activation calibration token budget preset.",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Enable only if you trust the model repo code.",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="Optional HF token for gated/private repos.",
+    )
+    parser.add_argument(
+        "--use-cache",
+        dest="use_cache",
+        action="store_true",
+        default=False,
+        help="Use model KV cache if enabled (off by default).",
+    )
+    parser.add_argument(
+        "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+    )
 
-# -------------------------------------------------------------------------
-# 3. Run GPTQ (weight-only) pass
-# -------------------------------------------------------------------------
-print("Applying GPTQ …")
-dataset = load_dataset("wikiText", "wikitext-2-raw-v1", split="test")
-q_m = prepare(model, GPTQConfig(), inplace=True)
+    args = parser.parse_args()
 
-for d in tqdm.tqdm(dataset, desc="GPTQ calibration"):
-    ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(DEVICE)
-    q_m(ids)  # observers gather weight stats
+    # Basic setup
+    torch.manual_seed(args.seed)
+    device = torch.device(args.device)
+    dtype = DTYPE_MAP[args.dtype]
 
-q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
+    print("=== Config ===")
+    print(f"Model : {args.model}")
+    print(f"Device : {device.type}")
+    print(f"DType : {args.dtype}")
+    print(f"Stride : {args.stride}")
+    print(
+        f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+    )
+    print(f"Use HF cache? : {args.use_cache}")
+    print()
 
-# -------------------------------------------------------------------------
-# 4. Wrap every layer with PTQWrapper (activation UINT-8)
-# -------------------------------------------------------------------------
-qcfg = QuantConfig()  # default: per-tensor UINT8
-new_layers = torch.nn.ModuleList()
-
-for idx, fp_layer in enumerate(q_m.model.layers):
-    layer_cfg = qcfg.child(f"layer{idx}")
-    q_layer = PTQWrapper(
-        fp_layer,
-        qcfg=layer_cfg,
-        fp_name=m_to_fqn.get(fp_layer),
+    # -------------------------------------------------------------------------
+    # 2. Load the FP backbone and tokenizer
+    # -------------------------------------------------------------------------
+    print("Loading FP model …")
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code,
+        token=args.hf_token,
+    )
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            torch_dtype=dtype,
+            trust_remote_code=args.trust_remote_code,
+            token=args.hf_token,
+        )
+        .to(device)
+        .eval()
     )
-    new_layers.append(q_layer)
 
-q_m.model.layers = new_layers
+    model.config.use_cache = args.use_cache
 
-# -------------------------------------------------------------------------
-# 5. Single-pass activation calibration
-# -------------------------------------------------------------------------
-print("Calibrating UINT-8 observers …")
-calib_txt = " ".join(
-    load_dataset("wikitext", "wikitext-2-raw-v1", split="train")["text"]
-)[:CALIB_TOKENS]
-ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+    # Build module -> FQN map BEFORE wrapping
+    m_to_fqn = build_fqn_map(model)
 
-# (a) Enable CALIB mode on every QuantModuleBase
-for l in q_m.model.layers:
-    l.enable_calibration()
+    # -------------------------------------------------------------------------
+    # 3. Run GPTQ (weight-only) pass
+    # -------------------------------------------------------------------------
+    print("Applying GPTQ …")
+    dataset_test = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
+    q_m = prepare(model, GPTQConfig(), inplace=True)
 
-# (b) Overwrite weight observers with GPTQ statistics
-inject_gptq_qparams(q_m, q_m.quantizers)
+    it = (
+        dataset_test
+        if args.no_tqdm
+        else tqdm.tqdm(dataset_test, desc="GPTQ calibration")
+    )
+    for d in it:
+        ids = tokenizer(d["text"], return_tensors="pt").input_ids.to(device)
+        q_m(ids)  # observers gather weight stats
 
-with torch.no_grad():
-    for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
-        q_m(ids[:, i : i + STRIDE])  # observers collect act. ranges
+    q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
 
-# (c) Freeze all Q-params (scale, zp)
-for l in q_m.model.layers:
-    l.freeze_qparams()
+    # -------------------------------------------------------------------------
+    # 4. Wrap every layer with PTQWrapper (activation UINT-8)
+    # -------------------------------------------------------------------------
+    print("Wrapping layers with PTQWrapper …")
+    layers = q_m.model.layers
+    if not isinstance(layers, (list, torch.nn.ModuleList)):
+        raise TypeError(f"'model.layers' must be list/ModuleList, got {type(layers)}")
 
-# -------------------------------------------------------------------------
-# 6. Evaluate perplexity on Wikitext-2
-# -------------------------------------------------------------------------
-print("\nCalculating perplexities ")
-test_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
-enc = tokenizer("\n\n".join(test_ds["text"]), return_tensors="pt")
-ppl_uint8 = perplexity(q_m, enc, DEVICE, stride=STRIDE)
-
-print("\n┌── Wikitext-2 test perplexity ─────────────")
-print(f"│ UINT-8 : {ppl_uint8:8.2f}")
-print("└───────────────────────────────────────────")
+    qcfg = QuantConfig()  # default: per-tensor UINT8
+    wrapped = torch.nn.ModuleList()
+    for idx, fp_layer in enumerate(layers):
+        layer_cfg = qcfg.child(f"layer{idx}")
+        wrapped.append(
+            PTQWrapper(
+                fp_layer,
+                qcfg=layer_cfg,
+                fp_name=m_to_fqn.get(fp_layer),
+            )
+        )
+    q_m.model.layers = wrapped
+
+    # -------------------------------------------------------------------------
+    # 5. Single-pass activation calibration
+    # -------------------------------------------------------------------------
+    print("Calibrating UINT-8 observers …")
+    CALIB_TOKENS = TOKENS[args.calib_preset]
+    print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+    dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+    calib_txt = " ".join(dataset_train["text"])[:CALIB_TOKENS]
+    train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+    # (a) Enable CALIB mode on every QuantModuleBase
+    for l in q_m.model.layers:
+        l.enable_calibration()
+
+    # (b) Overwrite weight observers with GPTQ statistics
+    if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
+        inject_gptq_qparams(q_m, q_m.quantizers)
+    else:
+        print(
+            "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
+        )
+
+    # (c) Forward passes to collect activation ranges
+    iterator = range(0, train_ids.size(1) - 1, args.stride)
+    if not args.no_tqdm:
+        iterator = tqdm.tqdm(iterator, desc="Act-calibration")
+    with torch.no_grad():
+        for i in iterator:
+            q_m(train_ids[:, i : i + args.stride])
+
+    # (d) Freeze all Q-params (scale, zero-point)
+    for l in q_m.model.layers:
+        l.freeze_qparams()
+
+    # -------------------------------------------------------------------------
+    # 6. Evaluate perplexity on Wikitext-2
+    # -------------------------------------------------------------------------
+    print("\nCalculating perplexities …")
+    enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
+    ppl_uint8 = perplexity(q_m, enc, device, stride=args.stride)
+
+    print("\n┌── Wikitext-2 test perplexity ─────────────")
+    print(f"│ UINT-8 : {ppl_uint8:8.2f}")
+    print("└───────────────────────────────────────────")
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as e:
+        print(f"\n[Error] {e}", file=sys.stderr)
+        sys.exit(1)
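
The refactor replaces the former MODEL_NAME / DEVICE / STRIDE globals with CLI flags and moves all work under main(), so the module no longer executes at import time. A hypothetical invocation follows; the example script's file name is not shown in this diff, so the path below is a placeholder, while the flags mirror the new argparse interface:

    python gptq_ptq_wikitext2.py \
        --model meta-llama/Meta-Llama-3-1B \
        --device cuda \
        --dtype float32 \
        --stride 512 \
        --calib-preset debug

For gated or private repos, --hf-token can be supplied, --trust-remote-code only if the model repository's custom code is trusted, and --no-tqdm silences the progress bars.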