nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SelfModifier(nn.Module):
    """
    MLP that turns (key, value, error) triples into a learned correction.

    The network does not emit weight deltas (Delta W). It emits a
    'target modification': a delta that is added to the target of the inner
    optimization step, so the inner loss becomes

        L = || f(x) - (y + delta) ||^2

    and the resulting gradient update is shifted accordingly. The original
    author describes this as functionally equivalent to a learned optimization
    step / hypernetwork that modulates the update direction, while avoiding the
    O(d^2) cost of generating weight parameters for large memory modules.
    """

    def __init__(self, dim: int, hidden_multiplier: int = 4):
        super().__init__()
        width = dim * hidden_multiplier
        # The first layer consumes key, value and error concatenated on the
        # feature axis, hence the 3 * dim input width.
        stages = [
            nn.Linear(dim * 3, width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        error_signal: torch.Tensor,
    ) -> torch.Tensor:
        """Return the predicted modification; the last axis has size ``dim``."""
        stacked = torch.cat((key, value, error_signal), dim=-1)
        return self.net(stacked)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Dict, List
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class UpdateEvent:
    """One recorded update: the step it happened at, the level, and its size."""

    # Step counter value supplied by the caller when the event was recorded.
    step: int
    # Name of the level the update belongs to.
    level: str
    # Size of the update; None when the caller did not measure one.
    magnitude: float | None = None
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class UpdateLog:
    """Lightweight container for tracking update magnitudes per level."""

    events: List[UpdateEvent] = field(default_factory=list)

    def record(self, *, step: int, level: str, magnitude: float | None = None) -> None:
        """Append one update event to the log."""
        self.events.append(UpdateEvent(step=step, level=level, magnitude=magnitude))

    def summary(self) -> Dict[str, Dict[str, float]]:
        """
        Aggregate the log per level.

        Returns a mapping ``level -> {"updates", "avg_magnitude"}`` where
        ``updates`` counts every event of that level and ``avg_magnitude`` is the
        mean over the events that actually carry a magnitude. A level with no
        measured magnitude reports ``nan``.
        """
        counts: Dict[str, int] = {}
        totals: Dict[str, float] = {}
        measured: Dict[str, int] = {}
        for event in self.events:
            counts[event.level] = counts.get(event.level, 0) + 1
            if event.magnitude is not None:
                totals[event.level] = totals.get(event.level, 0.0) + event.magnitude
                # Track the denominator separately: events without a magnitude
                # must not dilute the average.
                measured[event.level] = measured.get(event.level, 0) + 1
        return {
            level: {
                "updates": counts[level],
                "avg_magnitude": (
                    totals[level] / measured[level] if level in measured else float("nan")
                ),
            }
            for level in counts
        }
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, Iterable, List, MutableMapping, Sequence
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True)
class LevelSpec:
    """Immutable, validated description of one nested-learning level."""

    name: str
    # Base number of steps between updates; must be strictly positive.
    update_period: int
    # Updates are suppressed while the clock step is below this value.
    warmup_steps: int = 0
    # Upper bound on the extra, step-derived delay added to the period.
    jitter: int = 0
    optimizer_key: str | None = None

    def __post_init__(self) -> None:
        """Reject out-of-range values; fields are checked in declaration order."""
        if self.update_period <= 0:
            raise ValueError(f"update_period for level {self.name} must be positive")
        for attr in ("warmup_steps", "jitter"):
            if getattr(self, attr) < 0:
                raise ValueError(f"{attr} for level {self.name} must be non-negative")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class LevelState:
    """Mutable per-level bookkeeping: last update step and update count."""

    # Step of the most recent update; the default -1 means "no update yet".
    last_step: int = -1
    # Number of updates recorded so far.
    updates: int = 0
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class LevelClock:
    """Deterministic scheduler for Nested Learning level updates."""

    def __init__(self, specs: Sequence[LevelSpec]):
        by_name: Dict[str, LevelSpec] = {}
        for spec in specs:
            by_name[spec.name] = spec
        self._specs: Dict[str, LevelSpec] = by_name
        # A repeated name collapses into one dict entry, so a size mismatch
        # is how duplicates are detected.
        if len(by_name) != len(specs):
            raise ValueError("Duplicate level names provided to LevelClock")
        self._state: MutableMapping[str, LevelState] = {}
        for level_name in by_name:
            self._state[level_name] = LevelState()
        self._step: int = 0
        self._timeline: List[dict] = []

    @property
    def step(self) -> int:
        """Current value of the global step counter."""
        return self._step

    def tick(self) -> None:
        """Advance the global step counter by one."""
        self._step += 1

    def should_update(self, name: str) -> bool:
        """
        Tell whether level ``name`` is due for an update at the current step.

        A level is never due during warmup, always due if it has not been
        updated yet, and otherwise due once at least ``period`` steps have
        passed since its last recorded update. With jitter enabled the period
        grows by ``step % (jitter + 1)``, which is derived from the step and
        therefore reproducible.
        """
        spec = self._specs[name]
        state = self._state[name]
        now = self._step
        if now < spec.warmup_steps:
            return False
        if state.last_step < 0:
            return True
        period = spec.update_period
        if spec.jitter:
            period += now % (spec.jitter + 1)
        return now - state.last_step >= period

    def record_update(self, name: str) -> None:
        """Mark level ``name`` as updated at the current step."""
        now = self._step
        state = self._state[name]
        state.last_step = now
        state.updates += 1
        self._timeline.append({"step": now, "level": name})

    def levels_in_frequency_order(self) -> List[LevelSpec]:
        """Return the specs sorted by ascending ``update_period`` (stable)."""
        ordered = list(self._specs.values())
        ordered.sort(key=lambda spec: spec.update_period)
        return ordered

    def stats(self) -> Dict[str, LevelState]:
        """Return a per-level copy of the state so callers cannot mutate it."""
        snapshot: Dict[str, LevelState] = {}
        for level_name, state in self._state.items():
            snapshot[level_name] = LevelState(state.last_step, state.updates)
        return snapshot

    def timeline(self) -> List[dict]:
        """Return a shallow copy of the recorded update timeline."""
        return list(self._timeline)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def ensure_level_specs(entries: Iterable[LevelSpec]) -> List[LevelSpec]:
    """Materialize ``entries`` in input order, rejecting repeated names.

    Raises:
        ValueError: if two specs share the same ``name``.
    """
    # Consume the iterable fully before validating, as a generator can only
    # be walked once.
    materialized = list(entries)
    by_name: Dict[str, LevelSpec] = {}
    for candidate in materialized:
        if candidate.name in by_name:
            raise ValueError(f"Duplicate level spec {candidate.name}")
        by_name[candidate.name] = candidate
    # Dicts preserve insertion order, so this is the input order.
    return list(by_name.values())
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, cast
|
|
6
|
+
|
|
7
|
+
from omegaconf import DictConfig, OmegaConf
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseLogger:
    """Minimal logger interface: ``log`` must be overridden, ``finish`` is optional."""

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Record ``metrics`` for ``step``; subclasses must implement this."""
        raise NotImplementedError

    def finish(self) -> None:
        """Flush or close the logger; the base implementation does nothing."""
        return None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class NullLogger(BaseLogger):
    """Logger that accepts metrics and discards them."""

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Ignore the metrics entirely."""
        pass
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class JSONLogger(BaseLogger):
    """Buffers metric records in memory and writes them as one JSON array on finish."""

    def __init__(self, path: Path):
        self.path = path
        self.records: list[Dict[str, Any]] = []

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Buffer one record; a ``"step"`` key inside ``metrics`` overrides ``step``."""
        entry: Dict[str, Any] = {"step": step}
        entry.update(metrics)
        self.records.append(entry)

    def finish(self) -> None:
        """Create the parent directory if needed and dump all buffered records."""
        target_dir = self.path.parent
        target_dir.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(self.records, indent=2)
        self.path.write_text(serialized)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class WandbLogger(BaseLogger):
    """Logger that forwards metrics to a Weights & Biases run."""

    def __init__(self, cfg: DictConfig, full_cfg: DictConfig):
        # Imported lazily so wandb is only required when this backend is used.
        import wandb

        project = cfg.get("project", "nested-learning")
        run_name = cfg.get("run_name")
        # Resolve interpolations so the run stores plain values.
        resolved = OmegaConf.to_container(full_cfg, resolve=True)
        self.run = wandb.init(
            project=project,
            name=run_name,
            config=cast(dict[str, Any], resolved),
        )

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Send ``metrics`` to the run at ``step``; no-op when there is no run."""
        if self.run is None:
            return
        self.run.log(metrics, step=step)

    def finish(self) -> None:
        """Close the run if one exists."""
        if self.run is None:
            return
        self.run.finish()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def init_logger(logging_cfg: DictConfig | None, full_cfg: DictConfig) -> BaseLogger:
    """
    Build the logger selected by ``logging_cfg``.

    Args:
        logging_cfg: Logging section of the config. ``None`` or a config whose
            ``enabled`` flag is falsy disables logging.
        full_cfg: Complete run config, forwarded to the wandb backend.

    Returns:
        ``WandbLogger`` for backend ``"wandb"`` (the default), ``JSONLogger``
        for ``"json"`` (default path ``logs/train_metrics.json``), and
        ``NullLogger`` when logging is disabled or the backend is unknown.
    """
    import warnings

    if logging_cfg is None or not logging_cfg.get("enabled", False):
        return NullLogger()
    # Normalize via str() so a non-string value (e.g. YAML null) cannot crash
    # on .lower(), and tolerate stray whitespace around the name.
    backend = str(logging_cfg.get("backend", "wandb")).strip().lower()
    if backend == "wandb":
        return WandbLogger(logging_cfg, full_cfg)
    if backend == "json":
        path = Path(logging_cfg.get("path", "logs/train_metrics.json"))
        return JSONLogger(path)
    # Logging was explicitly enabled, so falling back silently would hide a
    # misconfiguration; keep the best-effort fallback but make it visible.
    warnings.warn(
        f"Unknown logging backend {backend!r}; metrics will not be logged.",
        RuntimeWarning,
        stacklevel=2,
    )
    return NullLogger()
|
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
|
|
9
|
+
from .tokenizer import SentencePieceTokenizer
|
|
10
|
+
from .training import compute_teach_signal
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class MemorizeConfig:
    """Options controlling test-time memorization passes."""

    # Master switch; not read by the helpers in this module — presumably
    # checked by callers (TODO confirm).
    enabled: bool = False
    # Number of update iterations in batch (non-chunked) mode.
    steps: int = 1
    # Not read in this module; presumably tells callers to reset state
    # between samples (TODO confirm).
    reset: bool = True
    # Not read in this module (TODO confirm intended use with callers).
    use_correct_answer: bool = False
    # When True, a fast_state must be supplied and is passed to the model.
    use_fast_state: bool = True
    # Updates whose surprise value falls below this are skipped; None disables.
    surprise_threshold: float | None = None
    # Names of the update levels allowed to change; None means no restriction
    # is passed to the model.
    paths: tuple[str, ...] | None = None
    # Indices of the layers allowed to change; None means no restriction
    # is passed to the model.
    layers: tuple[int, ...] | None = None
    # If set (and positive), use online chunked updates.
    online_chunk_size: int | None = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def snapshot_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Return an independent CPU copy of every entry in the model's state dict."""
    frozen: Dict[str, torch.Tensor] = {}
    for entry_name, tensor in model.state_dict().items():
        # clone() after cpu() guarantees the copy never aliases live
        # parameters, even when the model already lives on the CPU.
        frozen[entry_name] = tensor.detach().cpu().clone()
    return frozen
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def restore_state_dict(model: torch.nn.Module, state: Dict[str, torch.Tensor]) -> None:
    """Load ``state`` into ``model`` in place.

    Loading is non-strict, so ``state`` may cover only part of the model; the
    report returned by ``load_state_dict`` is intentionally discarded.
    """
    _ = model.load_state_dict(state, strict=False)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _setup_memorization_context(model, cfg: MemorizeConfig):
|
|
35
|
+
"""Helper to setup model state for memorization."""
|
|
36
|
+
prev_allowed = getattr(model, "get_allowed_update_levels", lambda: None)()
|
|
37
|
+
prev_threshold = getattr(model, "get_surprise_threshold", lambda: None)()
|
|
38
|
+
prev_layers = getattr(model, "get_allowed_update_layers", lambda: None)()
|
|
39
|
+
|
|
40
|
+
if hasattr(model, "set_allowed_update_levels"):
|
|
41
|
+
allowed = None
|
|
42
|
+
if cfg.paths is not None:
|
|
43
|
+
allowed = {path.strip() for path in cfg.paths if path.strip()}
|
|
44
|
+
getattr(model, "set_allowed_update_levels")(allowed)
|
|
45
|
+
|
|
46
|
+
if cfg.surprise_threshold is not None and hasattr(model, "set_surprise_threshold"):
|
|
47
|
+
getattr(model, "set_surprise_threshold")(cfg.surprise_threshold)
|
|
48
|
+
|
|
49
|
+
if hasattr(model, "set_allowed_update_layers"):
|
|
50
|
+
layers = None
|
|
51
|
+
if cfg.layers is not None:
|
|
52
|
+
layers = {int(idx) for idx in cfg.layers}
|
|
53
|
+
getattr(model, "set_allowed_update_layers")(layers)
|
|
54
|
+
|
|
55
|
+
return prev_allowed, prev_threshold, prev_layers
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _teardown_memorization_context(model, prev_allowed, prev_threshold, prev_layers):
|
|
59
|
+
"""Helper to restore model state after memorization."""
|
|
60
|
+
if hasattr(model, "set_allowed_update_levels"):
|
|
61
|
+
getattr(model, "set_allowed_update_levels")(
|
|
62
|
+
prev_allowed if prev_allowed is None else set(prev_allowed)
|
|
63
|
+
)
|
|
64
|
+
if hasattr(model, "set_surprise_threshold"):
|
|
65
|
+
getattr(model, "set_surprise_threshold")(prev_threshold)
|
|
66
|
+
if hasattr(model, "set_allowed_update_layers"):
|
|
67
|
+
getattr(model, "set_allowed_update_layers")(
|
|
68
|
+
None if prev_layers is None else {int(idx) for idx in prev_layers}
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _collect_metrics(model, stats: dict[str, float]):
|
|
73
|
+
"""Helper to collect and aggregate update metrics."""
|
|
74
|
+
if hasattr(model, "pop_update_metrics"):
|
|
75
|
+
metrics = model.pop_update_metrics()
|
|
76
|
+
titan_updates = sum(
|
|
77
|
+
value for key, value in metrics.items() if key.endswith("titan.titan.grad_norm")
|
|
78
|
+
)
|
|
79
|
+
titan_hits = sum(
|
|
80
|
+
value for key, value in metrics.items() if key.endswith("titan.titan.gate_hit")
|
|
81
|
+
)
|
|
82
|
+
stats["titan_mem_updates"] += titan_updates
|
|
83
|
+
stats["titan_update_events"] += titan_hits
|
|
84
|
+
|
|
85
|
+
# Aggregate CMS updates per level: keys look like "layer{idx}.cms.<level>.<metric>".
|
|
86
|
+
for key, value in metrics.items():
|
|
87
|
+
parts = key.split(".")
|
|
88
|
+
if len(parts) < 4:
|
|
89
|
+
continue
|
|
90
|
+
if parts[-3] != "cms":
|
|
91
|
+
continue
|
|
92
|
+
level = parts[-2]
|
|
93
|
+
metric = parts[-1]
|
|
94
|
+
if metric == "grad_norm":
|
|
95
|
+
stats_key = f"{level}_updates"
|
|
96
|
+
stats[stats_key] = stats.get(stats_key, 0.0) + float(value)
|
|
97
|
+
elif metric == "gate_hit":
|
|
98
|
+
stats_key = f"{level}_update_events"
|
|
99
|
+
stats[stats_key] = stats.get(stats_key, 0.0) + float(value)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _layernorm_backward(
|
|
103
|
+
grad_out: torch.Tensor,
|
|
104
|
+
pre_norm: torch.Tensor,
|
|
105
|
+
norm: nn.LayerNorm,
|
|
106
|
+
) -> torch.Tensor:
|
|
107
|
+
"""
|
|
108
|
+
Convert gradient w.r.t. LayerNorm output into gradient w.r.t. LayerNorm input.
|
|
109
|
+
|
|
110
|
+
This aligns the teach signal with the pre-norm hidden state that the blocks actually update.
|
|
111
|
+
"""
|
|
112
|
+
if grad_out.shape != pre_norm.shape:
|
|
113
|
+
raise ValueError("grad_out and pre_norm must have identical shapes")
|
|
114
|
+
weight = norm.weight
|
|
115
|
+
if weight is None:
|
|
116
|
+
weight = torch.ones(pre_norm.shape[-1], device=pre_norm.device, dtype=pre_norm.dtype)
|
|
117
|
+
grad_hat = grad_out * weight.to(grad_out.dtype).view(1, 1, -1)
|
|
118
|
+
mean = pre_norm.mean(dim=-1, keepdim=True)
|
|
119
|
+
var = pre_norm.var(dim=-1, unbiased=False, keepdim=True)
|
|
120
|
+
inv_std = torch.rsqrt(var + norm.eps)
|
|
121
|
+
x_hat = (pre_norm - mean) * inv_std
|
|
122
|
+
grad_mean = grad_hat.mean(dim=-1, keepdim=True)
|
|
123
|
+
grad_proj = (grad_hat * x_hat).mean(dim=-1, keepdim=True)
|
|
124
|
+
return (grad_hat - grad_mean - x_hat * grad_proj) * inv_std
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _get_model_surprise_metric(model) -> str:
|
|
128
|
+
getter = getattr(model, "get_surprise_metric", None)
|
|
129
|
+
if callable(getter):
|
|
130
|
+
return str(getter()).strip().lower()
|
|
131
|
+
return "l2"
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _compute_surprise_value(
|
|
135
|
+
*,
|
|
136
|
+
model,
|
|
137
|
+
metric: str,
|
|
138
|
+
logits: torch.Tensor,
|
|
139
|
+
tokens: torch.Tensor,
|
|
140
|
+
teach_signal: torch.Tensor,
|
|
141
|
+
) -> tuple[float, float | None]:
|
|
142
|
+
normalized = str(metric).strip().lower()
|
|
143
|
+
if normalized == "l2":
|
|
144
|
+
runtime_scale = float(getattr(model, "_runtime_teach_scale", 1.0))
|
|
145
|
+
runtime_clip = float(getattr(model, "_runtime_teach_clip", 0.0))
|
|
146
|
+
scaled = teach_signal * runtime_scale
|
|
147
|
+
if runtime_clip > 0:
|
|
148
|
+
norm = scaled.norm(dim=-1, keepdim=True)
|
|
149
|
+
scale = torch.clamp(norm / runtime_clip, min=1.0)
|
|
150
|
+
scaled = scaled / scale
|
|
151
|
+
value = float(scaled.norm(dim=-1).mean().item())
|
|
152
|
+
return value, None
|
|
153
|
+
if normalized == "loss":
|
|
154
|
+
loss = torch.nn.functional.cross_entropy(
|
|
155
|
+
logits[:, :-1].reshape(-1, logits.size(-1)),
|
|
156
|
+
tokens[:, 1:].reshape(-1),
|
|
157
|
+
)
|
|
158
|
+
value = float(loss.detach().item())
|
|
159
|
+
return value, value
|
|
160
|
+
if normalized == "logit_entropy":
|
|
161
|
+
logits_detached = logits[:, :-1].detach().float()
|
|
162
|
+
probs = torch.softmax(logits_detached, dim=-1)
|
|
163
|
+
entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1).mean()
|
|
164
|
+
value = float(entropy.item())
|
|
165
|
+
return value, value
|
|
166
|
+
raise ValueError(f"Unsupported surprise_metric={metric!r}")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _forward_teach_signal(model, tokens: torch.Tensor, cfg: MemorizeConfig, fast_state):
    """Run one forward pass over ``tokens`` and return ``(logits, teach_signal)``.

    Uses ``model.forward_with_pre_norm`` when available; in that case, if
    ``model.norm`` is an ``nn.LayerNorm``, the teach signal is pulled back
    through it so it refers to the pre-norm hidden state.
    """
    state_kwargs = {"fast_state": fast_state} if cfg.use_fast_state else {}
    pre_norm = None
    if hasattr(model, "forward_with_pre_norm"):
        logits, pre_norm = model.forward_with_pre_norm(tokens, **state_kwargs)
    else:
        logits = model(tokens, **state_kwargs)
    teach_signal = compute_teach_signal(model, logits, tokens)
    if pre_norm is not None:
        norm = getattr(model, "norm", None)
        if isinstance(norm, nn.LayerNorm):
            teach_signal = _layernorm_backward(teach_signal, pre_norm, norm)
    return logits, teach_signal


def _surprise_gated_update(
    model,
    tokens: torch.Tensor,
    logits: torch.Tensor,
    teach_signal: torch.Tensor,
    cfg: MemorizeConfig,
    fast_state,
    stats: dict[str, float],
) -> None:
    """Apply one teach-signal update and collect metrics, unless surprise is below threshold."""
    surprise_value, surprise_override = _compute_surprise_value(
        model=model,
        metric=_get_model_surprise_metric(model),
        logits=logits,
        tokens=tokens,
        teach_signal=teach_signal,
    )
    if cfg.surprise_threshold is not None and surprise_value < cfg.surprise_threshold:
        return
    state_kwargs = {"fast_state": fast_state} if cfg.use_fast_state else {}
    model(tokens, teach_signal=teach_signal, surprise_value=surprise_override, **state_kwargs)
    _collect_metrics(model, stats)


def memorize_tokens(
    model,
    token_batch: torch.Tensor,
    cfg: MemorizeConfig,
    *,
    fast_state=None,
    teach_mask: torch.Tensor | None = None,
) -> dict[str, float]:
    """
    Run test-time memorization updates on ``token_batch``.

    Two modes are supported:

    * Online / chunked (``cfg.online_chunk_size`` > 0): targets are learned in
      consecutive chunks; each chunk re-processes the whole prefix.
    * Batch (default): the full sequence is processed ``cfg.steps`` times.

    The model's update restrictions (allowed levels, allowed layers, surprise
    threshold) are overridden from ``cfg`` for the duration of the call and are
    restored afterwards, including when an exception is raised.

    Args:
        model: Model to update; called with ``teach_signal`` / ``surprise_value``
            keyword arguments to perform each update.
        token_batch: Token ids indexed as (batch, time).
        cfg: Memorization options.
        fast_state: Fast-weight state forwarded to the model; required when
            ``cfg.use_fast_state`` is true.
        teach_mask: Optional (B, T) mask multiplied into the teach signal.

    Returns:
        Accumulated update statistics, or ``{}`` when the sequence has fewer
        than two tokens.

    Raises:
        ValueError: if ``fast_state`` is missing while required, or if
            ``teach_mask`` has an invalid shape.
    """
    if token_batch.size(1) < 2:
        return {}

    if cfg.use_fast_state and fast_state is None:
        raise ValueError("cfg.use_fast_state=True requires passing fast_state")

    with torch.no_grad():
        stats: dict[str, float] = {
            "titan_mem_updates": 0.0,
            "titan_update_events": 0.0,
            "cms_fast_updates": 0.0,
            "cms_fast_update_events": 0.0,
            "cms_mid_updates": 0.0,
            "cms_mid_update_events": 0.0,
            "cms_slow_updates": 0.0,
            "cms_slow_update_events": 0.0,
            "cms_ultra_updates": 0.0,
            "cms_ultra_update_events": 0.0,
        }
        prev_allowed, prev_threshold, prev_layers = _setup_memorization_context(model, cfg)

        # Everything after setup runs under try/finally so the model's
        # previous restrictions are restored even if an update fails.
        try:
            if cfg.online_chunk_size and cfg.online_chunk_size > 0:
                # Online / Chunked Learning Mode
                seq_len = token_batch.size(1)
                chunk_size = cfg.online_chunk_size

                # Each chunk re-processes the full prefix, so cost grows
                # quadratically with sequence length; this is required for
                # online learning without external KV cache management.
                #
                # Alignment (per the original notes; assumes
                # compute_teach_signal pairs logits[:, :-1] with tokens[:, 1:]):
                # signal index i carries the error for target token i + 1, so
                # feeding tokens[:, :K] yields errors for targets 1 .. K-1.
                target_start = 1
                while target_start < seq_len:
                    # Targets of this chunk: [target_start, target_end).
                    target_end = min(target_start + chunk_size, seq_len)
                    context_tokens = token_batch[:, :target_end]

                    logits, full_signal = _forward_teach_signal(
                        model, context_tokens, cfg, fast_state
                    )

                    # Keep only this chunk's targets: signal indices
                    # [target_start - 1, target_end - 1).
                    mask = torch.zeros_like(full_signal)
                    mask[:, target_start - 1 : target_end - 1, :] = 1.0
                    masked_signal = full_signal * mask

                    if teach_mask is not None:
                        if teach_mask.ndim != 2:
                            raise ValueError("teach_mask must have shape (B, T)")
                        if teach_mask.shape[0] != token_batch.shape[0]:
                            raise ValueError("teach_mask batch size mismatch")
                        mask_slice = teach_mask[:, :target_end].to(masked_signal.device).float()
                        masked_signal = masked_signal * mask_slice.unsqueeze(-1)

                    _surprise_gated_update(
                        model, context_tokens, logits, masked_signal, cfg, fast_state, stats
                    )
                    target_start = target_end
            else:
                # Batch Mode (Default)
                for _ in range(cfg.steps):
                    logits, teach_signal = _forward_teach_signal(
                        model, token_batch, cfg, fast_state
                    )
                    if teach_mask is not None:
                        if teach_mask.ndim != 2:
                            raise ValueError("teach_mask must have shape (B, T)")
                        if teach_mask.shape[:2] != teach_signal.shape[:2]:
                            raise ValueError("teach_mask shape mismatch")
                        mask_f = teach_mask.to(teach_signal.device).float().unsqueeze(-1)
                        teach_signal = teach_signal * mask_f
                    _surprise_gated_update(
                        model, token_batch, logits, teach_signal, cfg, fast_state, stats
                    )
        finally:
            _teardown_memorization_context(model, prev_allowed, prev_threshold, prev_layers)
        return stats
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def memorize_sequence(
    model,
    tokenizer: SentencePieceTokenizer,
    text: str,
    device: torch.device,
    cfg: MemorizeConfig,
    *,
    fast_state=None,
    teach_mask: torch.Tensor | None = None,
) -> dict[str, float]:
    """Tokenize ``text`` and memorize it as a single-sequence batch.

    Returns ``{}`` without touching the model when ``text`` is empty or encodes
    to fewer than two tokens; otherwise returns the ``memorize_tokens`` stats.
    """
    if text:
        token_ids = tokenizer.encode(text)
        if token_ids.size(0) >= 2:
            # Add the leading batch axis expected by memorize_tokens.
            single_batch = token_ids.to(device).unsqueeze(0)
            return memorize_tokens(
                model,
                single_batch,
                cfg,
                fast_state=fast_state,
                teach_mask=teach_mask,
            )
    return {}
|