sam-gate 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam_gate/__init__.py +10 -0
- sam_gate/config.py +206 -0
- sam_gate/kv_cache.py +2384 -0
- sam_gate/sam.py +2420 -0
- sam_gate/spectral.py +1685 -0
- sam_gate-0.1.0.dist-info/METADATA +28 -0
- sam_gate-0.1.0.dist-info/RECORD +8 -0
- sam_gate-0.1.0.dist-info/WHEEL +4 -0
sam_gate/__init__.py
ADDED
sam_gate/config.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SAM-Gate hyperparameters and defaults.
|
|
3
|
+
|
|
4
|
+
Centralizes values that were previously scattered across the main script, CLI, and heuristics.
|
|
5
|
+
Adjust here to calibrate policy, prober, demo prompts, and KV estimates.
|
|
6
|
+
|
|
7
|
+
Typical usage:
|
|
8
|
+
from sam_gate.config import SAMConfig, DEFAULT_CLI
|
|
9
|
+
cfg = SAMConfig()
|
|
10
|
+
# or
|
|
11
|
+
cfg = SAMConfig(tau=0.03, max_ctx_flat=256)
|
|
12
|
+
|
|
13
|
+
─── CALIBRATION (Qwen2.5-7B-Instruct) ───────────────────────────────────────────
|
|
14
|
+
The f(t) thresholds below (f_flat_max / f_obs_min) are NOT yet calibrated for Qwen2.5-7B-Instruct.
|
|
15
|
+
Run calibration before use:
|
|
16
|
+
|
|
17
|
+
python -m sam_gate.sam --model Qwen/Qwen2.5-7B-Instruct --calibrate --verbose
|
|
18
|
+
|
|
19
|
+
Then adjust f_flat_max / f_obs_min in SAMConfig and SAMCliDefaults according to
|
|
20
|
+
the observed f(t) distribution per layer.
|
|
21
|
+
|
|
22
|
+
Reference values from Qwen2.5-3B-Instruct (for comparison):
|
|
23
|
+
L00-L01 : f(t) ~ 2e-05 – 2e-02 (quiet layers)
|
|
24
|
+
L02-L09 : f(t) ~ 3e-02 – 1.2e+00 (medium layers)
|
|
25
|
+
L10-L22 : f(t) ~ 4e-02 – 2.9e+01 (active layers)
|
|
26
|
+
L23 : f(t) ~ 5.2e+02 – 1.4e+03 (most active layer)
|
|
27
|
+
─────────────────────────────────────────────────────────────────────────────────
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
from dataclasses import dataclass
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ── Prompts ──────────────────────────────────────────────────────────────────
# Fixed texts for the stress test, the demo, calibration and long-context runs.

STRESS_TEST_PROMPT: str = (
    "Compare Gödel incompleteness with Turing undecidability "
    "and explain the implications for formal systems "
    "under different axiomatic assumptions."
)

DEFAULT_DEMO_PROMPT: str = "Explain how attention mechanisms work in transformers."

CALIBRATION_SIMPLE_PROMPT: str = "What is the capital of France?"

# Paragraph that --long_prompt / _make_long_prompt repeats until the target
# token count is reached. It ends with a space, presumably so that repeated
# copies stay separated.
LONG_PROMPT_FILLER_BASE: str = (
    "The relationship between computational complexity and semantic meaning "
    "in large language models is a deeply contested area of research. "
    "Attention mechanisms allow the model to dynamically weight different "
    "parts of the input sequence, creating context-dependent representations "
    "at each layer. "
    "The key-value cache stores intermediate computations to avoid "
    "redundant forward passes during autoregressive generation. "
)

# Rough characters-per-token ratio used when sizing the long prompt text.
LONG_PROMPT_CHARS_PER_TOKEN: float = 4.5

# Memory test: inclusive bounds of the random 6-digit secret code.
SECRET_CODE_MIN: int = 10**5
SECRET_CODE_MAX: int = 10**6 - 1


# ── Attention structure (inference when model doesn't expose head_dim) ───────
# Candidate head_dim values tried when the model does not report head_dim.
HEAD_DIM_CANDIDATES: tuple[int, ...] = (32, 64, 80, 96, 128, 256)


# ── RealDynamicsProber ───────────────────────────────────────────────────────
# Upper bound on heads used for the Gram norms at each step
# (trades compute cost against stability of the estimate).
PROBE_HEADS_CAP: int = 4
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# ── Heuristic KV estimate (_estimate_max_kv_mb) ──────────────────────────────
# Qwen2.5-7B-Instruct: 28 layers, 8 KV heads, head_dim=128.
# Does not replace actual model measurement — used only for pre-run estimates.

@dataclass
class SAMKVHeuristicEstimate:
    """Model-shape assumptions behind the pre-run KV-size heuristic.

    Defaults describe Qwen2.5-7B-Instruct. The three layer counts sum to its
    28 layers (2 flat + 22 transition + 4 obstructed). That split is an
    assumed regime distribution, not a measurement.
    """

    kv_heads: int = 8  # KV heads per layer
    head_dim: int = 128  # per-head dimension
    n_layers_flat: int = 2  # layers assumed to sit in the flat regime
    n_layers_trans: int = 22  # layers assumed to sit in the transition regime
    n_layers_obs: int = 4  # layers assumed to sit in the obstructed regime


# Shared default instance (per the section header, read by _estimate_max_kv_mb,
# which is defined outside this file).
DEFAULT_KV_HEURISTIC = SAMKVHeuristicEstimate()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# ── Generation (benchmark / main) ────────────────────────────────────────────

@dataclass
class SAMGenerationDefaults:
    """Default generation settings for benchmark / main runs.

    NOTE(review): the field names suggest these are forwarded to the model's
    generate() call; confirm at the call site.
    """

    max_new_tokens: int = 100  # token budget for a normal run
    max_new_tokens_calibrate: int = 30  # presumably the (shorter) budget for --calibrate runs
    use_cache: bool = False  # NOTE(review): cache disabled by default — confirm why at the call site
    do_sample: bool = False  # sampling off; presumably greedy decoding


DEFAULT_GENERATION = SAMGenerationDefaults()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# ── CLI and interactive wizard (argparse defaults / questions) ────────────────
# NOTE: all policy parameters here must stay aligned with SAMConfig defaults below.

@dataclass
class SAMCliDefaults:
    """Defaults for the argparse options and the interactive wizard.

    The policy fields (tau, *_bits, *_heads, f_flat_max, f_obs_min, max_ctx_*)
    mirror the SAMConfig defaults, so a CLI run and ``SAMConfig()`` select the
    same memory policy. When recalibrating, change both classes together.
    """

    model: str = "Qwen/Qwen2.5-7B-Instruct"
    device: str = "cpu"
    prompt: str = DEFAULT_DEMO_PROMPT
    max_new_tokens: int = 100  # same as SAMGenerationDefaults.max_new_tokens
    tau: float = 0.05  # aligned with SAMConfig
    flat_bits: int = 4  # aligned with SAMConfig
    trans_bits: int = 8  # aligned with SAMConfig
    obs_bits: int = 16  # aligned with SAMConfig
    flat_heads: float = 0.5  # aligned with SAMConfig
    trans_heads: float = 0.75  # aligned with SAMConfig
    # f(t) thresholds: these were 5.0 / 500.0, which silently diverged from
    # SAMConfig (1e-2 / 50.0) despite the alignment NOTE above, so CLI runs and
    # API runs classified layers differently.
    f_flat_max: float = 1e-2  # aligned with SAMConfig; ⚠ NOT calibrated for 7B — run --calibrate
    f_obs_min: float = 50.0  # aligned with SAMConfig; ⚠ NOT calibrated for 7B — run --calibrate
    max_ctx_flat: int = 128  # aligned with SAMConfig
    max_ctx_trans: int = 512  # aligned with SAMConfig
    max_ctx_obs: int = 2048  # aligned with SAMConfig
    long_tokens: int = 2000  # presumably the target token count for --long_prompt
    # Wizard: suggested minimums in questions
    wizard_max_new_tokens_min: int = 10
    wizard_ctx_flat_min: int = 16
    wizard_ctx_trans_min: int = 64


DEFAULT_CLI = SAMCliDefaults()
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# ── SAMConfig — core of semantic policy ──────────────────────────────────────

@dataclass
class SAMConfig:
    """
    SAM-Gate (Semantic-Aware Memory Gate) parameters.

    Per-regime settings come in flat / transition (trans) / obstructed (obs)
    variants: KV precision (*_bits), head fractions (*_heads) and context
    windows (max_ctx_*). f_flat_max / f_obs_min are the f(t) thresholds tuned
    from the per-layer f(t) distribution.

    Typical calibration: --calibrate + adjust f_flat_max / f_obs_min according to
    the observed f(t) distribution per layer.

    ⚠ Thresholds below are NOT yet calibrated for Qwen2.5-7B-Instruct.
    Run: python -m sam_gate.sam --model Qwen/Qwen2.5-7B-Instruct --calibrate --verbose
    """

    tau: float = 0.05  # NOTE(review): role not visible in this file; calibration passes it through unchanged
    flat_bits: int = 4  # KV precision in the flat regime (4 -> "int4" via bits_to_kv_prec_str)
    trans_bits: int = 8  # KV precision in the transition regime (8 -> "int8")
    obs_bits: int = 16  # KV precision in the obstructed regime (16 -> "fp16")
    flat_heads: float = 0.5  # presumably the fraction of heads kept in the flat regime — confirm
    trans_heads: float = 0.75  # presumably the fraction of heads kept in transition — confirm
    f_flat_max: float = 1e-2  # ⚠ NOT calibrated for 7B — run --calibrate
    f_obs_min: float = 50.0  # ⚠ NOT calibrated for 7B — run --calibrate
    dsem_obs_thresh: float = 5.0  # presumably a D_sem threshold for the obstructed regime — confirm
    ema_alpha: float = 0.3  # smoothing factor of the EMA referenced by probe_prefill_only
    # If True: probe_from_output only on prefill; decode reuses EMA from end of prefill.
    probe_prefill_only: bool = True
    max_ctx_flat: int = 128  # window in flat regime
    max_ctx_trans: int = 512  # window in transition
    max_ctx_obs: int = 2048  # window in obstructed
    # Decode with dense ring: use flash_attn.flash_attn_interface.flash_attn_with_kvcache
    # (CUDA) when available — fused kernel vs SDPA + materializing contiguous K/V.
    use_flash_attn_kvcache: bool = True
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# ── Per-layer state (initialization) ─────────────────────────────────────────

@dataclass
class LayerSemanticState:
    """Mutable per-layer semantic state; the field defaults are its initial values."""

    d_eff_ratio: float = 1.0  # presumably the d_eff (effective dimension) ratio; starts at 1.0
    f_t: float = 0.0  # latest f(t) for the layer (presumably compared with f_flat_max / f_obs_min)
    D_sem: float = 0.0  # presumably the quantity compared with SAMConfig.dsem_obs_thresh — confirm
    H_sem: float = 1.0  # NOTE(review): meaning not visible in this file
    step: int = 0  # step counter, starts at 0
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# ── Neutral policy for --calibrate (observe f(t) without compressing) ────────

# Value written into both f_flat_max and f_obs_min for calibration runs.
CALIBRATION_F_SENTINEL: float = 1.0e9


def neutral_sam_config_for_calibration(tau: float) -> SAMConfig:
    """Return a SAMConfig meant to disable compression and windows for --calibrate.

    Used to measure f(t) / d_eff without the policy changing the cache:
    - every regime keeps 16-bit KV and all heads;
    - both f(t) thresholds are CALIBRATION_F_SENTINEL;
    - every max_ctx_* window is 0;
    - the flash-attn KV-cache decode path is switched off.

    Fields not listed (dsem_obs_thresh, ema_alpha, probe_prefill_only) keep
    their SAMConfig defaults.

    Args:
        tau: forwarded unchanged as SAMConfig.tau.
    """
    full_precision = {"flat_bits": 16, "trans_bits": 16, "obs_bits": 16}
    every_head = {"flat_heads": 1.0, "trans_heads": 1.0}
    sentinel_thresholds = {
        "f_flat_max": CALIBRATION_F_SENTINEL,
        "f_obs_min": CALIBRATION_F_SENTINEL,
    }
    zero_windows = {"max_ctx_flat": 0, "max_ctx_trans": 0, "max_ctx_obs": 0}
    return SAMConfig(
        tau=tau,
        use_flash_attn_kvcache=False,
        **full_precision,
        **every_head,
        **sentinel_thresholds,
        **zero_windows,
    )
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def bits_to_kv_prec_str(bits: int, *, fallback: str = "int4") -> str:
    """Map a KV bit width to the precision label used in MemoryPolicy.prec_str.

    4 -> "int4", 8 -> "int8", 16 -> "fp16". Any other width returns
    *fallback*, which defaults to "int4".
    """
    labels = {4: "int4", 8: "int8", 16: "fp16"}
    try:
        return labels[bits]
    except KeyError:
        return fallback
|