invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,801 @@
+ """
+ InvarLock – RTN Quantization Edit Plugin
+ ========================================
+
+ Pure PyTorch Round-To-Nearest (RTN) weight-only quantization with no external dependencies.
+ Implements per-channel symmetric quantization with optional group size and outlier clipping.
+
+ Features:
+ - 8-bit weight quantization (INT8 RTN demo edit)
+ - Per-channel symmetric quantization (zero-point = 0)
+ - Configurable scope (FFN, attention, or all linear layers)
+ - Deterministic behavior with seed control
+ - GuardChain integration with quantization-aware policies
+
+ Follows the ModelEdit protocol with preview() and apply() methods.
+ """
+
+ from __future__ import annotations
+
+ import random
+ from typing import Any
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from invarlock.core.api import CalibrationData, GuardChain, ModelAdapter, ModelEdit
+
+ __all__ = ["RTNQuantEdit"]
+
+
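For orientation, the core rounding step that `_quantize_per_channel` implements further down reduces to the following minimal sketch (illustrative only; it assumes an INT8 target, so qmax = 127 and qmin = -128):

    import torch

    def rtn_int8_per_channel(weight: torch.Tensor) -> torch.Tensor:
        # Per-output-channel symmetric scale: absmax / qmax, guarded against all-zero rows.
        scales = weight.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-8) / 127
        # Round to the nearest INT8 level, clamp, then dequantize back to float.
        return (weight / scales).round().clamp(-128, 127) * scales

The weights therefore stay float tensors; only their values snap to the INT8 grid.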
+ class RTNQuantEdit(ModelEdit):
+     """
+     ModelEdit implementation for RTN (Round-To-Nearest) weight-only quantization.
+
+     This built-in edit is intentionally minimal and supports INT8 only.
+     It performs symmetric per-channel quantization with configurable scope and
+     deterministic operation.
+     """
+
+     name = "quant_rtn"
+
+     def __init__(
+         self,
+         bitwidth: int = 8,
+         per_channel: bool = True,
+         group_size: int | None = None,
+         clamp_ratio: float = 0.0,
+         scope: str = "ffn",
+         seed: int = 42,
+         guard_chain: GuardChain | None = None,
+         max_modules: int | None = None,
+     ):
+         """
+         Initialize RTN quantization edit.
+
+         Args:
+             bitwidth: Quantization bitwidth (INT8 only for built-in edit)
+             per_channel: Always True for per-channel quantization
+             group_size: Reserved for future use (ignored for INT8 demo edit)
+             clamp_ratio: Outlier clipping ratio (0.0 = no clipping)
+             scope: Target scope ("ffn", "attn", "all")
+             seed: Random seed for deterministic behavior
+             guard_chain: Optional GuardChain for safety checks
+             max_modules: Optional cap on how many target modules are quantized
+         """
+         # Validate configuration – built-in edit is INT8-only
+         if bitwidth != 8:
+             raise ValueError(
+                 f"RTNQuantEdit only supports 8-bit quantization (got bitwidth={bitwidth})"
+             )
+         if not (0.0 <= clamp_ratio <= 0.5):
+             raise ValueError(
+                 f"Clamp ratio must be between 0.0 and 0.5, got {clamp_ratio}"
+             )
+         if scope not in ["ffn", "attn", "all"]:
+             raise ValueError(f"Scope must be 'ffn', 'attn', or 'all', got {scope}")
+
+         self.bitwidth = bitwidth
+         self.per_channel = per_channel  # Always True
+         self.group_size = group_size
+         self.clamp_ratio = clamp_ratio
+         self.scope = scope
+         self.seed = seed
+         self.guard_chain = guard_chain
+         self.max_modules = max_modules
+
+         # group_size is currently reserved for potential future variants; it is
+         # ignored for the built-in INT8 demo edit.
+
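A typical construction, matching the argument documentation above (the specific values are illustrative):

    # INT8-only: any other bitwidth raises ValueError in __init__.
    edit = RTNQuantEdit(bitwidth=8, scope="attn", clamp_ratio=0.01, seed=0)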
+     def can_edit(self, model_desc: dict[str, Any]) -> bool:
+         """Check if RTN quantization can be applied to this model."""
+         # Basic requirements for quantization
+         required_keys = ["n_layer", "total_params"]
+         has_requirements = all(key in model_desc for key in required_keys)
+
+         # Need sufficient model size for meaningful quantization
+         if has_requirements and model_desc.get("total_params", 0) > 1000:
+             return True
+         return False
+
+     def preview(
+         self, model: nn.Module, adapter: ModelAdapter, calib: CalibrationData
+     ) -> dict:
+         """
+         Preview RTN quantization without modifying the model.
+
+         Args:
+             model: The model to preview quantization on
+             adapter: ModelAdapter for model-specific operations
+             calib: Calibration data (not used for RTN)
+
+         Returns:
+             Dictionary with preview results including quantization plan
+         """
+         try:
+             # Set deterministic seed
+             torch.manual_seed(self.seed)
+             random.seed(self.seed)
+             np.random.seed(self.seed)
+
+             # Get model description
+             model_desc = adapter.describe(model)
+
+             # Identify target modules
+             target_modules = self._identify_target_modules(model)
+             total_identified = len(target_modules)
+
+             if (
+                 isinstance(self.max_modules, int)
+                 and self.max_modules > 0
+                 and self.max_modules < total_identified
+             ):
+                 target_modules = target_modules[: self.max_modules]
+
+             # Compute quantization statistics
+             quant_stats = self._compute_quantization_stats(target_modules)
+
+             # Estimate parameter changes
+             total_params = sum(p.numel() for p in model.parameters())
+             target_params = sum(module.weight.numel() for _, module in target_modules)
+
+             # Create quantization plan
+             plan = {
+                 "operation": "rtn_quantization",
+                 "bitwidth": self.bitwidth,
+                 "per_channel": self.per_channel,
+                 "group_size": self.group_size if self.bitwidth == 4 else None,
+                 "clamp_ratio": self.clamp_ratio,
+                 "scope": self.scope,
+                 "seed": self.seed,
+                 "target_modules": [name for name, _ in target_modules],
+                 "quantization_stats": quant_stats,
+                 "anti_tying_map": self._get_weight_tying_map(model),
+             }
+             if (
+                 isinstance(self.max_modules, int)
+                 and self.max_modules > 0
+                 and self.max_modules < total_identified
+             ):
+                 plan["max_modules"] = self.max_modules
+
+             # Estimate sparsity (RTN doesn't create structural sparsity)
+             estimated_sparsity = {
+                 "head_sparsity": 0.0,
+                 "neuron_sparsity": 0.0,
+                 "weight_sparsity": 0.0,  # RTN doesn't create weight sparsity
+             }
+
+             # Preview metrics
+             bits_per_param = self.bitwidth
+             if self.bitwidth == 4 and self.group_size:
+                 # Account for scale storage
+                 scales_per_group = target_params / self.group_size
+                 bits_per_param = 4 + (
+                     32 * scales_per_group / target_params
+                 )  # 32-bit scales
+
+             memory_reduction_estimate = (
+                 target_params * (32 - bits_per_param) / 8
+             )  # bytes
+
+             preview_metrics = {
+                 "preview_duration": 0.0,
+                 "target_params": int(target_params),
+                 "total_params": int(total_params),
+                 "coverage_ratio": target_params / total_params
+                 if total_params > 0
+                 else 0.0,
+                 "target_modules_count": len(target_modules),
+                 "estimated_memory_saved_bytes": int(memory_reduction_estimate),
+                 "estimated_bits_per_param": bits_per_param,
+                 "will_use_clipping": self.clamp_ratio > 0.0,
+                 "will_use_grouping": self.bitwidth == 4 and self.group_size is not None,
+             }
+
+             return {
+                 "plan": plan,
+                 "estimated_sparsity": estimated_sparsity,
+                 "preview_metrics": preview_metrics,
+                 "model_info": model_desc,
+             }
+
+         except Exception as e:
+             # Return error in preview
+             return {
+                 "plan": {"operation": "failed", "error": str(e)},
+                 "estimated_sparsity": {
+                     "head_sparsity": 0.0,
+                     "neuron_sparsity": 0.0,
+                     "weight_sparsity": 0.0,
+                 },
+                 "preview_metrics": {"error": str(e)},
+                 "model_info": {},
+             }
+
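As a worked example of the memory estimate above (numbers are illustrative; the 32-bit baseline is the one assumed by the formula itself): quantizing 10 million targeted parameters to 8 bits gives

    target_params = 10_000_000
    bits_per_param = 8
    estimated_memory_saved_bytes = target_params * (32 - bits_per_param) / 8  # 30,000,000 bytes ≈ 30 MB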
+     def apply(self, model: Any, adapter, **kwargs) -> dict[str, Any]:
+         """
+         Apply RTN quantization to the model.
+
+         Args:
+             model: The model to edit (modified in-place)
+             adapter: ModelAdapter for model-specific operations
+             **kwargs: Edit parameters and configuration
+
+         Returns:
+             Dictionary with application results
+         """
+         try:
+             # Extract configuration from kwargs - handle both 'bits' and 'bitwidth' for compatibility
+             bitwidth = kwargs.get("bitwidth", kwargs.get("bits", self.bitwidth))
+             group_size = kwargs.get("group_size", self.group_size)
+             clamp_ratio = kwargs.get("clamp_ratio", self.clamp_ratio)
+             scope = kwargs.get("scope", self.scope)
+             seed = kwargs.get("seed", self.seed)
+
+             # Diagnostic reporting
+             print("🔧 RTN Quantization Configuration:")
+             print(
+                 f" Bitwidth: {bitwidth} (from config: {kwargs.get('bitwidth', kwargs.get('bits', 'default'))})"
+             )
+             print(f" Scope: {scope}")
+             print(f" Group size: {group_size}")
+             print(f" Clamp ratio: {clamp_ratio}")
+             print(f" Seed: {seed}")
+
+             # Persist configuration overrides for downstream helpers
+             self.bitwidth = bitwidth
+             self.group_size = group_size
+             self.clamp_ratio = clamp_ratio
+             self.scope = scope
+             self.seed = seed
+
+             # Set deterministic seed
+             torch.manual_seed(seed)
+             random.seed(seed)
+             np.random.seed(seed)
+
+             # Identify target modules and get weight tying map
+             print(f"🎯 Identifying target modules for scope '{scope}'...")
+             target_modules = self._identify_target_modules(model)
+             total_identified = len(target_modules)
+
+             max_modules = kwargs.get("max_modules")
+             if isinstance(max_modules, int) and max_modules > 0:
+                 if max_modules < total_identified:
+                     print(
+                         f" Limiting quantization to first {max_modules} modules "
+                         f"(of {total_identified}) based on plan.max_modules"
+                     )
+                     target_modules = target_modules[:max_modules]
+                     self.max_modules = max_modules
+                 else:
+                     print(
+                         f" max_modules={max_modules} >= available modules "
+                         f"({total_identified}); using all targets"
+                     )
+                     self.max_modules = None
+             else:
+                 self.max_modules = None
+
+             tying_map = self._get_weight_tying_map(model)
+
+             print(f" Found {len(target_modules)} target modules:")
+             for i, (name, module) in enumerate(target_modules):
+                 weight_shape = module.weight.shape
+                 param_count = module.weight.numel()
+                 print(f" [{i + 1}] {name}: {weight_shape} ({param_count:,} params)")
+
+             if len(target_modules) == 0:
+                 print("❌ WARNING: No target modules found! Check scope configuration.")
+                 print(" Available linear modules:")
+                 linear_modules = []
+                 for name, module in model.named_modules():
+                     if isinstance(module, nn.Linear | nn.Conv1d):
+                         linear_modules.append((name, module.weight.shape))
+                 for name, shape in linear_modules[:10]:  # Show first 10
+                     print(f" {name}: {shape}")
+                 if len(linear_modules) > 10:
+                     print(f" ... and {len(linear_modules) - 10} more")
+
+             # Execute GuardChain before edit (if provided)
+             guard_results = {}
+             if self.guard_chain is not None:
+                 print(" Executing guard chain preparation...")
+                 guard_results["prepare"] = self.guard_chain.prepare_all(
+                     model, adapter, None, {}
+                 )
+
+                 print(" Executing before-edit guards...")
+                 self.guard_chain.before_edit_all(model)
+
+             # Apply quantization to each target module
+             quantization_results = []
+             total_params_quantized = 0
+
+             for i, (module_name, module) in enumerate(target_modules):
+                 print(f" [{i + 1}/{len(target_modules)}] Quantizing: {module_name}")
+                 print(
+                     f" Shape: {module.weight.shape}, Params: {module.weight.numel():,}"
+                 )
+                 print(
+                     f" Weight range: [{module.weight.min():.4f}, {module.weight.max():.4f}]"
+                 )
+
+                 # Apply RTN quantization
+                 quant_result = self._apply_rtn_quantization(
+                     module,
+                     bitwidth,
+                     group_size,
+                     clamp_ratio,
+                     tying_map.get(module_name),
+                 )
+
+                 quant_result["module_name"] = module_name
+                 quantization_results.append(quant_result)
+                 total_params_quantized += quant_result["params_quantized"]
+
+                 print(
+                     f" ✓ Quantized {quant_result['params_quantized']:,} parameters"
+                 )
+
+             # Execute GuardChain after edit (if provided)
+             if self.guard_chain is not None:
+                 print(" Executing after-edit guards...")
+                 self.guard_chain.after_edit_all(model)
+
+                 print(" Finalizing guard chain...")
+                 guard_results["finalize"] = self.guard_chain.finalize_all(model)
+
+                 # Check if all guards passed
+                 if not self.guard_chain.all_passed(guard_results["finalize"]):
+                     print(" ⚠️ Guard chain validation failed!")
+                     guard_results["all_passed"] = False
+                 else:
+                     print(" ✓ All guards passed")
+                     guard_results["all_passed"] = True
+
+             # Create bitwidth map
+             bitwidth_map = {}
+             for result in quantization_results:
+                 bitwidth_map[result["module_name"]] = {
+                     "bitwidth": bitwidth,
+                     "group_size": group_size if bitwidth == 4 else None,
+                     "params": result["params_quantized"],
+                     "scale_stats": result.get("scale_stats", {}),
+                 }
+
+             # Identify modified layers
+             modified_layers = []
+             for result in quantization_results:
+                 # Extract layer name from module name (e.g., "transformer.h.0.mlp.c_fc" -> "layer_0")
+                 name_parts = result["module_name"].split(".")
+                 if "h" in name_parts:
+                     h_idx = name_parts.index("h")
+                     if h_idx + 1 < len(name_parts):
+                         layer_num = name_parts[h_idx + 1]
+                         layer_name = f"layer_{layer_num}"
+                         if layer_name not in modified_layers:
+                             modified_layers.append(layer_name)
+
+             # Store edit plan for certificate generation
+             modules_quantized = [r["module_name"] for r in quantization_results]
+
+             edit_plan = {
+                 "bitwidth": bitwidth,
+                 "scope": scope,
+                 "group_size": group_size,
+                 "clamp_ratio": clamp_ratio,
+                 "seed": seed,
+                 "total_modules_quantized": len(modules_quantized),
+                 "total_params_quantized": total_params_quantized,
+                 "modules_quantized": modules_quantized,
+             }
+
+             # Return in the standard format expected by the framework
+             return {
+                 "name": self.name,
+                 "plan_digest": f"rtn_quantization_{bitwidth}bit_{scope}",
+                 "plan": edit_plan,  # Include the plan for certificate generation
+                 "deltas": {
+                     "params_changed": total_params_quantized,
+                     "sparsity": None,  # Quantization doesn't create sparsity
+                     "bitwidth_map": bitwidth_map,
+                     "layers_modified": len(modified_layers),
+                 },
+                 "config": kwargs,
+                 "model_desc": adapter.describe(model)
+                 if hasattr(adapter, "describe")
+                 else {},
+             }
+
+         except Exception as e:
+             # Return error in expected format
+             return {
+                 "name": self.name,
+                 "plan_digest": "rtn_quantization_failed",
+                 "deltas": {
+                     "params_changed": 0,
+                     "sparsity": None,
+                     "bitwidth_map": None,
+                     "layers_modified": 0,
+                 },
+                 "config": kwargs,
+                 "model_desc": {},
+                 "error": str(e),
+             }
+
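For reference, a call sketch showing the keyword arguments that `apply()` reads (`bits` is accepted as an alias for `bitwidth`); `model` and `adapter` stand in for objects supplied by the caller:

    result = edit.apply(model, adapter, bits=8, scope="ffn", clamp_ratio=0.0, max_modules=4)
    print(result["deltas"]["params_changed"], result["plan"]["modules_quantized"])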
+     def _identify_target_modules(self, model: nn.Module) -> list[tuple[str, nn.Module]]:
+         """Identify target modules based on scope configuration."""
+         target_modules = []
+         skipped_modules = []
+
+         for name, module in model.named_modules():
+             # Check for both Linear and Conv1D (GPT-2 uses Conv1D)
+             if not isinstance(module, nn.Linear | nn.Conv1d):
+                 # Import Conv1D from transformers if available
+                 try:
+                     from transformers.pytorch_utils import Conv1D
+
+                     if not isinstance(module, Conv1D):
+                         continue
+                 except ImportError:
+                     continue
+
+             # Check scope
+             should_include = False
+             if self.scope == "ffn":
+                 # FFN layers - be more permissive with pattern matching
+                 ffn_patterns = [
+                     "mlp.c_fc",
+                     "mlp.c_proj",
+                     "feed_forward",
+                     "fc1",
+                     "fc2",
+                     "mlp",
+                     "ffn",
+                     "intermediate.dense",
+                     "output.dense",
+                 ]
+                 if any(pattern in name.lower() for pattern in ffn_patterns):
+                     should_include = True
+             elif self.scope == "attn":
+                 # Attention layers - be more permissive with pattern matching
+                 attn_patterns = [
+                     "attn.c_attn",
+                     "attn.c_proj",
+                     "attention",
+                     "q_proj",
+                     "k_proj",
+                     "v_proj",
+                     "o_proj",
+                     "attn",
+                 ]
+                 if any(pattern in name.lower() for pattern in attn_patterns):
+                     should_include = True
+             elif self.scope == "all":
+                 # All linear layers above a minimum size threshold
+                 if module.weight.numel() >= 100:  # Minimum parameter threshold
+                     should_include = True
+                 else:
+                     skipped_modules.append(
+                         (name, f"too small ({module.weight.numel()} params)")
+                     )
+
+             if should_include:
+                 target_modules.append((name, module))
+             else:
+                 if self.scope != "all":  # Only log for specific scopes
+                     skipped_modules.append((name, f"scope mismatch ({self.scope})"))
+
+         # Log diagnostic information
+         if skipped_modules:
+             print(f" Skipped {len(skipped_modules)} modules:")
+             for name, reason in skipped_modules[:5]:  # Show first 5
+                 print(f" {name}: {reason}")
+             if len(skipped_modules) > 5:
+                 print(f" ... and {len(skipped_modules) - 5} more")
+
+         return target_modules
+
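To make the pattern matching concrete: with GPT-2-style names, scope "ffn" matches modules such as `transformer.h.0.mlp.c_fc` (via the "mlp" patterns), while scope "attn" matches `transformer.h.0.attn.c_attn`. A quick way to inspect what a scope would select (illustrative; it calls the private helper directly on an already-loaded `model`):

    edit = RTNQuantEdit(scope="ffn")
    for name, module in edit._identify_target_modules(model):
        print(name, tuple(module.weight.shape))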
+     def _get_module_by_name(self, model: nn.Module, name: str) -> nn.Module | None:
+         """Get module by dotted name."""
+         try:
+             parts = name.split(".")
+             module = model
+             for part in parts:
+                 module = getattr(module, part)
+             return module
+         except AttributeError:
+             return None
+
+     def _get_weight_tying_map(self, model: nn.Module) -> dict[str, list[str]]:
+         """Identify weight tying relationships for preservation."""
+         tying_map = {}
+
+         # Common tying patterns (e.g., lm_head and wte sharing weights)
+         weight_to_modules: dict[int, list[str]] = {}
+
+         for name, module in model.named_modules():
+             if hasattr(module, "weight") and module.weight is not None:
+                 weight_id = id(module.weight)
+                 if weight_id not in weight_to_modules:
+                     weight_to_modules[weight_id] = []
+                 weight_to_modules[weight_id].append(name)
+
+         # Create tying map
+         for _weight_id, module_names in weight_to_modules.items():
+             if len(module_names) > 1:
+                 for name in module_names:
+                     tying_map[name] = [n for n in module_names if n != name]
+
+         return tying_map
+
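Concretely, for a GPT-2-style checkpoint whose `lm_head.weight` is the same tensor object as `transformer.wte.weight`, the map pairs the two names with each other (illustrative output):

    tying_map = edit._get_weight_tying_map(model)
    # e.g. {"lm_head": ["transformer.wte"], "transformer.wte": ["lm_head"]}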
+     def _compute_quantization_stats(
+         self, target_modules: list[tuple[str, nn.Module]]
+     ) -> dict[str, Any]:
+         """Compute statistics about what will be quantized."""
+         stats = {
+             "total_modules": len(target_modules),
+             "total_params": 0,
+             "module_stats": [],
+         }
+
+         for name, module in target_modules:
+             weight = module.weight
+             module_stat = {
+                 "name": name,
+                 "shape": list(weight.shape),
+                 "params": weight.numel(),
+                 "weight_range": [float(weight.min()), float(weight.max())],
+                 "weight_mean": float(weight.mean()),
+                 "weight_std": float(weight.std()),
+             }
+
+             # Compute per-channel statistics
+             if len(weight.shape) >= 2:
+                 channel_stats = []
+                 for c in range(weight.shape[0]):  # Output channels
+                     channel_weight = weight[c]
+                     channel_stats.append(
+                         {
+                             "channel": c,
+                             "absmax": float(channel_weight.abs().max()),
+                             "mean": float(channel_weight.mean()),
+                             "std": float(channel_weight.std()),
+                         }
+                     )
+                 module_stat["channel_stats"] = channel_stats[:10]  # Limit for preview
+
+             stats["module_stats"].append(module_stat)
+             stats["total_params"] += module_stat["params"]
+
+         return stats
+
+     def _apply_rtn_quantization(
+         self,
+         module: nn.Module,
+         bitwidth: int,
+         group_size: int | None,
+         clamp_ratio: float,
+         tied_modules: list[str] | None = None,
+     ) -> dict[str, Any]:
+         """Apply RTN quantization to a single module."""
+         weight = module.weight.data
+         original_shape = weight.shape
+         params_quantized = weight.numel()
+
+         # Store original for comparison
+         original_weight = weight.clone()
+
+         # Flatten weight for processing
+         if len(weight.shape) == 1:
+             # Handle bias or 1D weights
+             weight_2d = weight.unsqueeze(0)
+             is_1d = True
+         else:
+             weight_2d = weight.view(weight.shape[0], -1)  # [out_channels, in_features]
+             is_1d = False
+
+         # Apply outlier clipping if requested
+         if clamp_ratio > 0.0:
+             weight_2d = self._apply_outlier_clipping(weight_2d, clamp_ratio)
+
+         # Compute quantization parameters
+         qmin = -(2 ** (bitwidth - 1))
+         qmax = 2 ** (bitwidth - 1) - 1
+
+         if bitwidth == 4 and group_size is not None:
+             # Group-wise quantization for 4-bit
+             quantized_weight, scales, scale_stats = self._quantize_grouped(
+                 weight_2d, qmin, qmax, group_size
+             )
+         else:
+             # Per-channel quantization
+             quantized_weight, scales, scale_stats = self._quantize_per_channel(
+                 weight_2d, qmin, qmax
+             )
+
+         # Reshape back to original shape
+         if is_1d:
+             quantized_weight = quantized_weight.squeeze(0)
+         else:
+             quantized_weight = quantized_weight.view(original_shape)
+
+         # Report the mean absolute round-off error as a sanity check that
+         # quantization actually changed the weights
+         quantization_error = (quantized_weight - original_weight).abs().mean()
+         print(f" Quantization error: {quantization_error:.6f}")
+
+         # Write back to module (preserving tying if needed)
+         module.weight.data.copy_(quantized_weight)
+
+         # Verify the weights actually changed
+         final_weight = module.weight.data
+         actual_change = not torch.allclose(original_weight, final_weight, atol=1e-6)
+         if not actual_change:
+             print(f" WARNING: No actual weight change detected for {module}")
+
+         # Tied modules share the same weight tensor, so the in-place copy above
+         # already updated them; nothing extra is required here
+         if tied_modules:
+             for _tied_name in tied_modules:
+                 pass
+
+         return {
+             "params_quantized": params_quantized,
+             "original_shape": original_shape,
+             "bitwidth": bitwidth,
+             "group_size": group_size,
+             "scale_stats": scale_stats,
+             "clamp_applied": clamp_ratio > 0.0,
+         }
+
+     def _apply_outlier_clipping(
+         self, weight: torch.Tensor, clamp_ratio: float
+     ) -> torch.Tensor:
+         """Apply outlier clipping based on quantile thresholds."""
+         if clamp_ratio <= 0.0:
+             return weight
+
+         lower = clamp_ratio / 2
+         upper = 1 - lower
+         eps = torch.finfo(weight.dtype).eps
+
+         # Compute per-output-channel quantiles to preserve channel statistics
+         quantiles = torch.quantile(
+             weight,
+             torch.tensor([lower, upper], device=weight.device, dtype=weight.dtype),
+             dim=1,
+             keepdim=True,
+         )
+
+         q_low = quantiles[0].clamp_min(-torch.inf)
+         q_high = quantiles[1].clamp_min(eps)
+         return torch.clamp(weight, q_low, q_high)
+
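For example (illustrative values), `clamp_ratio = 0.01` clips each output channel to its 0.5th–99.5th percentile range before scales are computed:

    clamp_ratio = 0.01
    lower, upper = clamp_ratio / 2, 1 - clamp_ratio / 2  # 0.005 and 0.995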
+     def _quantize_per_channel(
+         self, weight: torch.Tensor, qmin: int, qmax: int
+     ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
+         """Apply per-channel symmetric quantization."""
+         # Compute per-channel scales (per output channel)
+         channel_absmax = weight.abs().max(dim=1, keepdim=True)[0]  # [out_channels, 1]
+
+         # Avoid division by zero
+         eps = 1e-8
+         channel_absmax = torch.clamp(channel_absmax, min=eps)
+
+         # Symmetric quantization scale
+         scales = channel_absmax / qmax
+
+         # Quantize
+         weight_scaled = weight / scales
+         weight_quantized = torch.clamp(torch.round(weight_scaled), qmin, qmax)
+
+         # Dequantize (write back as float)
+         weight_dequantized = weight_quantized * scales
+
+         # Compute statistics
+         scale_stats = {
+             "scale_mean": float(scales.mean()),
+             "scale_std": float(scales.std()),
+             "scale_min": float(scales.min()),
+             "scale_max": float(scales.max()),
+             "zero_scales": int((scales <= eps).sum()),
+         }
+
+         return weight_dequantized, scales.squeeze(), scale_stats
+
+     def _quantize_grouped(
+         self, weight: torch.Tensor, qmin: int, qmax: int, group_size: int
+     ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
+         """Apply group-wise quantization for 4-bit mode."""
+         out_channels, in_features = weight.shape
+
+         # Pad input features to be divisible by group_size
+         pad_size = (group_size - (in_features % group_size)) % group_size
+         if pad_size > 0:
+             weight_padded = torch.cat(
+                 [weight, torch.zeros(out_channels, pad_size, device=weight.device)],
+                 dim=1,
+             )
+         else:
+             weight_padded = weight
+
+         padded_in_features = weight_padded.shape[1]
+         num_groups = padded_in_features // group_size
+
+         # Reshape for group processing
+         weight_grouped = weight_padded.view(out_channels, num_groups, group_size)
+
+         # Compute per-group scales
+         group_absmax = weight_grouped.abs().max(dim=2, keepdim=True)[
+             0
+         ]  # [out_channels, num_groups, 1]
+
+         # Avoid division by zero
+         eps = 1e-8
+         group_absmax = torch.clamp(group_absmax, min=eps)
+
+         # Symmetric quantization scale
+         scales = group_absmax / qmax
+
+         # Quantize
+         weight_scaled = weight_grouped / scales
+         weight_quantized = torch.clamp(torch.round(weight_scaled), qmin, qmax)
+
+         # Dequantize
+         weight_dequantized = weight_quantized * scales
+
+         # Reshape back and remove padding
+         weight_dequantized = weight_dequantized.view(out_channels, padded_in_features)
+         if pad_size > 0:
+             weight_dequantized = weight_dequantized[:, :-pad_size]
+
+         # Compute statistics
+         scale_stats = {
+             "scale_mean": float(scales.mean()),
+             "scale_std": float(scales.std()),
+             "scale_min": float(scales.min()),
+             "scale_max": float(scales.max()),
+             "num_groups": num_groups,
+             "group_size": group_size,
+             "zero_scales": int((scales <= eps).sum()),
+         }
+
+         return weight_dequantized, scales.view(-1), scale_stats
+
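Worked example of the padding arithmetic above (illustrative sizes): with 300 input features and a group size of 128, 84 zero columns are appended, giving 3 groups, and the padding is sliced off again after dequantization:

    in_features, group_size = 300, 128
    pad_size = (group_size - (in_features % group_size)) % group_size  # 84
    num_groups = (in_features + pad_size) // group_size                # 3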
+
+ # For backward compatibility, provide a functional interface
+ def apply(
+     model: nn.Module,
+     adapter: ModelAdapter,
+     plan: dict[Any, Any] | None = None,
+     **kwargs,
+ ) -> dict:
+     """
+     Apply RTN quantization using the RTNQuantEdit API.
+
+     Kept for backward compatibility; new code should use the RTNQuantEdit class,
+     which follows the ModelEdit protocol.
+     """
+     if plan is None:
+         # Create plan from kwargs
+         edit = RTNQuantEdit(
+             bitwidth=kwargs.get("bitwidth", 8),
+             per_channel=kwargs.get("per_channel", True),
+             group_size=kwargs.get("group_size"),
+             clamp_ratio=kwargs.get("clamp_ratio", 0.0),
+             scope=kwargs.get("scope", "ffn"),
+             seed=kwargs.get("seed", 42),
+             max_modules=kwargs.get("max_modules"),
+         )
+
+         # Need calibration data for preview (though RTN doesn't use it)
+         calib = kwargs.get("calib")
+         preview_result = edit.preview(model, adapter, calib)
+         plan = preview_result["plan"]
+
+     # Apply the plan; its entries are forwarded as keyword arguments because
+     # RTNQuantEdit.apply reads its configuration from **kwargs
+     edit = RTNQuantEdit()
+     return edit.apply(model, adapter, **(plan or {}))
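Finally, a minimal end-to-end sketch of the class-based interface (illustrative: the tiny model and the `adapter=None` shortcut are stand-ins; a real run would pass a proper ModelAdapter, and the wheel must be installed for the import to resolve):

    import torch.nn as nn
    from invarlock.edits.quant_rtn import RTNQuantEdit

    class TinyFFN(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mlp = nn.Linear(16, 16)  # name contains "mlp", so scope="ffn" selects it

    result = RTNQuantEdit(scope="ffn").apply(TinyFFN(), adapter=None)
    print(result["deltas"]["params_changed"])  # 256 weights snapped to the INT8 grid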