landmarkdiff-0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/benchmark.py
@@ -0,0 +1,213 @@
"""Inference benchmarking for deployment sizing.

Measures throughput, latency, and memory usage for ControlNet inference
under various configurations (resolution, batch size, denoising steps).

Usage:
    from landmarkdiff.benchmark import InferenceBenchmark

    bench = InferenceBenchmark()
    bench.add_result("gpu_a6000", latency_ms=142.3, throughput_fps=7.0, vram_gb=4.2)
    bench.add_result("gpu_a6000", latency_ms=138.1, throughput_fps=7.2, vram_gb=4.2)
    print(bench.summary())
"""

from __future__ import annotations

import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass
class BenchmarkResult:
    """A single benchmark measurement."""

    config_name: str
    latency_ms: float
    throughput_fps: float = 0.0
    vram_gb: float = 0.0
    batch_size: int = 1
    resolution: int = 512
    num_inference_steps: int = 20
    device: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)


class InferenceBenchmark:
    """Collect and analyze inference benchmarks.

    Args:
        model_name: Name of the model being benchmarked.
    """

    def __init__(self, model_name: str = "LandmarkDiff-ControlNet") -> None:
        self.model_name = model_name
        self.results: list[BenchmarkResult] = []

    def add_result(
        self,
        config_name: str,
        latency_ms: float,
        throughput_fps: float = 0.0,
        vram_gb: float = 0.0,
        batch_size: int = 1,
        resolution: int = 512,
        num_inference_steps: int = 20,
        device: str = "",
        **metadata: Any,
    ) -> None:
        """Add a benchmark result."""
        if throughput_fps == 0.0 and latency_ms > 0:
            throughput_fps = 1000.0 / latency_ms * batch_size

        self.results.append(
            BenchmarkResult(
                config_name=config_name,
                latency_ms=latency_ms,
                throughput_fps=throughput_fps,
                vram_gb=vram_gb,
                batch_size=batch_size,
                resolution=resolution,
                num_inference_steps=num_inference_steps,
                device=device,
                metadata=metadata,
            )
        )

    def mean_latency(self, config_name: str | None = None) -> float:
        """Mean latency in ms, optionally filtered by config."""
        results = self._filter(config_name)
        if not results:
            return float("nan")
        return sum(r.latency_ms for r in results) / len(results)

    def p99_latency(self, config_name: str | None = None) -> float:
        """P99 latency in ms."""
        results = self._filter(config_name)
        if not results:
            return float("nan")
        sorted_latencies = sorted(r.latency_ms for r in results)
        idx = max(0, int(len(sorted_latencies) * 0.99) - 1)
        return sorted_latencies[idx]

    def mean_throughput(self, config_name: str | None = None) -> float:
        """Mean throughput in FPS."""
        results = self._filter(config_name)
        if not results:
            return float("nan")
        return sum(r.throughput_fps for r in results) / len(results)

    def max_vram(self, config_name: str | None = None) -> float:
        """Maximum VRAM usage in GB."""
        results = self._filter(config_name)
        if not results:
            return 0.0
        return max(r.vram_gb for r in results)

    def _filter(self, config_name: str | None) -> list[BenchmarkResult]:
        if config_name is None:
            return self.results
        return [r for r in self.results if r.config_name == config_name]

    @property
    def config_names(self) -> list[str]:
        """Unique config names in order."""
        seen: dict[str, None] = {}
        for r in self.results:
            seen.setdefault(r.config_name, None)
        return list(seen.keys())

    def summary(self) -> str:
        """Generate text summary table."""
        configs = self.config_names
        if not configs:
            return "No benchmark results."

        header = (
            f"{'Config':>20s} | {'Mean(ms)':>10s} | {'P99(ms)':>10s}"
            f" | {'FPS':>8s} | {'VRAM(GB)':>8s} | {'N':>4s}"
        )
        lines = [
            f"Inference Benchmark: {self.model_name}",
            header,
            "-" * len(header),
        ]

        for cfg in configs:
            results = self._filter(cfg)
            lines.append(
                f"{cfg:>20s} | "
                f"{self.mean_latency(cfg):>10.1f} | "
                f"{self.p99_latency(cfg):>10.1f} | "
                f"{self.mean_throughput(cfg):>8.2f} | "
                f"{self.max_vram(cfg):>8.1f} | "
                f"{len(results):>4d}"
            )

        return "\n".join(lines)

    def to_json(self, path: str | Path | None = None) -> str:
        """Export results as JSON."""
        data = {
            "model_name": self.model_name,
            "results": [
                {
                    "config_name": r.config_name,
                    "latency_ms": r.latency_ms,
                    "throughput_fps": round(r.throughput_fps, 2),
                    "vram_gb": r.vram_gb,
                    "batch_size": r.batch_size,
                    "resolution": r.resolution,
                    "num_inference_steps": r.num_inference_steps,
                    "device": r.device,
                }
                for r in self.results
            ],
            "summary": {
                cfg: {
                    "mean_latency_ms": round(self.mean_latency(cfg), 1),
                    "p99_latency_ms": round(self.p99_latency(cfg), 1),
                    "mean_fps": round(self.mean_throughput(cfg), 2),
                    "max_vram_gb": round(self.max_vram(cfg), 1),
                    "n_samples": len(self._filter(cfg)),
                }
                for cfg in self.config_names
            },
        }
        j = json.dumps(data, indent=2)
        if path:
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            Path(path).write_text(j)
        return j


class Timer:
    """Simple context manager for timing code blocks.

    Usage:
        with Timer() as t:
            run_inference()
        print(f"Took {t.elapsed_ms:.1f} ms")
    """

    def __init__(self) -> None:
        self.start_time: float = 0.0
        self.end_time: float = 0.0

    @property
    def elapsed_ms(self) -> float:
        return (self.end_time - self.start_time) * 1000

    @property
    def elapsed_s(self) -> float:
        return self.end_time - self.start_time

    def __enter__(self) -> Timer:
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args: Any) -> None:
        self.end_time = time.perf_counter()
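A typical sizing run pairs the Timer context manager with InferenceBenchmark, both defined in the hunk above. The following is a minimal sketch, not part of the package: `generate` is a hypothetical stand-in for the ControlNet pipeline call (which this diff does not include), and reading VRAM through torch.cuda is an assumption, since benchmark.py itself never imports torch. Note that when throughput_fps is omitted, add_result derives it as 1000 / latency_ms * batch_size.

    import torch
    from landmarkdiff.benchmark import InferenceBenchmark, Timer

    bench = InferenceBenchmark()
    for _ in range(20):
        torch.cuda.reset_peak_memory_stats()  # so max_memory_allocated reflects this run
        with Timer() as t:
            generate(resolution=512, num_inference_steps=20)  # hypothetical pipeline call
        bench.add_result(
            "a6000_512px",
            latency_ms=t.elapsed_ms,
            # throughput_fps omitted: derived from latency and batch size
            vram_gb=torch.cuda.max_memory_allocated() / 1024**3,
            resolution=512,
            num_inference_steps=20,
            device="cuda:0",
        )
    print(bench.summary())               # per-config mean/P99/FPS/VRAM table
    bench.to_json("bench/results.json")  # raw results plus per-config aggregates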
landmarkdiff/checkpoint_manager.py
@@ -0,0 +1,361 @@
"""Checkpoint management with metadata tracking, best-model selection, and pruning.

Provides a central manager for training checkpoints that:
- Tracks per-checkpoint metadata (step, metrics, timestamps)
- Maintains symlinks to best/latest checkpoints
- Prunes old checkpoints to save disk space
- Supports multiple ranking metrics (loss, FID, SSIM, etc.)

Usage:
    manager = CheckpointManager(
        output_dir="checkpoints/phaseA",
        keep_best=3,
        keep_latest=5,
        metric="loss",
        lower_is_better=True,
    )

    # During training loop:
    manager.save(
        step=1000,
        controlnet=controlnet,
        ema_controlnet=ema_controlnet,
        optimizer=optimizer,
        scheduler=scheduler,
        metrics={"loss": 0.0123, "val_ssim": 0.87},
    )
"""

from __future__ import annotations

import json
import shutil
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any

import torch


@dataclass
class CheckpointMetadata:
    """Metadata for a single checkpoint."""

    step: int
    timestamp: float
    metrics: dict[str, float] = field(default_factory=dict)
    epoch: int | None = None
    phase: str = ""
    is_best: bool = False
    size_mb: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> CheckpointMetadata:
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})


class CheckpointManager:
    """Manages training checkpoints with pruning and best-model tracking.

    Args:
        output_dir: Base directory for checkpoints.
        keep_best: Number of best checkpoints to retain.
        keep_latest: Number of most recent checkpoints to retain.
        metric: Metric name used to determine "best" checkpoint.
        lower_is_better: If True, lower metric values are better (e.g. loss, FID).
        prefix: Checkpoint directory prefix (default: "checkpoint").
    """

    INDEX_FILE = "checkpoint_index.json"

    def __init__(
        self,
        output_dir: str | Path,
        keep_best: int = 3,
        keep_latest: int = 5,
        metric: str = "loss",
        lower_is_better: bool = True,
        prefix: str = "checkpoint",
    ) -> None:
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.keep_best = keep_best
        self.keep_latest = keep_latest
        self.metric = metric
        self.lower_is_better = lower_is_better
        self.prefix = prefix

        self._index: dict[str, Any] = {"checkpoints": {}}
        self._load_index()

    # ------------------------------------------------------------------
    # Index persistence
    # ------------------------------------------------------------------

    def _index_path(self) -> Path:
        return self.output_dir / self.INDEX_FILE

    def _load_index(self) -> None:
        path = self._index_path()
        if path.exists():
            with open(path) as f:
                self._index = json.load(f)
        if "checkpoints" not in self._index:
            self._index["checkpoints"] = {}

    def _save_index(self) -> None:
        with open(self._index_path(), "w") as f:
            json.dump(self._index, f, indent=2)

    # ------------------------------------------------------------------
    # Save checkpoint
    # ------------------------------------------------------------------

    def save(
        self,
        step: int,
        controlnet: torch.nn.Module,
        ema_controlnet: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Any = None,
        metrics: dict[str, float] | None = None,
        epoch: int | None = None,
        phase: str = "",
        extra_state: dict[str, Any] | None = None,
    ) -> Path:
        """Save a checkpoint with metadata.

        Args:
            step: Current training step.
            controlnet: ControlNet model (or any nn.Module).
            ema_controlnet: EMA copy of the model.
            optimizer: Optimizer state.
            scheduler: Optional LR scheduler.
            metrics: Dict of metric values at this step.
            epoch: Optional epoch number.
            phase: Training phase label (e.g. "A", "B").
            extra_state: Any additional state to save.

        Returns:
            Path to the saved checkpoint directory.
        """
        ckpt_name = f"{self.prefix}-{step}"
        ckpt_dir = self.output_dir / ckpt_name
        ckpt_dir.mkdir(exist_ok=True)

        # Save EMA weights (used for inference)
        if hasattr(ema_controlnet, "save_pretrained"):
            ema_controlnet.save_pretrained(ckpt_dir / "controlnet_ema")

        # Save training state for resume
        state = {
            "controlnet": _get_state_dict(controlnet),
            "ema_controlnet": _get_state_dict(ema_controlnet),
            "optimizer": optimizer.state_dict(),
            "global_step": step,
        }
        if scheduler is not None:
            state["scheduler"] = scheduler.state_dict()
        if extra_state:
            state.update(extra_state)

        torch.save(state, ckpt_dir / "training_state.pt")

        # Compute checkpoint size
        size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)

        # Create metadata
        meta = CheckpointMetadata(
            step=step,
            timestamp=time.time(),
            metrics=metrics or {},
            epoch=epoch,
            phase=phase,
            size_mb=round(size_mb, 1),
        )

        # Save metadata alongside checkpoint
        with open(ckpt_dir / "metadata.json", "w") as f:
            json.dump(meta.to_dict(), f, indent=2)

        # Update index
        self._index["checkpoints"][ckpt_name] = meta.to_dict()
        self._update_best()
        self._save_index()

        # Update symlinks
        self._update_symlinks()

        # Prune old checkpoints
        self._prune()

        return ckpt_dir

    # ------------------------------------------------------------------
    # Best / latest tracking
    # ------------------------------------------------------------------

    def _update_best(self) -> None:
        """Recompute which checkpoints are 'best'."""
        entries = []
        for name, meta in self._index["checkpoints"].items():
            val = meta.get("metrics", {}).get(self.metric)
            if val is not None:
                entries.append((name, val, meta))

        if not entries:
            return

        # Sort by metric
        entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)

        # Mark best
        best_names = {e[0] for e in entries[: self.keep_best]}
        for name, meta in self._index["checkpoints"].items():
            meta["is_best"] = name in best_names

    def _update_symlinks(self) -> None:
        """Update 'latest' and 'best' symlinks."""
        checkpoints = self._sorted_by_step()
        if not checkpoints:
            return

        # Latest symlink
        latest_name = checkpoints[-1]
        latest_link = self.output_dir / "latest"
        _force_symlink(self.output_dir / latest_name, latest_link)

        # Best symlink
        best_name = self.get_best_checkpoint_name()
        if best_name:
            best_link = self.output_dir / "best"
            _force_symlink(self.output_dir / best_name, best_link)

    def get_best_checkpoint_name(self) -> str | None:
        """Return the name of the best checkpoint by tracked metric."""
        best = None
        best_val = None
        for name, meta in self._index["checkpoints"].items():
            val = meta.get("metrics", {}).get(self.metric)
            if val is None:
                continue
            if (
                best_val is None
                or (self.lower_is_better and val < best_val)
                or (not self.lower_is_better and val > best_val)
            ):
                best, best_val = name, val
        return best

    def get_best_metric_value(self) -> float | None:
        """Return the best value of the tracked metric."""
        name = self.get_best_checkpoint_name()
        if name is None:
            return None
        return self._index["checkpoints"][name]["metrics"].get(self.metric)

    # ------------------------------------------------------------------
    # Pruning
    # ------------------------------------------------------------------

    def _sorted_by_step(self) -> list[str]:
        """Return checkpoint names sorted by step (ascending)."""
        items = list(self._index["checkpoints"].items())
        items.sort(key=lambda x: x[1].get("step", 0))
        return [name for name, _ in items]

    def _prune(self) -> None:
        """Remove old checkpoints, keeping best N and latest M."""
        all_names = self._sorted_by_step()
        if len(all_names) <= self.keep_latest:
            return

        # Determine which to keep
        keep = set()

        # Keep latest
        for name in all_names[-self.keep_latest :]:
            keep.add(name)

        # Keep best
        for name, meta in self._index["checkpoints"].items():
            if meta.get("is_best", False):
                keep.add(name)

        # Delete the rest
        for name in all_names:
            if name not in keep:
                ckpt_dir = self.output_dir / name
                if ckpt_dir.exists():
                    shutil.rmtree(ckpt_dir)
                self._index["checkpoints"].pop(name, None)

        self._save_index()

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def list_checkpoints(self) -> list[dict[str, Any]]:
        """Return metadata for all tracked checkpoints, sorted by step."""
        result = []
        for name in self._sorted_by_step():
            meta = self._index["checkpoints"][name]
            result.append({"name": name, **meta})
        return result

    def get_checkpoint_path(self, name: str) -> Path:
        """Return the filesystem path for a checkpoint by name."""
        return self.output_dir / name

    def get_latest_step(self) -> int:
        """Return the step of the most recent checkpoint, or 0."""
        names = self._sorted_by_step()
        if not names:
            return 0
        return self._index["checkpoints"][names[-1]].get("step", 0)

    def total_size_mb(self) -> float:
        """Return total disk size of all tracked checkpoints."""
        return sum(meta.get("size_mb", 0.0) for meta in self._index["checkpoints"].values())

    def summary(self) -> str:
        """Return a human-readable summary of checkpoint state."""
        ckpts = self.list_checkpoints()
        if not ckpts:
            return "No checkpoints saved."

        lines = [
            f"Checkpoints: {len(ckpts)} saved ({self.total_size_mb():.0f} MB total)",
            f"Latest: step {self.get_latest_step()}",
        ]

        best_name = self.get_best_checkpoint_name()
        best_val = self.get_best_metric_value()
        if best_name and best_val is not None:
            lines.append(f"Best ({self.metric}): {best_val:.6f} @ {best_name}")

        return "\n".join(lines)


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------


def _get_state_dict(module: torch.nn.Module) -> dict:
    """Extract state dict, handling DDP wrapper."""
    if hasattr(module, "module"):
        return module.module.state_dict()
    return module.state_dict()


def _force_symlink(target: Path, link: Path) -> None:
    """Create or replace a symlink."""
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(target.name)