sam-gate 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sam_gate/__init__.py ADDED
@@ -0,0 +1,10 @@
1
"""SAM-Gate: semantic KV compression and RCI-guided quantization."""

from .config import SAMConfig
from .sam import attach_semantic_hooks, verify_sam_wblk_caches

# Names re-exported as the package's public API.
__all__ = [
    "SAMConfig",
    "attach_semantic_hooks",
    "verify_sam_wblk_caches",
]
sam_gate/config.py ADDED
@@ -0,0 +1,206 @@
1
+ """
2
+ SAM-Gate hyperparameters and defaults.
3
+
4
+ Centralizes values that were previously scattered across the main script, CLI, and heuristics.
5
+ Adjust here to calibrate policy, prober, demo prompts, and KV estimates.
6
+
7
+ Typical usage:
8
+ from sam_gate.config import SAMConfig, DEFAULT_CLI
9
+ cfg = SAMConfig()
10
+ # or
11
+ cfg = SAMConfig(tau=0.03, max_ctx_flat=256)
12
+
13
+ ─── CALIBRATION (Qwen2.5-7B-Instruct) ───────────────────────────────────────────
14
+ Thresholds below are NOT yet calibrated for Qwen2.5-7B-Instruct.
15
+ Run calibration before use:
16
+
17
+ python -m sam_gate.sam --model Qwen/Qwen2.5-7B-Instruct --calibrate --verbose
18
+
19
+ Then adjust f_flat_max / f_obs_min in SAMConfig and SAMCliDefaults according to
20
+ the observed f(t) distribution per layer.
21
+
22
+ Reference values from Qwen2.5-3B-Instruct (for comparison):
23
+ L00-L01 : f(t) ~ 2e-05 – 2e-02 (quiet layers)
24
+ L02-L09 : f(t) ~ 3e-02 – 1.2e+00 (medium layers)
25
+ L10-L22 : f(t) ~ 4e-02 – 2.9e+01 (active layers)
26
+ L23 : f(t) ~ 5.2e+02 – 1.4e+03 (most active layer)
27
+ ─────────────────────────────────────────────────────────────────────────────────
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ from dataclasses import dataclass
33
+
34
+
35
+ # ── Prompts ──────────────────────────────────────────────────────────────────
36
+
37
# Prompt texts. Adjacent string literals are concatenated by the parser, so
# the trailing space inside each fragment is significant.
STRESS_TEST_PROMPT: str = (
    "Compare Gödel incompleteness with Turing undecidability and explain the "
    "implications for formal systems under different axiomatic assumptions."
)

DEFAULT_DEMO_PROMPT: str = (
    "Explain how attention mechanisms work in transformers."
)

CALIBRATION_SIMPLE_PROMPT: str = "What is the capital of France?"

# Base paragraph for --long_prompt / _make_long_prompt (repeated until target tokens reached)
LONG_PROMPT_FILLER_BASE: str = (
    "The relationship between computational complexity and semantic meaning in "
    "large language models is a deeply contested area of research. Attention "
    "mechanisms allow the model to dynamically weight different parts of the "
    "input sequence, creating context-dependent representations at each layer. "
    "The key-value cache stores intermediate computations to avoid redundant "
    "forward passes during autoregressive generation. "
)

# Heuristic ~ chars per token for sizing long text
LONG_PROMPT_CHARS_PER_TOKEN: float = 4.5

# Memory test: random secret code (6 digits)
SECRET_CODE_MIN: int = 100_000
SECRET_CODE_MAX: int = 999_999
64
+
65
+
66
+ # ── Attention structure (inference when model doesn't expose head_dim) ───────
67
+
68
# Candidate head_dim values, in ascending order.
HEAD_DIM_CANDIDATES: tuple[int, ...] = (32, 64, 80, 96, 128, 256)


# ── RealDynamicsProber ───────────────────────────────────────────────────────

# Maximum heads used for Gram norms per step (cost vs stability)
PROBE_HEADS_CAP: int = 4
75
+
76
+
77
+ # ── Heuristic KV estimate (_estimate_max_kv_mb) ──────────────────────────────
78
+ # Qwen2.5-7B-Instruct: 28 layers, 8 KV heads, head_dim=128.
79
+ # Does not replace actual model measurement — used only for pre-run estimates.
80
+
81
@dataclass
class SAMKVHeuristicEstimate:
    """Model-shape assumptions used for the heuristic pre-run KV estimate.

    Defaults describe Qwen2.5-7B-Instruct (28 layers, 8 KV heads,
    head_dim=128); the three ``n_layers_*`` defaults sum to those 28 layers.
    This is an estimate only and does not replace measuring the real model.
    """

    kv_heads: int = 8
    head_dim: int = 128
    # Assumed split of the layers across the flat / transition / obstructed
    # regimes (naming follows the SAMConfig window comments).
    n_layers_flat: int = 2
    n_layers_trans: int = 22
    n_layers_obs: int = 4


# Shared instance holding the default heuristic.
DEFAULT_KV_HEURISTIC = SAMKVHeuristicEstimate()
91
+
92
+
93
+ # ── Generation (benchmark / main) ────────────────────────────────────────────
94
+
95
@dataclass
class SAMGenerationDefaults:
    """Default generation settings for the benchmark / main entry points."""

    max_new_tokens: int = 100
    # Shorter generation budget for calibration runs.
    max_new_tokens_calibrate: int = 30
    use_cache: bool = False
    do_sample: bool = False


# Shared instance holding the default generation settings.
DEFAULT_GENERATION = SAMGenerationDefaults()
104
+
105
+
106
+ # ── CLI and interactive wizard (argparse defaults / questions) ────────────────
107
+ # NOTE: all policy parameters here must stay aligned with SAMConfig defaults below.
108
+
109
@dataclass
class SAMCliDefaults:
    """Default values for the argparse CLI and the interactive wizard.

    Policy parameters here are meant to mirror the SAMConfig defaults.
    """

    model: str = "Qwen/Qwen2.5-7B-Instruct"
    device: str = "cpu"
    prompt: str = DEFAULT_DEMO_PROMPT
    max_new_tokens: int = 100
    tau: float = 0.05
    flat_bits: int = 4
    trans_bits: int = 8
    obs_bits: int = 16
    flat_heads: float = 0.5  # aligned with SAMConfig
    trans_heads: float = 0.75  # aligned with SAMConfig
    # NOTE(review): the two thresholds below do NOT match the SAMConfig
    # defaults (f_flat_max=1e-2, f_obs_min=50.0) although the policy values
    # are supposed to stay aligned. Values are left unchanged here; confirm
    # which pair is intended before relying on either.
    f_flat_max: float = 5.0  # ⚠ NOT calibrated for 7B — run --calibrate
    f_obs_min: float = 500.0  # ⚠ NOT calibrated for 7B — run --calibrate
    max_ctx_flat: int = 128  # aligned with SAMConfig
    max_ctx_trans: int = 512  # aligned with SAMConfig
    max_ctx_obs: int = 2048  # aligned with SAMConfig
    long_tokens: int = 2000
    # Wizard: suggested minimums in questions
    wizard_max_new_tokens_min: int = 10
    wizard_ctx_flat_min: int = 16
    wizard_ctx_trans_min: int = 64


# Shared instance holding the default CLI values.
DEFAULT_CLI = SAMCliDefaults()
134
+
135
+
136
+ # ── SAMConfig — core of semantic policy ──────────────────────────────────────
137
+
138
@dataclass
class SAMConfig:
    """
    SAM-Gate (Semantic-Aware Memory Gate) parameters.

    Typical calibration: --calibrate + adjust f_flat_max / f_obs_min according to
    the observed f(t) distribution per layer.

    ⚠ Thresholds below are NOT yet calibrated for Qwen2.5-7B-Instruct.
    Run: python -m sam_gate.sam --model Qwen/Qwen2.5-7B-Instruct --calibrate --verbose
    """

    tau: float = 0.05
    # KV precision (in bits) per regime: flat / transition / obstructed.
    flat_bits: int = 4
    trans_bits: int = 8
    obs_bits: int = 16
    # Head fractions for the flat and transition regimes.
    flat_heads: float = 0.5
    trans_heads: float = 0.75
    f_flat_max: float = 1e-2  # ⚠ NOT calibrated for 7B — run --calibrate
    f_obs_min: float = 50.0  # ⚠ NOT calibrated for 7B — run --calibrate
    dsem_obs_thresh: float = 5.0
    ema_alpha: float = 0.3
    # If True: probe_from_output only on prefill; decode reuses EMA from end of prefill.
    probe_prefill_only: bool = True
    max_ctx_flat: int = 128  # window in flat regime
    max_ctx_trans: int = 512  # window in transition
    max_ctx_obs: int = 2048  # window in obstructed
    # Decode with dense ring: use flash_attn.flash_attn_interface.flash_attn_with_kvcache
    # (CUDA) when available — fused kernel vs SDPA + materializing contiguous K/V.
    use_flash_attn_kvcache: bool = True
168
+
169
+
170
+ # ── Per-layer state (initialization) ─────────────────────────────────────────
171
+
172
@dataclass
class LayerSemanticState:
    """Per-layer semantic state; the defaults are the initial values."""

    d_eff_ratio: float = 1.0
    f_t: float = 0.0
    D_sem: float = 0.0
    H_sem: float = 1.0
    step: int = 0
179
+
180
+
181
+ # ── Neutral policy for --calibrate (observe f(t) without compressing) ────────
182
+
183
# Very large threshold used for both f_flat_max and f_obs_min in the neutral
# calibration config below.
CALIBRATION_F_SENTINEL: float = 1e9


def neutral_sam_config_for_calibration(tau: float) -> SAMConfig:
    """SAMConfig that disables compression and windows for measuring f(t) / d_eff.

    Args:
        tau: Forwarded unchanged to ``SAMConfig.tau``.

    Returns:
        A ``SAMConfig`` with every regime at 16 bits and head fraction 1.0,
        both f-thresholds set to ``CALIBRATION_F_SENTINEL``, all context
        windows set to 0 and the flash-attn KV-cache path turned off. Fields
        not listed here keep their ``SAMConfig`` defaults.
    """
    return SAMConfig(
        tau=tau,
        flat_bits=16,
        trans_bits=16,
        obs_bits=16,
        flat_heads=1.0,
        trans_heads=1.0,
        f_flat_max=CALIBRATION_F_SENTINEL,
        f_obs_min=CALIBRATION_F_SENTINEL,
        # NOTE(review): 0 is presumably read as "no window limit" by the
        # consumer of these fields — confirm against sam.py.
        max_ctx_flat=0,
        max_ctx_trans=0,
        max_ctx_obs=0,
        use_flash_attn_kvcache=False,
    )
202
+
203
+
204
# Bit width -> symbolic precision name; built once instead of on every call.
_KV_PREC_BY_BITS = {4: "int4", 8: "int8", 16: "fp16"}


def bits_to_kv_prec_str(bits: int, *, fallback: str = "int4") -> str:
    """Symbolic name of KV precision used in MemoryPolicy.prec_str.

    Args:
        bits: Bit width; 4, 8 and 16 map to "int4", "int8" and "fp16".
        fallback: Keyword-only name returned for any other bit width.

    Returns:
        The precision name for ``bits``, or ``fallback`` when unmapped.
    """
    return _KV_PREC_BY_BITS.get(bits, fallback)