nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,102 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
@dataclass
class DeepMomentumState:
    """Per-parameter buffers owned by ``DeepMomentum``.

    Both start as ``None`` and are allocated lazily as zeros shaped like the
    incoming gradient; a shape change re-allocates them.
    """

    # Exponential moving average of the (transformed) update; returned by forward().
    grad_avg: Optional[torch.Tensor] = None
    # Exponential moving average of squared gradients (preconditioned/muon only).
    sq_avg: Optional[torch.Tensor] = None


class DeepMomentum(nn.Module):
    """Implements momentum variants described in the NL paper.

    ``variant`` selects how each raw gradient is transformed before it is
    folded into the per-key moving average ``grad_avg``:

    * ``"preconditioned"``: divided by ``sqrt(sq_avg) + eps``.
    * ``"muon"``: the same preconditioning, then ``tanh``.
    * ``"dmgd"``: ``tanh`` of the raw gradient.
    * ``"l2_objective"``: gradient plus ``0.1 *`` its mean over the last dim.
    * ``"nl_l2_precond"``: gradient with its component along the (averaged)
      context direction removed.

    Any other string leaves the gradient untouched. State is kept per
    ``param_key`` so a single instance can serve many parameters.
    """

    def __init__(
        self,
        *,
        beta: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        variant: str = "preconditioned",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.beta2 = beta2
        self.eps = eps
        self.variant = variant
        self.state: dict[str, DeepMomentumState] = {}
        squashes = variant in {"dmgd", "muon"}
        self.nonlinearity = nn.Tanh() if squashes else nn.Identity()
        # Diagnostics from the most recent forward() call (nl_l2_precond only).
        self.last_metrics: dict[str, float] = {}

    def reset_state(self) -> None:
        """Forget every per-key buffer; the next call starts from zeros."""
        self.state.clear()

    def _precondition(self, grad: torch.Tensor, state: DeepMomentumState) -> torch.Tensor:
        """Scale ``grad`` by the RMS of past gradients, updating ``state.sq_avg`` in place."""
        second = state.sq_avg
        if second is None or second.shape != grad.shape:
            second = torch.zeros_like(grad)
            state.sq_avg = second
        second.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
        return grad / second.sqrt().add_(self.eps)

    def _nl_precondition(
        self,
        grad: torch.Tensor,
        context: torch.Tensor | None,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Remove from ``grad`` its component along the context direction.

        Returns the (possibly unchanged) gradient plus diagnostics:
        ``ctx_norm`` (norm of the averaged context), ``proj_norm`` (norm of the
        projected update) and ``proj_skipped`` (1.0 when the last dims differ
        or ``grad`` is a scalar).
        """
        report: dict[str, float] = {"ctx_norm": 0.0, "proj_norm": 0.0, "proj_skipped": 0.0}
        if context is None:
            return grad, report
        # Collapse all leading dims so the context is one direction vector.
        if context.ndim <= 1:
            direction = context
        else:
            direction = context.reshape(-1, context.shape[-1]).mean(dim=0)
        length = torch.norm(direction)
        report["ctx_norm"] = length.item()
        if not length > 0:
            return grad, report
        if grad.ndim == 0 or grad.shape[-1] != direction.shape[-1]:
            report["proj_skipped"] = 1.0
            return grad, report
        unit = direction / (length + self.eps)
        # Rank-1 projector: subtract <grad, unit> * unit along the last dim.
        coeff = (grad * unit).sum(dim=-1, keepdim=True)
        residual = grad - coeff * unit
        report["proj_norm"] = torch.norm(residual).item()
        return residual, report

    def forward(  # type: ignore[override]
        self,
        grad: torch.Tensor,
        *,
        context: torch.Tensor | None = None,
        param_key: str | None = None,
    ) -> torch.Tensor:
        """Fold ``grad`` into the moving average for ``param_key`` and return it.

        The returned tensor is the live ``grad_avg`` buffer for that key
        (``"__default__"`` when no key is given), not a copy.
        """
        slot = param_key or "__default__"
        state = self.state.get(slot)
        if state is None:
            state = self.state[slot] = DeepMomentumState()
        avg = state.grad_avg
        if avg is None or avg.shape != grad.shape:
            avg = torch.zeros_like(grad)
            state.grad_avg = avg
        self.last_metrics = {}
        mode = self.variant
        if mode in {"preconditioned", "muon"}:
            step = self._precondition(grad, state)
        elif mode == "l2_objective":
            step = grad + 0.1 * torch.mean(grad, dim=-1, keepdim=True)
        elif mode == "nl_l2_precond":
            step, report = self._nl_precondition(grad, context)
            self.last_metrics.update(report)
        else:
            step = grad
        if mode in {"dmgd", "muon"}:
            step = self.nonlinearity(step)
        avg.mul_(self.beta).add_(step, alpha=1 - self.beta)
        return avg
@@ -0,0 +1,13 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict
4
+
5
+ from .deep import DeepMomentum
6
+
7
+
8
def build_optimizer(config: Dict[str, Any]) -> DeepMomentum:
    """Build the level optimizer described by ``config``.

    Args:
        config: Mapping with an optional ``"type"`` (case-insensitive; only
            ``"deep_momentum"`` is supported) and an optional ``"params"``
            mapping forwarded as keyword arguments to ``DeepMomentum``.
            Keys that are present but ``None`` (e.g. an empty ``type:`` or
            ``params:`` entry in YAML) are treated like missing keys.

    Returns:
        A freshly constructed ``DeepMomentum``.

    Raises:
        ValueError: if ``type`` names anything other than ``deep_momentum``.
    """
    raw_type = config.get("type")
    # str() so a non-string type reaches the ValueError below instead of
    # failing with AttributeError on .lower().
    opt_type = ("deep_momentum" if raw_type is None else str(raw_type)).lower()
    if opt_type != "deep_momentum":
        raise ValueError(f"Unsupported optimizer type {opt_type}")
    params = config.get("params")
    if params is None:
        params = {}
    return DeepMomentum(**params)
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterable
4
+
5
+ import torch
6
+
7
+
8
+ def _newton_schulz(matrix: torch.Tensor, steps: int, eps: float = 1e-6) -> torch.Tensor:
9
+ if matrix.ndim != 2:
10
+ raise ValueError("Newton-Schulz expects a 2D matrix")
11
+ dtype = matrix.dtype
12
+ device = matrix.device
13
+ m, n = matrix.shape
14
+ x = matrix
15
+ norm = torch.linalg.norm(x)
16
+ x = x / (norm + eps)
17
+ eye = torch.eye(n, device=device, dtype=dtype)
18
+ for _ in range(steps):
19
+ x = 0.5 * x @ (3.0 * eye - x.T @ x)
20
+ return x
21
+
22
+
23
+ def _orthogonalize(tensor: torch.Tensor, steps: int, eps: float) -> torch.Tensor:
24
+ if tensor.ndim < 2:
25
+ return tensor
26
+ mat = tensor.reshape(tensor.shape[0], -1)
27
+ ortho = _newton_schulz(mat, steps=steps, eps=eps)
28
+ return ortho.reshape_as(tensor)
29
+
30
+
31
class M3(torch.optim.Optimizer):
    """Multi-scale Momentum Muon (M3) optimizer, following Nested Learning paper Algorithm 1.

    Per-parameter state, zero-initialised on the first step that sees a gradient:

    * ``m1``: fast momentum, ``m1 += beta1 * g`` every step.
    * ``v``: second-moment accumulator, ``v += beta2 * g * g`` every step.
    * ``slow_buffer``: sum of gradients since the last slow refresh.
    * ``m2``: slow momentum, ``m2 += beta3 * slow_buffer`` every ``slow_chunk`` steps.
    * ``o2``: Newton-Schulz orthogonalisation of ``m2``. It is refreshed together
      with ``m2`` after the parameter update, so it takes effect on the next step.

    Every step applies ``p -= lr * (O1 + alpha * o2) / (sqrt(v) + eps)``, where ``O1``
    is the orthogonalised ``m1``. Tensors with fewer than two dims are not
    orthogonalised. A non-zero ``weight_decay`` is added to the gradient (L2 style)
    before any state is updated.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        *,
        lr: float = 1e-3,
        beta1: float = 0.9,
        beta2: float = 0.999,
        beta3: float = 0.9,
        alpha: float = 1.0,
        eps: float = 1e-8,
        ns_steps: int = 3,
        slow_chunk: int = 100,
        weight_decay: float = 0.0,
    ) -> None:
        super().__init__(
            params,
            {
                "lr": lr,
                "beta1": beta1,
                "beta2": beta2,
                "beta3": beta3,
                "alpha": alpha,
                "eps": eps,
                "ns_steps": ns_steps,
                "slow_chunk": slow_chunk,
                "weight_decay": weight_decay,
            },
        )

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore[override]
        """Run one M3 update over every parameter that has a gradient.

        Returns the loss produced by ``closure`` (evaluated with grad enabled),
        or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr, eps, alpha = group["lr"], group["eps"], group["alpha"]
            beta1, beta2, beta3 = group["beta1"], group["beta2"], group["beta3"]
            ns_steps = group["ns_steps"]
            slow_chunk = group["slow_chunk"]
            weight_decay = group["weight_decay"]
            for p in (q for q in group["params"] if q.grad is not None):
                g = p.grad if weight_decay == 0.0 else p.grad.add(p, alpha=weight_decay)
                st = self.state[p]
                if not st:
                    st["step"] = 0
                    for buf in ("m1", "m2", "v", "slow_buffer", "o2"):
                        st[buf] = torch.zeros_like(p)
                st["step"] += 1

                st["m1"].add_(g, alpha=beta1)
                st["v"].addcmul_(g, g, value=beta2)
                st["slow_buffer"].add_(g)

                fast = _orthogonalize(st["m1"], steps=ns_steps, eps=eps)
                scale = st["v"].sqrt().add_(eps)
                # o2 is read here, before any refresh below, so a refreshed slow
                # term only takes effect on the following step.
                p.add_((fast + alpha * st["o2"]) / scale, alpha=-lr)

                if slow_chunk > 0 and st["step"] % slow_chunk == 0:
                    st["m2"].add_(st["slow_buffer"], alpha=beta3)
                    st["slow_buffer"].zero_()
                    st["o2"] = _orthogonalize(st["m2"], steps=ns_steps, eps=eps)
        return loss
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Sequence, Tuple
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from ..levels import LevelClock, LevelSpec
10
+ from .factory import build_optimizer
11
+
12
+
13
@dataclass
class LevelConfig:
    """Inputs used to build a ``LevelOptimizerManager``.

    Each spec's ``optimizer_key`` (or ``"default"``) selects an entry of
    ``optimizer_configs``; that entry may carry ``"type"``, ``"lr"`` and
    ``"params"``. ``default_lr`` applies when the entry has no ``"lr"``.
    """

    # One spec per optimized level.
    specs: Sequence[LevelSpec]
    # Optimizer key -> optimizer config dict.
    optimizer_configs: Dict[str, dict]
    # Fallback learning rate for configs without "lr".
    default_lr: float
18
+
19
+
20
+ class LevelOptimizerManager:
21
+ def __init__(self, config: LevelConfig):
22
+ self.clock = LevelClock(config.specs)
23
+ self.learning_rates: Dict[str, float] = {}
24
+ self.optimizers = {}
25
+ self._last_metrics: Dict[str, Dict[str, float]] = {}
26
+ for spec in config.specs:
27
+ key = spec.optimizer_key or "default"
28
+ optim_cfg = config.optimizer_configs.get(key, {"type": "deep_momentum", "params": {}})
29
+ lr = optim_cfg.get("lr", config.default_lr)
30
+ params_cfg = optim_cfg.get("params", {})
31
+ optimizer = build_optimizer(
32
+ {"type": optim_cfg.get("type", "deep_momentum"), "params": params_cfg}
33
+ )
34
+ self.optimizers[spec.name] = optimizer
35
+ self.learning_rates[spec.name] = lr
36
+
37
+ def should_update(self, level: str) -> bool:
38
+ return self.clock.should_update(level)
39
+
40
+ def optimize(
41
+ self,
42
+ level: str,
43
+ module: nn.Module,
44
+ loss: torch.Tensor,
45
+ *,
46
+ context: torch.Tensor | None = None,
47
+ force: bool = False,
48
+ ) -> float:
49
+ if (not force) and (not self.should_update(level)):
50
+ return 0.0
51
+ named_params: Tuple[Tuple[str, torch.nn.Parameter], ...] = tuple(
52
+ (name, param) for name, param in module.named_parameters() if param.requires_grad
53
+ )
54
+ if not named_params:
55
+ return 0.0
56
+ params = tuple(param for _, param in named_params)
57
+ grads = torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True)
58
+ grads_dict: Dict[str, torch.Tensor] = {}
59
+ for (name, _), grad in zip(named_params, grads, strict=True):
60
+ if grad is None:
61
+ continue
62
+ grads_dict[name] = grad
63
+ return self.apply_module_grads(
64
+ level,
65
+ module,
66
+ grads_dict,
67
+ context=context,
68
+ force=True,
69
+ )
70
+
71
+ def apply_module_grads(
72
+ self,
73
+ level: str,
74
+ module: nn.Module,
75
+ grads: Dict[str, torch.Tensor],
76
+ *,
77
+ context: torch.Tensor | None = None,
78
+ force: bool = False,
79
+ ) -> float:
80
+ if (not force) and (not self.should_update(level)):
81
+ return 0.0
82
+ optimizer = self.optimizers[level]
83
+ lr = self.learning_rates[level]
84
+ total_norm = 0.0
85
+ with torch.no_grad():
86
+ for name, param in module.named_parameters():
87
+ if not param.requires_grad:
88
+ continue
89
+ grad = grads.get(name)
90
+ if grad is None:
91
+ continue
92
+ update = optimizer(grad, context=context, param_key=name)
93
+ param.add_(update, alpha=-lr)
94
+ total_norm += grad.norm().item()
95
+ self.clock.record_update(level)
96
+ metrics = getattr(optimizer, "last_metrics", None)
97
+ if metrics:
98
+ self._last_metrics[level] = dict(metrics)
99
+ else:
100
+ self._last_metrics[level] = {}
101
+ return total_norm
102
+
103
+ def tick(self) -> None:
104
+ self.clock.tick()
105
+
106
+ def pop_last_metrics(self, level: str) -> Dict[str, float]:
107
+ return self._last_metrics.pop(level, {})
108
+
109
+ def apply_grads(
110
+ self,
111
+ level: str,
112
+ params: Dict[str, torch.Tensor],
113
+ grads: Dict[str, torch.Tensor],
114
+ *,
115
+ context: torch.Tensor | None = None,
116
+ force: bool = False,
117
+ differentiable: bool = False,
118
+ ) -> tuple[Dict[str, torch.Tensor], float]:
119
+ if (not force) and (not self.should_update(level)):
120
+ return params, 0.0
121
+ optimizer = self.optimizers[level]
122
+ lr = self.learning_rates[level]
123
+ updated: Dict[str, torch.Tensor] = {}
124
+ total_norm = 0.0
125
+ if differentiable:
126
+ for name, param in params.items():
127
+ grad = grads.get(name)
128
+ if grad is None:
129
+ updated[name] = param
130
+ continue
131
+ updated[name] = param - lr * grad
132
+ total_norm += float(grad.detach().norm().item())
133
+ self.clock.record_update(level)
134
+ self._last_metrics[level] = {"differentiable_updates": 1.0}
135
+ return updated, total_norm
136
+ with torch.no_grad():
137
+ for name, param in params.items():
138
+ grad = grads.get(name)
139
+ if grad is None:
140
+ updated[name] = param
141
+ continue
142
+ update = optimizer(grad, context=context, param_key=name)
143
+ updated[name] = (param - lr * update).detach()
144
+ total_norm += grad.norm().item()
145
+ self.clock.record_update(level)
146
+ metrics = getattr(optimizer, "last_metrics", None)
147
+ if metrics:
148
+ self._last_metrics[level] = dict(metrics)
149
+ else:
150
+ self._last_metrics[level] = {}
151
+ return updated, total_norm
File without changes
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ..assoc_memory import AssocMemory
10
+
11
+
12
@dataclass
class TitanMemoryConfig:
    """Settings for the MLP inside ``TitanMemory``.

    The hidden width is ``dim * hidden_multiplier`` and ``layers`` counts the
    Linear layers; ``activation`` is looked up case-insensitively among
    ``relu``, ``gelu`` and ``silu``.
    """

    # Input and output width of the memory MLP.
    dim: int
    # Hidden width as a multiple of ``dim``.
    hidden_multiplier: int = 4
    # Number of Linear layers.
    layers: int = 2
    # Activation name placed between Linear layers.
    activation: str = "gelu"
18
+
19
+
20
+ def _activation(name: str) -> nn.Module:
21
+ if name.lower() == "relu":
22
+ return nn.ReLU()
23
+ if name.lower() == "gelu":
24
+ return nn.GELU()
25
+ if name.lower() == "silu":
26
+ return nn.SiLU()
27
+ msg = f"Unsupported activation {name}"
28
+ raise ValueError(msg)
29
+
30
+
31
class TitanMemory(AssocMemory):
    """Simplified TITAN-style associative memory.

    Queries pass through an MLP (``net``) and then ``LayerNorm``. The memory
    can be written via ``update`` (one gradient step on ``net``) or
    ``apply_deltas`` (additive parameter deltas).
    """

    def __init__(self, config: TitanMemoryConfig):
        super().__init__()
        self.config = config
        hidden = config.dim * config.hidden_multiplier
        blocks: list[nn.Module] = []
        # One activation instance is shared between layers (it holds no parameters).
        activation = _activation(config.activation)
        for layer_idx in range(config.layers - 1):
            in_features = config.dim if layer_idx == 0 else hidden
            blocks.extend([nn.Linear(in_features, hidden), activation])
        blocks.append(nn.Linear(hidden if config.layers > 1 else config.dim, config.dim))
        self.net = nn.Sequential(*blocks)
        self.norm = nn.LayerNorm(config.dim)
        # In training mode, MLP outputs whose last-dim L2 norm exceeds this are
        # rescaled down to it; a value <= 0 disables the rescaling.
        self.grad_clip = 1.0

    def forward(self, query: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Map ``query`` through ``net`` (norm-capped in training mode) and ``LayerNorm``."""
        out = self.net(query)
        if self.training and self.grad_clip > 0:
            # Only the rescaling factor is treated as a constant. The division
            # must stay in the autograd graph: performing it under no_grad
            # detached the output from ``net``, so ``net`` received no gradients.
            with torch.no_grad():
                scale = torch.clamp(out.norm(dim=-1, keepdim=True) / self.grad_clip, min=1.0)
            out = out / scale
        return self.norm(out)

    def surprise(self, residual: torch.Tensor) -> torch.Tensor:
        """Return the L2 norm of ``residual`` over its last dim (kept as size 1)."""
        return residual.norm(dim=-1, keepdim=True)

    @torch.no_grad()
    def update(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        error_signal: torch.Tensor | None = None,
        lr: float = 1e-3,
    ) -> None:
        """Take one in-place SGD step on ``net``'s parameters.

        The loss is the MSE between ``forward(key)`` and ``value``. If
        ``error_signal`` is given, the loss is ``mean(error_signal * prediction)``
        instead. ``LayerNorm`` parameters are not updated.
        """
        params = list(self.net.parameters())
        with torch.enable_grad():
            prediction = self.forward(key.detach().requires_grad_(True))
            target = value.detach()
            if error_signal is None:
                loss = torch.mean((prediction - target) ** 2)
            else:
                loss = torch.mean(error_signal * prediction)
            grads = torch.autograd.grad(loss, params, retain_graph=False)
        # In-place parameter writes happen under the decorator's no_grad.
        for param, grad in zip(params, grads, strict=True):
            if grad is None:
                continue
            param.add_(grad, alpha=-lr)

    @torch.no_grad()
    def apply_deltas(self, deltas: Dict[str, torch.Tensor], scale: float = 1.0) -> None:
        """Add ``scale * delta`` in place to each named parameter; unknown names are skipped."""
        # Build the name -> parameter map once instead of once per delta.
        named = dict(self.named_parameters())
        for name, tensor in deltas.items():
            target = named.get(name)
            if target is None:
                continue
            target.add_(tensor, alpha=scale)