nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
nested_learning/model.py
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, Protocol, Sequence, cast
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from torch.utils.checkpoint import checkpoint
|
|
9
|
+
|
|
10
|
+
from .fast_state import (
|
|
11
|
+
AttentionKVCache,
|
|
12
|
+
ModelAttentionCache,
|
|
13
|
+
ModelFastState,
|
|
14
|
+
build_block_fast_state,
|
|
15
|
+
)
|
|
16
|
+
from .hope.block import (
|
|
17
|
+
HOPEAttentionBlock,
|
|
18
|
+
HOPEAttentionBlockConfig,
|
|
19
|
+
HOPEBlock,
|
|
20
|
+
HOPEBlockConfig,
|
|
21
|
+
HOPESelfModBlock,
|
|
22
|
+
HOPESelfModBlockConfig,
|
|
23
|
+
)
|
|
24
|
+
from .levels import LevelSpec
|
|
25
|
+
from .transformer import TransformerBlock, TransformerBlockConfig
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class ModelConfig:
    """Settings read by ``HOPEModel`` when it builds its embedding, blocks and head.

    The first six fields are required (positional order matters for callers
    that construct the config positionally); every other field has a default.
    Which fields are actually consumed depends on ``block_variant`` -- see the
    per-group comments below.
    """

    # --- Core sizes (required) -------------------------------------------
    vocab_size: int  # rows of the token embedding / width of the LM head
    dim: int  # model width shared by embedding, blocks, final norm
    num_layers: int  # number of identical blocks stacked in the model
    heads: int  # forwarded to every variant except "hope_selfmod"
    titan_level: LevelSpec  # forwarded only by the "hope_hybrid" variant
    cms_levels: Sequence[LevelSpec]  # forwarded by all three HOPE variants

    # --- CMS options (HOPE variants) -------------------------------------
    cms_flush_partial_at_end: bool = False
    cms_use_layernorm: bool = True

    # --- Inner optimizers and teach-signal handling -----------------------
    optimizers: Dict[str, dict] | None = None  # ``None`` is forwarded as ``{}``
    teach_scale: float = 1.0  # multiplier applied to teach signals
    teach_clip: float = 0.0  # row-norm clip; only applied when > 0
    teach_schedule: Dict[str, float] | None = None  # not read by the model class itself

    # --- Runtime behaviour ------------------------------------------------
    gradient_checkpointing: bool = False
    surprise_threshold: float | None = None
    surprise_metric: str = "l2"  # model accepts "l2", "loss", "logit_entropy"
    freeze_backbone: bool = False  # when true the model freezes itself on construction

    # --- Attention options ------------------------------------------------
    qk_l2_norm: bool = False
    local_conv_window: int | None = None  # not forwarded by "hope_selfmod"

    # --- Self-modification options ---------------------------------------
    self_mod_lr: float = 1e-3  # also forwarded as ``eta_scale`` for "hope_selfmod"
    self_mod_hidden: int = 4  # forwarded only by "hope_hybrid"
    # The remaining self_mod_* fields are forwarded only by "hope_selfmod".
    self_mod_chunk_size: int = 1
    self_mod_chunk_size_memory: int | None = None
    self_mod_objective: str = "l2"
    self_mod_stopgrad_vhat: bool = True
    self_mod_use_rank1_precond: bool = True
    self_mod_use_alpha: bool = True
    self_mod_use_skip: bool = True
    self_mod_momentum: float = 0.0
    self_mod_adaptive_q: bool = False
    self_mod_local_conv_window: int | None = 4

    # --- Plain transformer baseline ("transformer" variant only) ----------
    transformer_mlp_hidden_multiplier: int = 4
    transformer_activation: str = "gelu"

    # --- Variant selector -------------------------------------------------
    # One of "hope_attention", "hope_hybrid", "hope_selfmod", "transformer".
    block_variant: str = "hope_hybrid"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class HOPEModel(nn.Module):
|
|
66
|
+
def __init__(self, config: ModelConfig):
    """Build the token embedding, the block stack, the final norm and the LM head.

    ``config.block_variant`` is stripped and lower-cased, then selects which
    block class (and matching block config) is instantiated
    ``config.num_layers`` times.

    Raises:
        ValueError: If the normalised variant is not one of
            ``hope_attention``, ``hope_hybrid``, ``hope_selfmod`` or
            ``transformer``; also propagated from ``set_surprise_metric``
            when ``config.surprise_metric`` is not supported.
    """
    super().__init__()
    self.config = config
    self.embed = nn.Embedding(config.vocab_size, config.dim)

    # ``base_*`` keeps the configured teach settings; ``_runtime_*`` holds the
    # values actually used, which ``set_teach_runtime`` may override later.
    self.base_teach_scale = config.teach_scale
    self.base_teach_clip = config.teach_clip
    self._runtime_teach_scale = config.teach_scale
    self._runtime_teach_clip = config.teach_clip

    self.gradient_checkpointing = config.gradient_checkpointing
    self._surprise_threshold = config.surprise_threshold
    # Placeholder until the configured metric is validated further down.
    self._surprise_metric = "l2"
    self._allowed_update_levels: set[str] | None = None
    self._allowed_update_layers: set[int] | None = None

    variant = str(config.block_variant).strip().lower()
    optimizer_configs = config.optimizers or {}
    # CMS options shared by the three HOPE block configs.
    cms_options = {
        "cms_levels": config.cms_levels,
        "cms_flush_partial_at_end": config.cms_flush_partial_at_end,
        "cms_use_layernorm": config.cms_use_layernorm,
    }

    if variant == "hope_attention":
        block_cls = HOPEAttentionBlock
        block_config = HOPEAttentionBlockConfig(
            dim=config.dim,
            heads=config.heads,
            qk_l2_norm=config.qk_l2_norm,
            local_conv_window=config.local_conv_window,
            self_mod_lr=config.self_mod_lr,
            optimizer_configs=optimizer_configs,
            **cms_options,
        )
    elif variant == "hope_hybrid":
        block_cls = HOPEBlock
        block_config = HOPEBlockConfig(
            dim=config.dim,
            heads=config.heads,
            titan_level=config.titan_level,
            qk_l2_norm=config.qk_l2_norm,
            local_conv_window=config.local_conv_window,
            self_mod_lr=config.self_mod_lr,
            self_mod_hidden=config.self_mod_hidden,
            optimizer_configs=optimizer_configs,
            **cms_options,
        )
    elif variant == "hope_selfmod":
        block_cls = HOPESelfModBlock
        block_config = HOPESelfModBlockConfig(
            dim=config.dim,
            qk_l2_norm=config.qk_l2_norm,
            selfmod_adaptive_q=config.self_mod_adaptive_q,
            selfmod_local_conv_window=config.self_mod_local_conv_window,
            # ``self_mod_lr`` is passed twice on purpose: once as ``eta_scale``
            # and once under its own name, exactly as the block config expects.
            eta_scale=config.self_mod_lr,
            selfmod_chunk_size=config.self_mod_chunk_size,
            selfmod_chunk_size_memory=config.self_mod_chunk_size_memory,
            selfmod_objective=config.self_mod_objective,
            selfmod_stopgrad_vhat=config.self_mod_stopgrad_vhat,
            selfmod_use_rank1_precond=config.self_mod_use_rank1_precond,
            selfmod_use_alpha=config.self_mod_use_alpha,
            selfmod_use_skip=config.self_mod_use_skip,
            selfmod_momentum=config.self_mod_momentum,
            self_mod_lr=config.self_mod_lr,
            optimizer_configs=optimizer_configs,
            **cms_options,
        )
    elif variant == "transformer":
        block_cls = TransformerBlock
        block_config = TransformerBlockConfig(
            dim=config.dim,
            heads=config.heads,
            mlp_hidden_multiplier=config.transformer_mlp_hidden_multiplier,
            activation=config.transformer_activation,
            qk_l2_norm=config.qk_l2_norm,
            local_conv_window=config.local_conv_window,
        )
    else:
        raise ValueError(
            f"Unsupported block_variant={config.block_variant!r}; expected one of "
            "['hope_attention', 'hope_hybrid', 'hope_selfmod', 'transformer']"
        )

    # Every layer is built from the same config object.
    self.blocks = nn.ModuleList([block_cls(block_config) for _ in range(config.num_layers)])

    self.norm = nn.LayerNorm(config.dim)
    self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
    # Weight tying keeps the LM head gradient aligned with the embedding space.
    self.lm_head.weight = self.embed.weight
    self._latest_update_metrics: Dict[str, float] = {}

    # Validate the configured metric and push both gating settings to the blocks.
    self.set_surprise_metric(config.surprise_metric)
    self.set_surprise_threshold(self._surprise_threshold)
    if config.freeze_backbone:
        self.freeze_backbone()
|
|
162
|
+
|
|
163
|
+
def set_teach_runtime(self, *, scale: float | None = None, clip: float | None = None) -> None:
    """Override the runtime teach-signal scale and/or clip.

    An argument left as ``None`` keeps the current value; any other value
    (including ``0.0``) replaces it.
    """
    overrides = (
        ("_runtime_teach_scale", scale),
        ("_runtime_teach_clip", clip),
    )
    for attr_name, new_value in overrides:
        if new_value is not None:
            setattr(self, attr_name, new_value)
|
|
168
|
+
|
|
169
|
+
def set_surprise_threshold(self, threshold: float | None) -> None:
    """Store ``threshold`` on the model and hand the same value to every block."""
    self._surprise_threshold = threshold
    controlled_blocks = (cast(_UpdateControlledBlock, blk) for blk in self.blocks)
    for controlled in controlled_blocks:
        controlled.set_surprise_threshold(threshold)
|
|
173
|
+
|
|
174
|
+
def get_surprise_threshold(self) -> float | None:
    """Return the surprise threshold currently stored on the model (may be ``None``)."""
    current_threshold = self._surprise_threshold
    return current_threshold
|
|
176
|
+
|
|
177
|
+
def set_surprise_metric(self, metric: str) -> None:
    """Validate ``metric``, store its normalised form and pass it to every block.

    The value is converted with ``str()``, stripped and lower-cased before it
    is compared with the supported names.

    Raises:
        ValueError: If the normalised name is not ``l2``, ``loss`` or
            ``logit_entropy``. Nothing is modified in that case.
    """
    supported = {"l2", "loss", "logit_entropy"}
    metric_key = str(metric).strip().lower()
    if metric_key in supported:
        self._surprise_metric = metric_key
        for blk in self.blocks:
            cast(_UpdateControlledBlock, blk).set_surprise_metric(metric_key)
        return
    raise ValueError(
        f"Unsupported surprise_metric={metric!r}; expected one of {sorted(supported)}"
    )
|
|
187
|
+
|
|
188
|
+
def get_surprise_metric(self) -> str:
    """Return the normalised surprise-metric name currently stored on the model."""
    active_metric = self._surprise_metric
    return active_metric
|
|
190
|
+
|
|
191
|
+
def set_allowed_update_levels(self, levels: set[str] | None) -> None:
    """Store a copy of ``levels`` (or ``None``) and pass it to every block.

    The caller's set is copied, so later changes to it do not reach the
    model. Each block receives the model's stored object, not a further copy.
    """
    if levels is None:
        self._allowed_update_levels = None
    else:
        self._allowed_update_levels = levels.copy()
    for blk in self.blocks:
        controlled = cast(_UpdateControlledBlock, blk)
        controlled.set_allowed_levels(self._allowed_update_levels)
|
|
195
|
+
|
|
196
|
+
def get_allowed_update_levels(self) -> set[str] | None:
    """Return a copy of the stored level filter, or ``None`` when no filter is set."""
    stored = self._allowed_update_levels
    if stored is None:
        return None
    return stored.copy()
|
|
198
|
+
|
|
199
|
+
def set_allowed_update_layers(self, layers: set[int] | None) -> None:
    """Restrict which layer indices may receive a teach signal.

    ``None`` removes the restriction. Otherwise each entry is converted with
    ``int()``; negative values count from the end of the block list. The
    stored filter is replaced only after every entry has been validated.

    Raises:
        ValueError: If an index falls outside ``[0, len(self.blocks))`` after
            negative indices are resolved.
    """
    if layers is None:
        self._allowed_update_layers = None
        return

    depth = len(self.blocks)
    resolved: set[int] = set()
    for raw_index in layers:
        position = int(raw_index)
        if position < 0:
            # Python-style negative indexing: -1 is the last layer.
            position += depth
        if position < 0 or position >= depth:
            raise ValueError(f"Invalid layer index {raw_index} for model with {depth} layers")
        resolved.add(position)
    self._allowed_update_layers = resolved
|
|
213
|
+
|
|
214
|
+
def get_allowed_update_layers(self) -> set[int] | None:
    """Return a copy of the stored layer filter, or ``None`` when no filter is set."""
    stored = self._allowed_update_layers
    if stored is None:
        return None
    return stored.copy()
|
|
216
|
+
|
|
217
|
+
def forward(
    self,
    tokens: torch.Tensor,
    *,
    teach_signal: torch.Tensor | None = None,
    teach_signals: list[torch.Tensor] | None = None,
    fast_state: ModelFastState | None = None,
    surprise_value: float | None = None,
    finalize_updates: bool = True,
    attention_cache: ModelAttentionCache | None = None,
    return_attention_cache: bool = False,
    differentiable_updates: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ModelAttentionCache]:
    """Run the model and return logits, dropping the pre-norm hidden state.

    All arguments are passed unchanged to ``forward_with_pre_norm``.

    Returns:
        The logits alone, or ``(logits, next_attention_cache)`` when
        ``return_attention_cache`` is true.
    """
    passthrough = {
        "teach_signal": teach_signal,
        "teach_signals": teach_signals,
        "fast_state": fast_state,
        "surprise_value": surprise_value,
        "finalize_updates": finalize_updates,
        "attention_cache": attention_cache,
        "differentiable_updates": differentiable_updates,
    }
    if not return_attention_cache:
        logits, _pre_norm = cast(
            tuple[torch.Tensor, torch.Tensor],
            self.forward_with_pre_norm(tokens, return_attention_cache=False, **passthrough),
        )
        return logits

    logits, _pre_norm, next_attention_cache = cast(
        tuple[torch.Tensor, torch.Tensor, ModelAttentionCache],
        self.forward_with_pre_norm(tokens, return_attention_cache=True, **passthrough),
    )
    return logits, next_attention_cache
|
|
261
|
+
|
|
262
|
+
def forward_with_pre_norm(
    self,
    tokens: torch.Tensor,
    *,
    teach_signal: torch.Tensor | None = None,
    teach_signals: list[torch.Tensor] | None = None,
    fast_state: ModelFastState | None = None,
    surprise_value: float | None = None,
    finalize_updates: bool = True,
    attention_cache: ModelAttentionCache | None = None,
    return_attention_cache: bool = False,
    differentiable_updates: bool = False,
) -> (
    tuple[torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, ModelAttentionCache]
):
    """Run the block stack and return logits plus the hidden state before the final norm.

    The output of ``_run_blocks`` is kept as ``pre_norm``; logits are
    ``lm_head(norm(pre_norm))``. When either ``teach_signal`` or
    ``teach_signals`` is supplied, ``_latest_update_metrics`` is replaced
    with the result of ``_gather_block_stats()``.

    Returns:
        ``(logits, pre_norm)``, or ``(logits, pre_norm, ModelAttentionCache)``
        when ``return_attention_cache`` is true.
    """
    run_kwargs = {
        "teach_signal": teach_signal,
        "teach_signals": teach_signals,
        "fast_state": fast_state,
        "surprise_value": surprise_value,
        "finalize_updates": finalize_updates,
        "attention_cache": attention_cache,
        "differentiable_updates": differentiable_updates,
    }
    if return_attention_cache:
        hidden, next_attention_caches = cast(
            tuple[torch.Tensor, list[AttentionKVCache | None]],
            self._run_blocks(tokens, return_attention_cache=True, **run_kwargs),
        )
    else:
        hidden = cast(
            torch.Tensor,
            self._run_blocks(tokens, return_attention_cache=False, **run_kwargs),
        )

    pre_norm = cast(torch.Tensor, hidden)
    logits = self.lm_head(self.norm(pre_norm))

    teaching = not (teach_signal is None and teach_signals is None)
    if teaching:
        self._latest_update_metrics = self._gather_block_stats()

    if not return_attention_cache:
        return logits, pre_norm
    return logits, pre_norm, ModelAttentionCache(blocks=next_attention_caches)
|
|
316
|
+
|
|
317
|
+
def forward_with_block_outputs(
    self,
    tokens: torch.Tensor,
    *,
    teach_signal: torch.Tensor | None = None,
    teach_signals: list[torch.Tensor] | None = None,
    fast_state: ModelFastState | None = None,
    surprise_value: float | None = None,
    finalize_updates: bool = True,
    attention_cache: ModelAttentionCache | None = None,
    return_attention_cache: bool = False,
    differentiable_updates: bool = False,
) -> (
    tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], ModelAttentionCache]
):
    """Like ``forward_with_pre_norm`` but also return every block's output.

    ``_run_blocks`` is always called with ``collect_outputs=True``. When
    either ``teach_signal`` or ``teach_signals`` is supplied,
    ``_latest_update_metrics`` is replaced with ``_gather_block_stats()``.

    Returns:
        ``(logits, pre_norm, block_outputs)``, with a trailing
        ``ModelAttentionCache`` when ``return_attention_cache`` is true.
    """
    run_kwargs = {
        "teach_signal": teach_signal,
        "teach_signals": teach_signals,
        "fast_state": fast_state,
        "surprise_value": surprise_value,
        "finalize_updates": finalize_updates,
        "attention_cache": attention_cache,
        "collect_outputs": True,
        "differentiable_updates": differentiable_updates,
    }
    if return_attention_cache:
        hidden, block_outputs, next_attention_caches = cast(
            tuple[torch.Tensor, list[torch.Tensor], list[AttentionKVCache | None]],
            self._run_blocks(tokens, return_attention_cache=True, **run_kwargs),
        )
    else:
        hidden, block_outputs = cast(
            tuple[torch.Tensor, list[torch.Tensor]],
            self._run_blocks(tokens, return_attention_cache=False, **run_kwargs),
        )

    pre_norm = hidden
    logits = self.lm_head(self.norm(hidden))

    teaching = not (teach_signal is None and teach_signals is None)
    if teaching:
        self._latest_update_metrics = self._gather_block_stats()

    if not return_attention_cache:
        return logits, pre_norm, block_outputs
    return (
        logits,
        pre_norm,
        block_outputs,
        ModelAttentionCache(blocks=next_attention_caches),
    )
|
|
378
|
+
|
|
379
|
+
def _run_blocks(
    self,
    tokens: torch.Tensor,
    *,
    teach_signal: torch.Tensor | None,
    fast_state: ModelFastState | None,
    teach_signals: list[torch.Tensor] | None = None,
    surprise_value: float | None = None,
    finalize_updates: bool = True,
    attention_cache: ModelAttentionCache | None = None,
    return_attention_cache: bool = False,
    collect_outputs: bool = False,
    differentiable_updates: bool = False,
) -> (
    torch.Tensor
    | tuple[torch.Tensor, list[torch.Tensor]]
    | tuple[torch.Tensor, list[AttentionKVCache | None]]
    | tuple[torch.Tensor, list[torch.Tensor], list[AttentionKVCache | None]]
):
    """Embed ``tokens`` and feed the result through every block in order.

    Teach signals are multiplied by the runtime teach scale and, when the
    runtime clip is positive, rescaled so no row norm along the last
    dimension exceeds the clip. A single ``teach_signal`` is shared by all
    blocks; ``teach_signals`` supplies one tensor per block. Blocks whose
    index is excluded by the layer filter receive ``teach_signal=None``.

    When no ``surprise_value`` is given and the metric is ``"l2"``, the value
    handed to the blocks is the mean row norm of the scaled signal.

    NOTE(review): with ``teach_signal`` that norm is measured after clipping,
    with ``teach_signals`` it is measured before clipping. The asymmetry is
    preserved here as-is -- confirm which one is intended.

    Gradient checkpointing is only used when ``return_attention_cache`` is
    false; the cache-returning path calls the block directly.

    Returns:
        The final hidden state, followed by the list of per-block outputs
        when ``collect_outputs`` is true and by the list of per-block
        attention caches when ``return_attention_cache`` is true.

    Raises:
        ValueError: On a ``teach_signals`` / ``fast_state`` /
            ``attention_cache`` length that differs from the number of
            blocks, when both teach arguments are given, or when the metric
            is ``"loss"``/``"logit_entropy"``, a threshold is set, a teach
            signal is supplied and ``surprise_value`` is missing.
    """
    hidden = self.embed(tokens)
    teach_scale = self._runtime_teach_scale
    teach_clip = self._runtime_teach_clip
    depth = len(self.blocks)

    # ---- argument validation -------------------------------------------
    if teach_signals is not None:
        if len(teach_signals) != depth:
            raise ValueError(
                f"teach_signals length {len(teach_signals)} "
                f"does not match blocks {len(self.blocks)}"
            )
        if teach_signal is not None:
            raise ValueError("Provide either teach_signal or teach_signals, not both.")
    if fast_state is not None and len(fast_state.blocks) != depth:
        raise ValueError("fast_state.blocks length does not match model.blocks")
    if attention_cache is not None and len(attention_cache.blocks) != depth:
        raise ValueError("attention_cache.blocks length does not match model.blocks")

    teaching = teach_signal is not None or teach_signals is not None
    needs_external_surprise = self._surprise_metric in {"loss", "logit_entropy"}
    if needs_external_surprise and self._surprise_threshold is not None:
        if teaching and surprise_value is None:
            raise ValueError(
                f"surprise_metric={self._surprise_metric} requires passing surprise_value "
                "when model.surprise_threshold is set."
            )

    def clip_rows(signal: torch.Tensor) -> torch.Tensor:
        # Shrink rows whose norm exceeds the clip; rows below it are divided by 1.
        if teach_clip > 0:
            row_norm = signal.norm(dim=-1, keepdim=True)
            return signal / torch.clamp(row_norm / teach_clip, min=1.0)
        return signal

    def mean_row_norm(signal: torch.Tensor) -> float:
        return float(signal.norm(dim=-1).mean().item())

    # ---- shared signal / surprise for the single-signal L2 case ---------
    metric_is_l2 = self._surprise_metric == "l2"
    base_surprise = surprise_value
    shared_signal: torch.Tensor | None = None
    if teach_signal is not None and base_surprise is None and metric_is_l2:
        shared_signal = clip_rows(teach_signal * teach_scale)
        base_surprise = mean_row_norm(shared_signal)

    layer_filter = self._allowed_update_layers
    block_outputs: list[torch.Tensor] = []
    next_attention_caches: list[AttentionKVCache | None] = []

    for idx, block in enumerate(self.blocks):
        state_i = fast_state.blocks[idx] if fast_state is not None else None
        cache_i = attention_cache.blocks[idx] if attention_cache is not None else None

        signal_i = None
        surprise_i = base_surprise
        if teach_signal is not None:
            if shared_signal is not None:
                signal_i = shared_signal
            else:
                signal_i = clip_rows(teach_signal * teach_scale)
        elif teach_signals is not None:
            signal_i = teach_signals[idx] * teach_scale
            if metric_is_l2 and base_surprise is None:
                # Measured before clipping (see NOTE in the docstring).
                surprise_i = mean_row_norm(signal_i)
            signal_i = clip_rows(signal_i)
        if layer_filter is not None and idx not in layer_filter:
            # Filtered-out layers still run, just without a teach signal.
            signal_i = None

        call_kwargs = {
            "teach_signal": signal_i,
            "surprise_value": surprise_i,
            "fast_state": state_i,
            "finalize_updates": finalize_updates,
            "attention_cache": cache_i,
            "differentiable_updates": differentiable_updates,
        }

        # Defaults bind this iteration's block and arguments, so a later
        # re-invocation by ``checkpoint`` does not see the loop's final values.
        def invoke(h: torch.Tensor, *, _block=block, _kwargs=call_kwargs) -> torch.Tensor:
            return _block(h, **_kwargs)

        if return_attention_cache:
            hidden, cache_out = block(  # type: ignore[assignment]
                hidden,
                return_attention_cache=True,
                **call_kwargs,
            )
            next_attention_caches.append(cache_out)
        elif torch.is_grad_enabled() and self.training and self.gradient_checkpointing:
            hidden = checkpoint(invoke, hidden, use_reentrant=False)
        else:
            hidden = invoke(hidden)

        if collect_outputs:
            block_outputs.append(hidden)

    if return_attention_cache:
        if collect_outputs:
            return hidden, block_outputs, next_attention_caches
        return hidden, next_attention_caches
    if collect_outputs:
        return hidden, block_outputs
    return hidden
|
|
513
|
+
|
|
514
|
+
def _gather_block_stats(self) -> Dict[str, float]:
    """Drain per-block update stats into one flat ``layer{i}.{level}.{key}`` mapping.

    Blocks without a callable ``pop_update_stats`` attribute are skipped.
    """
    flat: Dict[str, float] = {}
    for layer_idx, blk in enumerate(self.blocks):
        popper = getattr(blk, "pop_update_stats", None)
        if not callable(popper):
            continue
        per_level = cast(Dict[str, Dict[str, float]], popper())
        flat.update(
            {
                f"layer{layer_idx}.{level_name}.{stat_name}": stat_value
                for level_name, payload in per_level.items()
                for stat_name, stat_value in payload.items()
            }
        )
    return flat
|
|
525
|
+
|
|
526
|
+
def pop_update_metrics(self) -> Dict[str, float]:
    """Return the most recently gathered update metrics and reset the store to ``{}``."""
    drained, self._latest_update_metrics = self._latest_update_metrics, {}
    return drained
|
|
530
|
+
|
|
531
|
+
def init_fast_state(self) -> ModelFastState:
    """Create one fast-state entry per block and wrap them in a ``ModelFastState``.

    The arguments given to ``build_block_fast_state`` depend on the block type:

    * ``HOPEBlock``: its titan memory, with the titan level placed before the
      CMS levels in ``specs``.
    * ``HOPEAttentionBlock``: no titan module, CMS levels only.
    * ``HOPESelfModBlock``: no titan module, CMS levels, plus its ``selfmod``
      module.
    * ``TransformerBlock``: no titan module, no CMS blocks, no specs, no
      optimizer configs and a default learning rate of ``0.0``.

    Raises:
        TypeError: If a block is none of the four types above.
    """
    per_block = []
    for blk in self.blocks:
        # The isinstance checks keep their original order on purpose.
        if isinstance(blk, HOPEBlock):
            variant_kwargs = {
                "titan_module": blk.titan_memory,
                "specs": [blk.config.titan_level, *blk.config.cms_levels],
            }
        elif isinstance(blk, HOPEAttentionBlock):
            variant_kwargs = {
                "titan_module": None,
                "specs": list(blk.config.cms_levels),
            }
        elif isinstance(blk, HOPESelfModBlock):
            variant_kwargs = {
                "titan_module": None,
                "selfmod_module": blk.selfmod,
                "specs": list(blk.config.cms_levels),
            }
        elif isinstance(blk, TransformerBlock):
            per_block.append(
                build_block_fast_state(
                    titan_module=None,
                    cms_blocks={},
                    specs=(),
                    optimizer_configs={},
                    default_lr=0.0,
                )
            )
            continue
        else:
            raise TypeError(f"Unsupported block type for fast state: {type(blk)}")

        # Arguments common to the three HOPE block types.
        per_block.append(
            build_block_fast_state(
                cms_blocks=dict(blk.cms.blocks.items()),
                optimizer_configs=blk.config.optimizer_configs,
                default_lr=blk.config.self_mod_lr,
                **variant_kwargs,
            )
        )
    return ModelFastState(blocks=per_block)
|
|
577
|
+
|
|
578
|
+
def init_attention_cache(self) -> ModelAttentionCache:
    """Return a ``ModelAttentionCache`` holding one ``None`` entry per block."""
    empty_slots = [None] * len(self.blocks)
    return ModelAttentionCache(blocks=empty_slots)
|
|
580
|
+
|
|
581
|
+
def freeze_backbone(self) -> None:
    """
    Disable gradients for the embedding, the final norm, the LM head and each
    block's ``attn`` submodule (when the block has one that is an ``nn.Module``).
    Parameters outside those modules are not touched, so the remaining block
    components stay trainable for adapter-style finetuning.
    """
    to_freeze = [self.embed, self.norm, self.lm_head]
    for blk in self.blocks:
        attn = getattr(blk, "attn", None)
        if isinstance(attn, nn.Module):
            to_freeze.append(attn)
    for module in to_freeze:
        for param in module.parameters():
            param.requires_grad = False
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
class _UpdateControlledBlock(Protocol):
    """Structural type for blocks that accept the model's update-gating settings.

    It is used only as a ``cast`` target so a type checker accepts these calls
    on the elements of ``HOPEModel.blocks``.
    """

    def set_surprise_threshold(self, threshold: float | None) -> None:
        """Accept the surprise threshold pushed down by the model (``None`` allowed)."""

    def set_surprise_metric(self, metric: str) -> None:
        """Accept the normalised surprise-metric name pushed down by the model."""

    def set_allowed_levels(self, allowed: set[str] | None) -> None:
        """Accept the set of level names allowed to update (``None`` allowed)."""
|
|
File without changes
|