nested-learning 0.2.0 (nested_learning-0.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
nested_learning/continual_streaming.py
@@ -0,0 +1,283 @@
+ from __future__ import annotations
+
+ import math
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Sequence
+
+ import torch
+
+ from .continual_classification import ClassificationExample, unique_labels
+ from .memorize import MemorizeConfig, memorize_sequence
+ from .tokenizer import SentencePieceTokenizer
+
+
+ @dataclass(frozen=True)
+ class StreamingTask:
+     task_id: int
+     labels: List[str]
+     train: List[ClassificationExample]
+     eval: List[ClassificationExample]
+
+
+ @dataclass(frozen=True)
+ class ContinualEvalConfig:
+     task_size: int = 10
+     seed: int = 0
+     train_per_label: int = 50
+     eval_per_label: int = 50
+     prompt_template: str = "Text: {text}\nLabel:"
+     label_template: str = "{label}"
+     task_aware: bool = True
+
+
+ def _logprob_completion(
+     model,
+     tokenizer: SentencePieceTokenizer,
+     prompt: str,
+     completion: str,
+     device: torch.device,
+     *,
+     fast_state=None,
+ ) -> float:
+     prompt_ids = tokenizer.encode(prompt, add_bos=True)
+     completion_ids = tokenizer.encode(" " + completion, add_bos=False)
+     tokens = torch.cat([prompt_ids, completion_ids], dim=0).unsqueeze(0).to(device)
+     with torch.no_grad():
+         logits = model(tokens, fast_state=fast_state) if fast_state is not None else model(tokens)
+     log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
+     target = tokens[:, 1:]
+     gathered = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
+     # gathered[0, i] scores tokens[0, i + 1], so the completion's log-probability
+     # starts at index prompt_len - 1.
+     prompt_len = prompt_ids.numel()
+     return float(gathered[0, prompt_len - 1 :].sum().item())
+
+
+ def predict_label(
+     model,
+     tokenizer: SentencePieceTokenizer,
+     text: str,
+     candidates: Sequence[str],
+     device: torch.device,
+     *,
+     prompt_template: str,
+     label_template: str,
+     fast_state=None,
+ ) -> str:
+     if not candidates:
+         raise ValueError("predict_label requires at least one candidate label")
+     prompt = prompt_template.format(text=text)
+     best_label = candidates[0]
+     best_score = -math.inf
+     for label in candidates:
+         label_str = label_template.format(label=label)
+         score = _logprob_completion(
+             model, tokenizer, prompt, label_str, device, fast_state=fast_state
+         )
+         if score > best_score:
+             best_score = score
+             best_label = label
+     return best_label
+
+
+ def _balanced_split(
+     examples: Sequence[ClassificationExample],
+     *,
+     labels: Sequence[str],
+     train_per_label: int,
+     eval_per_label: int,
+ ) -> tuple[List[ClassificationExample], List[ClassificationExample]]:
+     train: List[ClassificationExample] = []
+     eval_: List[ClassificationExample] = []
+     counts_train: Dict[str, int] = {lbl: 0 for lbl in labels}
+     counts_eval: Dict[str, int] = {lbl: 0 for lbl in labels}
+     for ex in examples:
+         lbl = ex.label
+         if lbl not in counts_train:
+             continue
+         if counts_train[lbl] < train_per_label:
+             train.append(ex)
+             counts_train[lbl] += 1
+         elif counts_eval[lbl] < eval_per_label:
+             eval_.append(ex)
+             counts_eval[lbl] += 1
+         if all(v >= train_per_label for v in counts_train.values()) and all(
+             v >= eval_per_label for v in counts_eval.values()
+         ):
+             break
+     return train, eval_
+
+
+ def build_streaming_tasks(
+     examples: Sequence[ClassificationExample],
+     *,
+     cfg: ContinualEvalConfig,
+     label_order: Sequence[str] | None = None,
+ ) -> List[StreamingTask]:
+     labels = list(label_order) if label_order is not None else unique_labels(examples)
+     if label_order is None:
+         import random
+
+         rng = random.Random(cfg.seed)
+         rng.shuffle(labels)
+     if cfg.task_size <= 0:
+         raise ValueError("task_size must be positive")
+     tasks: List[StreamingTask] = []
+     for task_id, start in enumerate(range(0, len(labels), cfg.task_size)):
+         task_labels = labels[start : start + cfg.task_size]
+         if not task_labels:
+             break
+         task_examples = [ex for ex in examples if ex.label in set(task_labels)]
+         train, eval_ = _balanced_split(
+             task_examples,
+             labels=task_labels,
+             train_per_label=cfg.train_per_label,
+             eval_per_label=cfg.eval_per_label,
+         )
+         tasks.append(
+             StreamingTask(task_id=task_id, labels=list(task_labels), train=train, eval=eval_)
+         )
+     return tasks
+
+
+ @dataclass(frozen=True)
+ class ContinualEvalResult:
+     task_accuracy_matrix: List[List[float]]
+     per_task_forgetting: List[float]
+     avg_accuracy_final: float
+     avg_forgetting: float
+
+
+ def evaluate_continual_classification(
+     model,
+     tokenizer: SentencePieceTokenizer,
+     tasks: Sequence[StreamingTask],
+     device: torch.device,
+     *,
+     cfg: ContinualEvalConfig,
+     memorize_cfg: MemorizeConfig,
+ ) -> tuple[ContinualEvalResult, Dict[str, Any]]:
+     """
+     Streaming class-incremental evaluation using generative classification + optional
+     test-time memorization.
+
+     - If `memorize_cfg.enabled`, each training example is memorized by appending the
+       correct label string.
+     - Accuracy is computed after each task on each task's eval set, producing a
+       task-accuracy matrix.
+     """
+     meta_snapshot: Dict[str, torch.Tensor] | None = None
+     if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset:
+         from .memorize import snapshot_state_dict  # local import to avoid cycles
+
+         meta_snapshot = snapshot_state_dict(model)
+
+     fast_state = None
+     if memorize_cfg.enabled and memorize_cfg.use_fast_state:
+         if not hasattr(model, "init_fast_state"):
+             raise RuntimeError("Model does not support fast state memorization")
+         fast_state = model.init_fast_state()
+
+     task_acc: List[List[float]] = [[float("nan") for _ in tasks] for _ in tasks]
+     best_acc: List[float] = [0.0 for _ in tasks]
+
+     memorize_stats_total: Dict[str, float] = {}
+
+     def _eval_task(task_idx: int) -> float:
+         # NB: closes over `current_task` from the loop below; only valid while
+         # that loop is running.
+         task = tasks[task_idx]
+         candidates = (
+             task.labels
+             if cfg.task_aware
+             else [lbl for t in tasks[: current_task + 1] for lbl in t.labels]
+         )
+         if not task.eval:
+             return float("nan")
+         correct = 0
+         for ex in task.eval:
+             pred = predict_label(
+                 model,
+                 tokenizer,
+                 ex.text,
+                 candidates,
+                 device,
+                 prompt_template=cfg.prompt_template,
+                 label_template=cfg.label_template,
+                 fast_state=fast_state,
+             )
+             correct += int(pred == ex.label)
+         return correct / len(task.eval)
+
+     for current_task, task in enumerate(tasks):
+         # Online "training" on this task's examples via optional memorization.
+         for ex in task.train:
+             candidates = (
+                 task.labels
+                 if cfg.task_aware
+                 else [lbl for t in tasks[: current_task + 1] for lbl in t.labels]
+             )
+             _ = predict_label(
+                 model,
+                 tokenizer,
+                 ex.text,
+                 candidates,
+                 device,
+                 prompt_template=cfg.prompt_template,
+                 label_template=cfg.label_template,
+                 fast_state=fast_state,
+             )
+             if memorize_cfg.enabled:
+                 prompt = cfg.prompt_template.format(text=ex.text)
+                 target = cfg.label_template.format(label=ex.label)
+                 memorize_text = f"{prompt} {target}"
+                 if memorize_cfg.use_fast_state and memorize_cfg.reset:
+                     fast_state = model.init_fast_state()
+                 stats = memorize_sequence(
+                     model, tokenizer, memorize_text, device, memorize_cfg, fast_state=fast_state
+                 )
+                 for k, v in stats.items():
+                     memorize_stats_total[k] = memorize_stats_total.get(k, 0.0) + v
+                 if (
+                     (not memorize_cfg.use_fast_state)
+                     and memorize_cfg.reset
+                     and meta_snapshot is not None
+                 ):
+                     from .memorize import restore_state_dict  # local import to avoid cycles
+
+                     restore_state_dict(model, meta_snapshot)
+
+         # Evaluate on all tasks seen so far.
+         for task_idx in range(current_task + 1):
+             acc = _eval_task(task_idx)
+             task_acc[task_idx][current_task] = acc
+             if not math.isnan(acc):
+                 best_acc[task_idx] = max(best_acc[task_idx], acc)
+
+     final_accs = [task_acc[i][-1] for i in range(len(tasks)) if not math.isnan(task_acc[i][-1])]
+     avg_accuracy_final = sum(final_accs) / len(final_accs) if final_accs else float("nan")
+
+     per_task_forgetting: List[float] = []
+     for i in range(len(tasks)):
+         last = task_acc[i][-1]
+         if math.isnan(last):
+             per_task_forgetting.append(float("nan"))
+             continue
+         per_task_forgetting.append(best_acc[i] - last)
+     valid_forgetting = [f for f in per_task_forgetting if not math.isnan(f)]
+     avg_forgetting = (
+         sum(valid_forgetting) / len(valid_forgetting) if valid_forgetting else float("nan")
+     )
+
+     result = ContinualEvalResult(
+         task_accuracy_matrix=task_acc,
+         per_task_forgetting=per_task_forgetting,
+         avg_accuracy_final=avg_accuracy_final,
+         avg_forgetting=avg_forgetting,
+     )
+     meta = {
+         "task_size": cfg.task_size,
+         "train_per_label": cfg.train_per_label,
+         "eval_per_label": cfg.eval_per_label,
+         "task_aware": cfg.task_aware,
+         "prompt_template": cfg.prompt_template,
+         "label_template": cfg.label_template,
+         "memorize_stats": memorize_stats_total,
+     }
+     return result, meta
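
Taken together, build_streaming_tasks and evaluate_continual_classification form the whole benchmark loop: split a labeled corpus into label-disjoint tasks, stream through them once, and read accuracy and forgetting off the task matrix. A minimal driver sketch follows; load_examples, model, and tokenizer are stand-ins (not names from this file), and the MemorizeConfig construction assumes its remaining fields have defaults:

import torch

from nested_learning.continual_streaming import (
    ContinualEvalConfig,
    build_streaming_tasks,
    evaluate_continual_classification,
)
from nested_learning.memorize import MemorizeConfig

device = torch.device("cpu")
cfg = ContinualEvalConfig(task_size=5, train_per_label=20, eval_per_label=20)
examples = load_examples()  # placeholder: any Sequence[ClassificationExample]
tasks = build_streaming_tasks(examples, cfg=cfg)  # label order shuffled by cfg.seed
result, meta = evaluate_continual_classification(
    model, tokenizer, tasks, device, cfg=cfg, memorize_cfg=MemorizeConfig(enabled=False)
)
# task_accuracy_matrix[i][j]: accuracy on task i after finishing task j.
# per_task_forgetting[i]: best accuracy ever observed for task i minus its final value.
print(result.avg_accuracy_final, result.avg_forgetting)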
nested_learning/data.py
@@ -0,0 +1,153 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Iterator, List, Sequence
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset, IterableDataset, get_worker_info
+
+
+ @dataclass
+ class SyntheticTextConfig:
+     vocab_size: int
+     seq_len: int
+     dataset_size: int
+
+
+ class SyntheticTextDataset(Dataset[torch.Tensor]):
+     def __init__(self, config: SyntheticTextConfig):
+         self.config = config
+
+     def __len__(self) -> int:
+         return self.config.dataset_size
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         # Seeding by index makes the dataset deterministic without any storage.
+         g = torch.Generator().manual_seed(idx)
+         return torch.randint(0, self.config.vocab_size, (self.config.seq_len,), generator=g)
+
+
+ class TokenShardDataset(Dataset[torch.Tensor]):
+     """Memory-mapped dataset over NumPy shards produced by shard_corpus.py."""
+
+     def __init__(self, shard_dir: str | Path):
+         self.shard_dir = Path(shard_dir)
+         if not self.shard_dir.exists():
+             msg = f"Shard directory {self.shard_dir} does not exist"
+             raise FileNotFoundError(msg)
+         self.paths = sorted(self.shard_dir.glob("*.npy"))
+         if not self.paths:
+             msg = f"No shard files found in {self.shard_dir}"
+             raise ValueError(msg)
+         self.metadata: List[tuple[int, int]] = []
+         self._cache: dict[int, np.memmap] = {}
+         total = 0
+         for path in self.paths:
+             arr = np.load(path, mmap_mode="r")
+             length = arr.shape[0]
+             self.metadata.append((total, length))
+             total += length
+         self.total_sequences = total
+
+     def __len__(self) -> int:
+         return self.total_sequences
+
+     def _load_array(self, shard_idx: int) -> np.memmap:
+         if shard_idx not in self._cache:
+             self._cache[shard_idx] = np.load(self.paths[shard_idx], mmap_mode="r")
+         return self._cache[shard_idx]
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         if idx < 0 or idx >= self.total_sequences:
+             raise IndexError(idx)
+         shard_idx = self._find_shard(idx)
+         start_offset = self.metadata[shard_idx][0]
+         arr = self._load_array(shard_idx)
+         local_idx = idx - start_offset
+         # Rows come straight from the read-only memmap; .long() upcasts (and
+         # copies) when shards are stored in a narrower dtype such as uint16.
+         tokens = torch.from_numpy(arr[local_idx])
+         return tokens.long()
+
+     def _find_shard(self, idx: int) -> int:
+         # Binary search over the (start_offset, length) table built in __init__.
+         lo, hi = 0, len(self.metadata) - 1
+         while lo <= hi:
+             mid = (lo + hi) // 2
+             start, length = self.metadata[mid]
+             if idx < start:
+                 hi = mid - 1
+             elif idx >= start + length:
+                 lo = mid + 1
+             else:
+                 return mid
+         return len(self.metadata) - 1
+
+
+ @dataclass
+ class ShardSourceConfig:
+     name: str
+     shards_dir: str
+     weight: float
+
+
+ class ShardSource:
+     def __init__(self, config: ShardSourceConfig):
+         self.name = config.name
+         self.weight = config.weight
+         self.dir = Path(config.shards_dir)
+         if not self.dir.exists():
+             msg = f"Shard directory {self.dir} missing for source {self.name}"
+             raise FileNotFoundError(msg)
+         self.paths = sorted(self.dir.glob("*.npy"))
+         if not self.paths:
+             raise ValueError(f"No shard files in {self.dir}")
+         self._cache: dict[Path, np.memmap] = {}
+
+     def sample(self, rng: np.random.Generator) -> np.ndarray:
+         shard_path = self.paths[rng.integers(0, len(self.paths))]
+         if shard_path not in self._cache:
+             self._cache[shard_path] = np.load(shard_path, mmap_mode="r")
+         shard = self._cache[shard_path]
+         idx = rng.integers(0, shard.shape[0])
+         return shard[idx]
+
+
+ class MixtureShardDataset(IterableDataset[torch.Tensor]):
+     def __init__(
+         self,
+         sources: Sequence[ShardSourceConfig],
+         *,
+         samples_per_epoch: int,
+         seed: int = 0,
+     ):
+         super().__init__()
+         self.sources = [ShardSource(cfg) for cfg in sources]
+         total_weight = sum(max(src.weight, 0.0) for src in self.sources)
+         if total_weight <= 0:
+             raise ValueError("Mixture weights must sum to > 0")
+         self.weights = np.array([max(src.weight, 0.0) / total_weight for src in self.sources])
+         self.samples_per_epoch = samples_per_epoch
+         self.seed = seed
+
+     def __len__(self) -> int:
+         return self.samples_per_epoch
+
+     def __iter__(self) -> Iterator[torch.Tensor]:
+         worker = get_worker_info()
+         if worker is None:
+             start = 0
+             end = self.samples_per_epoch
+             worker_seed = self.seed
+         else:
+             # Split the epoch across workers; each gets its own RNG stream.
+             per_worker = (self.samples_per_epoch + worker.num_workers - 1) // worker.num_workers
+             start = worker.id * per_worker
+             end = min(start + per_worker, self.samples_per_epoch)
+             worker_seed = self.seed + worker.id
+         rng = np.random.default_rng(worker_seed)
+         for _ in range(start, end):
+             idx = rng.choice(len(self.sources), p=self.weights)
+             sample = np.array(self.sources[idx].sample(rng), copy=True)
+             yield torch.from_numpy(sample).long()
+
+
+ def collate_batch(batch: list[torch.Tensor]) -> torch.Tensor:
+     return torch.stack(batch, dim=0)
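
A sketch of how the mixture pieces compose with a standard PyTorch DataLoader; the source names, directories, and weights below are invented:

from torch.utils.data import DataLoader

from nested_learning.data import MixtureShardDataset, ShardSourceConfig, collate_batch

dataset = MixtureShardDataset(
    [
        ShardSourceConfig(name="web", shards_dir="shards/web", weight=0.8),
        ShardSourceConfig(name="code", shards_dir="shards/code", weight=0.2),
    ],
    samples_per_epoch=10_000,
    seed=0,
)
# Each DataLoader worker draws a disjoint slice of samples_per_epoch using its
# own RNG stream (seed + worker.id), so runs with a fixed num_workers repeat.
loader = DataLoader(dataset, batch_size=8, num_workers=2, collate_fn=collate_batch)
for batch in loader:
    ...  # LongTensor of shape [8, seq_len]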
nested_learning/device.py
@@ -0,0 +1,21 @@
+ from __future__ import annotations
+
+ import torch
+
+
+ def resolve_device(device_str: str) -> torch.device:
+     normalized = str(device_str).strip().lower()
+     if normalized.startswith("cuda"):
+         if not torch.cuda.is_available():
+             return torch.device("cpu")
+         parts = normalized.split(":")
+         idx = int(parts[1]) if len(parts) > 1 else 0
+         if idx >= torch.cuda.device_count():
+             idx = max(torch.cuda.device_count() - 1, 0)
+         return torch.device(f"cuda:{idx}")
+     if normalized.startswith("mps"):
+         if not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
+             return torch.device("cpu")
+         return torch.device("mps")
+     return torch.device(device_str)
+
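resolve_device never raises on an unavailable accelerator; it falls back to CPU or clamps the CUDA index. For illustration:

from nested_learning.device import resolve_device

resolve_device("cuda:3")  # CPU-only machine -> device("cpu"); with fewer GPUs
                          # than requested, the index clamps to the last one
resolve_device(" MPS ")   # casing and whitespace are normalized before matching
resolve_device("cpu")     # passed through to torch.device unchanged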
nested_learning/eval_state.py
@@ -0,0 +1,72 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+
+
+ @dataclass
+ class EvalStreamingState:
+     fast_state: object | None = None
+     attention_cache: object | None = None
+
+
+ def parse_eval_state_mode(mode: str) -> bool:
+     """
+     Returns True when eval state should be carried across samples.
+     """
+     normalized = str(mode).strip().lower()
+     if normalized in {"reset", "reset_per_sample", "isolated"}:
+         return False
+     if normalized in {"carry", "carry_across_samples", "stream"}:
+         return True
+     raise ValueError(
+         "Unsupported eval_state_mode={!r}; expected one of "
+         "['reset_per_sample', 'carry_across_samples']".format(mode)
+     )
+
+
+ def init_eval_streaming_state(
+     model,
+     *,
+     use_fast_state: bool,
+     use_attention_cache: bool,
+ ) -> EvalStreamingState:
+     state = EvalStreamingState()
+     if use_fast_state:
+         init_fast_state = getattr(model, "init_fast_state", None)
+         if not callable(init_fast_state):
+             raise RuntimeError(
+                 "Requested fast-state eval mode, but model.init_fast_state() is missing"
+             )
+         state.fast_state = init_fast_state()
+     if use_attention_cache:
+         init_attention_cache = getattr(model, "init_attention_cache", None)
+         if not callable(init_attention_cache):
+             raise RuntimeError(
+                 "Requested attention-cache eval mode, but model.init_attention_cache() is missing"
+             )
+         state.attention_cache = init_attention_cache()
+     return state
+
+
+ def forward_with_eval_state(
+     model,
+     tokens: torch.Tensor,
+     *,
+     state: EvalStreamingState | None = None,
+ ) -> tuple[torch.Tensor, EvalStreamingState | None]:
+     if state is None:
+         return model(tokens), None
+     if state.attention_cache is not None:
+         logits, next_cache = model(
+             tokens,
+             fast_state=state.fast_state,
+             attention_cache=state.attention_cache,
+             return_attention_cache=True,
+         )
+         state.attention_cache = next_cache
+         return logits, state
+     if state.fast_state is not None:
+         return model(tokens, fast_state=state.fast_state), state
+     return model(tokens), state
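
The intended chunked-eval call pattern, sketched under two assumptions: model accepts the fast_state/attention_cache/return_attention_cache keywords used above, and chunks yields [1, chunk_len] token tensors from one long document:

from nested_learning.eval_state import (
    forward_with_eval_state,
    init_eval_streaming_state,
    parse_eval_state_mode,
)

carry = parse_eval_state_mode("carry_across_samples")  # True; "reset_per_sample" -> False
state = init_eval_streaming_state(model, use_fast_state=True, use_attention_cache=True)
for chunk in chunks:
    logits, state = forward_with_eval_state(model, chunk, state=state)
    # state.attention_cache was replaced by the cache the model returned;
    # the fast state is the same object, updated in place by the model.
# When carry is False, rebuild `state` with init_eval_streaming_state before
# each new document instead of threading it across documents.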
nested_learning/fast_state.py
@@ -0,0 +1,108 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Dict, cast
+
+ import torch
+ from torch import nn
+
+ from .optim.manager import LevelConfig, LevelOptimizerManager
+ from .titan.self_modifying import SelfModifyingTitansState
+
+ ParamDict = Dict[str, torch.Tensor]
+
+
+ @dataclass
+ class CMSChunkBuffer:
+     """
+     Streaming CMS chunk buffer persisted across multiple model calls.
+
+     This is required to preserve update-period cadence when a logical sequence is
+     processed in several chunked forward/update calls.
+     """
+
+     inputs: list[torch.Tensor] = field(default_factory=list)
+     teach: list[torch.Tensor] = field(default_factory=list)
+     active: list[torch.Tensor] = field(default_factory=list)
+     count: int = 0
+
+
+ def init_module_deltas(module: nn.Module) -> ParamDict:
+     """
+     Initialize a per-parameter "fast state" delta dict for meta+delta fast state.
+
+     The fast state stores *deltas* (initialized to 0) rather than detached parameter
+     clones so that forward passes can use `meta_param + delta`, allowing outer
+     gradients to flow to meta params while keeping online updates as stop-grad
+     writes into the delta tensors.
+     """
+
+     return {name: torch.zeros_like(param).detach() for name, param in module.named_parameters()}
+
+
+ @dataclass
+ class BlockFastState:
+     titan_params: ParamDict | None
+     cms_params: Dict[str, ParamDict]
+     cms_online_buffers: Dict[str, CMSChunkBuffer]
+     level_manager: LevelOptimizerManager
+     selfmod_state: SelfModifyingTitansState | None = None
+
+
+ def build_block_fast_state(
+     *,
+     titan_module: nn.Module | None,
+     cms_blocks: Dict[str, nn.Module],
+     selfmod_module: nn.Module | None = None,
+     specs,
+     optimizer_configs: Dict[str, dict],
+     default_lr: float,
+ ) -> BlockFastState:
+     titan_params = None
+     if titan_module is not None:
+         titan_params = init_module_deltas(titan_module)
+     cms_params = {name: init_module_deltas(block) for name, block in cms_blocks.items()}
+     cms_online_buffers = {name: CMSChunkBuffer() for name in cms_blocks}
+     level_cfg = LevelConfig(specs=specs, optimizer_configs=optimizer_configs, default_lr=default_lr)
+     level_manager = LevelOptimizerManager(level_cfg)
+     selfmod_state = None
+     if selfmod_module is not None:
+         init_fn = getattr(selfmod_module, "init_fast_state", None)
+         if callable(init_fn):
+             selfmod_state = cast(SelfModifyingTitansState, init_fn())
+     return BlockFastState(
+         titan_params=titan_params,
+         cms_params=cms_params,
+         cms_online_buffers=cms_online_buffers,
+         level_manager=level_manager,
+         selfmod_state=selfmod_state,
+     )
+
+
+ @dataclass
+ class ModelFastState:
+     blocks: list[BlockFastState]
+
+
+ @dataclass
+ class AttentionKVCache:
+     """
+     Per-layer autoregressive attention cache.
+
+     Shapes:
+       - key: [batch, heads, cached_tokens, head_dim]
+       - value: [batch, heads, cached_tokens, head_dim]
+     """
+
+     key: torch.Tensor
+     value: torch.Tensor
+
+
+ @dataclass
+ class ModelAttentionCache:
+     """
+     Model-level container for per-block attention caches.
+
+     Blocks without attention store `None` entries.
+     """
+
+     blocks: list[AttentionKVCache | None]
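
The delta convention behind init_module_deltas is easy to sanity-check on a plain nn.Linear, here standing in for the real Titan/CMS modules:

import torch
from torch import nn

from nested_learning.fast_state import init_module_deltas

layer = nn.Linear(4, 4)
deltas = init_module_deltas(layer)
assert set(deltas) == {"weight", "bias"}
# Deltas start at zero and are detached: forwards compute meta_param + delta,
# and online updates write only into the delta tensors.
assert all(not d.requires_grad for d in deltas.values())
assert all(torch.equal(d, torch.zeros_like(d)) for d in deltas.values())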
nested_learning/functional.py
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, Mapping, Tuple
+
+ import torch
+ from torch import nn
+ from torch.func import functional_call
+
+ ParamDict = Dict[str, torch.Tensor]
+
+
+ def params_with_deltas(module: nn.Module, deltas: ParamDict) -> ParamDict:
+     params: ParamDict = {}
+     missing: list[str] = []
+     for name, param in module.named_parameters():
+         delta = deltas.get(name)
+         if delta is None:
+             missing.append(name)
+             continue
+         params[name] = param + delta
+     if missing:
+         raise KeyError(
+             f"Missing fast-state delta(s) for {module.__class__.__name__}: {sorted(missing)[:10]}"
+         )
+     return params
+
+
+ def module_buffers(module: nn.Module) -> ParamDict:
+     return {name: buf for name, buf in module.named_buffers()}
+
+
+ def call_with_params(
+     module: nn.Module,
+     params: ParamDict,
+     *args: Any,
+     **kwargs: Any,
+ ) -> Any:
+     buffers = module_buffers(module)
+     return functional_call(module, (params, buffers), args, kwargs, strict=True)
+
+
+ def call_with_deltas(
+     module: nn.Module,
+     deltas: ParamDict,
+     *args: Any,
+     **kwargs: Any,
+ ) -> Any:
+     return call_with_params(module, params_with_deltas(module, deltas), *args, **kwargs)
+
+
+ def require_grad_params(
+     params: Mapping[str, torch.Tensor], *, detach: bool = True
+ ) -> ParamDict:
+     out: ParamDict = {}
+     for name, value in params.items():
+         if detach:
+             out[name] = value.detach().requires_grad_(True)
+         else:
+             # requires_grad_ mutates in place and only works on leaf tensors,
+             # so detach=False assumes `value` is already a leaf.
+             out[name] = value.requires_grad_(True)
+     return out
+
+
+ def grads_to_dict(params: ParamDict, grads: Tuple[torch.Tensor | None, ...]) -> ParamDict:
+     out: ParamDict = {}
+     for (name, _), grad in zip(params.items(), grads, strict=True):
+         if grad is None:
+             continue
+         out[name] = grad
+     return out
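
A minimal end-to-end check of the meta+delta mechanics, combining call_with_deltas with init_module_deltas from fast_state.py; nn.Linear again stands in for the real fast-weight modules:

import torch
from torch import nn

from nested_learning.fast_state import init_module_deltas
from nested_learning.functional import call_with_deltas

layer = nn.Linear(3, 3)
deltas = init_module_deltas(layer)  # all zeros, so the call matches layer(x)
x = torch.randn(2, 3)
out = call_with_deltas(layer, deltas, x)
assert torch.allclose(out, layer(x))
out.sum().backward()
# Outer gradients reach the meta parameters through param + delta, while the
# detached deltas accumulate nothing.
assert layer.weight.grad is not None
assert deltas["weight"].grad is None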