nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class SelfModifier(nn.Module):
    """MLP mapping (key, value, error) signals to a ``dim``-sized target modification.

    Rather than emitting a full weight delta (O(d^2) outputs) for a memory module, this
    network predicts a correction ``delta`` intended to shift the regression target of an
    inner optimization step. With an inner objective of

        L = || f(x) - (y + delta) ||^2

    the weight gradient equals the unmodified one plus a term proportional to ``-delta``
    back-propagated through ``f``. The learned network therefore steers the update
    direction, acting like a cheap learned optimization step / hypernetwork.

    Args:
        dim: Feature width of ``key``, ``value``, ``error_signal`` and of the output.
        hidden_multiplier: Hidden width expressed as a multiple of ``dim``.
    """

    def __init__(self, dim: int, hidden_multiplier: int = 4):
        super().__init__()
        width = hidden_multiplier * dim
        layers = [
            nn.Linear(3 * dim, width),  # input is [key, value, error] concatenated
            nn.GELU(),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        error_signal: torch.Tensor,
    ) -> torch.Tensor:
        """Return the predicted modification; inputs are joined along the last dim."""
        features = torch.cat((key, value, error_signal), dim=-1)
        return self.net(features)
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List
5
+
6
+
7
@dataclass
class UpdateEvent:
    """A single recorded update of one nested-learning level.

    Positional order is ``(step, level, magnitude)``.
    """

    step: int  # step value supplied by the caller when the update was recorded
    level: str  # name of the level that was updated
    magnitude: float | None = field(default=None)  # optional update size; None if not measured
12
+
13
+
14
@dataclass
class UpdateLog:
    """Lightweight container for tracking update magnitudes per level.

    Events are appended with :meth:`record` and aggregated per level by :meth:`summary`.
    """

    events: List[UpdateEvent] = field(default_factory=list)

    def record(self, *, step: int, level: str, magnitude: float | None = None) -> None:
        """Append one update event for ``level`` at ``step``."""
        self.events.append(UpdateEvent(step=step, level=level, magnitude=magnitude))

    def summary(self) -> Dict[str, Dict[str, float]]:
        """Aggregate the recorded events per level.

        Returns:
            ``{level: {"updates": n, "avg_magnitude": m}}``. Here ``n`` counts every
            event for the level, and ``m`` is the mean over only those events that
            carried a magnitude (``nan`` if none did). Levels appear in first-seen order.
        """
        counts: Dict[str, int] = {}
        totals: Dict[str, float] = {}
        measured: Dict[str, int] = {}
        for event in self.events:
            counts[event.level] = counts.get(event.level, 0) + 1
            if event.magnitude is not None:
                totals[event.level] = totals.get(event.level, 0.0) + event.magnitude
                # Count magnitudes separately so events recorded without one are not
                # averaged in as zero-magnitude updates.
                measured[event.level] = measured.get(event.level, 0) + 1
        return {
            level: {
                "updates": count,
                "avg_magnitude": (
                    totals[level] / measured[level] if level in measured else float("nan")
                ),
            }
            for level, count in counts.items()
        }
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Iterable, List, MutableMapping, Sequence
5
+
6
+
7
@dataclass(frozen=True)
class LevelSpec:
    """Immutable configuration of one nested-learning level.

    Attributes:
        name: Level identifier (uniqueness is enforced by the consumers in this module).
        update_period: Base number of clock steps between updates; must be > 0.
        warmup_steps: Steps before the level may update at all; must be >= 0.
        jitter: Upper bound of the extra, step-dependent slack on the period; must be >= 0.
        optimizer_key: Optional key, presumably selecting an optimizer for this level
            (not consumed in this module).

    Raises:
        ValueError: from construction when a numeric constraint is violated; the first
            offending field (in declaration order) is reported.
    """

    name: str
    update_period: int
    warmup_steps: int = 0
    jitter: int = 0
    optimizer_key: str | None = None

    def __post_init__(self) -> None:
        # Each predicate is evaluated lazily, in order, and flags an *invalid* value.
        for attr, requirement, violates in (
            ("update_period", "must be positive", lambda v: v <= 0),
            ("warmup_steps", "must be non-negative", lambda v: v < 0),
            ("jitter", "must be non-negative", lambda v: v < 0),
        ):
            if violates(getattr(self, attr)):
                raise ValueError(f"{attr} for level {self.name} {requirement}")
27
+
28
+
29
@dataclass
class LevelState:
    """Mutable per-level bookkeeping (last update step and update count)."""

    # Step of the most recent recorded update; a negative value means "never updated".
    last_step: int = -1
    # Number of updates recorded so far.
    updates: int = 0
33
+
34
+
35
class LevelClock:
    """Deterministic step counter that decides when each nested-learning level updates.

    The clock starts at step 0 and advances only via :meth:`tick`. A level is due once
    the step reaches its ``warmup_steps`` and either it has never been updated or at
    least its (optionally jittered) period has elapsed since its last recorded update.
    Callers report performed updates with :meth:`record_update`.
    """

    def __init__(self, specs: Sequence[LevelSpec]):
        by_name: Dict[str, LevelSpec] = {}
        for spec in specs:
            by_name[spec.name] = spec
        # Fewer distinct names than inputs means two specs shared a name.
        if len(by_name) != len(specs):
            raise ValueError("Duplicate level names provided to LevelClock")
        self._specs: Dict[str, LevelSpec] = by_name
        state: MutableMapping[str, LevelState] = {}
        for name in by_name:
            state[name] = LevelState()
        self._state: MutableMapping[str, LevelState] = state
        self._step: int = 0
        self._timeline: List[dict] = []

    @property
    def step(self) -> int:
        """Current clock step."""
        return self._step

    def tick(self) -> None:
        """Advance the clock by one step."""
        self._step = self._step + 1

    def should_update(self, name: str) -> bool:
        """Return True if level ``name`` is due at the current step (KeyError if unknown)."""
        spec = self._specs[name]
        state = self._state[name]
        now = self._step
        if now < spec.warmup_steps:
            return False
        period = spec.update_period
        if spec.jitter:
            # Deterministic jitter: adds a value in 0..jitter that cycles with the step.
            period += now % (spec.jitter + 1)
        if state.last_step < 0:
            return True
        return now - state.last_step >= period

    def record_update(self, name: str) -> None:
        """Mark level ``name`` as updated at the current step and log it to the timeline."""
        state = self._state[name]
        state.last_step = self._step
        state.updates = state.updates + 1
        self._timeline.append({"step": self._step, "level": name})

    def levels_in_frequency_order(self) -> List[LevelSpec]:
        """Return specs sorted by ascending ``update_period`` (stable for ties)."""
        ordered = list(self._specs.values())
        ordered.sort(key=lambda spec: spec.update_period)
        return ordered

    def stats(self) -> Dict[str, LevelState]:
        """Return independent copies of every level's state, keyed by level name."""
        snapshot: Dict[str, LevelState] = {}
        for name, state in self._state.items():
            snapshot[name] = LevelState(last_step=state.last_step, updates=state.updates)
        return snapshot

    def timeline(self) -> List[dict]:
        """Return a shallow copy of the recorded ``{"step", "level"}`` update events."""
        return self._timeline.copy()
80
+
81
+
82
def ensure_level_specs(entries: Iterable[LevelSpec]) -> List[LevelSpec]:
    """Materialize ``entries`` into a new list, keeping order and rejecting duplicates.

    Raises:
        ValueError: on the first spec whose ``name`` has already been seen.
    """
    specs = list(entries)
    names_seen: set[str] = set()
    for spec in specs:
        if spec.name in names_seen:
            raise ValueError(f"Duplicate level spec {spec.name}")
        names_seen.add(spec.name)
    return specs
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any, Dict, cast
6
+
7
+ from omegaconf import DictConfig, OmegaConf
8
+
9
+
10
class BaseLogger:
    """Minimal interface shared by the metric loggers in this module.

    Subclasses must implement :meth:`log`; :meth:`finish` is an optional hook that does
    nothing by default.
    """

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Record ``metrics`` for ``step``. Must be overridden by subclasses."""
        raise NotImplementedError

    def finish(self) -> None:
        """Flush/close the logger; the base implementation is a no-op."""
        return None
16
+
17
+
18
class NullLogger(BaseLogger):
    """Logger that discards every record (used when logging is disabled/unrecognized)."""

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Ignore ``metrics`` and ``step``."""
        del metrics, step  # intentionally unused
21
+
22
+
23
class JSONLogger(BaseLogger):
    """Buffer metric records in memory and dump them as one JSON array on finish.

    Nothing touches the disk until :meth:`finish`. Each record is ``metrics`` with a
    leading ``"step"`` entry; a ``"step"`` key inside ``metrics`` overrides the argument.
    """

    def __init__(self, path: Path):
        self.path = path
        self.records: list[Dict[str, Any]] = []

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Append one record built from ``step`` and ``metrics``."""
        record: Dict[str, Any] = {"step": step}
        record.update(metrics)
        self.records.append(record)

    def finish(self) -> None:
        """Create the parent directory if needed and write all buffered records."""
        out_path = self.path
        out_path.parent.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(self.records, indent=2)
        out_path.write_text(serialized)
35
+
36
+
37
class WandbLogger(BaseLogger):
    """Send metrics to a Weights & Biases run created at construction time.

    ``wandb`` is imported lazily, so it is only required when this backend is chosen.
    The fully resolved ``full_cfg`` is attached to the run as its config.
    """

    def __init__(self, cfg: DictConfig, full_cfg: DictConfig):
        import wandb

        init_kwargs = {
            "project": cfg.get("project", "nested-learning"),
            "name": cfg.get("run_name"),
            "config": cast(dict[str, Any], OmegaConf.to_container(full_cfg, resolve=True)),
        }
        self.run = wandb.init(**init_kwargs)

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        """Log ``metrics`` at ``step`` if a run exists."""
        run = self.run
        if run is None:
            return
        run.log(metrics, step=step)

    def finish(self) -> None:
        """Finish the run if one exists."""
        run = self.run
        if run is None:
            return
        run.finish()
53
+
54
+
55
def init_logger(logging_cfg: DictConfig | None, full_cfg: DictConfig) -> BaseLogger:
    """Build the metric logger described by ``logging_cfg``.

    Returns a :class:`NullLogger` when ``logging_cfg`` is missing or ``enabled`` is falsy.
    Otherwise ``backend`` selects the implementation. Matching is case-insensitive and
    ignores surrounding whitespace; a missing or null ``backend`` means ``"wandb"``.

    * ``"wandb"`` -> :class:`WandbLogger`
    * ``"json"``  -> :class:`JSONLogger` writing to ``path``. The default is
      ``logs/train_metrics.json``, and a null ``path`` also uses the default.

    Any other backend falls back to :class:`NullLogger` and emits a ``UserWarning``, so a
    misspelled backend no longer disables logging silently.
    """
    if logging_cfg is None or not logging_cfg.get("enabled", False):
        return NullLogger()
    raw_backend = logging_cfg.get("backend", "wandb")
    # An explicit null used to crash on ``.lower()``; treat it like the default.
    backend = "wandb" if raw_backend is None else str(raw_backend).strip().lower()
    if backend == "wandb":
        return WandbLogger(logging_cfg, full_cfg)
    if backend == "json":
        default_path = "logs/train_metrics.json"
        raw_path = logging_cfg.get("path", default_path)
        return JSONLogger(Path(default_path if raw_path is None else raw_path))
    import warnings

    warnings.warn(
        f"Unknown logging backend {raw_backend!r}; metrics will not be logged",
        stacklevel=2,
    )
    return NullLogger()
@@ -0,0 +1,382 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .tokenizer import SentencePieceTokenizer
10
+ from .training import compute_teach_signal
11
+
12
+
13
@dataclass
class MemorizeConfig:
    """Options controlling test-time memorization via ``memorize_tokens``.

    Field order and defaults form the positional constructor interface.
    """

    # Carried with the config; not read by the functions in this module.
    enabled: bool = False
    # Number of full-sequence passes in batch mode (unused in online chunked mode).
    steps: int = 1
    # Not read in this module; presumably tells callers to reset state between samples.
    reset: bool = True
    # Not read in this module; presumably consumed by evaluation callers.
    use_correct_answer: bool = False
    # When True, every forward pass receives the caller-provided ``fast_state``.
    use_fast_state: bool = True
    # Skip an update whose measured surprise is below this; also pushed to the model.
    surprise_threshold: float | None = None
    # Allowed update levels (stripped, blanks dropped); None is passed to the model as-is.
    paths: tuple[str, ...] | None = None
    # Allowed layer indices (coerced to int); None is passed to the model as-is.
    layers: tuple[int, ...] | None = None
    # If set (> 0), use online chunked updates of this many target tokens per pass.
    online_chunk_size: int | None = None
24
+
25
+
26
def snapshot_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Return a detached CPU clone of every entry of ``model.state_dict()``.

    The clones own their storage, so later in-place changes to the model do not alter
    the snapshot.
    """
    snapshot: Dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().cpu().clone()
    return snapshot
28
+
29
+
30
def restore_state_dict(model: torch.nn.Module, state: Dict[str, torch.Tensor]) -> None:
    """Load ``state`` (e.g. a ``snapshot_state_dict`` result) back into ``model``.

    Loading is non-strict: keys absent from ``state`` or unknown to ``model`` are
    ignored, and the report returned by ``load_state_dict`` is discarded.
    """
    model.load_state_dict(state_dict=state, strict=False)
32
+
33
+
34
def _setup_memorization_context(model, cfg: MemorizeConfig):
    """Apply ``cfg``'s update restrictions to ``model`` and return the replaced values.

    Every getter/setter is optional. A missing getter yields ``None`` for its slot, and a
    missing setter leaves that setting alone. The surprise threshold is only pushed when
    ``cfg.surprise_threshold`` is not None. All getters run before any setter.

    Returns:
        ``(prev_allowed_levels, prev_surprise_threshold, prev_allowed_layers)``, suitable
        for ``_teardown_memorization_context``.
    """

    def read(getter_name):
        return getattr(model, getter_name, lambda: None)()

    previous = (
        read("get_allowed_update_levels"),
        read("get_surprise_threshold"),
        read("get_allowed_update_layers"),
    )

    if hasattr(model, "set_allowed_update_levels"):
        levels = None if cfg.paths is None else {p.strip() for p in cfg.paths if p.strip()}
        model.set_allowed_update_levels(levels)

    if cfg.surprise_threshold is not None and hasattr(model, "set_surprise_threshold"):
        model.set_surprise_threshold(cfg.surprise_threshold)

    if hasattr(model, "set_allowed_update_layers"):
        layer_ids = None if cfg.layers is None else set(map(int, cfg.layers))
        model.set_allowed_update_layers(layer_ids)

    return previous
56
+
57
+
58
def _teardown_memorization_context(model, prev_allowed, prev_threshold, prev_layers):
    """Restore the settings captured by ``_setup_memorization_context``.

    Only setters the model exposes are called. Saved level/layer collections are passed
    as fresh sets (layers coerced to ``int``), and ``None`` is passed through. The
    threshold setter, when present, is always called, even with ``None``.
    """
    if hasattr(model, "set_allowed_update_levels"):
        model.set_allowed_update_levels(None if prev_allowed is None else set(prev_allowed))
    if hasattr(model, "set_surprise_threshold"):
        model.set_surprise_threshold(prev_threshold)
    if hasattr(model, "set_allowed_update_layers"):
        restored_layers = None
        if prev_layers is not None:
            restored_layers = set(map(int, prev_layers))
        model.set_allowed_update_layers(restored_layers)
70
+
71
+
72
def _collect_metrics(model, stats: dict[str, float]):
    """Fold the model's pending update metrics into ``stats`` in place.

    Does nothing if the model has no ``pop_update_metrics``. Otherwise:

    * values whose keys end in ``titan.titan.grad_norm`` / ``titan.titan.gate_hit`` are
      summed into ``stats["titan_mem_updates"]`` / ``stats["titan_update_events"]``
      (both keys must already exist);
    * keys shaped ``<...>.cms.<level>.grad_norm`` / ``<...>.cms.<level>.gate_hit`` add
      ``float(value)`` to ``stats["<level>_updates"]`` / ``stats["<level>_update_events"]``,
      which start at 0.0 when absent.
    """
    if not hasattr(model, "pop_update_metrics"):
        return
    metrics = model.pop_update_metrics()

    grad_total = sum(v for k, v in metrics.items() if k.endswith("titan.titan.grad_norm"))
    hit_total = sum(v for k, v in metrics.items() if k.endswith("titan.titan.gate_hit"))
    stats["titan_mem_updates"] += grad_total
    stats["titan_update_events"] += hit_total

    suffix_for_metric = {"grad_norm": "_updates", "gate_hit": "_update_events"}
    for key, value in metrics.items():
        parts = key.split(".")
        if len(parts) < 4 or parts[-3] != "cms":
            continue
        suffix = suffix_for_metric.get(parts[-1])
        if suffix is None:
            continue
        stats_key = parts[-2] + suffix
        stats[stats_key] = stats.get(stats_key, 0.0) + float(value)
100
+
101
+
102
def _layernorm_backward(
    grad_out: torch.Tensor,
    pre_norm: torch.Tensor,
    norm: nn.LayerNorm,
) -> torch.Tensor:
    """
    Convert gradient w.r.t. LayerNorm output into gradient w.r.t. LayerNorm input.

    This aligns the teach signal with the pre-norm hidden state that the blocks actually
    update. Statistics are recomputed from ``pre_norm`` over the last dimension only, so
    ``norm`` is assumed to normalize a single trailing dimension. Any number of leading
    dimensions is supported, and the result has the same shape as ``pre_norm``.

    With ``x_hat = (x - mean) / std`` and ``g = grad_out * weight``::

        dL/dx = (g - mean(g) - x_hat * mean(g * x_hat)) / std

    Args:
        grad_out: Gradient w.r.t. the LayerNorm output; same shape as ``pre_norm``.
        pre_norm: The LayerNorm input.
        norm: The LayerNorm module. Its ``weight`` (ones if absent) and ``eps`` are used;
            the bias does not affect the input gradient.

    Raises:
        ValueError: if ``grad_out`` and ``pre_norm`` differ in shape.
    """
    if grad_out.shape != pre_norm.shape:
        raise ValueError("grad_out and pre_norm must have identical shapes")
    weight = norm.weight
    if weight is None:
        weight = torch.ones(pre_norm.shape[-1], device=pre_norm.device, dtype=pre_norm.dtype)
    # Plain broadcasting over the last dim keeps the input's rank. A fixed
    # ``view(1, 1, -1)`` would turn 1-D/2-D inputs into 3-D outputs.
    grad_hat = grad_out * weight.to(grad_out.dtype)
    mean = pre_norm.mean(dim=-1, keepdim=True)
    var = pre_norm.var(dim=-1, unbiased=False, keepdim=True)
    inv_std = torch.rsqrt(var + norm.eps)
    x_hat = (pre_norm - mean) * inv_std
    grad_mean = grad_hat.mean(dim=-1, keepdim=True)
    grad_proj = (grad_hat * x_hat).mean(dim=-1, keepdim=True)
    return (grad_hat - grad_mean - x_hat * grad_proj) * inv_std
125
+
126
+
127
def _get_model_surprise_metric(model) -> str:
    """Return the model's surprise metric name, normalized; ``"l2"`` if it exposes none.

    The name comes from a callable ``model.get_surprise_metric()`` and is stripped and
    lower-cased. A missing or non-callable attribute falls back to ``"l2"``.
    """
    getter = getattr(model, "get_surprise_metric", None)
    if not callable(getter):
        return "l2"
    return str(getter()).strip().lower()
132
+
133
+
134
def _compute_surprise_value(
    *,
    model,
    metric: str,
    logits: torch.Tensor,
    tokens: torch.Tensor,
    teach_signal: torch.Tensor,
) -> tuple[float, float | None]:
    """Measure how surprising a batch is under ``metric``.

    Metric names are case-insensitive and surrounding whitespace is ignored.

    * ``"l2"``: mean over all positions of the last-dim L2 norm of
      ``teach_signal * model._runtime_teach_scale`` (default 1.0). When
      ``model._runtime_teach_clip`` (default 0.0) is positive, each row is first clipped
      to that norm. Zeroed/masked positions count in the mean, so heavily masked signals
      give smaller values. NOTE(review): confirm this dilution is intended for
      threshold gating.
    * ``"loss"``: next-token cross-entropy of ``logits[:, :-1]`` vs ``tokens[:, 1:]``.
    * ``"logit_entropy"``: mean entropy of ``softmax(logits[:, :-1])``, computed in float32.

    Returns:
        ``(value, override)``. ``override`` is ``None`` for ``"l2"`` and equals ``value``
        for the other metrics.

    Raises:
        ValueError: for an unsupported metric.
    """
    kind = str(metric).strip().lower()

    if kind == "l2":
        teach_scale = float(getattr(model, "_runtime_teach_scale", 1.0))
        teach_clip = float(getattr(model, "_runtime_teach_clip", 0.0))
        signal = teach_signal * teach_scale
        if teach_clip > 0:
            # Rows over the clip norm are shrunk onto it; shorter rows are divided by 1.
            row_norm = signal.norm(dim=-1, keepdim=True)
            signal = signal / torch.clamp(row_norm / teach_clip, min=1.0)
        return float(signal.norm(dim=-1).mean().item()), None

    if kind == "loss":
        vocab = logits.size(-1)
        ce = torch.nn.functional.cross_entropy(
            logits[:, :-1].reshape(-1, vocab),
            tokens[:, 1:].reshape(-1),
        )
        loss_value = float(ce.detach().item())
        return loss_value, loss_value

    if kind == "logit_entropy":
        probs = torch.softmax(logits[:, :-1].detach().float(), dim=-1)
        log_probs = torch.log(probs.clamp(min=1e-9))
        entropy_value = float((-(probs * log_probs).sum(dim=-1).mean()).item())
        return entropy_value, entropy_value

    raise ValueError(f"Unsupported surprise_metric={metric!r}")
167
+
168
+
169
def _forward_with_teach_signal(model, tokens: torch.Tensor, cfg: MemorizeConfig, fast_state):
    """Run a plain forward pass over ``tokens`` and derive its teach signal.

    Uses ``model.forward_with_pre_norm`` when available. In that case, if ``model.norm``
    is an ``nn.LayerNorm``, the signal from ``compute_teach_signal`` is mapped back
    through that norm so it refers to the pre-norm hidden state.

    Returns:
        ``(logits, teach_signal)``.
    """
    fast_kwargs = {"fast_state": fast_state} if cfg.use_fast_state else {}
    pre_norm = None
    if hasattr(model, "forward_with_pre_norm"):
        logits, pre_norm = getattr(model, "forward_with_pre_norm")(tokens, **fast_kwargs)
    else:
        logits = model(tokens, **fast_kwargs)
    teach_signal = compute_teach_signal(model, logits, tokens)
    if pre_norm is not None:
        norm = getattr(model, "norm", None)
        if isinstance(norm, nn.LayerNorm):
            teach_signal = _layernorm_backward(teach_signal, pre_norm, norm)
    return logits, teach_signal


def _apply_teach_update(model, tokens, teach_signal, surprise_override, cfg, fast_state) -> None:
    """Run the update pass that feeds ``teach_signal`` and the surprise override to ``model``."""
    fast_kwargs = {"fast_state": fast_state} if cfg.use_fast_state else {}
    model(tokens, teach_signal=teach_signal, surprise_value=surprise_override, **fast_kwargs)


def _memorize_online_chunks(
    model, token_batch, cfg, fast_state, teach_mask, surprise_metric, stats
) -> None:
    """Online mode: learn target tokens chunk by chunk, re-reading the prefix each time.

    For targets ``[start, end)`` the model sees ``token_batch[:, :end]``. Teach-signal
    index ``i`` carries the error for target token ``i + 1``, so only signal indices
    ``[start - 1, end - 1)`` are kept. Every chunk costs two forward passes over the
    whole prefix (signal + update); this avoids any external KV-cache management.
    """
    seq_len = token_batch.size(1)
    chunk_size = cfg.online_chunk_size
    if teach_mask is not None:
        if teach_mask.shape[0] != token_batch.shape[0]:
            raise ValueError("teach_mask batch size mismatch")
        if teach_mask.shape[1] < seq_len:
            # A shorter mask would either broadcast silently (length 1) or fail mid-loop.
            raise ValueError("teach_mask shape mismatch")

    target_start = 1
    while target_start < seq_len:
        target_end = min(target_start + chunk_size, seq_len)
        context_tokens = token_batch[:, :target_end]
        logits, full_signal = _forward_with_teach_signal(model, context_tokens, cfg, fast_state)

        chunk_mask = torch.zeros_like(full_signal)
        chunk_mask[:, target_start - 1 : target_end - 1, :] = 1.0
        masked_signal = full_signal * chunk_mask
        if teach_mask is not None:
            # Cast to the signal dtype so a low-precision signal is not upcast by the mask.
            mask_slice = teach_mask[:, :target_end].to(
                device=masked_signal.device, dtype=masked_signal.dtype
            )
            masked_signal = masked_signal * mask_slice.unsqueeze(-1)

        surprise_value, surprise_override = _compute_surprise_value(
            model=model,
            metric=surprise_metric,
            logits=logits,
            tokens=context_tokens,
            teach_signal=masked_signal,
        )
        below_threshold = (
            cfg.surprise_threshold is not None and surprise_value < cfg.surprise_threshold
        )
        if not below_threshold:
            _apply_teach_update(
                model, context_tokens, masked_signal, surprise_override, cfg, fast_state
            )
            _collect_metrics(model, stats)
        target_start = target_end


def _memorize_batch_steps(
    model, token_batch, cfg, fast_state, teach_mask, surprise_metric, stats
) -> None:
    """Batch mode: run ``cfg.steps`` full-sequence update passes."""
    for _ in range(cfg.steps):
        logits, teach_signal = _forward_with_teach_signal(model, token_batch, cfg, fast_state)
        if teach_mask is not None:
            if teach_mask.shape[:2] != teach_signal.shape[:2]:
                raise ValueError("teach_mask shape mismatch")
            mask_f = teach_mask.to(device=teach_signal.device, dtype=teach_signal.dtype)
            teach_signal = teach_signal * mask_f.unsqueeze(-1)
        surprise_value, surprise_override = _compute_surprise_value(
            model=model,
            metric=surprise_metric,
            logits=logits,
            tokens=token_batch,
            teach_signal=teach_signal,
        )
        if cfg.surprise_threshold is not None and surprise_value < cfg.surprise_threshold:
            continue
        _apply_teach_update(model, token_batch, teach_signal, surprise_override, cfg, fast_state)
        _collect_metrics(model, stats)


def memorize_tokens(
    model,
    token_batch: torch.Tensor,
    cfg: MemorizeConfig,
    *,
    fast_state=None,
    teach_mask: torch.Tensor | None = None,
) -> dict[str, float]:
    """Run test-time memorization updates on ``token_batch`` (indexed ``(batch, time)``).

    The model's allowed levels, surprise threshold and allowed layers are set from
    ``cfg`` for the duration of the call. They are restored afterwards, even if an
    update pass raises. When ``cfg.online_chunk_size`` is a positive int, targets are
    learned chunk by chunk; otherwise ``cfg.steps`` full-sequence passes are run. Updates
    whose surprise is below ``cfg.surprise_threshold`` are skipped. Everything runs
    under ``torch.no_grad()``.

    Args:
        model: Model exposing the optional memorization hooks used below.
        token_batch: Token ids, batch dimension first.
        cfg: Memorization options.
        fast_state: Passed to every forward pass when ``cfg.use_fast_state`` (required then).
        teach_mask: Optional ``(B, T)`` per-position weights multiplied into the teach signal.

    Returns:
        Aggregated update statistics, or ``{}`` when the sequence has fewer than 2 tokens.

    Raises:
        ValueError: if ``fast_state`` is missing while required, or ``teach_mask`` has the
            wrong rank or shape.
    """
    if token_batch.size(1) < 2:
        return {}
    if cfg.use_fast_state and fast_state is None:
        raise ValueError("cfg.use_fast_state=True requires passing fast_state")
    if teach_mask is not None and teach_mask.ndim != 2:
        raise ValueError("teach_mask must have shape (B, T)")

    stats: dict[str, float] = {
        "titan_mem_updates": 0.0,
        "titan_update_events": 0.0,
        "cms_fast_updates": 0.0,
        "cms_fast_update_events": 0.0,
        "cms_mid_updates": 0.0,
        "cms_mid_update_events": 0.0,
        "cms_slow_updates": 0.0,
        "cms_slow_update_events": 0.0,
        "cms_ultra_updates": 0.0,
        "cms_ultra_update_events": 0.0,
    }
    with torch.no_grad():
        prev_allowed, prev_threshold, prev_layers = _setup_memorization_context(model, cfg)
        try:
            surprise_metric = _get_model_surprise_metric(model)
            if (cfg.online_chunk_size or 0) > 0:
                _memorize_online_chunks(
                    model, token_batch, cfg, fast_state, teach_mask, surprise_metric, stats
                )
            else:
                _memorize_batch_steps(
                    model, token_batch, cfg, fast_state, teach_mask, surprise_metric, stats
                )
        finally:
            # Never leave the model with memorization-only restrictions applied.
            _teardown_memorization_context(model, prev_allowed, prev_threshold, prev_layers)
    return stats
364
+
365
+
366
def memorize_sequence(
    model,
    tokenizer: SentencePieceTokenizer,
    text: str,
    device: torch.device,
    cfg: MemorizeConfig,
    *,
    fast_state=None,
    teach_mask: torch.Tensor | None = None,
) -> dict[str, float]:
    """Tokenize ``text`` and memorize it as a single-sequence batch.

    Returns ``{}`` without touching the model when ``text`` is empty or encodes to fewer
    than two tokens. Otherwise it moves the ids to ``device``, prepends a batch
    dimension, and delegates to ``memorize_tokens``.
    """
    if text:
        token_ids = tokenizer.encode(text)
        if token_ids.size(0) >= 2:
            batch = token_ids.to(device).unsqueeze(0)
            return memorize_tokens(
                model, batch, cfg, fast_state=fast_state, teach_mask=teach_mask
            )
    return {}