tico 0.1.0.dev250915-py3-none-any.whl → 0.1.0.dev250917-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tico has been flagged as possibly problematic by the registry scanner.
tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
  ]

  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
- __version__ = "0.1.0.dev250915"
+ __version__ = "0.1.0.dev250917"

  MINIMUM_SUPPORTED_VERSION = "2.5.0"
  SECURE_TORCH_VERSION = "2.6.0"
tico/experimental/quantization/ptq/examples/compare_ppl.py CHANGED
@@ -77,7 +77,7 @@ def main():
  "--dtype",
  choices=list(DTYPE_MAP.keys()),
  default="float32",
- help="Model dtype for load (float32|bfloat16|float16).",
+ help=f"Model dtype for load.",
  )
  parser.add_argument(
  "--stride", type=int, default=512, help="Sliding-window stride for perplexity."
@@ -126,7 +126,9 @@ def main():
  print(f"DType : {args.dtype}")
  print(f"Stride : {args.stride}")
  print(f"Use HF cache? : {args.use_cache}")
- print(f"Calib preset : {args.calib_preset}")
+ print(
+ f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+ )
  print()

  # -------------------------------------------------------------------------
tico/experimental/quantization/ptq/examples/debug_quant_outputs.py CHANGED
@@ -12,19 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import torch
- import tqdm
- from datasets import load_dataset
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.utils.introspection import (
- build_fqn_map,
- compare_layer_outputs,
- save_fp_outputs,
- )
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
-
  # ============================================================================
  # LAYER-WISE DIFF DEBUGGING PIPELINE
  # ----------------------------------------------------------------------------
@@ -43,12 +30,21 @@ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
  # problematic modules during post-training quantization.
  # ============================================================================

- # -------------------------------------------------------------------------
- # 0. Global configuration
- # -------------------------------------------------------------------------
- MODEL_NAME = "meta-llama/Meta-Llama-3-1B"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- STRIDE = 512
+ import argparse
+ import sys
+
+ import torch
+ import tqdm
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from tico.experimental.quantization.ptq.quant_config import QuantConfig
+ from tico.experimental.quantization.ptq.utils.introspection import (
+ build_fqn_map,
+ compare_layer_outputs,
+ save_fp_outputs,
+ )
+ from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper

  # Token-budget presets for activation calibration
  TOKENS: dict[str, int] = {
@@ -59,71 +55,185 @@ TOKENS: dict[str, int] = {
  # Production / 4-bit observer smoothing
  "production": 200_000,
  }
- CALIB_TOKENS = TOKENS["baseline"]
- print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
-
- # -------------------------------------------------------------------------
- # 1. Load the FP backbone
- # -------------------------------------------------------------------------
- print("Loading FP model …")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE).eval()
- model.config.use_cache = False # disable KV-cache → full forward
- m_to_fqn = build_fqn_map(model) # map modules → fully-qualified names
-
- # Use Wikitext-2 train split for calibration.
- dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-
- # -------------------------------------------------------------------------
- # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
- # -------------------------------------------------------------------------
- print("Wrapping layers with PTQWrapper …")
- qcfg = QuantConfig() # default: per-tensor UINT8
-
- new_layers = torch.nn.ModuleList()
- for idx, fp_layer in enumerate(model.model.layers):
- layer_cfg = qcfg.child(f"layer{idx}")
- q_layer = PTQWrapper(
- fp_layer,
- qcfg=layer_cfg,
- fp_name=m_to_fqn.get(fp_layer),
- )
- new_layers.append(q_layer)

- model.model.layers = new_layers # swap in quant wrappers
-
- # -------------------------------------------------------------------------
- # 3. Activation calibration plus FP-vs-UINT8 diffing
- # -------------------------------------------------------------------------
- print("Calibrating UINT-8 observers …")
- calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
- ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(DEVICE)
+ DTYPE_MAP = {
+ "float32": torch.float32,
+ "bfloat16": torch.bfloat16,
+ "float16": torch.float16,
+ }

- # (a) Enable CALIB mode on every QuantModuleBase
- for l in model.model.layers:
- l.enable_calibration()
+ # Hardcoded dataset settings
+ DATASET_NAME = "wikitext"
+ DATASET_CONFIG = "wikitext-2-raw-v1"
+ TRAIN_SPLIT = "train"

- # Save reference FP activations before observers clamp/quantize
- save_handles, act_cache = save_fp_outputs(model)

- with torch.no_grad():
- for i in tqdm.trange(0, ids.size(1) - 1, STRIDE, desc="Act-calibration"):
- inputs = ids[:, i : i + STRIDE]
- model(inputs) # observers collect act. ranges
+ def main():
+ parser = argparse.ArgumentParser(
+ description="Layer-wise diff debugging pipeline for PTQ"
+ )
+ parser.add_argument(
+ "--model", type=str, required=True, help="HF repo name or local path."
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda" if torch.cuda.is_available() else "cpu",
+ help="Device to run on (cuda|cpu|mps).",
+ )
+ parser.add_argument(
+ "--dtype",
+ choices=list(DTYPE_MAP.keys()),
+ default="float32",
+ help=f"Model dtype for load.",
+ )
+ parser.add_argument(
+ "--stride",
+ type=int,
+ default=512,
+ help="Sliding-window stride used during calibration.",
+ )
+ parser.add_argument(
+ "--calib-preset",
+ choices=list(TOKENS.keys()),
+ default="debug",
+ help="Calibration token budget preset.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+ parser.add_argument(
+ "--trust-remote-code",
+ action="store_true",
+ help="Enable only if you trust the model repo code.",
+ )
+ parser.add_argument(
+ "--hf-token",
+ type=str,
+ default=None,
+ help="Optional HF token for gated/private repos.",
+ )
+ parser.add_argument(
+ "--use-cache",
+ dest="use_cache",
+ action="store_true",
+ default=False,
+ help="Use model KV cache if enabled (off by default).",
+ )
+ parser.add_argument(
+ "--no-tqdm", action="store_true", help="Disable tqdm progress bars."
+ )

- # Remove save hooks now that FP activations are cached
- for h in save_handles:
- h.remove()
+ args = parser.parse_args()

- # (b) Freeze (scale, zero-point) after calibration
- for l in model.model.layers:
- l.freeze_qparams()
+ # Basic setup
+ torch.manual_seed(args.seed)
+ device = torch.device(args.device)
+ dtype = DTYPE_MAP[args.dtype] # noqa: E999 (kept readable)

- # (c) Register diff hooks and measure per-layer deltas
- cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
- # Use same inputs for comparison.
- model(inputs)
+ print("=== Config ===")
+ print(f"Model : {args.model}")
+ print(f"Device : {device.type}")
+ print(f"DType : {args.dtype}")
+ print(f"Stride : {args.stride}")
+ print(
+ f"Calib preset : {args.calib_preset} ({TOKENS[args.calib_preset]:,} tokens)"
+ )
+ print(f"Use HF cache? : {args.use_cache}")
+ print()
+
+ # -------------------------------------------------------------------------
+ # 1. Load the FP backbone and tokenizer
+ # -------------------------------------------------------------------------
+ print("Loading FP model …")
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.model,
+ trust_remote_code=args.trust_remote_code,
+ token=args.hf_token,
+ )
+ model = (
+ AutoModelForCausalLM.from_pretrained(
+ args.model,
+ torch_dtype=dtype,
+ trust_remote_code=args.trust_remote_code,
+ token=args.hf_token,
+ )
+ .to(device)
+ .eval()
+ )

- assert isinstance(cmp_handles, list)
- for h in cmp_handles:
- h.remove()
+ # Disable KV cache to force full forward passes for introspection
+ model.config.use_cache = args.use_cache
+
+ # Build module -> FQN map before wrapping
+ m_to_fqn = build_fqn_map(model)
+
+ # Prepare calibration inputs (HF Wikitext-2 train split)
+ CALIB_TOKENS = TOKENS[args.calib_preset]
+ print(f"Calibrating with {CALIB_TOKENS:,} tokens.\n")
+ # Use Wikitext-2 train split for calibration.
+ dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
+
+ # -------------------------------------------------------------------------
+ # 2. Wrap every layer with PTQWrapper (UINT-8 activations)
+ # -------------------------------------------------------------------------
+ print("Wrapping layers with PTQWrapper …")
+ qcfg = QuantConfig() # default: per-tensor UINT8
+
+ new_layers = torch.nn.ModuleList()
+ for idx, fp_layer in enumerate(model.model.layers):
+ layer_cfg = qcfg.child(f"layer{idx}")
+ q_layer = PTQWrapper(
+ fp_layer,
+ qcfg=layer_cfg,
+ fp_name=m_to_fqn.get(fp_layer),
+ )
+ new_layers.append(q_layer)
+
+ model.model.layers = new_layers # swap in quant wrappers
+
+ # -------------------------------------------------------------------------
+ # 3. Activation calibration plus FP-vs-UINT8 diffing
+ # -------------------------------------------------------------------------
+ print("Calibrating UINT-8 observers …")
+ calib_txt = " ".join(dataset["text"])[:CALIB_TOKENS]
+ ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
+
+ # (a) Enable CALIB mode on every QuantModuleBase
+ for l in model.model.layers:
+ l.enable_calibration()
+
+ # Save reference FP activations before observers clamp/quantize
+ save_handles, act_cache = save_fp_outputs(model)
+
+ iterator = range(0, ids.size(1) - 1, args.stride)
+ if not args.no_tqdm:
+ iterator = tqdm.tqdm(iterator, desc="Act-Calibration")
+ with torch.no_grad():
+ for i in iterator:
+ inputs = ids[:, i : i + args.stride]
+ model(inputs) # observers collect act. ranges
+
+ # Remove save hooks now that FP activations are cached
+ for h in save_handles:
+ h.remove()
+
+ # (b) Freeze (scale, zero-point) after calibration
+ for l in model.model.layers:
+ l.freeze_qparams()
+
+ # (c) Register diff hooks and measure per-layer deltas
+ cmp_handles = compare_layer_outputs(model, act_cache, metrics=["diff", "peir"])
+ # Use same inputs for comparison.
+ with torch.no_grad():
+ model(inputs)
+
+ assert isinstance(cmp_handles, list)
+ for h in cmp_handles:
+ h.remove()
+
+
+ if __name__ == "__main__":
+ try:
+ main()
+ except Exception as e:
+ print(f"\n[Error] {e}", file=sys.stderr)
+ sys.exit(1)
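
Note: the rewritten example is now a self-contained argparse CLI, so it can be launched as a module (for instance `python -m tico.experimental.quantization.ptq.examples.debug_quant_outputs --model <hf-repo-or-path>`) or driven programmatically. A minimal sketch of the latter, assuming the wheel is installed; the model id is taken from the removed hardcoded default and is only illustrative:

import sys

from tico.experimental.quantization.ptq.examples import debug_quant_outputs

# Flags mirror the argparse interface added in the diff above.
sys.argv = [
    "debug_quant_outputs",
    "--model", "meta-llama/Meta-Llama-3-1B",
    "--dtype", "float16",
    "--calib-preset", "baseline",
    "--no-tqdm",
]
debug_quant_outputs.main()  # runs calibration, then the per-layer FP-vs-UINT8 comparison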
tico/utils/signature.py CHANGED
@@ -141,22 +141,21 @@ class ModelInputSpec:
  args = flatten_and_convert_args(args)
  kwargs = flatten_and_convert_kwargs(kwargs)

+ arg_num = len(args) + len(kwargs)
+ m_input_num = len(self.names)
+ if arg_num != m_input_num:
+ raise ValueError(
+ f"Mismatch: number of model inputs and number of passed arguments are not the same: inputs({m_input_num}) != passed({arg_num}), input spec: {self.names}"
+ )
+
  # 1. positional arguments
  for i, val in enumerate(args):
- if i >= len(self.names):
- raise ValueError(f"Too many positional arguments ({i+1}).")
  name = self.names[i]
- if name in kwargs:
- raise TypeError(
- f"Got multiple values for argument '{name}' (positional and keyword)."
- )
  inputs.append(val)

  # 2. keyword arguments
  for idx in range(len(args), len(self.names)):
  name = self.names[idx]
- if name not in kwargs:
- raise ValueError(f"Missing argument for input '{name}'.")
  inputs.append(kwargs[name])

  if check:
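
The signature.py change replaces the per-argument checks ("Too many positional arguments", "Got multiple values", "Missing argument") with a single up-front count check. A standalone sketch of the new behavior; check_input_count is a hypothetical helper that mirrors the added hunk, not a function in tico:

def check_input_count(names, args, kwargs):
    # Same rule as the new code: the total number of passed inputs must equal
    # the number of flattened model input names.
    arg_num = len(args) + len(kwargs)
    m_input_num = len(names)
    if arg_num != m_input_num:
        raise ValueError(
            f"Mismatch: inputs({m_input_num}) != passed({arg_num}), input spec: {names}"
        )

names = ["input_ids", "attention_mask"]
check_input_count(names, ("ids",), {"attention_mask": "mask"})  # passes
try:
    check_input_count(names, ("ids",), {})  # old code: "Missing argument for input 'attention_mask'."
except ValueError as err:
    print(err)  # new code: single count-mismatch message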
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tico
- Version: 0.1.0.dev250915
+ Version: 0.1.0.dev250917
  Summary: Convert exported Torch module to circle
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,4 +1,4 @@
- tico/__init__.py,sha256=RT4YNN5EM4rbOVOWo1BHEO8vnfPWLEwVNjMFh3qRYeY,1883
+ tico/__init__.py,sha256=Da7Ln6MuWCBJXrjts6OsAslWSS79toVgPG2PITYPzE0,1883
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
  tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -62,8 +62,8 @@ tico/experimental/quantization/ptq/mode.py,sha256=lT-T8vIv8YWcwrjT7xXVhOw1g7aoAd
  tico/experimental/quantization/ptq/qscheme.py,sha256=uwhv7bCxOOXB3I-IKlRyr_u4eXOq48uIqGy4TLDqGxY,1301
  tico/experimental/quantization/ptq/quant_config.py,sha256=nm7570Y1X2mOT_8s27ilWid04otor6cVTi9GwgAEaKc,4300
  tico/experimental/quantization/ptq/examples/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
- tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=QWUuO50lITnooYqEe57VV6mvIHKWZMB_TOGvtZ8C8qQ,8238
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py,sha256=astXzx-maq1W4gKvX2QaGmD2Tpmjunv4JqDYVk9eZRQ,5177
+ tico/experimental/quantization/ptq/examples/compare_ppl.py,sha256=eVQn8-M24QkoCy_FCBQLSlUrnyqUDSkvUFpUpEdpMx4,8265
+ tico/experimental/quantization/ptq/examples/debug_quant_outputs.py,sha256=Hpx_jj46WISwdVp33NrIadizVAzf1nFTXuAcHsKEQuQ,8179
  tico/experimental/quantization/ptq/examples/quantize_linear.py,sha256=8zq-ZJDYgam0xQ-PbC6Xb1I7W1mv0Wi-b--IP2wwXtw,4539
  tico/experimental/quantization/ptq/examples/quantize_llama_attn.py,sha256=cVWUSSzaZWFp5QZkNkrlpHU3kXyP84QtnZbahVml_yQ,4329
  tico/experimental/quantization/ptq/examples/quantize_llama_decoder_layer.py,sha256=mBWrjkyEovYQsPC4Rrsri6Pm1rlFmDb3NiP0DQQhFyM,5751
@@ -243,7 +243,7 @@ tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,52
  tico/utils/record_input.py,sha256=QN-8D71G_WAX3QQQ5CIwbEfFJZTQ3CvL4wCMiVddua4,3894
  tico/utils/register_custom_op.py,sha256=895SKZeXQzolK-mPG38cQC37Be76xUV_Ujw1k1ts9_w,28218
  tico/utils/serialize.py,sha256=mEuusEzi82WFsz3AkowgWwxSLeo50JDxyOj6yYDQhEI,1914
- tico/utils/signature.py,sha256=R2GV0alRpXEbZISqPKyxCUWbgDcsrQ2ovbVG3737IzA,9595
+ tico/utils/signature.py,sha256=3OOwyVJzfcGcgC0LiVmOcUIzfqSk27qoNHhkoCI7zPY,9530
  tico/utils/torch_compat.py,sha256=oc6PztVsXdHcQ3iaVR90wLLxrGaj6zFHWZ8K9rRS6q8,1795
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
  tico/utils/utils.py,sha256=aySftYnNTsqVAMcGs_3uX3-hz577a2cj4p1aVV-1XeQ,12747
@@ -252,9 +252,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
- tico-0.1.0.dev250915.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
- tico-0.1.0.dev250915.dist-info/METADATA,sha256=5l-EgJKZwF179QnVqWApKdARhQxw0c2iibtckWUu-XA,8450
- tico-0.1.0.dev250915.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
- tico-0.1.0.dev250915.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
- tico-0.1.0.dev250915.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
- tico-0.1.0.dev250915.dist-info/RECORD,,
+ tico-0.1.0.dev250917.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+ tico-0.1.0.dev250917.dist-info/METADATA,sha256=WJdcwQ8suuOhdWCv9cW8_RW_qyckaOM5jEzlvi00vbM,8450
+ tico-0.1.0.dev250917.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+ tico-0.1.0.dev250917.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+ tico-0.1.0.dev250917.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+ tico-0.1.0.dev250917.dist-info/RECORD,,