landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,213 @@
1
+ """Inference benchmarking for deployment sizing.
2
+
3
+ Measures throughput, latency, and memory usage for ControlNet inference
4
+ under various configurations (resolution, batch size, denoising steps).
5
+
6
+ Usage:
7
+ from landmarkdiff.benchmark import InferenceBenchmark
8
+
9
+ bench = InferenceBenchmark()
10
+ bench.add_result("gpu_a6000", latency_ms=142.3, throughput_fps=7.0, vram_gb=4.2)
11
+ bench.add_result("gpu_a6000", latency_ms=138.1, throughput_fps=7.2, vram_gb=4.2)
12
+ print(bench.summary())
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import time
19
+ from dataclasses import dataclass, field
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+
24
@dataclass
class BenchmarkResult:
    """A single benchmark measurement.

    One record per inference run under a named configuration. ``metadata``
    captures any extra keyword arguments passed to
    :meth:`InferenceBenchmark.add_result`.
    """

    config_name: str
    latency_ms: float
    throughput_fps: float = 0.0
    vram_gb: float = 0.0
    batch_size: int = 1
    resolution: int = 512
    num_inference_steps: int = 20
    device: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)


class InferenceBenchmark:
    """Collect and analyze inference benchmarks.

    Args:
        model_name: Name of the model being benchmarked.
    """

    def __init__(self, model_name: str = "LandmarkDiff-ControlNet") -> None:
        self.model_name = model_name
        self.results: list[BenchmarkResult] = []

    def add_result(
        self,
        config_name: str,
        latency_ms: float,
        throughput_fps: float = 0.0,
        vram_gb: float = 0.0,
        batch_size: int = 1,
        resolution: int = 512,
        num_inference_steps: int = 20,
        device: str = "",
        **metadata: Any,
    ) -> None:
        """Add a benchmark result.

        If ``throughput_fps`` is left at 0.0 and latency is positive, it is
        derived as ``batch_size`` frames per ``latency_ms`` milliseconds.
        """
        if throughput_fps == 0.0 and latency_ms > 0:
            throughput_fps = 1000.0 / latency_ms * batch_size

        self.results.append(
            BenchmarkResult(
                config_name=config_name,
                latency_ms=latency_ms,
                throughput_fps=throughput_fps,
                vram_gb=vram_gb,
                batch_size=batch_size,
                resolution=resolution,
                num_inference_steps=num_inference_steps,
                device=device,
                metadata=metadata,
            )
        )

    def mean_latency(self, config_name: str | None = None) -> float:
        """Mean latency in ms, optionally filtered by config; NaN when empty."""
        results = self._filter(config_name)
        if not results:
            return float("nan")
        return sum(r.latency_ms for r in results) / len(results)

    def p99_latency(self, config_name: str | None = None) -> float:
        """P99 latency in ms using the nearest-rank method; NaN when empty.

        Nearest-rank percentile: the 1-based rank is ``ceil(0.99 * n)``.
        The previous ``int()``-based truncation under-reported P99 for
        n < 100 (e.g. it returned the *minimum* of a 2-sample run).
        """
        results = self._filter(config_name)
        if not results:
            return float("nan")
        sorted_latencies = sorted(r.latency_ms for r in results)
        # Integer ceil(99*n/100) without importing math.
        rank = -(-99 * len(sorted_latencies) // 100)
        return sorted_latencies[rank - 1]

    def mean_throughput(self, config_name: str | None = None) -> float:
        """Mean throughput in FPS; NaN when empty."""
        results = self._filter(config_name)
        if not results:
            return float("nan")
        return sum(r.throughput_fps for r in results) / len(results)

    def max_vram(self, config_name: str | None = None) -> float:
        """Maximum VRAM usage in GB; 0.0 when empty."""
        results = self._filter(config_name)
        if not results:
            return 0.0
        return max(r.vram_gb for r in results)

    def _filter(self, config_name: str | None) -> list[BenchmarkResult]:
        """Results matching ``config_name``, or all results when None."""
        if config_name is None:
            return self.results
        return [r for r in self.results if r.config_name == config_name]

    @property
    def config_names(self) -> list[str]:
        """Unique config names in first-seen order."""
        # dict preserves insertion order; used here as an ordered set.
        seen: dict[str, None] = {}
        for r in self.results:
            seen.setdefault(r.config_name, None)
        return list(seen.keys())

    def summary(self) -> str:
        """Generate a fixed-width text summary table, one row per config."""
        configs = self.config_names
        if not configs:
            return "No benchmark results."

        header = (
            f"{'Config':>20s} | {'Mean(ms)':>10s} | {'P99(ms)':>10s}"
            f" | {'FPS':>8s} | {'VRAM(GB)':>8s} | {'N':>4s}"
        )
        lines = [
            f"Inference Benchmark: {self.model_name}",
            header,
            "-" * len(header),
        ]

        for cfg in configs:
            results = self._filter(cfg)
            lines.append(
                f"{cfg:>20s} | "
                f"{self.mean_latency(cfg):>10.1f} | "
                f"{self.p99_latency(cfg):>10.1f} | "
                f"{self.mean_throughput(cfg):>8.2f} | "
                f"{self.max_vram(cfg):>8.1f} | "
                f"{len(results):>4d}"
            )

        return "\n".join(lines)

    def to_json(self, path: str | Path | None = None) -> str:
        """Export results as a JSON string; optionally also write it to ``path``.

        Returns:
            The serialized JSON document (raw results plus per-config summary).
        """
        data = {
            "model_name": self.model_name,
            "results": [
                {
                    "config_name": r.config_name,
                    "latency_ms": r.latency_ms,
                    "throughput_fps": round(r.throughput_fps, 2),
                    "vram_gb": r.vram_gb,
                    "batch_size": r.batch_size,
                    "resolution": r.resolution,
                    "num_inference_steps": r.num_inference_steps,
                    "device": r.device,
                }
                for r in self.results
            ],
            "summary": {
                cfg: {
                    "mean_latency_ms": round(self.mean_latency(cfg), 1),
                    "p99_latency_ms": round(self.p99_latency(cfg), 1),
                    "mean_fps": round(self.mean_throughput(cfg), 2),
                    "max_vram_gb": round(self.max_vram(cfg), 1),
                    "n_samples": len(self._filter(cfg)),
                }
                for cfg in self.config_names
            },
        }
        j = json.dumps(data, indent=2)
        if path:
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            Path(path).write_text(j)
        return j
185
+
186
+
187
class Timer:
    """Context manager that measures wall-clock duration with ``time.perf_counter``.

    Usage:
        with Timer() as t:
            run_inference()
        print(f"Took {t.elapsed_ms:.1f} ms")
    """

    def __init__(self) -> None:
        # Both timestamps start at zero so elapsed_* reads 0.0 before use.
        self.start_time: float = 0.0
        self.end_time: float = 0.0

    def __enter__(self) -> Timer:
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args: Any) -> None:
        self.end_time = time.perf_counter()

    @property
    def elapsed_s(self) -> float:
        """Seconds elapsed between enter and exit."""
        return self.end_time - self.start_time

    @property
    def elapsed_ms(self) -> float:
        """Milliseconds elapsed between enter and exit."""
        return self.elapsed_s * 1000
@@ -0,0 +1,361 @@
1
+ """Checkpoint management with metadata tracking, best-model selection, and pruning.
2
+
3
+ Provides a central manager for training checkpoints that:
4
+ - Tracks per-checkpoint metadata (step, metrics, timestamps)
5
+ - Maintains symlinks to best/latest checkpoints
6
+ - Prunes old checkpoints to save disk space
7
+ - Supports multiple ranking metrics (loss, FID, SSIM, etc.)
8
+
9
+ Usage:
10
+ manager = CheckpointManager(
11
+ output_dir="checkpoints/phaseA",
12
+ keep_best=3,
13
+ keep_latest=5,
14
+ metric="loss",
15
+ lower_is_better=True,
16
+ )
17
+
18
+ # During training loop:
19
+ manager.save(
20
+ step=1000,
21
+ controlnet=controlnet,
22
+ ema_controlnet=ema_controlnet,
23
+ optimizer=optimizer,
24
+ scheduler=scheduler,
25
+ metrics={"loss": 0.0123, "val_ssim": 0.87},
26
+ )
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import shutil
33
+ import time
34
+ from dataclasses import asdict, dataclass, field
35
+ from pathlib import Path
36
+ from typing import Any
37
+
38
+ import torch
39
+
40
+
41
@dataclass
class CheckpointMetadata:
    """Metadata recorded for a single saved checkpoint."""

    step: int  # training step at save time
    timestamp: float  # UNIX time of the save
    metrics: dict[str, float] = field(default_factory=dict)
    epoch: int | None = None
    phase: str = ""
    is_best: bool = False
    size_mb: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain, JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> CheckpointMetadata:
        """Construct from a dict, silently dropping unknown keys."""
        known = cls.__dataclass_fields__
        kwargs: dict[str, Any] = {}
        for key, value in d.items():
            if key in known:
                kwargs[key] = value
        return cls(**kwargs)
59
+
60
+
61
class CheckpointManager:
    """Manages training checkpoints with pruning and best-model tracking.

    A JSON index file (``checkpoint_index.json``) inside ``output_dir``
    records metadata for every tracked checkpoint, so the manager's view
    survives restarts; ``latest`` and ``best`` symlinks always point at the
    corresponding checkpoint directories.

    Args:
        output_dir: Base directory for checkpoints.
        keep_best: Number of best checkpoints to retain.
        keep_latest: Number of most recent checkpoints to retain.
        metric: Metric name used to determine "best" checkpoint.
        lower_is_better: If True, lower metric values are better (e.g. loss, FID).
        prefix: Checkpoint directory prefix (default: "checkpoint").
    """

    # Name of the JSON index file stored directly under output_dir.
    INDEX_FILE = "checkpoint_index.json"

    def __init__(
        self,
        output_dir: str | Path,
        keep_best: int = 3,
        keep_latest: int = 5,
        metric: str = "loss",
        lower_is_better: bool = True,
        prefix: str = "checkpoint",
    ) -> None:
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.keep_best = keep_best
        self.keep_latest = keep_latest
        self.metric = metric
        self.lower_is_better = lower_is_better
        self.prefix = prefix

        # In-memory mirror of the on-disk index:
        # {"checkpoints": {checkpoint_name: metadata-dict}}.
        self._index: dict[str, Any] = {"checkpoints": {}}
        self._load_index()

    # ------------------------------------------------------------------
    # Index persistence
    # ------------------------------------------------------------------

    def _index_path(self) -> Path:
        """Filesystem path of the JSON index file."""
        return self.output_dir / self.INDEX_FILE

    def _load_index(self) -> None:
        """Load the index from disk if present, replacing the in-memory copy."""
        path = self._index_path()
        if path.exists():
            with open(path) as f:
                self._index = json.load(f)
        # Defensive: tolerate an older or hand-edited index file.
        if "checkpoints" not in self._index:
            self._index["checkpoints"] = {}

    def _save_index(self) -> None:
        """Write the in-memory index back to disk (pretty-printed JSON)."""
        with open(self._index_path(), "w") as f:
            json.dump(self._index, f, indent=2)

    # ------------------------------------------------------------------
    # Save checkpoint
    # ------------------------------------------------------------------

    def save(
        self,
        step: int,
        controlnet: torch.nn.Module,
        ema_controlnet: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Any = None,
        metrics: dict[str, float] | None = None,
        epoch: int | None = None,
        phase: str = "",
        extra_state: dict[str, Any] | None = None,
    ) -> Path:
        """Save a checkpoint with metadata.

        Writes the checkpoint directory, its ``metadata.json``, updates the
        index, refreshes the ``latest``/``best`` symlinks, and prunes old
        checkpoints — all in one call.

        Args:
            step: Current training step.
            controlnet: ControlNet model (or any nn.Module).
            ema_controlnet: EMA copy of the model.
            optimizer: Optimizer state.
            scheduler: Optional LR scheduler.
            metrics: Dict of metric values at this step.
            epoch: Optional epoch number.
            phase: Training phase label (e.g. "A", "B").
            extra_state: Any additional state to save.

        Returns:
            Path to the saved checkpoint directory.
        """
        ckpt_name = f"{self.prefix}-{step}"
        ckpt_dir = self.output_dir / ckpt_name
        ckpt_dir.mkdir(exist_ok=True)

        # Save EMA weights (used for inference).
        # Only modules exposing a diffusers-style save_pretrained get this
        # extra export; everything is also captured in training_state.pt below.
        if hasattr(ema_controlnet, "save_pretrained"):
            ema_controlnet.save_pretrained(ckpt_dir / "controlnet_ema")

        # Save training state for resume.
        state = {
            "controlnet": _get_state_dict(controlnet),
            "ema_controlnet": _get_state_dict(ema_controlnet),
            "optimizer": optimizer.state_dict(),
            "global_step": step,
        }
        if scheduler is not None:
            state["scheduler"] = scheduler.state_dict()
        if extra_state:
            # NOTE(review): keys in extra_state can overwrite the entries
            # built above (e.g. "global_step") — confirm that is intended.
            state.update(extra_state)

        torch.save(state, ckpt_dir / "training_state.pt")

        # Compute checkpoint size (sum of all files under the directory, MB).
        size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)

        # Create metadata record for the index.
        meta = CheckpointMetadata(
            step=step,
            timestamp=time.time(),
            metrics=metrics or {},
            epoch=epoch,
            phase=phase,
            size_mb=round(size_mb, 1),
        )

        # Save metadata alongside checkpoint for standalone inspection.
        with open(ckpt_dir / "metadata.json", "w") as f:
            json.dump(meta.to_dict(), f, indent=2)

        # Update index (persisted after best-flags are recomputed).
        self._index["checkpoints"][ckpt_name] = meta.to_dict()
        self._update_best()
        self._save_index()

        # Update symlinks
        self._update_symlinks()

        # Prune old checkpoints
        self._prune()

        return ckpt_dir

    # ------------------------------------------------------------------
    # Best / latest tracking
    # ------------------------------------------------------------------

    def _update_best(self) -> None:
        """Recompute which checkpoints are 'best'.

        Flags the top ``keep_best`` checkpoints (ranked by the tracked
        metric) with ``is_best=True`` in the index; all others are reset to
        False. Checkpoints lacking a value for the metric are ignored.
        """
        entries = []
        for name, meta in self._index["checkpoints"].items():
            val = meta.get("metrics", {}).get(self.metric)
            if val is not None:
                entries.append((name, val, meta))

        if not entries:
            return

        # Sort by metric: ascending when lower_is_better, descending
        # otherwise, so the best candidates always come first.
        entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)

        # Mark best
        best_names = {e[0] for e in entries[: self.keep_best]}
        for name, meta in self._index["checkpoints"].items():
            meta["is_best"] = name in best_names

    def _update_symlinks(self) -> None:
        """Update 'latest' and 'best' symlinks in the output directory."""
        checkpoints = self._sorted_by_step()
        if not checkpoints:
            return

        # Latest symlink -> highest-step checkpoint.
        latest_name = checkpoints[-1]
        latest_link = self.output_dir / "latest"
        _force_symlink(self.output_dir / latest_name, latest_link)

        # Best symlink -> best checkpoint by the tracked metric (if any).
        best_name = self.get_best_checkpoint_name()
        if best_name:
            best_link = self.output_dir / "best"
            _force_symlink(self.output_dir / best_name, best_link)

    def get_best_checkpoint_name(self) -> str | None:
        """Return the name of the best checkpoint by tracked metric.

        Returns None when no tracked checkpoint records a value for the
        metric.
        """
        best = None
        best_val = None
        for name, meta in self._index["checkpoints"].items():
            val = meta.get("metrics", {}).get(self.metric)
            if val is None:
                continue
            # Direction of "better" depends on lower_is_better.
            if (
                best_val is None
                or (self.lower_is_better and val < best_val)
                or (not self.lower_is_better and val > best_val)
            ):
                best, best_val = name, val
        return best

    def get_best_metric_value(self) -> float | None:
        """Return the best value of the tracked metric, or None if untracked."""
        name = self.get_best_checkpoint_name()
        if name is None:
            return None
        return self._index["checkpoints"][name]["metrics"].get(self.metric)

    # ------------------------------------------------------------------
    # Pruning
    # ------------------------------------------------------------------

    def _sorted_by_step(self) -> list[str]:
        """Return checkpoint names sorted by step (ascending)."""
        items = list(self._index["checkpoints"].items())
        items.sort(key=lambda x: x[1].get("step", 0))
        return [name for name, _ in items]

    def _prune(self) -> None:
        """Remove old checkpoints, keeping best N and latest M.

        Deletes both the on-disk directory and the index entry for any
        checkpoint that is neither among the ``keep_latest`` most recent
        nor flagged ``is_best``.

        NOTE(review): with keep_latest == 0 the slice ``all_names[-0:]``
        is the *whole* list, so nothing would ever be pruned — confirm
        whether keep_latest=0 is meant to be supported.
        """
        all_names = self._sorted_by_step()
        if len(all_names) <= self.keep_latest:
            return

        # Determine which to keep
        keep = set()

        # Keep latest
        for name in all_names[-self.keep_latest :]:
            keep.add(name)

        # Keep best
        for name, meta in self._index["checkpoints"].items():
            if meta.get("is_best", False):
                keep.add(name)

        # Delete the rest (directory first, then index entry).
        for name in all_names:
            if name not in keep:
                ckpt_dir = self.output_dir / name
                if ckpt_dir.exists():
                    shutil.rmtree(ckpt_dir)
                self._index["checkpoints"].pop(name, None)

        self._save_index()

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def list_checkpoints(self) -> list[dict[str, Any]]:
        """Return metadata for all tracked checkpoints, sorted by step.

        Each dict is the stored metadata plus a "name" key.
        """
        result = []
        for name in self._sorted_by_step():
            meta = self._index["checkpoints"][name]
            result.append({"name": name, **meta})
        return result

    def get_checkpoint_path(self, name: str) -> Path:
        """Return the filesystem path for a checkpoint by name."""
        return self.output_dir / name

    def get_latest_step(self) -> int:
        """Return the step of the most recent checkpoint, or 0 if none."""
        names = self._sorted_by_step()
        if not names:
            return 0
        return self._index["checkpoints"][names[-1]].get("step", 0)

    def total_size_mb(self) -> float:
        """Return total disk size of all tracked checkpoints (from index)."""
        return sum(meta.get("size_mb", 0.0) for meta in self._index["checkpoints"].values())

    def summary(self) -> str:
        """Return a human-readable summary of checkpoint state."""
        ckpts = self.list_checkpoints()
        if not ckpts:
            return "No checkpoints saved."

        lines = [
            f"Checkpoints: {len(ckpts)} saved ({self.total_size_mb():.0f} MB total)",
            f"Latest: step {self.get_latest_step()}",
        ]

        best_name = self.get_best_checkpoint_name()
        best_val = self.get_best_metric_value()
        if best_name and best_val is not None:
            lines.append(f"Best ({self.metric}): {best_val:.6f} @ {best_name}")

        return "\n".join(lines)
343
+
344
+
345
+ # ------------------------------------------------------------------
346
+ # Helpers
347
+ # ------------------------------------------------------------------
348
+
349
+
350
+ def _get_state_dict(module: torch.nn.Module) -> dict:
351
+ """Extract state dict, handling DDP wrapper."""
352
+ if hasattr(module, "module"):
353
+ return module.module.state_dict()
354
+ return module.state_dict()
355
+
356
+
357
+ def _force_symlink(target: Path, link: Path) -> None:
358
+ """Create or replace a symlink."""
359
+ if link.is_symlink() or link.exists():
360
+ link.unlink()
361
+ link.symlink_to(target.name)