nested_learning-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
nested_learning/continual_streaming.py
ADDED
@@ -0,0 +1,283 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Sequence
+
+import torch
+
+from .continual_classification import ClassificationExample, unique_labels
+from .memorize import MemorizeConfig, memorize_sequence
+from .tokenizer import SentencePieceTokenizer
+
+
+@dataclass(frozen=True)
+class StreamingTask:
+    task_id: int
+    labels: List[str]
+    train: List[ClassificationExample]
+    eval: List[ClassificationExample]
+
+
+@dataclass(frozen=True)
+class ContinualEvalConfig:
+    task_size: int = 10
+    seed: int = 0
+    train_per_label: int = 50
+    eval_per_label: int = 50
+    prompt_template: str = "Text: {text}\nLabel:"
+    label_template: str = "{label}"
+    task_aware: bool = True
+
+
+def _logprob_completion(
+    model,
+    tokenizer: SentencePieceTokenizer,
+    prompt: str,
+    completion: str,
+    device: torch.device,
+    *,
+    fast_state=None,
+) -> float:
+    prompt_ids = tokenizer.encode(prompt, add_bos=True)
+    completion_ids = tokenizer.encode(" " + completion, add_bos=False)
+    tokens = torch.cat([prompt_ids, completion_ids], dim=0).unsqueeze(0).to(device)
+    with torch.no_grad():
+        logits = model(tokens, fast_state=fast_state) if fast_state is not None else model(tokens)
+    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
+    target = tokens[:, 1:]
+    gathered = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
+    prompt_len = prompt_ids.numel()
+    return float(gathered[0, prompt_len - 1 :].sum().item())
+
+
+def predict_label(
+    model,
+    tokenizer: SentencePieceTokenizer,
+    text: str,
+    candidates: Sequence[str],
+    device: torch.device,
+    *,
+    prompt_template: str,
+    label_template: str,
+    fast_state=None,
+) -> str:
+    if not candidates:
+        raise ValueError("predict_label requires at least one candidate label")
+    prompt = prompt_template.format(text=text)
+    best_label = candidates[0]
+    best_score = -math.inf
+    for label in candidates:
+        label_str = label_template.format(label=label)
+        score = _logprob_completion(
+            model, tokenizer, prompt, label_str, device, fast_state=fast_state
+        )
+        if score > best_score:
+            best_score = score
+            best_label = label
+    return best_label
+
+
+def _balanced_split(
+    examples: Sequence[ClassificationExample],
+    *,
+    labels: Sequence[str],
+    train_per_label: int,
+    eval_per_label: int,
+) -> tuple[List[ClassificationExample], List[ClassificationExample]]:
+    train: List[ClassificationExample] = []
+    eval_: List[ClassificationExample] = []
+    counts_train: Dict[str, int] = {lbl: 0 for lbl in labels}
+    counts_eval: Dict[str, int] = {lbl: 0 for lbl in labels}
+    for ex in examples:
+        lbl = ex.label
+        if lbl not in counts_train:
+            continue
+        if counts_train[lbl] < train_per_label:
+            train.append(ex)
+            counts_train[lbl] += 1
+        elif counts_eval[lbl] < eval_per_label:
+            eval_.append(ex)
+            counts_eval[lbl] += 1
+        if all(v >= train_per_label for v in counts_train.values()) and all(
+            v >= eval_per_label for v in counts_eval.values()
+        ):
+            break
+    return train, eval_
+
+
+def build_streaming_tasks(
+    examples: Sequence[ClassificationExample],
+    *,
+    cfg: ContinualEvalConfig,
+    label_order: Sequence[str] | None = None,
+) -> List[StreamingTask]:
+    labels = list(label_order) if label_order is not None else unique_labels(examples)
+    if label_order is None:
+        import random
+
+        rng = random.Random(cfg.seed)
+        rng.shuffle(labels)
+    if cfg.task_size <= 0:
+        raise ValueError("task_size must be positive")
+    tasks: List[StreamingTask] = []
+    for task_id, start in enumerate(range(0, len(labels), cfg.task_size)):
+        task_labels = labels[start : start + cfg.task_size]
+        if not task_labels:
+            break
+        task_examples = [ex for ex in examples if ex.label in set(task_labels)]
+        train, eval_ = _balanced_split(
+            task_examples,
+            labels=task_labels,
+            train_per_label=cfg.train_per_label,
+            eval_per_label=cfg.eval_per_label,
+        )
+        tasks.append(
+            StreamingTask(task_id=task_id, labels=list(task_labels), train=train, eval=eval_)
+        )
+    return tasks
+
+
+@dataclass(frozen=True)
+class ContinualEvalResult:
+    task_accuracy_matrix: List[List[float]]
+    per_task_forgetting: List[float]
+    avg_accuracy_final: float
+    avg_forgetting: float
+
+
+def evaluate_continual_classification(
+    model,
+    tokenizer: SentencePieceTokenizer,
+    tasks: Sequence[StreamingTask],
+    device: torch.device,
+    *,
+    cfg: ContinualEvalConfig,
+    memorize_cfg: MemorizeConfig,
+) -> tuple[ContinualEvalResult, Dict[str, Any]]:
+    """
+    Streaming class-incremental evaluation using generative classification + optional
+    test-time memorization.
+
+    - If `memorize_cfg.enabled`, each training example is memorized by appending the correct
+      label string.
+    - Accuracy is computed after each task on each task's eval set, producing a task-accuracy
+      matrix.
+    """
+    meta_snapshot: Dict[str, torch.Tensor] | None = None
+    if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset:
+        from .memorize import snapshot_state_dict  # local import to avoid cycles
+
+        meta_snapshot = snapshot_state_dict(model)
+
+    fast_state = None
+    if memorize_cfg.enabled and memorize_cfg.use_fast_state:
+        if not hasattr(model, "init_fast_state"):
+            raise RuntimeError("Model does not support fast state memorization")
+        fast_state = model.init_fast_state()
+
+    task_acc: List[List[float]] = [[float("nan") for _ in tasks] for _ in tasks]
+    best_acc: List[float] = [0.0 for _ in tasks]
+
+    memorize_stats_total: Dict[str, float] = {}
+
+    def _eval_task(task_idx: int) -> float:
+        task = tasks[task_idx]
+        candidates = (
+            task.labels
+            if cfg.task_aware
+            else [lbl for t in tasks[: current_task + 1] for lbl in t.labels]
+        )
+        if not task.eval:
+            return float("nan")
+        correct = 0
+        for ex in task.eval:
+            pred = predict_label(
+                model,
+                tokenizer,
+                ex.text,
+                candidates,
+                device,
+                prompt_template=cfg.prompt_template,
+                label_template=cfg.label_template,
+                fast_state=fast_state,
+            )
+            correct += int(pred == ex.label)
+        return correct / len(task.eval) if task.eval else float("nan")
+
+    for current_task, task in enumerate(tasks):
+        # Online "training" on this task's examples via optional memorization.
+        for ex in task.train:
+            candidates = (
+                task.labels
+                if cfg.task_aware
+                else [lbl for t in tasks[: current_task + 1] for lbl in t.labels]
+            )
+            _ = predict_label(
+                model,
+                tokenizer,
+                ex.text,
+                candidates,
+                device,
+                prompt_template=cfg.prompt_template,
+                label_template=cfg.label_template,
+                fast_state=fast_state,
+            )
+            if memorize_cfg.enabled:
+                prompt = cfg.prompt_template.format(text=ex.text)
+                target = cfg.label_template.format(label=ex.label)
+                memorize_text = f"{prompt} {target}"
+                if memorize_cfg.use_fast_state and memorize_cfg.reset:
+                    fast_state = model.init_fast_state()
+                stats = memorize_sequence(
+                    model, tokenizer, memorize_text, device, memorize_cfg, fast_state=fast_state
+                )
+                for k, v in stats.items():
+                    memorize_stats_total[k] = memorize_stats_total.get(k, 0.0) + v
+                if (
+                    (not memorize_cfg.use_fast_state)
+                    and memorize_cfg.reset
+                    and meta_snapshot is not None
+                ):
+                    from .memorize import restore_state_dict  # local import to avoid cycles
+
+                    restore_state_dict(model, meta_snapshot)
+
+        # Evaluate on all tasks seen so far.
+        for task_idx in range(current_task + 1):
+            acc = _eval_task(task_idx)
+            task_acc[task_idx][current_task] = acc
+            if not math.isnan(acc):
+                best_acc[task_idx] = max(best_acc[task_idx], acc)
+
+    final_accs = [task_acc[i][-1] for i in range(len(tasks)) if not math.isnan(task_acc[i][-1])]
+    avg_accuracy_final = sum(final_accs) / len(final_accs) if final_accs else float("nan")
+
+    per_task_forgetting: List[float] = []
+    for i in range(len(tasks)):
+        last = task_acc[i][-1]
+        if math.isnan(last):
+            per_task_forgetting.append(float("nan"))
+            continue
+        per_task_forgetting.append(best_acc[i] - last)
+    valid_forgetting = [f for f in per_task_forgetting if not math.isnan(f)]
+    avg_forgetting = (
+        sum(valid_forgetting) / len(valid_forgetting) if valid_forgetting else float("nan")
+    )
+
+    result = ContinualEvalResult(
+        task_accuracy_matrix=task_acc,
+        per_task_forgetting=per_task_forgetting,
+        avg_accuracy_final=avg_accuracy_final,
+        avg_forgetting=avg_forgetting,
+    )
+    meta = {
+        "task_size": cfg.task_size,
+        "train_per_label": cfg.train_per_label,
+        "eval_per_label": cfg.eval_per_label,
+        "task_aware": cfg.task_aware,
+        "prompt_template": cfg.prompt_template,
+        "label_template": cfg.label_template,
+        "memorize_stats": memorize_stats_total,
+    }
+    return result, meta
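For orientation, a minimal driver for this evaluator could look as follows. This is a sketch, not part of the package: `model`, `tokenizer`, `examples`, and `memorize_cfg` are assumed to exist, and the `ClassificationExample` field names are inferred from the accesses above.

    import torch

    from nested_learning.continual_streaming import (
        ContinualEvalConfig,
        build_streaming_tasks,
        evaluate_continual_classification,
    )

    # examples: Sequence[ClassificationExample]; each exposes .text and .label.
    cfg = ContinualEvalConfig(task_size=5, train_per_label=20, eval_per_label=20)
    tasks = build_streaming_tasks(examples, cfg=cfg)
    # memorize_cfg: a MemorizeConfig built elsewhere; only the fields read above
    # (enabled, use_fast_state, reset) are known from this file.
    result, meta = evaluate_continual_classification(
        model, tokenizer, tasks, torch.device("cpu"), cfg=cfg, memorize_cfg=memorize_cfg
    )
    print(result.avg_accuracy_final, result.avg_forgetting)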
nested_learning/data.py
ADDED
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterator, List, Sequence
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, IterableDataset, get_worker_info
+
+
+@dataclass
+class SyntheticTextConfig:
+    vocab_size: int
+    seq_len: int
+    dataset_size: int
+
+
+class SyntheticTextDataset(Dataset[torch.Tensor]):
+    def __init__(self, config: SyntheticTextConfig):
+        self.config = config
+
+    def __len__(self) -> int:
+        return self.config.dataset_size
+
+    def __getitem__(self, idx: int) -> torch.Tensor:
+        g = torch.Generator().manual_seed(idx)
+        return torch.randint(0, self.config.vocab_size, (self.config.seq_len,), generator=g)
+
+
+class TokenShardDataset(Dataset[torch.Tensor]):
+    """Memory-mapped dataset over NumPy shards produced by shard_corpus.py."""
+
+    def __init__(self, shard_dir: str | Path):
+        self.shard_dir = Path(shard_dir)
+        if not self.shard_dir.exists():
+            msg = f"Shard directory {self.shard_dir} does not exist"
+            raise FileNotFoundError(msg)
+        self.paths = sorted(self.shard_dir.glob("*.npy"))
+        if not self.paths:
+            msg = f"No shard files found in {self.shard_dir}"
+            raise ValueError(msg)
+        self.metadata: List[tuple[int, int]] = []
+        self._cache: dict[int, np.memmap] = {}
+        total = 0
+        for idx, path in enumerate(self.paths):
+            arr = np.load(path, mmap_mode="r")
+            length = arr.shape[0]
+            self.metadata.append((total, length))
+            total += length
+        self.total_sequences = total
+
+    def __len__(self) -> int:
+        return self.total_sequences
+
+    def _load_array(self, shard_idx: int) -> np.memmap:
+        if shard_idx not in self._cache:
+            self._cache[shard_idx] = np.load(self.paths[shard_idx], mmap_mode="r")
+        return self._cache[shard_idx]
+
+    def __getitem__(self, idx: int) -> torch.Tensor:
+        if idx < 0 or idx >= self.total_sequences:
+            raise IndexError(idx)
+        shard_idx = self._find_shard(idx)
+        start_offset = self.metadata[shard_idx][0]
+        arr = self._load_array(shard_idx)
+        local_idx = idx - start_offset
+        tokens = torch.from_numpy(arr[local_idx])
+        return tokens.long()
+
+    def _find_shard(self, idx: int) -> int:
+        lo, hi = 0, len(self.metadata) - 1
+        while lo <= hi:
+            mid = (lo + hi) // 2
+            start, length = self.metadata[mid]
+            if idx < start:
+                hi = mid - 1
+            elif idx >= start + length:
+                lo = mid + 1
+            else:
+                return mid
+        return len(self.metadata) - 1
+
+
+@dataclass
+class ShardSourceConfig:
+    name: str
+    shards_dir: str
+    weight: float
+
+
+class ShardSource:
+    def __init__(self, config: ShardSourceConfig):
+        self.name = config.name
+        self.weight = config.weight
+        self.dir = Path(config.shards_dir)
+        if not self.dir.exists():
+            msg = f"Shard directory {self.dir} missing for source {self.name}"
+            raise FileNotFoundError(msg)
+        self.paths = sorted(self.dir.glob("*.npy"))
+        if not self.paths:
+            raise ValueError(f"No shard files in {self.dir}")
+        self._cache: dict[Path, np.memmap] = {}
+
+    def sample(self, rng: np.random.Generator) -> np.ndarray:
+        shard_path = self.paths[rng.integers(0, len(self.paths))]
+        if shard_path not in self._cache:
+            self._cache[shard_path] = np.load(shard_path, mmap_mode="r")
+        shard = self._cache[shard_path]
+        idx = rng.integers(0, shard.shape[0])
+        return shard[idx]
+
+
+class MixtureShardDataset(IterableDataset[torch.Tensor]):
+    def __init__(
+        self,
+        sources: Sequence[ShardSourceConfig],
+        *,
+        samples_per_epoch: int,
+        seed: int = 0,
+    ):
+        super().__init__()
+        self.sources = [ShardSource(cfg) for cfg in sources]
+        total_weight = sum(max(src.weight, 0.0) for src in self.sources)
+        if total_weight <= 0:
+            raise ValueError("Mixture weights must sum to > 0")
+        self.weights = np.array([max(src.weight, 0.0) / total_weight for src in self.sources])
+        self.samples_per_epoch = samples_per_epoch
+        self.seed = seed
+
+    def __len__(self) -> int:
+        return self.samples_per_epoch
+
+    def __iter__(self) -> Iterator[torch.Tensor]:
+        worker = get_worker_info()
+        if worker is None:
+            start = 0
+            end = self.samples_per_epoch
+            worker_seed = self.seed
+        else:
+            per_worker = (self.samples_per_epoch + worker.num_workers - 1) // worker.num_workers
+            start = worker.id * per_worker
+            end = min(start + per_worker, self.samples_per_epoch)
+            worker_seed = self.seed + worker.id
+        rng = np.random.default_rng(worker_seed)
+        for _ in range(start, end):
+            idx = rng.choice(len(self.sources), p=self.weights)
+            sample = np.array(self.sources[idx].sample(rng), copy=True)
+            yield torch.from_numpy(sample).long()
+
+
+def collate_batch(batch: list[torch.Tensor]) -> torch.Tensor:
+    return torch.stack(batch, dim=0)
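A minimal usage sketch for the mixture pipeline (the shard directories and weights are illustrative; each `*.npy` shard is assumed to hold a 2-D array of token ids, consistent with how TokenShardDataset and ShardSource index rows):

    from torch.utils.data import DataLoader

    from nested_learning.data import MixtureShardDataset, ShardSourceConfig, collate_batch

    sources = [
        ShardSourceConfig(name="web", shards_dir="shards/web", weight=0.8),
        ShardSourceConfig(name="code", shards_dir="shards/code", weight=0.2),
    ]
    ds = MixtureShardDataset(sources, samples_per_epoch=10_000, seed=0)
    # The per-worker sharding in __iter__ keeps multi-worker loading duplicate-free.
    loader = DataLoader(ds, batch_size=8, num_workers=2, collate_fn=collate_batch)
    batch = next(iter(loader))  # LongTensor of shape [8, seq_len]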
nested_learning/device.py
ADDED
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import torch
+
+
+def resolve_device(device_str: str) -> torch.device:
+    normalized = str(device_str).strip().lower()
+    if normalized.startswith("cuda"):
+        if not torch.cuda.is_available():
+            return torch.device("cpu")
+        parts = normalized.split(":")
+        idx = int(parts[1]) if len(parts) > 1 else 0
+        if idx >= torch.cuda.device_count():
+            idx = max(torch.cuda.device_count() - 1, 0)
+        return torch.device(f"cuda:{idx}")
+    if normalized.startswith("mps"):
+        if not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
+            return torch.device("cpu")
+        return torch.device("mps")
+    return torch.device(device_str)
+
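The fallback behavior is easiest to see by example (a sketch; results depend on the machine's visible devices):

    from nested_learning.device import resolve_device

    resolve_device("cuda:3")  # clamped to the last visible GPU, or CPU if CUDA is absent
    resolve_device("mps")     # CPU fallback when the MPS backend is unavailable
    resolve_device("cpu")     # passed through to torch.device unchanged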
nested_learning/eval_state.py
ADDED
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class EvalStreamingState:
+    fast_state: object | None = None
+    attention_cache: object | None = None
+
+
+def parse_eval_state_mode(mode: str) -> bool:
+    """
+    Returns True when eval state should be carried across samples.
+    """
+    normalized = str(mode).strip().lower()
+    if normalized in {"reset", "reset_per_sample", "isolated"}:
+        return False
+    if normalized in {"carry", "carry_across_samples", "stream"}:
+        return True
+    raise ValueError(
+        "Unsupported eval_state_mode={!r}; expected one of "
+        "['reset_per_sample', 'carry_across_samples']".format(mode)
+    )
+
+
+def init_eval_streaming_state(
+    model,
+    *,
+    use_fast_state: bool,
+    use_attention_cache: bool,
+) -> EvalStreamingState:
+    state = EvalStreamingState()
+    if use_fast_state:
+        init_fast_state = getattr(model, "init_fast_state", None)
+        if not callable(init_fast_state):
+            raise RuntimeError(
+                "Requested fast-state eval mode, but model.init_fast_state() is missing"
+            )
+        state.fast_state = init_fast_state()
+    if use_attention_cache:
+        init_attention_cache = getattr(model, "init_attention_cache", None)
+        if not callable(init_attention_cache):
+            raise RuntimeError(
+                "Requested attention-cache eval mode, but model.init_attention_cache() is missing"
+            )
+        state.attention_cache = init_attention_cache()
+    return state
+
+
+def forward_with_eval_state(
+    model,
+    tokens: torch.Tensor,
+    *,
+    state: EvalStreamingState | None = None,
+) -> tuple[torch.Tensor, EvalStreamingState | None]:
+    if state is None:
+        return model(tokens), None
+    if state.attention_cache is not None:
+        logits, next_cache = model(
+            tokens,
+            fast_state=state.fast_state,
+            attention_cache=state.attention_cache,
+            return_attention_cache=True,
+        )
+        state.attention_cache = next_cache
+        return logits, state
+    if state.fast_state is not None:
+        return model(tokens, fast_state=state.fast_state), state
+    return model(tokens), state
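A sketch of the intended streaming-eval loop (assumes a `model` exposing `init_fast_state()` and an iterable `eval_batches` of token tensors; neither is defined in this file):

    from nested_learning.eval_state import (
        forward_with_eval_state,
        init_eval_streaming_state,
        parse_eval_state_mode,
    )

    carry = parse_eval_state_mode("carry_across_samples")  # True
    state = (
        init_eval_streaming_state(model, use_fast_state=True, use_attention_cache=False)
        if carry
        else None
    )
    for tokens in eval_batches:  # each of shape [batch, seq_len]
        logits, state = forward_with_eval_state(model, tokens, state=state)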
nested_learning/fast_state.py
ADDED
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Dict, cast
+
+import torch
+from torch import nn
+
+from .optim.manager import LevelConfig, LevelOptimizerManager
+from .titan.self_modifying import SelfModifyingTitansState
+
+ParamDict = Dict[str, torch.Tensor]
+
+
+@dataclass
+class CMSChunkBuffer:
+    """
+    Streaming CMS chunk buffer persisted across multiple model calls.
+
+    This is required to preserve update-period cadence when a logical sequence is
+    processed in several chunked forward/update calls.
+    """
+
+    inputs: list[torch.Tensor] = field(default_factory=list)
+    teach: list[torch.Tensor] = field(default_factory=list)
+    active: list[torch.Tensor] = field(default_factory=list)
+    count: int = 0
+
+
+def init_module_deltas(module: nn.Module) -> ParamDict:
+    """
+    Initialize a per-parameter "fast state" delta dict for meta+delta fast state.
+
+    The fast state stores *deltas* (initialized to 0) rather than detached parameter clones so that
+    forward passes can use `meta_param + delta`, allowing outer gradients to flow to meta params
+    while keeping online updates as stop-grad writes into the delta tensors.
+    """
+
+    return {name: torch.zeros_like(param).detach() for name, param in module.named_parameters()}
+
+
+@dataclass
+class BlockFastState:
+    titan_params: ParamDict | None
+    cms_params: Dict[str, ParamDict]
+    cms_online_buffers: Dict[str, CMSChunkBuffer]
+    level_manager: LevelOptimizerManager
+    selfmod_state: SelfModifyingTitansState | None = None
+
+
+def build_block_fast_state(
+    *,
+    titan_module: nn.Module | None,
+    cms_blocks: Dict[str, nn.Module],
+    selfmod_module: nn.Module | None = None,
+    specs,
+    optimizer_configs: Dict[str, dict],
+    default_lr: float,
+) -> BlockFastState:
+    titan_params = None
+    if titan_module is not None:
+        titan_params = init_module_deltas(titan_module)
+    cms_params = {name: init_module_deltas(block) for name, block in cms_blocks.items()}
+    cms_online_buffers = {name: CMSChunkBuffer() for name in cms_blocks}
+    level_cfg = LevelConfig(specs=specs, optimizer_configs=optimizer_configs, default_lr=default_lr)
+    level_manager = LevelOptimizerManager(level_cfg)
+    selfmod_state = None
+    if selfmod_module is not None:
+        init_fn = getattr(selfmod_module, "init_fast_state", None)
+        if callable(init_fn):
+            selfmod_state = cast(SelfModifyingTitansState, init_fn())
+    return BlockFastState(
+        titan_params=titan_params,
+        cms_params=cms_params,
+        cms_online_buffers=cms_online_buffers,
+        level_manager=level_manager,
+        selfmod_state=selfmod_state,
+    )
+
+
+@dataclass
+class ModelFastState:
+    blocks: list[BlockFastState]
+
+
+@dataclass
+class AttentionKVCache:
+    """
+    Per-layer autoregressive attention cache.
+
+    Shapes:
+    - key: [batch, heads, cached_tokens, head_dim]
+    - value: [batch, heads, cached_tokens, head_dim]
+    """
+
+    key: torch.Tensor
+    value: torch.Tensor
+
+
+@dataclass
+class ModelAttentionCache:
+    """
+    Model-level container for per-block attention caches.
+
+    Blocks without attention store `None` entries.
+    """
+
+    blocks: list[AttentionKVCache | None]
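A minimal sketch of the meta+delta contract that the `init_module_deltas` docstring describes; the `nn.Linear` stand-in is illustrative:

    import torch
    from torch import nn

    from nested_learning.fast_state import init_module_deltas

    layer = nn.Linear(4, 4)
    deltas = init_module_deltas(layer)  # {"weight": zeros, "bias": zeros}, detached
    # An online update is a stop-grad write into the delta, not into the meta param:
    with torch.no_grad():
        deltas["weight"] += 0.01 * torch.randn_like(deltas["weight"])
    # Forward passes then see meta_param + delta (see call_with_deltas in functional.py),
    # so outer-loop gradients still reach layer.weight itself.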
nested_learning/functional.py
ADDED
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Mapping, Tuple
+
+import torch
+from torch import nn
+from torch.func import functional_call
+
+ParamDict = Dict[str, torch.Tensor]
+
+
+def params_with_deltas(module: nn.Module, deltas: ParamDict) -> ParamDict:
+    params: ParamDict = {}
+    missing: list[str] = []
+    for name, param in module.named_parameters():
+        delta = deltas.get(name)
+        if delta is None:
+            missing.append(name)
+            continue
+        params[name] = param + delta
+    if missing:
+        raise KeyError(
+            f"Missing fast-state delta(s) for {module.__class__.__name__}: {sorted(missing)[:10]}"
+        )
+    return params
+
+
+def module_buffers(module: nn.Module) -> ParamDict:
+    return {name: buf for name, buf in module.named_buffers()}
+
+
+def call_with_params(
+    module: nn.Module,
+    params: ParamDict,
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
+    buffers = module_buffers(module)
+    return functional_call(module, (params, buffers), args, kwargs, strict=True)
+
+
+def call_with_deltas(
+    module: nn.Module,
+    deltas: ParamDict,
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
+    return call_with_params(module, params_with_deltas(module, deltas), *args, **kwargs)
+
+
+def require_grad_params(
+    params: Mapping[str, torch.Tensor], *, detach: bool = True
+) -> ParamDict:
+    out: ParamDict = {}
+    for name, value in params.items():
+        if detach:
+            out[name] = value.detach().requires_grad_(True)
+        else:
+            out[name] = value.requires_grad_(True)
+    return out
+
+
+def grads_to_dict(params: ParamDict, grads: Tuple[torch.Tensor | None, ...]) -> ParamDict:
+    out: ParamDict = {}
+    for (name, _), grad in zip(params.items(), grads, strict=True):
+        if grad is None:
+            continue
+        out[name] = grad
+    return out