@zigrivers/scaffold 3.14.0 → 3.16.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -21
- package/content/knowledge/core/automated-review-tooling.md +21 -26
- package/content/knowledge/core/multi-model-review-dispatch.md +30 -55
- package/content/knowledge/research/research-architecture.md +385 -0
- package/content/knowledge/research/research-conventions.md +248 -0
- package/content/knowledge/research/research-dev-environment.md +303 -0
- package/content/knowledge/research/research-experiment-loop.md +429 -0
- package/content/knowledge/research/research-experiment-tracking.md +336 -0
- package/content/knowledge/research/research-ml-architecture-search.md +383 -0
- package/content/knowledge/research/research-ml-evaluation.md +407 -0
- package/content/knowledge/research/research-ml-experiment-tracking.md +466 -0
- package/content/knowledge/research/research-ml-training-patterns.md +413 -0
- package/content/knowledge/research/research-observability.md +395 -0
- package/content/knowledge/research/research-overfitting-prevention.md +306 -0
- package/content/knowledge/research/research-project-structure.md +264 -0
- package/content/knowledge/research/research-quant-backtesting.md +326 -0
- package/content/knowledge/research/research-quant-market-data.md +366 -0
- package/content/knowledge/research/research-quant-metrics.md +335 -0
- package/content/knowledge/research/research-quant-requirements.md +223 -0
- package/content/knowledge/research/research-quant-risk.md +469 -0
- package/content/knowledge/research/research-quant-strategy-patterns.md +412 -0
- package/content/knowledge/research/research-requirements.md +201 -0
- package/content/knowledge/research/research-security.md +374 -0
- package/content/knowledge/research/research-sim-compute-management.md +538 -0
- package/content/knowledge/research/research-sim-engine-patterns.md +448 -0
- package/content/knowledge/research/research-sim-parameter-spaces.md +425 -0
- package/content/knowledge/research/research-sim-validation.md +456 -0
- package/content/knowledge/research/research-testing.md +334 -0
- package/content/methodology/research-ml-research.yml +23 -0
- package/content/methodology/research-overlay.yml +65 -0
- package/content/methodology/research-quant-finance.yml +29 -0
- package/content/methodology/research-simulation.yml +23 -0
- package/content/tools/post-implementation-review.md +36 -7
- package/content/tools/review-code.md +33 -8
- package/content/tools/review-pr.md +79 -95
- package/dist/cli/commands/adopt.d.ts.map +1 -1
- package/dist/cli/commands/adopt.js +22 -1
- package/dist/cli/commands/adopt.js.map +1 -1
- package/dist/cli/commands/adopt.serialization.test.js +41 -0
- package/dist/cli/commands/adopt.serialization.test.js.map +1 -1
- package/dist/cli/commands/init.d.ts +4 -0
- package/dist/cli/commands/init.d.ts.map +1 -1
- package/dist/cli/commands/init.js +32 -2
- package/dist/cli/commands/init.js.map +1 -1
- package/dist/cli/init-flag-families.d.ts +6 -1
- package/dist/cli/init-flag-families.d.ts.map +1 -1
- package/dist/cli/init-flag-families.js +32 -1
- package/dist/cli/init-flag-families.js.map +1 -1
- package/dist/cli/init-flag-families.test.js +47 -0
- package/dist/cli/init-flag-families.test.js.map +1 -1
- package/dist/config/schema.d.ts +272 -16
- package/dist/config/schema.d.ts.map +1 -1
- package/dist/config/schema.js +25 -1
- package/dist/config/schema.js.map +1 -1
- package/dist/config/schema.test.js +103 -3
- package/dist/config/schema.test.js.map +1 -1
- package/dist/core/assembly/overlay-loader.d.ts +12 -0
- package/dist/core/assembly/overlay-loader.d.ts.map +1 -1
- package/dist/core/assembly/overlay-loader.js +30 -0
- package/dist/core/assembly/overlay-loader.js.map +1 -1
- package/dist/core/assembly/overlay-loader.test.js +66 -1
- package/dist/core/assembly/overlay-loader.test.js.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.d.ts.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.js +48 -19
- package/dist/core/assembly/overlay-state-resolver.js.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.test.js +80 -0
- package/dist/core/assembly/overlay-state-resolver.test.js.map +1 -1
- package/dist/e2e/project-type-overlays.test.js +119 -0
- package/dist/e2e/project-type-overlays.test.js.map +1 -1
- package/dist/project/adopt.d.ts.map +1 -1
- package/dist/project/adopt.js +3 -1
- package/dist/project/adopt.js.map +1 -1
- package/dist/project/detectors/disambiguate.js +1 -1
- package/dist/project/detectors/disambiguate.js.map +1 -1
- package/dist/project/detectors/index.d.ts.map +1 -1
- package/dist/project/detectors/index.js +2 -1
- package/dist/project/detectors/index.js.map +1 -1
- package/dist/project/detectors/ml.d.ts.map +1 -1
- package/dist/project/detectors/ml.js +2 -6
- package/dist/project/detectors/ml.js.map +1 -1
- package/dist/project/detectors/research.d.ts +4 -0
- package/dist/project/detectors/research.d.ts.map +1 -0
- package/dist/project/detectors/research.js +141 -0
- package/dist/project/detectors/research.js.map +1 -0
- package/dist/project/detectors/research.test.d.ts +2 -0
- package/dist/project/detectors/research.test.d.ts.map +1 -0
- package/dist/project/detectors/research.test.js +235 -0
- package/dist/project/detectors/research.test.js.map +1 -0
- package/dist/project/detectors/shared-signals.d.ts +3 -0
- package/dist/project/detectors/shared-signals.d.ts.map +1 -0
- package/dist/project/detectors/shared-signals.js +9 -0
- package/dist/project/detectors/shared-signals.js.map +1 -0
- package/dist/project/detectors/types.d.ts +6 -2
- package/dist/project/detectors/types.d.ts.map +1 -1
- package/dist/project/detectors/types.js.map +1 -1
- package/dist/types/config.d.ts +7 -1
- package/dist/types/config.d.ts.map +1 -1
- package/dist/wizard/copy/core.d.ts.map +1 -1
- package/dist/wizard/copy/core.js +4 -0
- package/dist/wizard/copy/core.js.map +1 -1
- package/dist/wizard/copy/index.d.ts.map +1 -1
- package/dist/wizard/copy/index.js +2 -0
- package/dist/wizard/copy/index.js.map +1 -1
- package/dist/wizard/copy/research.d.ts +3 -0
- package/dist/wizard/copy/research.d.ts.map +1 -0
- package/dist/wizard/copy/research.js +27 -0
- package/dist/wizard/copy/research.js.map +1 -0
- package/dist/wizard/copy/types.d.ts +5 -1
- package/dist/wizard/copy/types.d.ts.map +1 -1
- package/dist/wizard/flags.d.ts +7 -1
- package/dist/wizard/flags.d.ts.map +1 -1
- package/dist/wizard/questions.d.ts +4 -2
- package/dist/wizard/questions.d.ts.map +1 -1
- package/dist/wizard/questions.js +27 -1
- package/dist/wizard/questions.js.map +1 -1
- package/dist/wizard/questions.test.js +51 -0
- package/dist/wizard/questions.test.js.map +1 -1
- package/dist/wizard/wizard.d.ts +3 -2
- package/dist/wizard/wizard.d.ts.map +1 -1
- package/dist/wizard/wizard.js +3 -1
- package/dist/wizard/wizard.js.map +1 -1
- package/package.json +1 -1
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: research-ml-training-patterns
|
|
3
|
+
description: Training loop patterns for research iteration including fast-fail detection, curriculum exploration, hyperparameter-conditioned training, reproducibility seeding, and checkpoint warm-starting
|
|
4
|
+
topics: [research, ml-research, training, fast-fail, curriculum, hyperparameter, reproducibility, checkpoint, warm-start]
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
Research training differs fundamentally from production training. In production, you train one model to convergence and ship it. In research, you train hundreds or thousands of configurations, most of which will fail -- the goal is to identify failures as early as possible and invest compute only in promising directions. A well-designed research training loop detects bad configurations within the first few percent of training, supports curriculum and schedule exploration without code changes, enables warm-starting from checkpoints to avoid redundant computation, and guarantees reproducibility so that any promising result can be verified.
|
|
8
|
+
|
|
9
|
+
## Summary
|
|
10
|
+
|
|
11
|
+
Design training loops for rapid iteration: implement fast-fail detection that aborts unpromising runs within 5-10% of the full budget, use hyperparameter-conditioned training that takes the full config as input (no hardcoded values), support curriculum schedules as first-class configuration objects, seed all randomness for exact reproducibility, and implement checkpoint-based warm-starting to resume experiments from any saved state. Separate the training loop from the evaluation loop so that evaluation strategies can evolve independently.
|
|
12
|
+
|
|
13
|
+
## Deep Guidance
|
|
14
|
+
|
|
15
|
+
### Fast-Fail Training
|
|
16
|
+
|
|
17
|
+
The most important research training pattern: detect bad configurations early and abort them. A run that will ultimately score poorly usually shows signals (diverging loss, NaN gradients, flat learning curves) within the first epoch:
|
|
18
|
+
|
|
19
|
+
```python
|
|
20
|
+
# src/training/fast_fail.py
|
|
21
|
+
from dataclasses import dataclass
|
|
22
|
+
import math
|
|
23
|
+
|
|
24
|
+
@dataclass
class FastFailConfig:
    """Thresholds used to decide when a training run should be stopped early.

    All fields have defaults, so ``FastFailConfig()`` yields a usable setup and
    individual thresholds can be overridden by keyword.
    """

    loss_explosion_factor: float = 10.0  # abort when loss exceeds this multiple of the initial loss
    patience_steps: int = 500  # abort when loss has not decreased after this many steps
    min_improvement_pct: float = 1.0  # minimum improvement (percent) required within the patience window
    abort_on_nan: bool = True  # abort immediately on a NaN/Inf loss
    check_every_n_steps: int = 50  # threshold checks run only every N steps (every step is too expensive)
|
|
37
|
+
|
|
38
|
+
class FastFailDetector:
    """Detect and abort unpromising training runs early.

    Call :meth:`check` once per training step with that step's loss. The
    detector keeps the first observed loss as the baseline, tracks the best
    loss seen at check points, and reports when the run should be aborted.
    """

    def __init__(self, config: "FastFailConfig"):
        """Store the thresholds and reset all tracking state."""
        self.config = config
        self.initial_loss: float | None = None
        self.best_loss: float = float("inf")
        self.steps_since_improvement: int = 0
        self.total_steps: int = 0

    def check(self, loss: float) -> tuple[bool, str]:
        """Return ``(should_abort, reason)`` after observing a loss value.

        ``reason`` is an empty string whenever ``should_abort`` is False.
        """
        self.total_steps += 1

        # NaN/Inf check (always, regardless of interval)
        if self.config.abort_on_nan and not math.isfinite(loss):
            return True, f"NaN/Inf loss at step {self.total_steps}"

        # Record the baseline on the very first observation. Waiting for the
        # first interval check would compare against a loss that may already
        # have exploded (or improved) during the skipped steps.
        if self.initial_loss is None:
            self.initial_loss = loss
            self.best_loss = loss
            return False, ""

        # The remaining checks only run every `check_every_n_steps` steps.
        if self.total_steps % self.config.check_every_n_steps != 0:
            return False, ""

        # Loss explosion check. The threshold is expressed as an offset from
        # the baseline so it equals `initial * factor` for positive baselines
        # and still lies above the baseline when the baseline is negative.
        threshold = self.initial_loss + abs(self.initial_loss) * (
            self.config.loss_explosion_factor - 1
        )
        if loss > threshold:
            return True, (
                f"Loss exploded: {loss:.4f} > "
                f"{threshold:.4f}"
            )

        # Improvement check, in percent of the best loss so far. A best loss
        # of exactly zero has no meaningful percentage, so any decrease counts
        # as an improvement and anything else counts as none.
        scale = abs(self.best_loss)
        if scale == 0:
            improvement = math.inf if loss < self.best_loss else 0.0
        else:
            improvement = (self.best_loss - loss) / scale * 100

        if improvement > self.config.min_improvement_pct:
            self.best_loss = loss
            self.steps_since_improvement = 0
        else:
            # One interval's worth of steps passed without enough improvement.
            self.steps_since_improvement += self.config.check_every_n_steps

        if self.steps_since_improvement >= self.config.patience_steps:
            return True, (
                f"No improvement for {self.steps_since_improvement} steps "
                f"(best: {self.best_loss:.4f}, current: {loss:.4f})"
            )

        return False, ""
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
### Hyperparameter-Conditioned Training
|
|
91
|
+
|
|
92
|
+
Research training loops must accept the full experiment config as input, with zero hardcoded values. This enables sweep tools to drive training externally:
|
|
93
|
+
|
|
94
|
+
```python
|
|
95
|
+
# src/training/configurable_trainer.py
|
|
96
|
+
from dataclasses import dataclass, field
|
|
97
|
+
from typing import Any
|
|
98
|
+
import torch
|
|
99
|
+
import torch.nn as nn
|
|
100
|
+
|
|
101
|
+
@dataclass
class TrainingConfig:
    """Single source of truth for one training run's settings.

    Every value has a default, so a sweep tool can override just the fields it
    is exploring and leave the rest untouched.
    """

    # --- Optimization ---
    learning_rate: float = 0.001
    weight_decay: float = 0.0001
    # "adam" | "adamw" | "sgd" | "lion" (build_optimizer in this file maps only the first three)
    optimizer: str = "adamw"
    # "cosine" | "linear" | "step" | "none"
    scheduler: str = "cosine"
    warmup_steps: int = 100
    max_steps: int = 10_000

    # --- Architecture: handed through to the model builder ---
    model_config: dict[str, Any] = field(default_factory=dict)

    # --- Training behavior ---
    batch_size: int = 32
    gradient_clip_norm: float = 1.0
    mixed_precision: bool = True
    gradient_accumulation_steps: int = 1

    # --- Fast-fail ---
    fast_fail: bool = True
    fast_fail_patience: int = 500

    # --- Reproducibility ---
    seed: int = 42
|
|
127
|
+
|
|
128
|
+
def build_optimizer(model: nn.Module, config: "TrainingConfig") -> torch.optim.Optimizer:
    """Build the optimizer named by ``config.optimizer`` for ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        config: Supplies ``optimizer``, ``learning_rate`` and ``weight_decay``.
            For ``"sgd"`` an optional ``momentum`` attribute is honored and
            defaults to 0.9 when the config does not define one.

    Returns:
        The constructed optimizer instance.

    Raises:
        KeyError: If ``config.optimizer`` is not one of the supported names.
    """
    optimizers = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
    }
    try:
        cls = optimizers[config.optimizer]
    except KeyError:
        # Same exception type as the plain dict lookup, but it now names the
        # offending value and the accepted alternatives.
        raise KeyError(
            f"Unsupported optimizer {config.optimizer!r}; "
            f"expected one of {sorted(optimizers)}"
        ) from None

    kwargs = {"lr": config.learning_rate, "weight_decay": config.weight_decay}
    if config.optimizer == "sgd":
        # Let the config drive momentum when it provides one, so sweeps are
        # not stuck with a hard-coded value.
        kwargs["momentum"] = getattr(config, "momentum", 0.9)
    return cls(model.parameters(), **kwargs)
|
|
140
|
+
|
|
141
|
+
def build_scheduler(optimizer, config: "TrainingConfig"):
    """Build the LR scheduler named by ``config.scheduler``.

    The cosine and linear schedules span ``max_steps - warmup_steps`` scheduler
    steps; this function only shortens the decay by the warmup length and does
    not itself implement a warmup ramp.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        config: Supplies ``scheduler``, ``max_steps`` and ``warmup_steps``.

    Returns:
        A ``torch.optim.lr_scheduler`` instance for ``"cosine"``, ``"linear"``
        or ``"step"``; ``None`` for any other value (including ``"none"``).
    """
    # Clamp to at least one step: a zero or negative horizon (warmup as long
    # as the whole run) makes the cosine schedule divide by zero on its first
    # step.
    decay_steps = max(1, config.max_steps - config.warmup_steps)

    if config.scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=decay_steps
        )
    if config.scheduler == "linear":
        return torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=0.0,
            total_iters=decay_steps,
        )
    if config.scheduler == "step":
        # Five decay stages over the run; runs shorter than five steps would
        # otherwise produce step_size == 0.
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=max(1, config.max_steps // 5), gamma=0.5
        )
    return None
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
### Curriculum and Schedule Exploration
|
|
160
|
+
|
|
161
|
+
Research often explores different training curricula (data ordering, task difficulty progression, loss weighting schedules). Define these as first-class objects:
|
|
162
|
+
|
|
163
|
+
```python
|
|
164
|
+
# src/training/curriculum.py
|
|
165
|
+
from dataclasses import dataclass
|
|
166
|
+
from typing import Callable
|
|
167
|
+
import math
|
|
168
|
+
|
|
169
|
+
@dataclass
class CurriculumSchedule:
    """A named curriculum schedule that maps training progress to difficulty."""

    name: str
    # (step, max_steps) -> difficulty; the built-in schedules stay in 0.0..1.0
    difficulty_fn: Callable[[int, int], float]

    def get_difficulty(self, step: int, max_steps: int) -> float:
        """Return the difficulty for ``step`` out of ``max_steps``."""
        return self.difficulty_fn(step, max_steps)


def _progress(step: int, max_steps: int) -> float:
    """Return the completed fraction of training, clamped to [0.0, 1.0].

    A non-positive ``max_steps`` leaves nothing to train, so it counts as
    fully complete rather than dividing by zero.
    """
    if max_steps <= 0:
        return 1.0
    return min(1.0, max(0.0, step / max_steps))


def _exponential_difficulty(step: int, max_steps: int) -> float:
    """Slow start, fast finish: (e^(3p) - 1) / (e^3 - 1) for progress p."""
    return (math.exp(3 * _progress(step, max_steps)) - 1) / (math.exp(3) - 1)


def _three_stage_difficulty(step: int, max_steps: int) -> float:
    """Three equal stages at difficulty 1/3, 2/3 and 1.0."""
    # Runs shorter than three steps get one-step stages instead of a
    # zero-length stage (which would divide by zero).
    stage_length = max(1, max_steps // 3)
    return min(1.0, (max(0, step) // stage_length + 1) / 3)


# Built-in schedules for experimentation
CURRICULUM_SCHEDULES = {
    "linear": CurriculumSchedule(
        name="linear",
        difficulty_fn=_progress,
    ),
    "exponential": CurriculumSchedule(
        name="exponential",
        difficulty_fn=_exponential_difficulty,
    ),
    "step_3": CurriculumSchedule(
        name="step_3",
        difficulty_fn=_three_stage_difficulty,
    ),
    "constant_easy": CurriculumSchedule(
        name="constant_easy",
        difficulty_fn=lambda step, max_steps: 0.3,
    ),
    "constant_hard": CurriculumSchedule(
        name="constant_hard",
        difficulty_fn=lambda step, max_steps: 1.0,
    ),
}
|
|
203
|
+
|
|
204
|
+
def filter_by_difficulty(
    dataset,
    difficulty_scores: list[float],
    current_difficulty: float,
    tolerance: float = 0.1,
) -> list[int]:
    """Select the positions of samples that are easy enough for the current stage.

    A sample qualifies when its score is at most
    ``current_difficulty + tolerance``. ``dataset`` is accepted but not read
    here; the result is built from ``difficulty_scores`` alone.

    Returns:
        Qualifying indices into ``difficulty_scores``, in ascending order.
    """
    selected: list[int] = []
    for position, sample_score in enumerate(difficulty_scores):
        if sample_score <= current_difficulty + tolerance:
            selected.append(position)
    return selected
|
|
215
|
+
```
|
|
216
|
+
|
|
217
|
+
### Reproducibility Seeding
|
|
218
|
+
|
|
219
|
+
Exact reproducibility requires seeding every source of randomness. This is non-trivial with GPU operations:
|
|
220
|
+
|
|
221
|
+
```python
|
|
222
|
+
# src/training/reproducibility.py
|
|
223
|
+
import os
|
|
224
|
+
import random
|
|
225
|
+
import numpy as np
|
|
226
|
+
import torch
|
|
227
|
+
|
|
228
|
+
def seed_everything(seed: int) -> None:
    """Seed all random number generators for reproducibility.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), then switches PyTorch to deterministic algorithms.

    Args:
        seed: Seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # NOTE: hash randomization is fixed when the interpreter starts, so this
    # cannot change hashing in the current process; it only reaches child
    # processes that inherit the environment. Launch Python with
    # PYTHONHASHSEED already set if hash ordering matters in this process.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # PyTorch's reproducibility notes require this variable for deterministic
    # cuBLAS on CUDA >= 10.2; without it, deterministic mode warns or errors
    # on affected ops. setdefault keeps any value the user chose already.
    # NOTE(review): presumably needs to be set before the first cuBLAS call.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # Deterministic algorithms (slower but reproducible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # warn_only: ops without a deterministic implementation warn, not raise.
    torch.use_deterministic_algorithms(True, warn_only=True)
|
|
240
|
+
|
|
241
|
+
def worker_init_fn(worker_id: int) -> None:
    """Re-seed Python's and NumPy's global RNGs from the current torch seed.

    The seed is ``torch.initial_seed()`` reduced to 32 bits. ``worker_id`` is
    accepted but not used.
    NOTE(review): presumably passed as a DataLoader ``worker_init_fn`` -- confirm at the call site.
    """
    seed32 = torch.initial_seed() & 0xFFFFFFFF  # same value as % 2**32
    random.seed(seed32)
    np.random.seed(seed32)
|
|
246
|
+
|
|
247
|
+
def get_reproducibility_info(seed: int) -> dict:
    """Snapshot the seed and the torch/CUDA/cuDNN settings of this process.

    Returns:
        A dict with the keys ``seed``, ``torch_version``, ``cuda_version``,
        ``cudnn_version``, ``deterministic``, ``benchmark`` and ``gpu_name``.
        Without CUDA, ``cudnn_version`` is ``None`` and ``gpu_name`` is
        ``"none"``.
    """
    has_cuda = torch.cuda.is_available()
    cudnn = torch.backends.cudnn

    info: dict = {"seed": seed, "torch_version": torch.__version__}
    info["cuda_version"] = torch.version.cuda or "none"
    info["cudnn_version"] = cudnn.version() if has_cuda else None
    info["deterministic"] = cudnn.deterministic
    info["benchmark"] = cudnn.benchmark
    # Only the first device (index 0) is recorded.
    info["gpu_name"] = torch.cuda.get_device_name(0) if has_cuda else "none"
    return info
|
|
260
|
+
```
|
|
261
|
+
|
|
262
|
+
### Checkpoint-Based Warm Starting
|
|
263
|
+
|
|
264
|
+
Warm-starting avoids re-training from scratch when exploring nearby configurations. Save and restore training state completely:
|
|
265
|
+
|
|
266
|
+
```python
|
|
267
|
+
# src/training/checkpointing.py
|
|
268
|
+
from dataclasses import dataclass
|
|
269
|
+
from pathlib import Path
|
|
270
|
+
import torch
|
|
271
|
+
import json
|
|
272
|
+
|
|
273
|
+
@dataclass
|
|
274
|
+
class CheckpointManager:
|
|
275
|
+
"""Manage training checkpoints for warm-starting experiments."""
|
|
276
|
+
checkpoint_dir: Path
|
|
277
|
+
keep_top_k: int = 5 # Keep only the best K checkpoints
|
|
278
|
+
|
|
279
|
+
def __post_init__(self):
|
|
280
|
+
self.checkpoint_dir = Path(self.checkpoint_dir)
|
|
281
|
+
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
282
|
+
|
|
283
|
+
def save(
|
|
284
|
+
self,
|
|
285
|
+
model,
|
|
286
|
+
optimizer,
|
|
287
|
+
scheduler,
|
|
288
|
+
step: int,
|
|
289
|
+
metrics: dict[str, float],
|
|
290
|
+
config: dict,
|
|
291
|
+
) -> Path:
|
|
292
|
+
"""Save complete training state for warm-starting."""
|
|
293
|
+
checkpoint = {
|
|
294
|
+
"model_state_dict": model.state_dict(),
|
|
295
|
+
"optimizer_state_dict": optimizer.state_dict(),
|
|
296
|
+
"scheduler_state_dict": scheduler.state_dict() if scheduler else None,
|
|
297
|
+
"step": step,
|
|
298
|
+
"metrics": metrics,
|
|
299
|
+
"config": config,
|
|
300
|
+
}
|
|
301
|
+
path = self.checkpoint_dir / f"checkpoint_step_{step}.pt"
|
|
302
|
+
torch.save(checkpoint, path)
|
|
303
|
+
|
|
304
|
+
# Save metadata for quick filtering
|
|
305
|
+
meta_path = self.checkpoint_dir / f"checkpoint_step_{step}.json"
|
|
306
|
+
with open(meta_path, "w") as f:
|
|
307
|
+
json.dump({"step": step, "metrics": metrics, "config": config}, f, indent=2)
|
|
308
|
+
|
|
309
|
+
self._enforce_top_k(metrics)
|
|
310
|
+
return path
|
|
311
|
+
|
|
312
|
+
def load(self, path: Path, model, optimizer=None, scheduler=None) -> dict:
|
|
313
|
+
"""Load checkpoint and restore training state."""
|
|
314
|
+
checkpoint = torch.load(path, map_location="cpu", weights_only=False)
|
|
315
|
+
model.load_state_dict(checkpoint["model_state_dict"])
|
|
316
|
+
if optimizer and checkpoint["optimizer_state_dict"]:
|
|
317
|
+
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
|
318
|
+
if scheduler and checkpoint["scheduler_state_dict"]:
|
|
319
|
+
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
|
320
|
+
return checkpoint
|
|
321
|
+
|
|
322
|
+
def find_warmstart_checkpoint(self, config: dict, metric: str = "val_loss") -> Path | None:
|
|
323
|
+
"""Find the best checkpoint from a similar config for warm-starting."""
|
|
324
|
+
meta_files = sorted(self.checkpoint_dir.glob("*.json"))
|
|
325
|
+
candidates = []
|
|
326
|
+
for meta_path in meta_files:
|
|
327
|
+
with open(meta_path) as f:
|
|
328
|
+
meta = json.load(f)
|
|
329
|
+
similarity = self._config_similarity(config, meta["config"])
|
|
330
|
+
if similarity > 0.7: # At least 70% similar
|
|
331
|
+
candidates.append((meta_path, meta, similarity))
|
|
332
|
+
|
|
333
|
+
if not candidates:
|
|
334
|
+
return None
|
|
335
|
+
|
|
336
|
+
# Pick highest similarity, break ties by best metric
|
|
337
|
+
candidates.sort(key=lambda x: (x[2], -x[1]["metrics"].get(metric, float("inf"))))
|
|
338
|
+
best_meta = candidates[-1][0]
|
|
339
|
+
ckpt_path = best_meta.with_suffix(".pt")
|
|
340
|
+
return ckpt_path if ckpt_path.exists() else None
|
|
341
|
+
|
|
342
|
+
def _config_similarity(self, config_a: dict, config_b: dict) -> float:
|
|
343
|
+
"""Compute fraction of matching config keys."""
|
|
344
|
+
all_keys = set(config_a) | set(config_b)
|
|
345
|
+
if not all_keys:
|
|
346
|
+
return 1.0
|
|
347
|
+
matching = sum(1 for k in all_keys if config_a.get(k) == config_b.get(k))
|
|
348
|
+
return matching / len(all_keys)
|
|
349
|
+
|
|
350
|
+
def _enforce_top_k(self, latest_metrics: dict) -> None:
|
|
351
|
+
"""Keep only top-K checkpoints by primary metric."""
|
|
352
|
+
meta_files = list(self.checkpoint_dir.glob("*.json"))
|
|
353
|
+
if len(meta_files) <= self.keep_top_k:
|
|
354
|
+
return
|
|
355
|
+
entries = []
|
|
356
|
+
for meta_path in meta_files:
|
|
357
|
+
with open(meta_path) as f:
|
|
358
|
+
meta = json.load(f)
|
|
359
|
+
entries.append((meta_path, meta))
|
|
360
|
+
|
|
361
|
+
# Sort by val_loss ascending (lower is better) -- remove worst
|
|
362
|
+
entries.sort(key=lambda x: x[1]["metrics"].get("val_loss", float("inf")))
|
|
363
|
+
for meta_path, _ in entries[self.keep_top_k:]:
|
|
364
|
+
meta_path.unlink(missing_ok=True)
|
|
365
|
+
meta_path.with_suffix(".pt").unlink(missing_ok=True)
|
|
366
|
+
```
|
|
367
|
+
|
|
368
|
+
### Research Training Loop Integration
|
|
369
|
+
|
|
370
|
+
Combine all patterns into a single research-oriented training loop:
|
|
371
|
+
|
|
372
|
+
```python
|
|
373
|
+
# src/training/research_trainer.py
|
|
374
|
+
from pathlib import Path

from src.training.checkpointing import CheckpointManager
from src.training.configurable_trainer import (
    TrainingConfig,
    build_optimizer,
    build_scheduler,
)
from src.training.fast_fail import FastFailDetector, FastFailConfig
from src.training.reproducibility import seed_everything
|
|
377
|
+
|
|
378
|
+
def research_train(
    config: TrainingConfig,
    model,
    train_loader,
    val_loader,
    *,
    checkpoint_dir="checkpoints",
    keep_top_k: int = 3,
    eval_every_n_steps: int = 500,
) -> dict:
    """Research training loop with fast-fail, seeding, and checkpointing.

    Args:
        config: Full training configuration; drives seed, optimizer,
            scheduler, step budget and fast-fail settings.
        model: Model to train (updated in place).
        train_loader: Passed through to ``train_step``.
        val_loader: Passed through to ``evaluate``.
        checkpoint_dir: Directory for checkpoints and warm-start lookup.
        keep_top_k: Number of best checkpoints to retain.
        eval_every_n_steps: Interval for evaluation + checkpointing; a value
            <= 0 disables periodic checkpoints.

    Returns:
        ``{"status": "aborted", "reason": ..., "step": ...}`` when fast-fail
        triggers, otherwise
        ``{"status": "completed", "metrics": ..., "steps": config.max_steps}``.
    """
    seed_everything(config.seed)

    optimizer = build_optimizer(model, config)
    scheduler = build_scheduler(optimizer, config)
    checkpoint_mgr = CheckpointManager(Path(checkpoint_dir), keep_top_k=keep_top_k)

    # Attempt warm-start from similar config
    warmstart_ckpt = checkpoint_mgr.find_warmstart_checkpoint(vars(config))
    start_step = 0
    if warmstart_ckpt:
        ckpt_data = checkpoint_mgr.load(warmstart_ckpt, model, optimizer, scheduler)
        # The checkpoint stores the index of the last *completed* step, so
        # resume with the next one instead of repeating it.
        start_step = ckpt_data["step"] + 1

    fast_fail = FastFailDetector(FastFailConfig(
        patience_steps=config.fast_fail_patience
    )) if config.fast_fail else None

    for step in range(start_step, config.max_steps):
        # NOTE(review): assumes train_step returns a scalar tensor-like
        # object exposing .item() -- confirm against its definition.
        loss = train_step(model, optimizer, train_loader, config)

        # Advance the LR schedule. build_scheduler sizes the decay for
        # max_steps - warmup_steps, so it only starts after the warmup window.
        # train_step never receives the scheduler, so it must be stepped here.
        if scheduler is not None and step >= config.warmup_steps:
            scheduler.step()

        # Fast-fail check
        if fast_fail:
            should_abort, reason = fast_fail.check(loss.item())
            if should_abort:
                return {"status": "aborted", "reason": reason, "step": step}

        # Periodic evaluation and checkpointing
        if eval_every_n_steps > 0 and step > 0 and step % eval_every_n_steps == 0:
            metrics = evaluate(model, val_loader)
            checkpoint_mgr.save(model, optimizer, scheduler, step, metrics, vars(config))

    final_metrics = evaluate(model, val_loader)
    return {"status": "completed", "metrics": final_metrics, "steps": config.max_steps}
|
|
413
|
+
```
|