nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,604 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Protocol, Sequence, cast
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from .fast_state import (
11
+ AttentionKVCache,
12
+ ModelAttentionCache,
13
+ ModelFastState,
14
+ build_block_fast_state,
15
+ )
16
+ from .hope.block import (
17
+ HOPEAttentionBlock,
18
+ HOPEAttentionBlockConfig,
19
+ HOPEBlock,
20
+ HOPEBlockConfig,
21
+ HOPESelfModBlock,
22
+ HOPESelfModBlockConfig,
23
+ )
24
+ from .levels import LevelSpec
25
+ from .transformer import TransformerBlock, TransformerBlockConfig
26
+
27
+
28
@dataclass
class ModelConfig:
    """Construction-time settings consumed by ``HOPEModel``.

    Field order is part of the positional constructor signature and must not
    be rearranged.
    """

    # --- Core geometry -----------------------------------------------------
    # Embedding table is (vocab_size, dim); the LM head maps dim -> vocab_size.
    vocab_size: int
    dim: int
    # Number of identical blocks stacked in ``HOPEModel.blocks``.
    num_layers: int
    # Forwarded to every block variant except "hope_selfmod".
    heads: int

    # --- Memory level specifications --------------------------------------
    # Only the "hope_hybrid" variant receives ``titan_level``.
    titan_level: LevelSpec
    cms_levels: Sequence[LevelSpec]
    cms_flush_partial_at_end: bool = False
    cms_use_layernorm: bool = True
    # Passed to the HOPE blocks as ``optimizer_configs`` (``{}`` when None).
    optimizers: Dict[str, dict] | None = None

    # --- Teach-signal handling --------------------------------------------
    # Initial multiplier applied to teach signals before they reach a block.
    teach_scale: float = 1.0
    # Maximum last-dimension L2 norm of a scaled teach signal; values <= 0
    # disable clipping.
    teach_clip: float = 0.0
    # Not read by HOPEModel itself; presumably consumed by the training
    # loop -- TODO confirm.
    teach_schedule: Dict[str, float] | None = None

    # --- Runtime switches ---------------------------------------------------
    # Wraps block calls in torch.utils.checkpoint during training.
    gradient_checkpointing: bool = False
    surprise_threshold: float | None = None
    # One of "l2", "loss", "logit_entropy" (validated by HOPEModel).
    surprise_metric: str = "l2"
    # When True, HOPEModel.freeze_backbone() is called at construction.
    freeze_backbone: bool = False
    qk_l2_norm: bool = False
    local_conv_window: int | None = None

    # --- Self-modification settings ----------------------------------------
    # Forwarded as ``self_mod_lr`` to all HOPE variants and additionally as
    # ``eta_scale`` to the "hope_selfmod" variant.
    self_mod_lr: float = 1e-3
    # Only forwarded to the "hope_hybrid" variant.
    self_mod_hidden: int = 4
    # The remaining ``self_mod_*`` fields are only forwarded to the
    # "hope_selfmod" variant.
    self_mod_chunk_size: int = 1
    self_mod_chunk_size_memory: int | None = None
    self_mod_objective: str = "l2"
    self_mod_stopgrad_vhat: bool = True
    self_mod_use_rank1_precond: bool = True
    self_mod_use_alpha: bool = True
    self_mod_use_skip: bool = True
    self_mod_momentum: float = 0.0
    self_mod_adaptive_q: bool = False
    self_mod_local_conv_window: int | None = 4

    # --- Plain transformer baseline ----------------------------------------
    transformer_mlp_hidden_multiplier: int = 4
    transformer_activation: str = "gelu"

    # --- Block selection -----------------------------------------------------
    # One of "hope_attention", "hope_hybrid", "hope_selfmod", "transformer";
    # compared case-insensitively after stripping whitespace.
    block_variant: str = "hope_hybrid"
63
+
64
+
65
class HOPEModel(nn.Module):
    """Token-level language model built from a stack of identical blocks.

    The model embeds tokens, runs them through ``config.num_layers`` blocks of
    the variant selected by ``config.block_variant``, applies a final
    LayerNorm and projects to vocabulary logits with a head whose weight is
    tied to the embedding table.  Optional "teach signals" are scaled,
    clipped and handed to the blocks together with a surprise value.
    """

    def __init__(self, config: ModelConfig):
        """Build embedding, blocks, final norm and tied LM head from ``config``.

        Raises:
            ValueError: if ``config.block_variant`` or ``config.surprise_metric``
                is not one of the supported values.
        """
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.dim)
        self.base_teach_scale = config.teach_scale
        self.base_teach_clip = config.teach_clip
        self._runtime_teach_scale = config.teach_scale
        self._runtime_teach_clip = config.teach_clip
        self.gradient_checkpointing = config.gradient_checkpointing
        self._surprise_threshold = config.surprise_threshold
        # Placeholder until set_surprise_metric() validates the configured value.
        self._surprise_metric = "l2"
        self._allowed_update_levels: set[str] | None = None
        self._allowed_update_layers: set[int] | None = None
        # Blocks are created right after the embedding so module registration
        # (and therefore parameter initialisation) order is unchanged.
        self.blocks = self._build_blocks(config)
        self.norm = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Weight tying keeps the LM head gradient aligned with the embedding space.
        self.lm_head.weight = self.embed.weight
        self._latest_update_metrics: Dict[str, float] = {}
        self.set_surprise_metric(config.surprise_metric)
        self.set_surprise_threshold(self._surprise_threshold)
        if config.freeze_backbone:
            self.freeze_backbone()

    @staticmethod
    def _build_blocks(config: ModelConfig) -> nn.ModuleList:
        """Create the block stack for ``config.block_variant``.

        A single block-config object is built per call and shared by every
        layer of the stack.

        Raises:
            ValueError: if the variant name is not recognised.
        """
        variant = str(config.block_variant).strip().lower()
        depth = range(config.num_layers)
        optimizer_configs = config.optimizers or {}

        if variant == "hope_attention":
            attn_block_config = HOPEAttentionBlockConfig(
                dim=config.dim,
                heads=config.heads,
                cms_levels=config.cms_levels,
                cms_flush_partial_at_end=config.cms_flush_partial_at_end,
                cms_use_layernorm=config.cms_use_layernorm,
                qk_l2_norm=config.qk_l2_norm,
                local_conv_window=config.local_conv_window,
                self_mod_lr=config.self_mod_lr,
                optimizer_configs=optimizer_configs,
            )
            return nn.ModuleList([HOPEAttentionBlock(attn_block_config) for _ in depth])

        if variant == "hope_hybrid":
            hybrid_block_config = HOPEBlockConfig(
                dim=config.dim,
                heads=config.heads,
                titan_level=config.titan_level,
                cms_levels=config.cms_levels,
                cms_flush_partial_at_end=config.cms_flush_partial_at_end,
                cms_use_layernorm=config.cms_use_layernorm,
                qk_l2_norm=config.qk_l2_norm,
                local_conv_window=config.local_conv_window,
                self_mod_lr=config.self_mod_lr,
                self_mod_hidden=config.self_mod_hidden,
                optimizer_configs=optimizer_configs,
            )
            return nn.ModuleList([HOPEBlock(hybrid_block_config) for _ in depth])

        if variant == "hope_selfmod":
            selfmod_block_config = HOPESelfModBlockConfig(
                dim=config.dim,
                cms_levels=config.cms_levels,
                cms_flush_partial_at_end=config.cms_flush_partial_at_end,
                cms_use_layernorm=config.cms_use_layernorm,
                qk_l2_norm=config.qk_l2_norm,
                selfmod_adaptive_q=config.self_mod_adaptive_q,
                selfmod_local_conv_window=config.self_mod_local_conv_window,
                eta_scale=config.self_mod_lr,
                selfmod_chunk_size=config.self_mod_chunk_size,
                selfmod_chunk_size_memory=config.self_mod_chunk_size_memory,
                selfmod_objective=config.self_mod_objective,
                selfmod_stopgrad_vhat=config.self_mod_stopgrad_vhat,
                selfmod_use_rank1_precond=config.self_mod_use_rank1_precond,
                selfmod_use_alpha=config.self_mod_use_alpha,
                selfmod_use_skip=config.self_mod_use_skip,
                selfmod_momentum=config.self_mod_momentum,
                self_mod_lr=config.self_mod_lr,
                optimizer_configs=optimizer_configs,
            )
            return nn.ModuleList([HOPESelfModBlock(selfmod_block_config) for _ in depth])

        if variant == "transformer":
            transformer_block_config = TransformerBlockConfig(
                dim=config.dim,
                heads=config.heads,
                mlp_hidden_multiplier=config.transformer_mlp_hidden_multiplier,
                activation=config.transformer_activation,
                qk_l2_norm=config.qk_l2_norm,
                local_conv_window=config.local_conv_window,
            )
            return nn.ModuleList([TransformerBlock(transformer_block_config) for _ in depth])

        raise ValueError(
            f"Unsupported block_variant={config.block_variant!r}; expected one of "
            "['hope_attention', 'hope_hybrid', 'hope_selfmod', 'transformer']"
        )

    def set_teach_runtime(self, *, scale: float | None = None, clip: float | None = None) -> None:
        """Override the runtime teach scale and/or clip; ``None`` leaves a value as is."""
        if scale is not None:
            self._runtime_teach_scale = scale
        if clip is not None:
            self._runtime_teach_clip = clip

    def set_surprise_threshold(self, threshold: float | None) -> None:
        """Store ``threshold`` and forward it to every block."""
        self._surprise_threshold = threshold
        for block in self.blocks:
            cast("_UpdateControlledBlock", block).set_surprise_threshold(threshold)

    def get_surprise_threshold(self) -> float | None:
        """Return the currently stored surprise threshold."""
        return self._surprise_threshold

    def set_surprise_metric(self, metric: str) -> None:
        """Validate ``metric`` (case/whitespace-insensitive) and forward it to every block.

        Raises:
            ValueError: if the normalised metric is not "l2", "loss" or
                "logit_entropy".
        """
        normalized = str(metric).strip().lower()
        allowed = {"l2", "loss", "logit_entropy"}
        if normalized not in allowed:
            raise ValueError(
                f"Unsupported surprise_metric={metric!r}; expected one of {sorted(allowed)}"
            )
        self._surprise_metric = normalized
        for block in self.blocks:
            cast("_UpdateControlledBlock", block).set_surprise_metric(normalized)

    def get_surprise_metric(self) -> str:
        """Return the normalised surprise metric name."""
        return self._surprise_metric

    def set_allowed_update_levels(self, levels: set[str] | None) -> None:
        """Store a private copy of ``levels`` and forward it to every block."""
        self._allowed_update_levels = None if levels is None else levels.copy()
        for block in self.blocks:
            cast("_UpdateControlledBlock", block).set_allowed_levels(self._allowed_update_levels)

    def get_allowed_update_levels(self) -> set[str] | None:
        """Return a copy of the allowed level names, or ``None`` when unrestricted."""
        if self._allowed_update_levels is None:
            return None
        return self._allowed_update_levels.copy()

    def set_allowed_update_layers(self, layers: set[int] | None) -> None:
        """Restrict which block indices receive teach signals.

        Negative indices count from the end of the stack.  ``None`` removes
        the restriction.

        Raises:
            ValueError: if an index falls outside ``[-len(blocks), len(blocks))``.
        """
        if layers is None:
            self._allowed_update_layers = None
            return
        total = len(self.blocks)
        normalized: set[int] = set()
        for idx in layers:
            layer_idx = int(idx)
            if layer_idx < 0:
                layer_idx += total
            if not 0 <= layer_idx < total:
                raise ValueError(f"Invalid layer index {idx} for model with {total} layers")
            normalized.add(layer_idx)
        self._allowed_update_layers = normalized

    def get_allowed_update_layers(self) -> set[int] | None:
        """Return a copy of the allowed (non-negative) layer indices, or ``None``."""
        if self._allowed_update_layers is None:
            return None
        return self._allowed_update_layers.copy()

    def forward(
        self,
        tokens: torch.Tensor,
        *,
        teach_signal: torch.Tensor | None = None,
        teach_signals: list[torch.Tensor] | None = None,
        fast_state: ModelFastState | None = None,
        surprise_value: float | None = None,
        finalize_updates: bool = True,
        attention_cache: ModelAttentionCache | None = None,
        return_attention_cache: bool = False,
        differentiable_updates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, ModelAttentionCache]:
        """Return logits, plus the next attention cache when requested.

        Thin wrapper over ``forward_with_pre_norm`` that drops the pre-norm
        hidden state.
        """
        want_cache = bool(return_attention_cache)
        outputs = self.forward_with_pre_norm(
            tokens,
            teach_signal=teach_signal,
            teach_signals=teach_signals,
            fast_state=fast_state,
            surprise_value=surprise_value,
            finalize_updates=finalize_updates,
            attention_cache=attention_cache,
            return_attention_cache=want_cache,
            differentiable_updates=differentiable_updates,
        )
        if want_cache:
            logits, _pre_norm, next_attention_cache = cast(
                "tuple[torch.Tensor, torch.Tensor, ModelAttentionCache]", outputs
            )
            return logits, next_attention_cache
        logits, _pre_norm = cast("tuple[torch.Tensor, torch.Tensor]", outputs)
        return logits

    def forward_with_pre_norm(
        self,
        tokens: torch.Tensor,
        *,
        teach_signal: torch.Tensor | None = None,
        teach_signals: list[torch.Tensor] | None = None,
        fast_state: ModelFastState | None = None,
        surprise_value: float | None = None,
        finalize_updates: bool = True,
        attention_cache: ModelAttentionCache | None = None,
        return_attention_cache: bool = False,
        differentiable_updates: bool = False,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, ModelAttentionCache]
    ):
        """Return ``(logits, pre_norm)`` and optionally the next attention cache.

        ``pre_norm`` is the output of the last block before the final
        LayerNorm.  When any teach signal is supplied, per-block update
        statistics are collected and kept for ``pop_update_metrics``.
        """
        want_cache = bool(return_attention_cache)
        result = self._run_blocks(
            tokens,
            teach_signal=teach_signal,
            teach_signals=teach_signals,
            fast_state=fast_state,
            surprise_value=surprise_value,
            finalize_updates=finalize_updates,
            attention_cache=attention_cache,
            return_attention_cache=want_cache,
            differentiable_updates=differentiable_updates,
        )
        next_attention_caches: list[AttentionKVCache | None] = []
        if want_cache:
            pre_norm, next_attention_caches = cast(
                "tuple[torch.Tensor, list[AttentionKVCache | None]]", result
            )
        else:
            pre_norm = cast("torch.Tensor", result)
        logits = self.lm_head(self.norm(pre_norm))
        if teach_signal is not None or teach_signals is not None:
            self._latest_update_metrics = self._gather_block_stats()
        if want_cache:
            return logits, pre_norm, ModelAttentionCache(blocks=next_attention_caches)
        return logits, pre_norm

    def forward_with_block_outputs(
        self,
        tokens: torch.Tensor,
        *,
        teach_signal: torch.Tensor | None = None,
        teach_signals: list[torch.Tensor] | None = None,
        fast_state: ModelFastState | None = None,
        surprise_value: float | None = None,
        finalize_updates: bool = True,
        attention_cache: ModelAttentionCache | None = None,
        return_attention_cache: bool = False,
        differentiable_updates: bool = False,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], ModelAttentionCache]
    ):
        """Like ``forward_with_pre_norm`` but also return every block's output.

        Returns ``(logits, pre_norm, block_outputs)`` with the next attention
        cache appended as a fourth element when requested.
        """
        want_cache = bool(return_attention_cache)
        result = self._run_blocks(
            tokens,
            teach_signal=teach_signal,
            teach_signals=teach_signals,
            fast_state=fast_state,
            surprise_value=surprise_value,
            finalize_updates=finalize_updates,
            attention_cache=attention_cache,
            return_attention_cache=want_cache,
            collect_outputs=True,
            differentiable_updates=differentiable_updates,
        )
        next_attention_caches: list[AttentionKVCache | None] = []
        if want_cache:
            pre_norm, block_outputs, next_attention_caches = cast(
                "tuple[torch.Tensor, list[torch.Tensor], list[AttentionKVCache | None]]",
                result,
            )
        else:
            pre_norm, block_outputs = cast(
                "tuple[torch.Tensor, list[torch.Tensor]]", result
            )
        logits = self.lm_head(self.norm(pre_norm))
        if teach_signal is not None or teach_signals is not None:
            self._latest_update_metrics = self._gather_block_stats()
        if want_cache:
            return (
                logits,
                pre_norm,
                block_outputs,
                ModelAttentionCache(blocks=next_attention_caches),
            )
        return logits, pre_norm, block_outputs

    @staticmethod
    def _clip_teach_signal(signal: torch.Tensor, clip: float) -> torch.Tensor:
        """Rescale ``signal`` so its last-dimension L2 norm does not exceed ``clip``.

        Vectors already within the limit are left unchanged (divisor clamped
        to at least 1).  ``clip <= 0`` disables clipping and returns
        ``signal`` itself.
        """
        if clip > 0:
            norm = signal.norm(dim=-1, keepdim=True)
            signal = signal / torch.clamp(norm / clip, min=1.0)
        return signal

    def _run_blocks(
        self,
        tokens: torch.Tensor,
        *,
        teach_signal: torch.Tensor | None,
        fast_state: ModelFastState | None,
        teach_signals: list[torch.Tensor] | None = None,
        surprise_value: float | None = None,
        finalize_updates: bool = True,
        attention_cache: ModelAttentionCache | None = None,
        return_attention_cache: bool = False,
        collect_outputs: bool = False,
        differentiable_updates: bool = False,
    ) -> (
        torch.Tensor
        | tuple[torch.Tensor, list[torch.Tensor]]
        | tuple[torch.Tensor, list[AttentionKVCache | None]]
        | tuple[torch.Tensor, list[torch.Tensor], list[AttentionKVCache | None]]
    ):
        """Embed ``tokens`` and run them through every block in order.

        Args:
            tokens: token ids fed to the embedding table.
            teach_signal: one signal shared by all blocks (mutually exclusive
                with ``teach_signals``).
            fast_state: optional per-block fast state; ``blocks`` must have one
                entry per model block.
            teach_signals: one signal per block.
            surprise_value: externally computed surprise; required with the
                "loss"/"logit_entropy" metrics when a threshold is set and a
                teach signal is supplied.
            finalize_updates, differentiable_updates: forwarded to each block.
            attention_cache: optional per-block attention cache.
            return_attention_cache: also return each block's next cache.
            collect_outputs: also return each block's output tensor.

        Returns:
            The last block's output, followed by the list of block outputs
            (if ``collect_outputs``) and the list of next caches (if
            ``return_attention_cache``), in that order.

        Raises:
            ValueError: on mismatched list lengths, on both teach-signal forms
                being supplied, or on a missing ``surprise_value``.
        """
        x = self.embed(tokens)
        block_outputs: list[torch.Tensor] = []
        next_attention_caches: list[AttentionKVCache | None] = []
        runtime_scale = self._runtime_teach_scale
        runtime_clip = self._runtime_teach_clip
        update_layers = self._allowed_update_layers

        if teach_signals is not None:
            if len(teach_signals) != len(self.blocks):
                raise ValueError(
                    f"teach_signals length {len(teach_signals)} "
                    f"does not match blocks {len(self.blocks)}"
                )
            if teach_signal is not None:
                raise ValueError("Provide either teach_signal or teach_signals, not both.")
        if fast_state is not None and len(fast_state.blocks) != len(self.blocks):
            raise ValueError("fast_state.blocks length does not match model.blocks")
        if attention_cache is not None and len(attention_cache.blocks) != len(self.blocks):
            raise ValueError("attention_cache.blocks length does not match model.blocks")

        has_teach = teach_signal is not None or teach_signals is not None
        needs_external = self._surprise_metric in {"loss", "logit_entropy"}
        if (
            needs_external
            and self._surprise_threshold is not None
            and has_teach
            and surprise_value is None
        ):
            raise ValueError(
                f"surprise_metric={self._surprise_metric} requires passing surprise_value "
                "when model.surprise_threshold is set."
            )

        derive_l2_surprise = surprise_value is None and self._surprise_metric == "l2"
        base_surprise = surprise_value

        # The shared signal does not depend on the block, so scale and clip it
        # once up front instead of once per layer.
        shared_signal: torch.Tensor | None = None
        if teach_signal is not None:
            shared_signal = self._clip_teach_signal(teach_signal * runtime_scale, runtime_clip)
            if derive_l2_surprise:
                # Shared path: surprise is measured on the clipped signal.
                base_surprise = float(shared_signal.norm(dim=-1).mean().item())

        for idx, block in enumerate(self.blocks):
            block_state = None if fast_state is None else fast_state.blocks[idx]
            block_attention_cache = None if attention_cache is None else attention_cache.blocks[idx]
            block_signal = shared_signal
            block_surprise = base_surprise

            if teach_signals is not None:
                block_signal = teach_signals[idx] * runtime_scale
                if derive_l2_surprise:
                    # NOTE(review): per-layer surprise is measured BEFORE
                    # clipping, unlike the shared-signal path above.  Kept
                    # as-is to preserve behaviour; confirm this is intended.
                    block_surprise = float(block_signal.norm(dim=-1).mean().item())
                block_signal = self._clip_teach_signal(block_signal, runtime_clip)

            # Layers outside the allow-list still run, but see no teach signal.
            if update_layers is not None and idx not in update_layers:
                block_signal = None

            call_kwargs = {
                "teach_signal": block_signal,
                "surprise_value": block_surprise,
                "fast_state": block_state,
                "finalize_updates": finalize_updates,
                "attention_cache": block_attention_cache,
                "differentiable_updates": differentiable_updates,
            }

            # Defaults bind this iteration's block/kwargs so a checkpoint
            # recomputation later does not see values from a later iteration.
            def block_call(
                hidden: torch.Tensor,
                *,
                blk=block,
                kwargs=call_kwargs,
            ) -> torch.Tensor:
                return blk(hidden, **kwargs)

            if return_attention_cache:
                # Cache-returning calls bypass gradient checkpointing.
                x, next_cache = block(  # type: ignore[assignment]
                    x,
                    return_attention_cache=True,
                    **call_kwargs,
                )
                next_attention_caches.append(next_cache)
            elif torch.is_grad_enabled() and self.training and self.gradient_checkpointing:
                x = checkpoint(block_call, x, use_reentrant=False)
            else:
                x = block_call(x)
            if collect_outputs:
                block_outputs.append(x)

        if collect_outputs and return_attention_cache:
            return x, block_outputs, next_attention_caches
        if collect_outputs:
            return x, block_outputs
        if return_attention_cache:
            return x, next_attention_caches
        return x

    def _gather_block_stats(self) -> Dict[str, float]:
        """Pop update stats from every block into a flat ``layerN.level.key`` dict.

        Blocks without a callable ``pop_update_stats`` attribute are skipped.
        """
        metrics: Dict[str, float] = {}
        for idx, block in enumerate(self.blocks):
            pop_fn = getattr(block, "pop_update_stats", None)
            if not callable(pop_fn):
                continue
            stats = cast("Dict[str, Dict[str, float]]", pop_fn())
            for level_name, payload in stats.items():
                prefix = f"layer{idx}.{level_name}"
                for key, value in payload.items():
                    metrics[f"{prefix}.{key}"] = value
        return metrics

    def pop_update_metrics(self) -> Dict[str, float]:
        """Return the most recently gathered update metrics and reset them to ``{}``."""
        metrics, self._latest_update_metrics = self._latest_update_metrics, {}
        return metrics

    def init_fast_state(self) -> ModelFastState:
        """Build a fresh ``ModelFastState`` with one entry per block.

        Raises:
            TypeError: if a block is not one of the four supported block types.
        """
        states = []
        for block in self.blocks:
            # The isinstance order below is kept exactly as before in case the
            # block classes are related by inheritance.
            if isinstance(block, HOPEBlock):
                state = build_block_fast_state(
                    titan_module=block.titan_memory,
                    cms_blocks=dict(block.cms.blocks.items()),
                    specs=[block.config.titan_level, *block.config.cms_levels],
                    optimizer_configs=block.config.optimizer_configs,
                    default_lr=block.config.self_mod_lr,
                )
            elif isinstance(block, HOPEAttentionBlock):
                state = build_block_fast_state(
                    titan_module=None,
                    cms_blocks=dict(block.cms.blocks.items()),
                    specs=list(block.config.cms_levels),
                    optimizer_configs=block.config.optimizer_configs,
                    default_lr=block.config.self_mod_lr,
                )
            elif isinstance(block, HOPESelfModBlock):
                state = build_block_fast_state(
                    titan_module=None,
                    cms_blocks=dict(block.cms.blocks.items()),
                    selfmod_module=block.selfmod,
                    specs=list(block.config.cms_levels),
                    optimizer_configs=block.config.optimizer_configs,
                    default_lr=block.config.self_mod_lr,
                )
            elif isinstance(block, TransformerBlock):
                # Plain transformer blocks carry no memories to track.
                state = build_block_fast_state(
                    titan_module=None,
                    cms_blocks={},
                    specs=(),
                    optimizer_configs={},
                    default_lr=0.0,
                )
            else:
                raise TypeError(f"Unsupported block type for fast state: {type(block)}")
            states.append(state)
        return ModelFastState(blocks=states)

    def init_attention_cache(self) -> ModelAttentionCache:
        """Return an attention cache holding ``None`` for every block."""
        return ModelAttentionCache(blocks=[None for _ in self.blocks])

    def freeze_backbone(self) -> None:
        """
        Freeze the shared transformer spine (embeddings, attention blocks, norm, LM head).
        HOPE/TITAN/CMS memories remain trainable for adapter-style finetuning.

        Per block, only the parameters of an ``attn`` attribute that is an
        ``nn.Module`` are frozen; all other block parameters are untouched.
        """
        frozen_modules: list[nn.Module] = [self.embed, self.norm, self.lm_head]
        for block in self.blocks:
            attn = getattr(block, "attn", None)
            if isinstance(attn, nn.Module):
                frozen_modules.append(attn)
        for module in frozen_modules:
            for p in module.parameters():
                p.requires_grad = False
597
+
598
+
599
class _UpdateControlledBlock(Protocol):
    """Structural type for blocks whose update behaviour HOPEModel can configure.

    Used only as a ``cast`` target when HOPEModel forwards its surprise and
    level settings to the blocks.
    """

    def set_surprise_threshold(self, threshold: float | None) -> None:
        """Accept a new surprise threshold (``None`` is a valid value)."""
        ...

    def set_surprise_metric(self, metric: str) -> None:
        """Accept the name of the surprise metric to use."""
        ...

    def set_allowed_levels(self, allowed: set[str] | None) -> None:
        """Accept the set of level names allowed to update (``None`` is valid)."""
        ...
File without changes