nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,412 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, cast
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ..backbones import AttentionConfig, SelfAttention
10
+ from ..fast_state import (
11
+ AttentionKVCache,
12
+ BlockFastState,
13
+ ModelAttentionCache,
14
+ ModelFastState,
15
+ build_block_fast_state,
16
+ )
17
+ from ..functional import (
18
+ call_with_deltas,
19
+ call_with_params,
20
+ grads_to_dict,
21
+ params_with_deltas,
22
+ require_grad_params,
23
+ )
24
+ from ..hope.self_mod import SelfModifier
25
+ from ..levels import LevelSpec
26
+ from ..optim.manager import LevelConfig, LevelOptimizerManager
27
+ from ..titan.memory import TitanMemory, TitanMemoryConfig
28
+
29
+
30
@dataclass
class TitanOnlyModelConfig:
    """Hyper-parameters shared by ``TitanOnlyModel`` and each of its ``TitanOnlyBlock`` layers.

    The field order defines the positional constructor signature, so new fields
    must only ever be appended (with defaults).
    """

    # ---- required architecture settings -------------------------------------
    vocab_size: int  # rows of the token embedding / outputs of the (tied) LM head
    dim: int  # model width used by embeddings, attention, TITAN memory and norms
    num_layers: int  # number of stacked TitanOnlyBlock layers
    heads: int  # attention head count forwarded to AttentionConfig
    titan_level: LevelSpec  # level spec for the TITAN memory; its ``name`` keys the level optimizer

    # ---- optional settings ---------------------------------------------------
    optimizers: Dict[str, dict] | None = None  # optimizer configs for the level manager (None -> {})
    teach_scale: float = 1.0  # initial multiplier applied to the teach signal
    teach_clip: float = 0.0  # per-position L2 clip of the scaled teach signal; <= 0 disables clipping
    teach_schedule: Dict[str, float] | None = None  # NOTE(review): not read by this module
    qk_l2_norm: bool = False  # forwarded to AttentionConfig.qk_l2_norm
    local_conv_window: int | None = None  # forwarded to AttentionConfig.local_conv_window
    titan_hidden_multiplier: int = 4  # TitanMemoryConfig.hidden_multiplier
    activation: str = "gelu"  # TitanMemoryConfig.activation
    self_mod_hidden: int = 4  # hidden_multiplier of the SelfModifier
    self_mod_lr: float = 1e-3  # default learning rate handed to the level optimizer manager
    surprise_threshold: float | None = None  # None disables surprise gating
    surprise_metric: str = "l2"  # one of "l2", "loss", "logit_entropy" (validated by the model)
    freeze_backbone: bool = False  # if True the model freezes embeddings/norm/head/attention at init
50
+
51
+
52
class TitanOnlyBlock(nn.Module):
    """Self-attention followed by a TITAN memory, with optional online memory updates.

    ``forward`` returns ``norm(attn(x) + titan_memory(attn(x)))``. When a
    ``teach_signal`` is supplied, the TITAN memory is also fitted toward
    ``teach_signal + self_modifier(...)`` with a masked MSE objective. The
    resulting gradients go to a level optimizer manager:

    * no fast state: ``self.level_manager.apply_module_grads`` on ``self.titan_memory``;
    * with fast state: ``fast_state.level_manager.apply_grads`` on
      ``fast_state.titan_params``, whose result replaces ``fast_state.titan_params``.
    """

    def __init__(self, config: TitanOnlyModelConfig):
        super().__init__()
        self.config = config
        # Seed surprise gating from the config so a standalone block honours it;
        # TitanOnlyModel re-applies (and validates) the same values afterwards.
        self.surprise_threshold: float | None = config.surprise_threshold
        self.surprise_metric: str = str(config.surprise_metric).strip().lower()
        self.enabled: bool = True
        self.attn = SelfAttention(
            AttentionConfig(
                dim=config.dim,
                heads=config.heads,
                qk_l2_norm=config.qk_l2_norm,
                local_conv_window=config.local_conv_window,
            )
        )
        titan_config = TitanMemoryConfig(
            dim=config.dim,
            hidden_multiplier=config.titan_hidden_multiplier,
            activation=config.activation,
        )
        self.titan_memory = TitanMemory(titan_config)
        self.self_modifier = SelfModifier(config.dim, hidden_multiplier=config.self_mod_hidden)
        # NOTE(review): not used by forward(); kept so the module structure is unchanged.
        self.dropout = nn.Dropout(0.0)
        self.norm = nn.LayerNorm(config.dim)
        level_config = LevelConfig(
            specs=[config.titan_level],
            optimizer_configs=config.optimizers or {},
            default_lr=config.self_mod_lr,
        )
        self.level_manager = LevelOptimizerManager(level_config)

    def forward(
        self,
        x: torch.Tensor,
        *,
        teach_signal: torch.Tensor | None = None,
        surprise_value: float | None = None,
        fast_state: BlockFastState | None = None,
        attention_cache: AttentionKVCache | None = None,
        return_attention_cache: bool = False,
        differentiable_updates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:
        """Run attention + TITAN memory and optionally update the memory.

        Args:
            x: block input.
            teach_signal: target signal for the memory update; ``None`` skips updating.
            surprise_value: scalar compared against ``surprise_threshold``.
            fast_state: per-sequence parameters; when given, the memory runs on
                ``fast_state.titan_params`` and updates are written there.
            attention_cache: KV cache forwarded to the attention layer.
            return_attention_cache: also return the attention layer's new KV cache.
            differentiable_updates: accepted for interface parity; unused here.

        Returns:
            The normalised output, or ``(output, next_attention_cache)``.

        Raises:
            ValueError: if ``fast_state`` is given without ``titan_params``.
        """
        _ = differentiable_updates
        next_attn_cache: AttentionKVCache | None = None
        if return_attention_cache:
            attn_out, next_attn_cache = self.attn(
                x,
                kv_cache=attention_cache,
                return_kv_cache=True,
            )
        else:
            attn_out = self.attn(x, kv_cache=attention_cache)
        if fast_state is None:
            mem_out = self.titan_memory(attn_out)
        else:
            if fast_state.titan_params is None:
                raise ValueError(
                    "fast_state.titan_params is required for TitanOnlyBlock fast-state forward"
                )
            mem_out = call_with_deltas(self.titan_memory, fast_state.titan_params, attn_out)
        combined = attn_out + mem_out
        if teach_signal is not None:
            # Update the memory, then advance the matching level manager's step.
            if fast_state is None:
                self._update_titan(attn_out, mem_out, teach_signal, surprise_value)
                self.level_manager.tick()
            else:
                self._update_titan_fast(fast_state, attn_out, mem_out, teach_signal, surprise_value)
                fast_state.level_manager.tick()
        out = self.norm(combined)
        if return_attention_cache:
            assert next_attn_cache is not None
            return out, next_attn_cache
        return out

    def set_surprise_threshold(self, threshold: float | None) -> None:
        """Set the surprise threshold; ``None`` disables surprise gating."""
        self.surprise_threshold = threshold

    def set_surprise_metric(self, metric: str) -> None:
        """Set the surprise metric name (stored stripped and lower-cased)."""
        self.surprise_metric = str(metric).strip().lower()

    def set_enabled(self, enabled: bool) -> None:
        """Enable or disable memory updates for this block."""
        self.enabled = enabled

    def _passes_surprise(self, surprise_value: float | None) -> bool:
        """Return True if ``surprise_value`` clears the threshold (always True without one)."""
        if self.surprise_threshold is None:
            return True
        if surprise_value is None:
            return False
        return surprise_value >= self.surprise_threshold

    def _update_gate_open(self, level_manager, surprise_value: float | None) -> bool:
        """Check, in order: block enabled, level scheduled on ``level_manager``, surprise passes."""
        if not self.enabled:
            return False
        if not level_manager.should_update(self.config.titan_level.name):
            return False
        return self._passes_surprise(surprise_value)

    def _teach_target(
        self,
        attn_out: torch.Tensor,
        mem_out: torch.Tensor,
        teach_signal: torch.Tensor,
    ) -> torch.Tensor:
        """Build the detached regression target ``teach_signal + self_modifier(...)``.

        The full sequence is used (no pooling over dim=1) for granular updates.
        """
        modifier = self.self_modifier(
            key=attn_out.detach(),
            value=mem_out.detach(),
            error_signal=teach_signal.detach(),
        )
        return (teach_signal.detach() + modifier).detach()

    def _masked_titan_loss(
        self,
        prediction: torch.Tensor,
        target: torch.Tensor,
        teach_signal: torch.Tensor,
    ) -> torch.Tensor:
        """Masked squared error between ``prediction`` and ``target``.

        Positions whose teach signal is all-zero along the last dim are masked
        out. With the "l2" metric and a threshold set, positions whose teach
        signal L2 norm is below the threshold are masked out too. Squared errors
        are summed over the last dim and divided by the number of unmasked
        positions (at least 1), not by the feature count.
        """
        loss_terms = nn.functional.mse_loss(prediction, target, reduction="none")
        active = teach_signal.detach().abs().sum(dim=-1, keepdim=True) > 0
        mask = active.float()
        if self.surprise_threshold is not None and self.surprise_metric == "l2":
            norms = teach_signal.norm(dim=-1, keepdim=True)
            mask = mask * (norms >= self.surprise_threshold).float()
        return (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)

    def _titan_grads(
        self,
        params_req: Dict[str, torch.Tensor],
        query: torch.Tensor,
        target: torch.Tensor,
        teach_signal: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Gradients of the masked loss w.r.t. ``params_req`` (grad mode forced on)."""
        with torch.enable_grad():
            prediction = call_with_params(self.titan_memory, params_req, query)
            loss = self._masked_titan_loss(prediction, target, teach_signal)
            grads = torch.autograd.grad(
                loss,
                tuple(params_req.values()),
                retain_graph=False,
                allow_unused=True,
            )
        return grads_to_dict(params_req, grads)

    def _update_titan(
        self,
        attn_out: torch.Tensor,
        mem_out: torch.Tensor,
        teach_signal: torch.Tensor,
        surprise_value: float | None,
    ) -> None:
        """Update ``self.titan_memory`` through ``self.level_manager`` (no fast state)."""
        level_name = self.config.titan_level.name
        if not self._update_gate_open(self.level_manager, surprise_value):
            return
        target = self._teach_target(attn_out, mem_out, teach_signal)
        context_vec = attn_out.detach().mean(dim=(0, 1))
        with torch.enable_grad():
            params_req = require_grad_params(dict(self.titan_memory.named_parameters()))
        grads_dict = self._titan_grads(params_req, attn_out.detach(), target, teach_signal)
        # The schedule was already checked above, hence force=True.
        self.level_manager.apply_module_grads(
            level_name,
            self.titan_memory,
            grads_dict,
            context=context_vec,
            force=True,
        )
        # Pop metrics to avoid stale entries even if we do not log them yet.
        self.level_manager.pop_last_metrics(level_name)

    def _update_titan_fast(
        self,
        fast_state: BlockFastState,
        attn_out: torch.Tensor,
        mem_out: torch.Tensor,
        teach_signal: torch.Tensor,
        surprise_value: float | None,
    ) -> None:
        """Update ``fast_state.titan_params`` through ``fast_state.level_manager``."""
        level_name = self.config.titan_level.name
        if not self._update_gate_open(fast_state.level_manager, surprise_value):
            return
        if fast_state.titan_params is None:
            return
        target = self._teach_target(attn_out, mem_out, teach_signal)
        context_vec = attn_out.detach().mean(dim=(0, 1))
        base_params = fast_state.titan_params
        # Prepared under the caller's grad mode, as before; only the forward/grad
        # computation inside _titan_grads forces grad mode on.
        params_req = require_grad_params(params_with_deltas(self.titan_memory, base_params))
        grads_dict = self._titan_grads(params_req, attn_out.detach(), target, teach_signal)
        updated, _magnitude = fast_state.level_manager.apply_grads(
            level_name,
            base_params,
            grads_dict,
            context=context_vec,
            force=False,
        )
        fast_state.titan_params = updated
        fast_state.level_manager.pop_last_metrics(level_name)
250
+
251
+
252
+ class TitanOnlyModel(nn.Module):
253
+ def __init__(self, config: TitanOnlyModelConfig):
254
+ super().__init__()
255
+ self.config = config
256
+ self.embed = nn.Embedding(config.vocab_size, config.dim)
257
+ self.blocks = nn.ModuleList([TitanOnlyBlock(config) for _ in range(config.num_layers)])
258
+ self.norm = nn.LayerNorm(config.dim)
259
+ self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
260
+ self.lm_head.weight = self.embed.weight
261
+ self._runtime_teach_scale = config.teach_scale
262
+ self._runtime_teach_clip = config.teach_clip
263
+ self._surprise_threshold: float | None = None
264
+ self._surprise_metric = "l2"
265
+ self._updates_enabled: bool = True
266
+ self.set_surprise_metric(config.surprise_metric)
267
+ self.set_surprise_threshold(config.surprise_threshold)
268
+ if config.freeze_backbone:
269
+ self.freeze_backbone()
270
+
271
+ def set_teach_runtime(self, *, scale: float | None = None, clip: float | None = None) -> None:
272
+ if scale is not None:
273
+ self._runtime_teach_scale = scale
274
+ if clip is not None:
275
+ self._runtime_teach_clip = clip
276
+
277
+ def set_surprise_threshold(self, threshold: float | None) -> None:
278
+ self._surprise_threshold = threshold
279
+ for block in self.blocks:
280
+ cast(TitanOnlyBlock, block).set_surprise_threshold(threshold)
281
+
282
+ def get_surprise_threshold(self) -> float | None:
283
+ return self._surprise_threshold
284
+
285
+ def set_surprise_metric(self, metric: str) -> None:
286
+ normalized = str(metric).strip().lower()
287
+ allowed = {"l2", "loss", "logit_entropy"}
288
+ if normalized not in allowed:
289
+ raise ValueError(
290
+ f"Unsupported surprise_metric={metric!r}; expected one of {sorted(allowed)}"
291
+ )
292
+ self._surprise_metric = normalized
293
+ for block in self.blocks:
294
+ cast(TitanOnlyBlock, block).set_surprise_metric(normalized)
295
+
296
+ def get_surprise_metric(self) -> str:
297
+ return self._surprise_metric
298
+
299
+ def set_allowed_update_levels(self, levels: set[str] | None) -> None:
300
+ enabled = True
301
+ if levels is not None and "titan" not in levels and len(levels) > 0:
302
+ enabled = False
303
+ self._updates_enabled = enabled
304
+ for block in self.blocks:
305
+ cast(TitanOnlyBlock, block).set_enabled(enabled)
306
+
307
+ def get_allowed_update_levels(self) -> set[str] | None:
308
+ if self._updates_enabled:
309
+ return {"titan"}
310
+ return set()
311
+
312
+ def forward(
313
+ self,
314
+ tokens: torch.Tensor,
315
+ *,
316
+ teach_signal: torch.Tensor | None = None,
317
+ fast_state: ModelFastState | None = None,
318
+ surprise_value: float | None = None,
319
+ attention_cache: ModelAttentionCache | None = None,
320
+ return_attention_cache: bool = False,
321
+ differentiable_updates: bool = False,
322
+ ) -> torch.Tensor | tuple[torch.Tensor, ModelAttentionCache]:
323
+ require_external = self._surprise_metric in {"loss", "logit_entropy"}
324
+ if require_external and self._surprise_threshold is not None:
325
+ if teach_signal is not None and surprise_value is None:
326
+ raise ValueError(
327
+ f"surprise_metric={self._surprise_metric} requires passing surprise_value "
328
+ "when model.surprise_threshold is set."
329
+ )
330
+ x = self.embed(tokens)
331
+ if fast_state is not None and len(fast_state.blocks) != len(self.blocks):
332
+ raise ValueError("fast_state.blocks length does not match model.blocks")
333
+ if attention_cache is not None and len(attention_cache.blocks) != len(self.blocks):
334
+ raise ValueError("attention_cache.blocks length does not match model.blocks")
335
+ base_surprise = surprise_value
336
+ next_caches: list[AttentionKVCache | None] = []
337
+ for idx, block in enumerate(self.blocks):
338
+ scaled_signal = None
339
+ if teach_signal is not None:
340
+ scaled_signal = teach_signal * self._runtime_teach_scale
341
+ if self._runtime_teach_clip > 0:
342
+ with torch.no_grad():
343
+ norm = scaled_signal.norm(dim=-1, keepdim=True)
344
+ scale = torch.clamp(norm / self._runtime_teach_clip, min=1.0)
345
+ scaled_signal = scaled_signal / scale
346
+ block_surprise = base_surprise
347
+ if (
348
+ scaled_signal is not None
349
+ and base_surprise is None
350
+ and self._surprise_metric == "l2"
351
+ ):
352
+ block_surprise = float(scaled_signal.norm(dim=-1).mean().item())
353
+ block_state = None if fast_state is None else fast_state.blocks[idx]
354
+ block_cache = None if attention_cache is None else attention_cache.blocks[idx]
355
+ if return_attention_cache:
356
+ x, next_cache = block( # type: ignore[arg-type]
357
+ x,
358
+ teach_signal=scaled_signal,
359
+ surprise_value=block_surprise,
360
+ fast_state=block_state,
361
+ attention_cache=block_cache,
362
+ return_attention_cache=True,
363
+ differentiable_updates=differentiable_updates,
364
+ )
365
+ next_caches.append(next_cache)
366
+ else:
367
+ x = block( # type: ignore[arg-type]
368
+ x,
369
+ teach_signal=scaled_signal,
370
+ surprise_value=block_surprise,
371
+ fast_state=block_state,
372
+ attention_cache=block_cache,
373
+ differentiable_updates=differentiable_updates,
374
+ )
375
+ x = self.norm(x)
376
+ logits = self.lm_head(x)
377
+ if return_attention_cache:
378
+ return logits, ModelAttentionCache(blocks=next_caches)
379
+ return logits
380
+
381
+ def freeze_backbone(self) -> None:
382
+ """
383
+ Freeze shared transformer components; leave TITAN memory/trainable paths active.
384
+ """
385
+ for p in self.embed.parameters():
386
+ p.requires_grad = False
387
+ for p in self.norm.parameters():
388
+ p.requires_grad = False
389
+ for p in self.lm_head.parameters():
390
+ p.requires_grad = False
391
+ for block in self.blocks:
392
+ typed_block = cast(TitanOnlyBlock, block)
393
+ for p in typed_block.attn.parameters():
394
+ p.requires_grad = False
395
+
396
+ def init_fast_state(self) -> ModelFastState:
397
+ states = []
398
+ for block in self.blocks:
399
+ typed_block = cast(TitanOnlyBlock, block)
400
+ specs = [typed_block.config.titan_level]
401
+ state = build_block_fast_state(
402
+ titan_module=typed_block.titan_memory,
403
+ cms_blocks={},
404
+ specs=specs,
405
+ optimizer_configs=typed_block.config.optimizers or {},
406
+ default_lr=typed_block.config.self_mod_lr,
407
+ )
408
+ states.append(state)
409
+ return ModelFastState(blocks=states)
410
+
411
+ def init_attention_cache(self) -> ModelAttentionCache:
412
+ return ModelAttentionCache(blocks=[None for _ in self.blocks])