nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class DeepMomentumState:
    """Momentum buffers that ``DeepMomentum`` keeps for one parameter key.

    Both buffers start out as ``None`` and are allocated lazily (as zeros shaped
    like the incoming gradient) the first time they are needed.
    """

    # Running EMA of the transformed update; this is what DeepMomentum.forward returns.
    grad_avg: Optional[torch.Tensor] = None
    # Running EMA of squared gradients; only touched by the RMS-preconditioned variants.
    sq_avg: Optional[torch.Tensor] = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DeepMomentum(nn.Module):
    """Stateful per-parameter momentum transform (momentum variants from the NL paper).

    Every :meth:`forward` call transforms a raw gradient according to ``variant``
    and folds the result into an exponential moving average (decay ``beta``) that
    is kept separately for each ``param_key``. Recognised ``variant`` values:

    * ``"preconditioned"`` -- divide by the root of an EMA of squared gradients
      (decay ``beta2``; ``eps`` is added to the root).
    * ``"muon"`` -- the same RMS preconditioning followed by ``tanh``.
    * ``"dmgd"`` -- ``tanh`` of the raw gradient.
    * ``"l2_objective"`` -- add ``0.1 *`` the mean of the gradient over its last dim.
    * ``"nl_l2_precond"`` -- remove the gradient component along the averaged
      ``context`` direction; diagnostics are written to :attr:`last_metrics`.
    * any other string -- plain EMA of the raw gradient (the value is not validated).

    The per-key buffers live in a plain ``dict`` rather than registered buffers, so
    ``module.to(device)`` does not move them. To avoid device-mismatch errors in the
    in-place EMA updates, a key's buffers are moved to the incoming gradient's device
    whenever the two differ (a shape change still resets them to zeros).
    """

    def __init__(
        self,
        *,
        beta: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        variant: str = "preconditioned",
    ) -> None:
        super().__init__()
        self.beta = beta
        self.beta2 = beta2
        self.eps = eps
        self.variant = variant
        # One DeepMomentumState per param_key ("__default__" when no key is given).
        self.state: dict[str, DeepMomentumState] = {}
        self.nonlinearity = nn.Tanh() if variant in {"dmgd", "muon"} else nn.Identity()
        # Metrics from the most recent forward() call; only "nl_l2_precond" fills it.
        self.last_metrics: dict[str, float] = {}

    def reset_state(self) -> None:
        """Drop all per-key momentum buffers."""
        self.state.clear()

    @staticmethod
    def _buffer_for(buffer: torch.Tensor | None, grad: torch.Tensor) -> torch.Tensor:
        """Return a buffer usable in in-place ops with ``grad``.

        A missing buffer, or one whose shape differs from ``grad``, is replaced by
        zeros like ``grad``. A buffer living on another device is moved there with
        its values preserved (previously this raised a device-mismatch error).
        """
        if buffer is None or buffer.shape != grad.shape:
            return torch.zeros_like(grad)
        if buffer.device != grad.device:
            return buffer.to(grad.device)
        return buffer

    def _precondition(self, grad: torch.Tensor, state: DeepMomentumState) -> torch.Tensor:
        """Divide ``grad`` by ``sqrt(sq_avg) + eps`` after updating the ``sq_avg`` EMA."""
        state.sq_avg = self._buffer_for(state.sq_avg, grad)
        state.sq_avg.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
        denom = state.sq_avg.sqrt().add_(self.eps)
        return grad / denom

    def _nl_precondition(
        self,
        grad: torch.Tensor,
        context: torch.Tensor | None,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Project ``grad`` orthogonally to the mean context direction.

        A ``context`` of rank > 1 is flattened into rows over its last dim and
        averaged into a single vector. The projection is skipped (``proj_skipped``
        set to 1.0) when ``grad`` is 0-d or its last dim differs from the context's.
        A ``None`` or all-zero context returns ``grad`` unchanged.

        Returns:
            ``(update, metrics)`` where metrics holds ``ctx_norm``, ``proj_norm``
            (norm of the projected update, 0.0 when no projection happened) and
            ``proj_skipped``.
        """
        metrics: dict[str, float] = {
            "ctx_norm": 0.0,
            "proj_norm": 0.0,
            "proj_skipped": 0.0,
        }
        if context is None:
            return grad, metrics
        ctx = context
        if ctx.ndim > 1:
            ctx = ctx.reshape(-1, ctx.shape[-1]).mean(dim=0)
        ctx_norm = torch.norm(ctx)
        metrics["ctx_norm"] = ctx_norm.item()

        if ctx_norm > 0:
            if grad.ndim == 0 or grad.shape[-1] != ctx.shape[-1]:
                metrics["proj_skipped"] = 1.0
                return grad, metrics
            unit = ctx / (ctx_norm + self.eps)
            # Project grad orthogonal to context (rank-1 projector).
            projection = (grad * unit).sum(dim=-1, keepdim=True) * unit
            update = grad - projection
            metrics["proj_norm"] = torch.norm(update).item()
            return update, metrics
        return grad, metrics

    def forward(  # type: ignore[override]
        self,
        grad: torch.Tensor,
        *,
        context: torch.Tensor | None = None,
        param_key: str | None = None,
    ) -> torch.Tensor:
        """Fold one gradient into the momentum kept for ``param_key`` and return it.

        Args:
            grad: raw gradient of one parameter.
            context: optional context tensor; only used by ``"nl_l2_precond"``.
            param_key: selects the buffers to use; falsy keys share ``"__default__"``.

        Returns:
            The updated ``grad_avg`` buffer itself (not a copy). Treat it as
            read-only: it is mutated in place by the next call with the same key.
        """
        key = param_key or "__default__"
        state = self.state.get(key)
        if state is None:
            state = DeepMomentumState()
            self.state[key] = state
        state.grad_avg = self._buffer_for(state.grad_avg, grad)
        self.last_metrics = {}
        update = grad
        # These are deliberately independent `if`s: "muon" is preconditioned here
        # and then also passes through the tanh nonlinearity below.
        if self.variant in {"preconditioned", "muon"}:
            update = self._precondition(grad, state)
        if self.variant == "l2_objective":
            update = grad + 0.1 * torch.mean(grad, dim=-1, keepdim=True)
        if self.variant == "nl_l2_precond":
            update, metrics = self._nl_precondition(grad, context)
            self.last_metrics.update(metrics)
        if self.variant in {"dmgd", "muon"}:
            update = self.nonlinearity(update)
        state.grad_avg.mul_(self.beta).add_(update, alpha=1 - self.beta)
        return state.grad_avg
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
from .deep import DeepMomentum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def build_optimizer(config: Dict[str, Any]) -> DeepMomentum:
    """Construct the level optimizer described by ``config``.

    Args:
        config: mapping with an optional ``"type"`` (case-insensitive; only
            ``"deep_momentum"`` is supported and it is the default) and optional
            ``"params"`` (keyword arguments forwarded to ``DeepMomentum``).
            Keys explicitly set to ``None`` -- e.g. an empty ``params:`` entry in a
            YAML file -- are treated the same as missing keys.

    Returns:
        A freshly constructed ``DeepMomentum``.

    Raises:
        ValueError: if ``type`` names any other optimizer.
    """
    opt_type = config.get("type")
    if opt_type is None:
        opt_type = "deep_momentum"
    opt_type = opt_type.lower()
    if opt_type != "deep_momentum":
        raise ValueError(f"Unsupported optimizer type {opt_type}")
    params = config.get("params")
    if params is None:
        params = {}
    return DeepMomentum(**params)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Iterable
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _newton_schulz(matrix: torch.Tensor, steps: int, eps: float = 1e-6) -> torch.Tensor:
|
|
9
|
+
if matrix.ndim != 2:
|
|
10
|
+
raise ValueError("Newton-Schulz expects a 2D matrix")
|
|
11
|
+
dtype = matrix.dtype
|
|
12
|
+
device = matrix.device
|
|
13
|
+
m, n = matrix.shape
|
|
14
|
+
x = matrix
|
|
15
|
+
norm = torch.linalg.norm(x)
|
|
16
|
+
x = x / (norm + eps)
|
|
17
|
+
eye = torch.eye(n, device=device, dtype=dtype)
|
|
18
|
+
for _ in range(steps):
|
|
19
|
+
x = 0.5 * x @ (3.0 * eye - x.T @ x)
|
|
20
|
+
return x
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _orthogonalize(tensor: torch.Tensor, steps: int, eps: float) -> torch.Tensor:
|
|
24
|
+
if tensor.ndim < 2:
|
|
25
|
+
return tensor
|
|
26
|
+
mat = tensor.reshape(tensor.shape[0], -1)
|
|
27
|
+
ortho = _newton_schulz(mat, steps=steps, eps=eps)
|
|
28
|
+
return ortho.reshape_as(tensor)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class M3(torch.optim.Optimizer):
    """
    Multi-scale Momentum Muon (M3) optimizer (Nested Learning paper, Algorithm 1).

    This is a paper-faithful implementation for 2D weight tensors:
    - M1: fast momentum
    - M2: slow momentum (updated every `slow_chunk` steps)
    - V: second moment
    - O1/O2: Newton-Schulz orthogonalized momenta

    Per step, with ``g`` the (optionally L2-decayed) gradient:
    ``m1 += beta1 * g``, ``v += beta2 * g**2`` and ``slow_buffer += g``. Then
    ``p -= lr * (NS(m1) + alpha * o2) / (sqrt(v) + eps)``. Every ``slow_chunk``
    steps (when ``slow_chunk > 0``), ``m2 += beta3 * slow_buffer``, the buffer is
    cleared and ``o2 = NS(m2)`` is refreshed for the following chunk.
    Tensors with fewer than two dims skip orthogonalization.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        *,
        lr: float = 1e-3,
        beta1: float = 0.9,
        beta2: float = 0.999,
        beta3: float = 0.9,
        alpha: float = 1.0,
        eps: float = 1e-8,
        ns_steps: int = 3,
        slow_chunk: int = 100,
        weight_decay: float = 0.0,
    ) -> None:
        super().__init__(
            params,
            {
                "lr": lr,
                "beta1": beta1,
                "beta2": beta2,
                "beta3": beta3,
                "alpha": alpha,
                "eps": eps,
                "ns_steps": ns_steps,
                "slow_chunk": slow_chunk,
                "weight_decay": weight_decay,
            },
        )

    @torch.no_grad()
    def step(self, closure=None):  # type: ignore[override]
        """Apply one M3 update to every parameter that has a gradient.

        Args:
            closure: optional callable re-evaluating the loss (run with grad enabled).

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    self._m3_update(p, group)
        return loss

    def _m3_update(self, p: torch.nn.Parameter, group: dict) -> None:
        """Update a single parameter ``p`` in place with hyper-parameters from ``group``."""
        eps = group["eps"]
        ns_steps = group["ns_steps"]
        slow_chunk = group["slow_chunk"]

        g = p.grad
        if group["weight_decay"] != 0.0:
            # Coupled (L2-style) weight decay folded into the gradient.
            g = g.add(p, alpha=group["weight_decay"])

        st = self.state[p]
        if not st:
            st["step"] = 0
            for buffer_name in ("m1", "m2", "v", "slow_buffer", "o2"):
                st[buffer_name] = torch.zeros_like(p)
        st["step"] += 1

        st["m1"].add_(g, alpha=group["beta1"])
        st["v"].addcmul_(g, g, value=group["beta2"])
        st["slow_buffer"].add_(g)

        fast_dir = _orthogonalize(st["m1"], steps=ns_steps, eps=eps)
        step_dir = (fast_dir + group["alpha"] * st["o2"]) / st["v"].sqrt().add_(eps)
        p.add_(step_dir, alpha=-group["lr"])

        if slow_chunk > 0 and st["step"] % slow_chunk == 0:
            # Paper Algorithm 1 uses the refreshed slow momentum only in the *next*
            # chunk, so it is recomputed after this step's parameter update.
            st["m2"].add_(st["slow_buffer"], alpha=group["beta3"])
            st["slow_buffer"].zero_()
            st["o2"] = _orthogonalize(st["m2"], steps=ns_steps, eps=eps)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, Sequence, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
from ..levels import LevelClock, LevelSpec
|
|
10
|
+
from .factory import build_optimizer
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class LevelConfig:
    """Settings consumed by ``LevelOptimizerManager`` when it builds per-level optimizers."""

    # Level specs; the manager reads each spec's ``name`` and ``optimizer_key``.
    specs: Sequence[LevelSpec]
    # optimizer_key -> {"type": ..., "params": {...}, "lr": ...}; absent keys fall back to defaults.
    optimizer_configs: Dict[str, dict]
    # Learning rate used when an optimizer config has no "lr" entry.
    default_lr: float
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LevelOptimizerManager:
    """Keeps one optimizer and learning rate per level and gates updates on a clock.

    Levels are keyed by ``spec.name``. ``should_update`` defers to a ``LevelClock``
    built from the specs, and every applied update is reported back to it via
    ``record_update``. Metrics from the last applied update of each level can be
    retrieved once with :meth:`pop_last_metrics`.
    """

    def __init__(self, config: LevelConfig):
        self.clock = LevelClock(config.specs)
        self.learning_rates: Dict[str, float] = {}
        self.optimizers = {}
        self._last_metrics: Dict[str, Dict[str, float]] = {}
        for spec in config.specs:
            optim_key = spec.optimizer_key or "default"
            level_cfg = config.optimizer_configs.get(
                optim_key, {"type": "deep_momentum", "params": {}}
            )
            level_lr = level_cfg.get("lr", config.default_lr)
            self.optimizers[spec.name] = build_optimizer(
                {
                    "type": level_cfg.get("type", "deep_momentum"),
                    "params": level_cfg.get("params", {}),
                }
            )
            self.learning_rates[spec.name] = level_lr

    def should_update(self, level: str) -> bool:
        """Return whether the clock allows ``level`` to update now."""
        return self.clock.should_update(level)

    def _store_metrics(self, level: str, optimizer) -> None:
        """Snapshot the optimizer's ``last_metrics`` (or ``{}``) for ``level``."""
        latest = getattr(optimizer, "last_metrics", None)
        self._last_metrics[level] = dict(latest) if latest else {}

    def optimize(
        self,
        level: str,
        module: nn.Module,
        loss: torch.Tensor,
        *,
        context: torch.Tensor | None = None,
        force: bool = False,
    ) -> float:
        """Differentiate ``loss`` w.r.t. ``module``'s trainable params and apply the update.

        Returns the value from :meth:`apply_module_grads`, or 0.0 when the level is
        not due (and ``force`` is False) or the module has no trainable params.
        """
        if not (force or self.should_update(level)):
            return 0.0
        trainable: Tuple[Tuple[str, torch.nn.Parameter], ...] = tuple(
            (name, param) for name, param in module.named_parameters() if param.requires_grad
        )
        if not trainable:
            return 0.0
        names = [name for name, _ in trainable]
        grads = torch.autograd.grad(
            loss,
            tuple(param for _, param in trainable),
            retain_graph=False,
            allow_unused=True,
        )
        # Parameters that did not take part in the loss get no entry.
        grads_by_name: Dict[str, torch.Tensor] = {
            name: grad for name, grad in zip(names, grads, strict=True) if grad is not None
        }
        return self.apply_module_grads(
            level,
            module,
            grads_by_name,
            context=context,
            force=True,
        )

    def apply_module_grads(
        self,
        level: str,
        module: nn.Module,
        grads: Dict[str, torch.Tensor],
        *,
        context: torch.Tensor | None = None,
        force: bool = False,
    ) -> float:
        """Update ``module``'s trainable parameters in place from ``grads``.

        Each gradient goes through the level's optimizer, and the result is
        subtracted with the level's learning rate.

        Returns:
            The sum of per-parameter gradient norms (not a global L2 norm), or 0.0
            when the level is not due and ``force`` is False.
        """
        if not (force or self.should_update(level)):
            return 0.0
        optimizer = self.optimizers[level]
        lr = self.learning_rates[level]
        total_norm = 0.0
        with torch.no_grad():
            for name, param in module.named_parameters():
                grad = grads.get(name) if param.requires_grad else None
                if grad is None:
                    continue
                step = optimizer(grad, context=context, param_key=name)
                param.add_(step, alpha=-lr)
                total_norm += grad.norm().item()
        self.clock.record_update(level)
        self._store_metrics(level, optimizer)
        return total_norm

    def tick(self) -> None:
        """Advance the level clock."""
        self.clock.tick()

    def pop_last_metrics(self, level: str) -> Dict[str, float]:
        """Remove and return the stored metrics for ``level`` (``{}`` if none)."""
        return self._last_metrics.pop(level, {})

    def apply_grads(
        self,
        level: str,
        params: Dict[str, torch.Tensor],
        grads: Dict[str, torch.Tensor],
        *,
        context: torch.Tensor | None = None,
        force: bool = False,
        differentiable: bool = False,
    ) -> tuple[Dict[str, torch.Tensor], float]:
        """Return updated copies of functional ``params`` without mutating them.

        With ``differentiable=True`` the update is plain ``param - lr * grad``
        (the optimizer and ``context`` are not used), so the result stays in the
        autograd graph. Otherwise the optimizer transforms each gradient under
        ``no_grad`` and the updated tensors are detached. Entries without a
        gradient are passed through unchanged.

        Returns:
            ``(updated_params, sum_of_grad_norms)``; ``(params, 0.0)`` when the level
            is not due and ``force`` is False.
        """
        if not (force or self.should_update(level)):
            return params, 0.0
        optimizer = self.optimizers[level]
        lr = self.learning_rates[level]
        updated: Dict[str, torch.Tensor] = {}
        total_norm = 0.0

        if differentiable:
            for name, param in params.items():
                grad = grads.get(name)
                if grad is None:
                    updated[name] = param
                    continue
                updated[name] = param - lr * grad
                total_norm += float(grad.detach().norm().item())
            self.clock.record_update(level)
            self._last_metrics[level] = {"differentiable_updates": 1.0}
            return updated, total_norm

        with torch.no_grad():
            for name, param in params.items():
                grad = grads.get(name)
                if grad is None:
                    updated[name] = param
                    continue
                step = optimizer(grad, context=context, param_key=name)
                updated[name] = (param - lr * step).detach()
                total_norm += grad.norm().item()
        self.clock.record_update(level)
        self._store_metrics(level, optimizer)
        return updated, total_norm
|
|
File without changes
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
from ..assoc_memory import AssocMemory
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class TitanMemoryConfig:
    """Shape and activation settings for ``TitanMemory``'s MLP."""

    # Input/output feature size of the memory.
    dim: int
    # Hidden width is dim * hidden_multiplier.
    hidden_multiplier: int = 4
    # Total number of Linear layers in the MLP.
    layers: int = 2
    # Activation name understood by ``_activation`` ("relu", "gelu" or "silu").
    activation: str = "gelu"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _activation(name: str) -> nn.Module:
|
|
21
|
+
if name.lower() == "relu":
|
|
22
|
+
return nn.ReLU()
|
|
23
|
+
if name.lower() == "gelu":
|
|
24
|
+
return nn.GELU()
|
|
25
|
+
if name.lower() == "silu":
|
|
26
|
+
return nn.SiLU()
|
|
27
|
+
msg = f"Unsupported activation {name}"
|
|
28
|
+
raise ValueError(msg)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TitanMemory(AssocMemory):
    """Simplified TITAN-style associative memory.

    A query is mapped through an MLP (``net``) and then ``LayerNorm``. The MLP has
    ``config.layers`` Linear layers: ``layers - 1`` hidden layers of width
    ``dim * hidden_multiplier``, each followed by the configured activation, and a
    final projection back to ``dim``.
    """

    def __init__(self, config: TitanMemoryConfig):
        super().__init__()
        self.config = config
        hidden = config.dim * config.hidden_multiplier
        blocks = []
        # One activation instance is reused after every hidden layer; this is
        # harmless for the parameter-free ReLU/GELU/SiLU modules _activation returns.
        activation = _activation(config.activation)
        for layer_idx in range(config.layers - 1):
            blocks.extend([nn.Linear(config.dim if layer_idx == 0 else hidden, hidden), activation])
        blocks.append(nn.Linear(hidden if config.layers > 1 else config.dim, config.dim))
        self.net = nn.Sequential(*blocks)
        self.norm = nn.LayerNorm(config.dim)
        # Per-row L2-norm cap applied to the MLP output in training mode (see
        # forward). Despite the name it bounds activations, not gradients.
        self.grad_clip = 1.0

    def forward(self, query: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Read from memory: ``LayerNorm(net(query))`` with optional norm capping."""
        attn = self.net(query)
        if self.training and self.grad_clip > 0:
            # The rescale factor is computed without autograd, so it is a constant
            # divisor. The division itself stays outside no_grad so gradients still
            # flow to `net` (update() below relies on that).
            with torch.no_grad():
                norm = attn.norm(dim=-1, keepdim=True)
                scale = torch.clamp(norm / self.grad_clip, min=1.0)
            attn = attn / scale
        return self.norm(attn)

    def surprise(self, residual: torch.Tensor) -> torch.Tensor:
        """Return the L2 norm of ``residual`` over its last dim (kept as size 1)."""
        return residual.norm(dim=-1, keepdim=True)

    @torch.no_grad()
    def update(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        error_signal: torch.Tensor | None = None,
        lr: float = 1e-3,
    ) -> None:
        """Take one SGD step of size ``lr`` on the MLP weights (``net`` only).

        The loss is the MSE between ``forward(key)`` and ``value``. If
        ``error_signal`` is given, the loss is instead ``mean(error_signal * prediction)``,
        so ``error_signal`` acts as an upstream gradient (up to the 1/N mean factor).
        The ``LayerNorm`` parameters are not updated.
        """
        with torch.enable_grad():
            key_detached = key.detach().requires_grad_(True)
            prediction = self.forward(key_detached)
            target = value.detach()
            if error_signal is None:
                loss = torch.mean((prediction - target) ** 2)
            else:
                loss = torch.mean(error_signal * prediction)
            grads = torch.autograd.grad(loss, list(self.net.parameters()), retain_graph=False)
        # In-place parameter updates run under the decorator's no_grad.
        for param, grad in zip(self.net.parameters(), grads, strict=False):
            if grad is None:
                continue
            param.add_(grad, alpha=-lr)

    @torch.no_grad()
    def apply_deltas(self, deltas: Dict[str, torch.Tensor], scale: float = 1.0) -> None:
        """Add ``scale * delta`` to each named parameter; unknown names are ignored."""
        # Build the name->parameter map once, not once per delta as before.
        named = dict(self.named_parameters())
        for name, tensor in deltas.items():
            target = named.get(name)
            if target is None:
                continue
            target.add_(tensor, alpha=scale)
|