nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, cast
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
from ..backbones import AttentionConfig, SelfAttention
|
|
10
|
+
from ..fast_state import (
|
|
11
|
+
AttentionKVCache,
|
|
12
|
+
BlockFastState,
|
|
13
|
+
ModelAttentionCache,
|
|
14
|
+
ModelFastState,
|
|
15
|
+
build_block_fast_state,
|
|
16
|
+
)
|
|
17
|
+
from ..functional import (
|
|
18
|
+
call_with_deltas,
|
|
19
|
+
call_with_params,
|
|
20
|
+
grads_to_dict,
|
|
21
|
+
params_with_deltas,
|
|
22
|
+
require_grad_params,
|
|
23
|
+
)
|
|
24
|
+
from ..hope.self_mod import SelfModifier
|
|
25
|
+
from ..levels import LevelSpec
|
|
26
|
+
from ..optim.manager import LevelConfig, LevelOptimizerManager
|
|
27
|
+
from ..titan.memory import TitanMemory, TitanMemoryConfig
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class TitanOnlyModelConfig:
    """Settings consumed by ``TitanOnlyModel`` and each ``TitanOnlyBlock`` it stacks.

    The declaration order of the fields is part of the generated positional
    ``__init__`` signature and is therefore kept exactly as it was.
    """

    # ---- model shape -------------------------------------------------------
    # Rows of the token embedding table (the LM head shares the same weight).
    vocab_size: int
    # Feature width used by the embedding, attention, TITAN memory and norms.
    dim: int
    # How many TitanOnlyBlock layers the model builds.
    num_layers: int
    # Passed through to AttentionConfig.heads.
    heads: int

    # ---- TITAN level / optimiser ------------------------------------------
    # The single level spec handed to LevelConfig; its ``name`` keys all updates.
    titan_level: LevelSpec
    # Passed as LevelConfig.optimizer_configs; ``None`` is replaced by ``{}``.
    optimizers: Dict[str, dict] | None = None

    # ---- teach-signal preprocessing ---------------------------------------
    # Initial multiplier applied to the teach signal in the model's forward pass.
    teach_scale: float = 1.0
    # Initial cap on the last-dim L2 norm of the scaled teach signal; values
    # <= 0 leave the signal unclipped.
    teach_clip: float = 0.0
    # Not read anywhere in this module. NOTE(review): presumably consumed by the
    # training loop - confirm against callers.
    teach_schedule: Dict[str, float] | None = None

    # ---- attention options (forwarded to AttentionConfig) ------------------
    qk_l2_norm: bool = False
    local_conv_window: int | None = None

    # ---- TITAN memory options (forwarded to TitanMemoryConfig) -------------
    titan_hidden_multiplier: int = 4
    activation: str = "gelu"

    # ---- self-modifier / inner-loop learning rate --------------------------
    # Passed to SelfModifier as ``hidden_multiplier``.
    self_mod_hidden: int = 4
    # Passed to LevelConfig / build_block_fast_state as ``default_lr``.
    self_mod_lr: float = 1e-3

    # ---- surprise gating ---------------------------------------------------
    # ``None`` disables gating; otherwise updates require surprise >= threshold.
    surprise_threshold: float | None = None
    # Validated by TitanOnlyModel.set_surprise_metric ("l2", "loss", "logit_entropy").
    surprise_metric: str = "l2"

    # When true, TitanOnlyModel calls ``freeze_backbone()`` at construction.
    freeze_backbone: bool = False
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TitanOnlyBlock(nn.Module):
    """One layer: self-attention, a TITAN memory read, a residual sum and LayerNorm.

    When a ``teach_signal`` is supplied to ``forward`` the block also performs
    an online update of the TITAN memory: either in place on the module's own
    parameters (no ``fast_state``) or on the per-call ``fast_state.titan_params``.
    Updates are gated by ``enabled``, by the level manager's ``should_update``
    and by the surprise threshold.
    """

    def __init__(self, config: TitanOnlyModelConfig):
        """Build the sub-modules and the level optimiser manager from ``config``."""
        super().__init__()
        self.config = config
        # Gating state; mutated afterwards through the setters below.
        self.surprise_threshold: float | None = None
        self.surprise_metric: str = "l2"
        self.enabled: bool = True
        self.attn = SelfAttention(
            AttentionConfig(
                dim=config.dim,
                heads=config.heads,
                qk_l2_norm=config.qk_l2_norm,
                local_conv_window=config.local_conv_window,
            )
        )
        titan_config = TitanMemoryConfig(
            dim=config.dim,
            hidden_multiplier=config.titan_hidden_multiplier,
            activation=config.activation,
        )
        self.titan_memory = TitanMemory(titan_config)
        self.self_modifier = SelfModifier(config.dim, hidden_multiplier=config.self_mod_hidden)
        # Kept as an attribute for compatibility; forward() does not apply it.
        self.dropout = nn.Dropout(0.0)
        self.norm = nn.LayerNorm(config.dim)
        level_config = LevelConfig(
            specs=[config.titan_level],
            optimizer_configs=config.optimizers or {},
            default_lr=config.self_mod_lr,
        )
        self.level_manager = LevelOptimizerManager(level_config)

    def forward(
        self,
        x: torch.Tensor,
        *,
        teach_signal: torch.Tensor | None = None,
        surprise_value: float | None = None,
        fast_state: BlockFastState | None = None,
        attention_cache: AttentionKVCache | None = None,
        return_attention_cache: bool = False,
        differentiable_updates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:
        """Run the block and, if ``teach_signal`` is given, update the memory.

        Args:
            x: Input activations for the attention layer.
            teach_signal: Optional error signal driving the memory update.
            surprise_value: Scalar compared against ``surprise_threshold``.
            fast_state: When given, the memory is read and updated through
                ``fast_state.titan_params`` and ``fast_state.level_manager``
                instead of the module's own parameters and manager.
            attention_cache: Optional KV cache passed to the attention layer.
            return_attention_cache: Also return the attention layer's new cache.
            differentiable_updates: Accepted for interface parity; ignored here.

        Returns:
            ``norm(attn_out + mem_out)``, or that tensor paired with the new
            attention cache when ``return_attention_cache`` is true.

        Raises:
            ValueError: ``fast_state`` is given but has no ``titan_params``.
        """
        _ = differentiable_updates  # intentionally unused by this block
        next_attn_cache: AttentionKVCache | None = None
        if return_attention_cache:
            attn_out, next_attn_cache = self.attn(
                x,
                kv_cache=attention_cache,
                return_kv_cache=True,
            )
        else:
            attn_out = self.attn(x, kv_cache=attention_cache)

        if fast_state is None:
            mem_out = self.titan_memory(attn_out)
        else:
            if fast_state.titan_params is None:
                raise ValueError(
                    "fast_state.titan_params is required for TitanOnlyBlock fast-state forward"
                )
            mem_out = call_with_deltas(self.titan_memory, fast_state.titan_params, attn_out)
        combined = attn_out + mem_out

        # The update uses the pre-update read (mem_out) computed above.
        if teach_signal is not None:
            if fast_state is None:
                self._update_titan(attn_out, mem_out, teach_signal, surprise_value)
            else:
                self._update_titan_fast(fast_state, attn_out, mem_out, teach_signal, surprise_value)

        # Advance whichever level manager owns this call, update or not.
        if fast_state is None:
            self.level_manager.tick()
        else:
            fast_state.level_manager.tick()

        out = self.norm(combined)
        if return_attention_cache:
            assert next_attn_cache is not None
            return out, next_attn_cache
        return out

    def set_surprise_threshold(self, threshold: float | None) -> None:
        """Set the surprise gate; ``None`` disables gating."""
        self.surprise_threshold = threshold

    def set_surprise_metric(self, metric: str) -> None:
        """Store the metric name, stripped and lower-cased (no validation here)."""
        self.surprise_metric = str(metric).strip().lower()

    def set_enabled(self, enabled: bool) -> None:
        """Enable or disable memory updates for this block."""
        self.enabled = enabled

    def _passes_surprise(self, surprise_value: float | None) -> bool:
        """Return whether an update may proceed for the given surprise value.

        With no threshold every update passes; with a threshold a missing
        ``surprise_value`` fails and a present one must be ``>=`` the threshold.
        """
        if self.surprise_threshold is None:
            return True
        if surprise_value is None:
            return False
        return surprise_value >= self.surprise_threshold

    def _update_mask(self, teach_signal: torch.Tensor) -> torch.Tensor:
        """Build the 0/1 float mask (last dim kept as size 1) of positions to learn from.

        A position is active when its teach vector is non-zero. When a threshold
        is set and the metric is ``"l2"``, its last-dim L2 norm must also reach
        the threshold.
        """
        signal = teach_signal.detach()
        mask = (signal.abs().sum(dim=-1, keepdim=True) > 0).float()
        if self.surprise_threshold is not None and self.surprise_metric == "l2":
            norms = signal.norm(dim=-1, keepdim=True)
            mask = mask * (norms >= self.surprise_threshold).float()
        return mask

    def _titan_grads(
        self,
        params_req: Dict[str, torch.Tensor],
        attn_out: torch.Tensor,
        teach_signal: torch.Tensor,
        modifier: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Compute gradients of the masked MSE memory loss w.r.t. ``params_req``.

        The memory is evaluated with ``params_req`` on the detached attention
        output and regressed onto ``teach_signal + modifier`` (detached). The
        loss is summed over masked elements and divided by the number of
        active positions (at least 1). The result is the mapping produced by
        ``grads_to_dict``.
        """
        # Grad mode is forced on so this works even under a caller's no_grad().
        with torch.enable_grad():
            query = attn_out.detach()
            target = (teach_signal.detach() + modifier).detach()
            prediction = call_with_params(self.titan_memory, params_req, query)
            loss_terms = nn.functional.mse_loss(prediction, target, reduction="none")
            mask = self._update_mask(teach_signal)
            loss = (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)
        grads = torch.autograd.grad(
            loss,
            tuple(params_req.values()),
            retain_graph=False,
            allow_unused=True,
        )
        return grads_to_dict(params_req, grads)

    def _update_titan(
        self,
        attn_out: torch.Tensor,
        mem_out: torch.Tensor,
        teach_signal: torch.Tensor,
        surprise_value: float | None,
    ) -> None:
        """Apply one gated update directly to ``self.titan_memory``'s parameters."""
        level_name = self.config.titan_level.name
        if not self.enabled:
            return
        if not self.level_manager.should_update(level_name):
            return
        if not self._passes_surprise(surprise_value):
            return
        # Use the full sequence for granular updates: nothing is pooled over
        # the sequence dimension before the loss is computed.
        modifier = self.self_modifier(
            key=attn_out.detach(),
            value=mem_out.detach(),
            error_signal=teach_signal.detach(),
        )
        # Mean over the first two dims; assumes (batch, seq, dim) - TODO confirm.
        context_vec = attn_out.detach().mean(dim=(0, 1))
        with torch.enable_grad():
            params_req = require_grad_params(dict(self.titan_memory.named_parameters()))
        grads_dict = self._titan_grads(params_req, attn_out, teach_signal, modifier)
        self.level_manager.apply_module_grads(
            level_name,
            self.titan_memory,
            grads_dict,
            context=context_vec,
            force=True,
        )
        # Pop metrics to avoid stale entries even if we do not log them yet.
        self.level_manager.pop_last_metrics(level_name)

    def _update_titan_fast(
        self,
        fast_state: BlockFastState,
        attn_out: torch.Tensor,
        mem_out: torch.Tensor,
        teach_signal: torch.Tensor,
        surprise_value: float | None,
    ) -> None:
        """Apply one gated update to ``fast_state.titan_params`` (module weights untouched)."""
        level_name = self.config.titan_level.name
        if not self.enabled:
            return
        if not fast_state.level_manager.should_update(level_name):
            return
        if not self._passes_surprise(surprise_value):
            return
        if fast_state.titan_params is None:
            return
        modifier = self.self_modifier(
            key=attn_out.detach(),
            value=mem_out.detach(),
            error_signal=teach_signal.detach(),
        )
        # Mean over the first two dims; assumes (batch, seq, dim) - TODO confirm.
        context_vec = attn_out.detach().mean(dim=(0, 1))
        base_params = fast_state.titan_params
        # Gradients are taken w.r.t. the effective parameters, while the
        # manager is handed the stored fast-state entries to update.
        forward_params = params_with_deltas(self.titan_memory, base_params)
        params_req = require_grad_params(forward_params)
        grads_dict = self._titan_grads(params_req, attn_out, teach_signal, modifier)
        updated, _magnitude = fast_state.level_manager.apply_grads(
            level_name,
            base_params,
            grads_dict,
            context=context_vec,
            force=False,
        )
        fast_state.titan_params = updated
        # Discard metrics so they do not go stale on the manager.
        fast_state.level_manager.pop_last_metrics(level_name)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class TitanOnlyModel(nn.Module):
    """Token embedding -> stack of ``TitanOnlyBlock`` -> LayerNorm -> tied LM head.

    The model owns the runtime teach-signal scale/clip and the surprise
    settings, and pushes the surprise/enable settings down to every block.
    """

    def __init__(self, config: TitanOnlyModelConfig):
        """Build the model and apply the surprise / freeze options from ``config``."""
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.dim)
        self.blocks = nn.ModuleList([TitanOnlyBlock(config) for _ in range(config.num_layers)])
        self.norm = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embed.weight
        self._runtime_teach_scale = config.teach_scale
        self._runtime_teach_clip = config.teach_clip
        self._surprise_threshold: float | None = None
        self._surprise_metric = "l2"
        self._updates_enabled: bool = True
        # Route through the setters so the blocks receive the same values.
        self.set_surprise_metric(config.surprise_metric)
        self.set_surprise_threshold(config.surprise_threshold)
        if config.freeze_backbone:
            self.freeze_backbone()

    def set_teach_runtime(self, *, scale: float | None = None, clip: float | None = None) -> None:
        """Override the teach-signal scale and/or clip; ``None`` leaves a value as is."""
        if scale is not None:
            self._runtime_teach_scale = scale
        if clip is not None:
            self._runtime_teach_clip = clip

    def set_surprise_threshold(self, threshold: float | None) -> None:
        """Set the surprise threshold on the model and on every block."""
        self._surprise_threshold = threshold
        for block in self.blocks:
            cast(TitanOnlyBlock, block).set_surprise_threshold(threshold)

    def get_surprise_threshold(self) -> float | None:
        """Return the current surprise threshold (``None`` means no gating)."""
        return self._surprise_threshold

    def set_surprise_metric(self, metric: str) -> None:
        """Validate ``metric`` and propagate its normalised form to every block.

        Raises:
            ValueError: the stripped, lower-cased metric is not one of
                ``"l2"``, ``"loss"`` or ``"logit_entropy"``.
        """
        normalized = str(metric).strip().lower()
        allowed = {"l2", "loss", "logit_entropy"}
        if normalized not in allowed:
            raise ValueError(
                f"Unsupported surprise_metric={metric!r}; expected one of {sorted(allowed)}"
            )
        self._surprise_metric = normalized
        for block in self.blocks:
            cast(TitanOnlyBlock, block).set_surprise_metric(normalized)

    def get_surprise_metric(self) -> str:
        """Return the normalised surprise metric name."""
        return self._surprise_metric

    def set_allowed_update_levels(self, levels: set[str] | None) -> None:
        """Enable updates unless ``levels`` is a non-empty set lacking ``"titan"``.

        ``None`` and the empty set both leave updates enabled.
        """
        enabled = not levels or "titan" in levels
        self._updates_enabled = enabled
        for block in self.blocks:
            cast(TitanOnlyBlock, block).set_enabled(enabled)

    def get_allowed_update_levels(self) -> set[str] | None:
        """Return ``{"titan"}`` when updates are enabled, otherwise an empty set.

        NOTE(review): the empty set returned while disabled is interpreted as
        "enabled" by ``set_allowed_update_levels``, so a get/set round trip
        does not restore a disabled state - confirm whether callers rely on it.
        """
        if self._updates_enabled:
            return {"titan"}
        return set()

    def _scaled_teach_signal(self, teach_signal: torch.Tensor | None) -> torch.Tensor | None:
        """Apply the runtime scale and optional last-dim L2-norm clip to the teach signal."""
        if teach_signal is None:
            return None
        scaled = teach_signal * self._runtime_teach_scale
        if self._runtime_teach_clip > 0:
            # The shrink factor is computed without grad; vectors whose norm is
            # already within the clip are divided by 1 (left unchanged).
            with torch.no_grad():
                norm = scaled.norm(dim=-1, keepdim=True)
                shrink = torch.clamp(norm / self._runtime_teach_clip, min=1.0)
            scaled = scaled / shrink
        return scaled

    def forward(
        self,
        tokens: torch.Tensor,
        *,
        teach_signal: torch.Tensor | None = None,
        fast_state: ModelFastState | None = None,
        surprise_value: float | None = None,
        attention_cache: ModelAttentionCache | None = None,
        return_attention_cache: bool = False,
        differentiable_updates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, ModelAttentionCache]:
        """Compute logits for ``tokens``, optionally driving online memory updates.

        Args:
            tokens: Token ids fed to the embedding.
            teach_signal: Optional error signal; scaled/clipped once and passed
                to every block.
            fast_state: Optional per-block fast state (one entry per block).
            surprise_value: Externally computed surprise. When omitted and the
                metric is ``"l2"``, the mean last-dim L2 norm of the scaled
                teach signal is used.
            attention_cache: Optional per-block KV caches (one entry per block).
            return_attention_cache: Also return the updated caches.
            differentiable_updates: Forwarded unchanged to each block.

        Returns:
            The logits, or ``(logits, ModelAttentionCache)`` when
            ``return_attention_cache`` is true.

        Raises:
            ValueError: an external metric with a threshold is configured and a
                teach signal is given without ``surprise_value``, or the
                ``fast_state`` / ``attention_cache`` length differs from the
                number of blocks.
        """
        require_external = self._surprise_metric in {"loss", "logit_entropy"}
        if (
            require_external
            and self._surprise_threshold is not None
            and teach_signal is not None
            and surprise_value is None
        ):
            raise ValueError(
                f"surprise_metric={self._surprise_metric} requires passing surprise_value "
                "when model.surprise_threshold is set."
            )
        x = self.embed(tokens)
        if fast_state is not None and len(fast_state.blocks) != len(self.blocks):
            raise ValueError("fast_state.blocks length does not match model.blocks")
        if attention_cache is not None and len(attention_cache.blocks) != len(self.blocks):
            raise ValueError("attention_cache.blocks length does not match model.blocks")

        # Neither the scaled signal nor the default surprise depends on the
        # block, so both are computed once here rather than per layer; this
        # also avoids one .item() host sync per block.
        scaled_signal = self._scaled_teach_signal(teach_signal)
        block_surprise = surprise_value
        if scaled_signal is not None and surprise_value is None and self._surprise_metric == "l2":
            block_surprise = float(scaled_signal.norm(dim=-1).mean().item())

        next_caches: list[AttentionKVCache | None] = []
        for idx, block in enumerate(self.blocks):
            block_state = None if fast_state is None else fast_state.blocks[idx]
            block_cache = None if attention_cache is None else attention_cache.blocks[idx]
            result = block(  # type: ignore[arg-type]
                x,
                teach_signal=scaled_signal,
                surprise_value=block_surprise,
                fast_state=block_state,
                attention_cache=block_cache,
                return_attention_cache=return_attention_cache,
                differentiable_updates=differentiable_updates,
            )
            if return_attention_cache:
                x, next_cache = result
                next_caches.append(next_cache)
            else:
                x = result
        x = self.norm(x)
        logits = self.lm_head(x)
        if return_attention_cache:
            return logits, ModelAttentionCache(blocks=next_caches)
        return logits

    def freeze_backbone(self) -> None:
        """
        Freeze shared transformer components; leave TITAN memory/trainable paths active.

        Disables ``requires_grad`` on the embedding, the final norm, the LM head
        and each block's attention parameters.
        """
        self.embed.requires_grad_(False)
        self.norm.requires_grad_(False)
        self.lm_head.requires_grad_(False)
        for block in self.blocks:
            cast(TitanOnlyBlock, block).attn.requires_grad_(False)

    def init_fast_state(self) -> ModelFastState:
        """Create a fresh fast state with one entry per block (no CMS blocks)."""
        states = []
        for block in self.blocks:
            typed_block = cast(TitanOnlyBlock, block)
            states.append(
                build_block_fast_state(
                    titan_module=typed_block.titan_memory,
                    cms_blocks={},
                    specs=[typed_block.config.titan_level],
                    optimizer_configs=typed_block.config.optimizers or {},
                    default_lr=typed_block.config.self_mod_lr,
                )
            )
        return ModelFastState(blocks=states)

    def init_attention_cache(self) -> ModelAttentionCache:
        """Create an attention cache holding ``None`` for every block."""
        return ModelAttentionCache(blocks=[None] * len(self.blocks))
|