landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/fid.py ADDED
@@ -0,0 +1,244 @@
1
+ """Self-contained FID computation using InceptionV3 feature extraction.
2
+
3
+ Avoids dependency on torch-fidelity by implementing FID directly.
4
+ Supports GPU acceleration and batched processing.
5
+
6
+ Usage:
7
+ from landmarkdiff.fid import compute_fid_from_dirs, compute_fid_from_arrays
8
+
9
+ # From directories
10
+ fid = compute_fid_from_dirs("path/to/real", "path/to/generated")
11
+
12
+ # From numpy arrays
13
+ fid = compute_fid_from_arrays(real_images, generated_images)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from pathlib import Path
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+
23
+ try:
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.utils.data import DataLoader, Dataset
27
+
28
+ HAS_TORCH = True
29
+ except ImportError:
30
+ HAS_TORCH = False
31
+ Dataset = object # type: ignore[misc,assignment]
32
+
33
+
34
def _load_inception_v3() -> Any:
    """Build the InceptionV3 network used as the FID feature extractor.

    The model is created with torchvision's ``IMAGENET1K_V1`` weights, its
    final fully connected layer is replaced by ``nn.Identity`` so a forward
    pass yields the features that fed that layer (described in the original
    note as the 2048-dim pool3 features) instead of class logits, and the
    module is switched to eval mode.

    Returns:
        The modified InceptionV3 module in eval mode.
    """
    import torchvision.models as tv_models

    extractor = tv_models.inception_v3(weights=tv_models.Inception_V3_Weights.IMAGENET1K_V1)
    # Dropping the classifier head exposes the pooled features directly.
    extractor.fc = nn.Identity()
    # nn.Module.eval() returns the module itself.
    return extractor.eval()
44
+
45
+
46
class ImageFolderDataset(Dataset):
    """Dataset over the image files located directly inside one directory.

    Only regular files whose lower-cased suffix is one of ``.jpg``, ``.jpeg``,
    ``.png``, ``.webp`` or ``.bmp`` are indexed; sub-directories are not
    descended into. The file list is sorted so indexing is deterministic.
    """

    def __init__(self, directory: str | Path, image_size: int = 299):
        """Index the supported image files in ``directory``.

        Args:
            directory: Folder to scan (non-recursive).
            image_size: Side length every image is resized to on access.
        """
        self.directory = Path(directory)
        self.image_size = image_size
        allowed_suffixes = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        matches = [
            entry
            for entry in self.directory.iterdir()
            if entry.suffix.lower() in allowed_suffixes and entry.is_file()
        ]
        matches.sort()
        self.files = matches

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx: int) -> Any:
        """Load image ``idx`` as an ImageNet-normalized ``(3, S, S)`` float tensor.

        If OpenCV cannot decode the file, an all-zero tensor of the same shape
        is returned instead (note: that fallback is NOT ImageNet-normalized).
        """
        import cv2

        bgr = cv2.imread(str(self.files[idx]))
        side = self.image_size
        if bgr is None:
            # Best-effort fallback so one bad file does not abort the whole run.
            return torch.zeros(3, side, side)

        rgb = cv2.cvtColor(cv2.resize(bgr, (side, side)), cv2.COLOR_BGR2RGB)
        # Scale to [0, 1], move channels first, then apply ImageNet statistics.
        scaled = rgb.astype(np.float32) / 255.0
        return _imagenet_normalize(torch.from_numpy(scaled).permute(2, 0, 1))
73
+
74
+
75
class NumpyArrayDataset(Dataset):
    """Dataset view over an in-memory list of numpy images."""

    def __init__(self, images: list[np.ndarray], image_size: int = 299):
        """Store the image list and the target side length.

        Args:
            images: Images to serve; the list is kept by reference, not copied.
            image_size: Side length every image is resized to on access.
        """
        self.images = images
        self.image_size = image_size

    def __len__(self) -> int:
        """Return the number of wrapped images."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Any:
        """Return image ``idx`` as an ImageNet-normalized channel-first tensor.

        The image is resized when its first two dimensions differ from
        ``(image_size, image_size)``. 2-D arrays are treated as grayscale,
        4-channel arrays as BGRA and 3-channel arrays as BGR; each is converted
        to RGB. Any other channel count is passed through unconverted.
        """
        import cv2

        frame = self.images[idx]
        target = (self.image_size, self.image_size)
        if frame.shape[:2] != target:
            frame = cv2.resize(frame, target)

        # Pick the colour conversion from the array layout.
        if frame.ndim == 2:
            conversion = cv2.COLOR_GRAY2RGB
        else:
            conversion = {4: cv2.COLOR_BGRA2RGB, 3: cv2.COLOR_BGR2RGB}.get(frame.shape[2])
        if conversion is not None:
            frame = cv2.cvtColor(frame, conversion)

        scaled = frame.astype(np.float32) / 255.0
        return _imagenet_normalize(torch.from_numpy(scaled).permute(2, 0, 1))
100
+
101
+
102
+ def _imagenet_normalize(t: torch.Tensor) -> torch.Tensor:
103
+ """Apply ImageNet normalization."""
104
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
105
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
106
+ return (t - mean) / std
107
+
108
+
109
+ def _extract_features(
110
+ model: nn.Module,
111
+ dataloader: DataLoader,
112
+ device: torch.device,
113
+ ) -> np.ndarray:
114
+ """Extract InceptionV3 pool3 features from a dataloader."""
115
+ features = []
116
+ with torch.no_grad():
117
+ for batch in dataloader:
118
+ batch = batch.to(device)
119
+ feat = model(batch)
120
+ if isinstance(feat, tuple):
121
+ feat = feat[0]
122
+ features.append(feat.cpu().numpy())
123
+ return np.concatenate(features, axis=0)
124
+
125
+
126
+ def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
127
+ """Compute mean and covariance of feature vectors."""
128
+ if features.shape[0] < 2:
129
+ raise ValueError(f"FID requires at least 2 images, got {features.shape[0]}")
130
+ mu = np.mean(features, axis=0)
131
+ sigma = np.cov(features, rowvar=False)
132
+ return mu, sigma
133
+
134
+
135
+ def _calculate_fid(
136
+ mu1: np.ndarray,
137
+ sigma1: np.ndarray,
138
+ mu2: np.ndarray,
139
+ sigma2: np.ndarray,
140
+ ) -> float:
141
+ """Calculate FID given two sets of statistics.
142
+
143
+ FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2))
144
+ """
145
+ from scipy.linalg import sqrtm
146
+
147
+ diff = mu1 - mu2
148
+ covmean = sqrtm(sigma1 @ sigma2)
149
+
150
+ # Handle numerical instability
151
+ if np.iscomplexobj(covmean):
152
+ covmean = covmean.real
153
+
154
+ fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
155
+ return float(max(fid, 0.0))
156
+
157
+
158
def compute_fid_from_dirs(
    real_dir: str | Path,
    generated_dir: str | Path,
    batch_size: int = 32,
    num_workers: int = 4,
    device: str | None = None,
) -> float:
    """Compute FID between two directories of images.

    Both directories are indexed and validated before the InceptionV3 weights
    are loaded, so a bad path or a too-small image set fails without paying
    for model construction.

    Args:
        real_dir: Path to real images.
        generated_dir: Path to generated images.
        batch_size: Batch size for feature extraction.
        num_workers: DataLoader workers.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either directory holds fewer than 2 supported images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    real_ds = ImageFolderDataset(real_dir)
    gen_ds = ImageFolderDataset(generated_dir)

    # _compute_statistics needs at least 2 samples per set for a covariance.
    if len(real_ds) < 2 or len(gen_ds) < 2:
        raise ValueError(
            "Need at least 2 images in each directory "
            f"(found {len(real_ds)} real, {len(gen_ds)} generated)"
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)

    model = _load_inception_v3().to(dev)

    # Pinned host memory only benefits host-to-GPU copies; skip it on CPU.
    pin_memory = dev.type == "cuda"
    real_loader = DataLoader(
        real_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
    )
    gen_loader = DataLoader(
        gen_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
    )

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)

    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
204
+
205
+
206
def compute_fid_from_arrays(
    real_images: list[np.ndarray],
    generated_images: list[np.ndarray],
    batch_size: int = 32,
    device: str | None = None,
) -> float:
    """Compute FID from lists of numpy arrays.

    Input sizes are validated before the InceptionV3 weights are loaded, so
    an empty or single-image list fails with a clear message instead of an
    error raised deep inside feature extraction.

    Args:
        real_images: List of (H, W, 3) BGR uint8 images.
        generated_images: List of (H, W, 3) BGR uint8 images.
        batch_size: Batch size for feature extraction.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either list holds fewer than 2 images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    # _compute_statistics needs at least 2 samples per set for a covariance.
    if len(real_images) < 2 or len(generated_images) < 2:
        raise ValueError(
            "Need at least 2 images in each list "
            f"(got {len(real_images)} real, {len(generated_images)} generated)"
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)

    model = _load_inception_v3().to(dev)

    real_ds = NumpyArrayDataset(real_images)
    gen_ds = NumpyArrayDataset(generated_images)

    # Data is already in memory, so batches are built in the main process.
    real_loader = DataLoader(real_ds, batch_size=batch_size, num_workers=0)
    gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=0)

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)

    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
landmarkdiff/hyperparam.py ADDED
@@ -0,0 +1,347 @@
1
+ """Hyperparameter search utilities for systematic ControlNet tuning.
2
+
3
+ Supports grid search and random search
4
+ over training hyperparameters. Generates YAML configs for each trial and
5
+ tracks results for comparison.
6
+
7
+ Usage:
8
+ from landmarkdiff.hyperparam import HyperparamSearch, SearchSpace
9
+
10
+ space = SearchSpace()
11
+ space.add_float("learning_rate", 1e-6, 1e-4, log_scale=True)
12
+ space.add_choice("optimizer", ["adamw", "adam8bit"])
13
+ space.add_int("batch_size", 2, 8, step=2)
14
+
15
+ search = HyperparamSearch(space, output_dir="hp_search")
16
+ for trial in search.generate_trials(strategy="random", n_trials=20):
17
+ print(trial.config)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import hashlib
23
+ import json
24
+ import math
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Any
28
+
29
+
30
+ def _to_native(val: Any) -> Any:
31
+ """Convert numpy/non-standard types to native Python for YAML serialization."""
32
+ if hasattr(val, "item"): # numpy scalar
33
+ return val.item()
34
+ return val
35
+
36
+
37
@dataclass
class ParamSpec:
    """Specification for a single hyperparameter.

    ``low``/``high``/``step``/``log_scale`` are read for "float" and "int"
    parameters, ``choices`` for "choice" parameters.
    """

    name: str
    param_type: str  # "float", "int", "choice"
    low: float | None = None
    high: float | None = None
    step: float | None = None  # only consulted for "int" parameters, when > 1
    log_scale: bool = False  # only consulted for "float" parameters
    choices: list[Any] | None = None

    def sample(self, rng) -> Any:
        """Sample a value from this parameter spec.

        Args:
            rng: Random source providing ``choice``, ``uniform`` and
                ``integers`` (the calls match ``numpy.random.Generator``).

        Returns:
            One of ``choices``, a float between ``low`` and ``high`` (drawn
            uniformly in log space when ``log_scale`` is set), or an int on
            the ``low``..``high`` grid.

        Raises:
            ValueError: If ``param_type`` is not "choice", "float" or "int".
        """
        if self.param_type == "choice":
            return rng.choice(self.choices)
        if self.param_type == "float":
            if self.log_scale:
                log_low = math.log(self.low)
                log_high = math.log(self.high)
                return float(math.exp(rng.uniform(log_low, log_high)))
            return float(rng.uniform(self.low, self.high))
        if self.param_type == "int":
            if self.step and self.step > 1:
                # Draw an index on the stepped grid, then map it back to a value.
                n_steps = int((self.high - self.low) / self.step) + 1
                idx = rng.integers(0, n_steps)
                return int(self.low + idx * self.step)
            return int(rng.integers(int(self.low), int(self.high) + 1))
        raise ValueError(f"Unknown param type: {self.param_type}")

    def grid_values(self, n_points: int = 5) -> list[Any]:
        """Generate grid values for this parameter.

        Args:
            n_points: Number of evenly spaced points for "float" parameters
                (spaced in log space when ``log_scale`` is set). Ignored for
                "int" and "choice" parameters, which enumerate every value.

        Returns:
            The grid values; an empty list for an unknown ``param_type`` or
            when ``n_points`` is not positive on a float parameter.
        """
        if self.param_type == "choice":
            return list(self.choices)
        if self.param_type == "int":
            if self.step and self.step > 1:
                vals = []
                v = self.low
                while v <= self.high:
                    vals.append(int(v))
                    v += self.step
                return vals
            return list(range(int(self.low), int(self.high) + 1))
        if self.param_type == "float":
            if n_points == 1:
                # A one-point grid has no spacing to compute; the general
                # formula would divide by zero (n_points - 1).
                return [float(self.low)]
            if self.log_scale:
                log_low = math.log(self.low)
                log_high = math.log(self.high)
                return [
                    float(math.exp(log_low + i * (log_high - log_low) / (n_points - 1)))
                    for i in range(n_points)
                ]
            return [
                float(self.low + i * (self.high - self.low) / (n_points - 1))
                for i in range(n_points)
            ]
        return []
93
+
94
+
95
class SearchSpace:
    """Container mapping parameter names to their ``ParamSpec``.

    Every ``add_*`` method returns the space itself so calls can be chained;
    adding a name that already exists replaces the earlier spec.
    """

    def __init__(self) -> None:
        """Start with an empty search space."""
        self.params: dict[str, ParamSpec] = {}

    def add_float(
        self,
        name: str,
        low: float,
        high: float,
        log_scale: bool = False,
    ) -> SearchSpace:
        """Register a continuous parameter between ``low`` and ``high``."""
        spec = ParamSpec(
            name=name, param_type="float", low=low, high=high, log_scale=log_scale
        )
        self.params[name] = spec
        return self

    def add_int(
        self,
        name: str,
        low: int,
        high: int,
        step: int = 1,
    ) -> SearchSpace:
        """Register an integer parameter between ``low`` and ``high``."""
        spec = ParamSpec(name=name, param_type="int", low=low, high=high, step=step)
        self.params[name] = spec
        return self

    def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
        """Register a categorical parameter drawn from ``choices``."""
        spec = ParamSpec(name=name, param_type="choice", choices=choices)
        self.params[name] = spec
        return self

    def __len__(self) -> int:
        """Return how many parameters are registered."""
        return len(self.params)

    def __contains__(self, name: str) -> bool:
        """Report whether a parameter called ``name`` is registered."""
        return name in self.params
149
+
150
+
151
@dataclass
class Trial:
    """One hyperparameter configuration together with its outcome."""

    trial_id: str
    config: dict[str, Any]
    result: dict[str, float] = field(default_factory=dict)
    status: str = "pending"  # pending, running, completed, failed

    @property
    def config_hash(self) -> str:
        """First 8 hex digits of the MD5 of the config, used for deduplication.

        The config is serialized as JSON with sorted keys, so key order does
        not affect the hash; values JSON cannot encode are stringified.
        """
        canonical = json.dumps(self.config, sort_keys=True, default=str)
        digest = hashlib.md5(canonical.encode())
        return digest.hexdigest()[:8]
165
+
166
+
167
class HyperparamSearch:
    """Hyperparameter search engine.

    Generates trial configurations from a search space, records per-trial
    metrics, and writes the configs to disk.

    Args:
        space: Search space definition.
        output_dir: Directory to save trial configs and results.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        space: SearchSpace,
        output_dir: str | Path = "hp_search",
        seed: int = 42,
    ) -> None:
        self.space = space
        self.output_dir = Path(output_dir)
        self.seed = seed
        self.trials: list[Trial] = []

    def generate_trials(
        self,
        strategy: str = "random",
        n_trials: int = 20,
        grid_points: int = 5,
    ) -> list[Trial]:
        """Generate trial configurations.

        The new trials are also appended to ``self.trials``. Trial ids keep
        counting from the trials already registered, so repeated calls never
        reuse an id. Random search reseeds from ``self.seed`` on every call.

        Args:
            strategy: "random" or "grid".
            n_trials: Number of trials for random search.
            grid_points: Points per continuous dimension for grid search.

        Returns:
            List of Trial objects with configs.

        Raises:
            ValueError: If ``strategy`` is neither "random" nor "grid".
        """
        if strategy == "grid":
            trials = self._grid_search(grid_points)
        elif strategy == "random":
            trials = self._random_search(n_trials)
        else:
            raise ValueError(f"Unknown strategy: {strategy}. Use 'random' or 'grid'.")

        self.trials.extend(trials)
        return trials

    def _random_search(self, n_trials: int) -> list[Trial]:
        """Generate up to ``n_trials`` distinct random configs.

        Configs whose hash was already drawn in this call are discarded;
        sampling gives up after ``n_trials * 10`` attempts, so fewer trials
        can be returned when the space is small.
        """
        import numpy as np

        rng = np.random.default_rng(self.seed)
        seen_hashes: set[str] = set()
        trials: list[Trial] = []
        # Continue numbering after existing trials so ids stay unique.
        first_index = len(self.trials)

        max_attempts = n_trials * 10
        attempts = 0
        while len(trials) < n_trials and attempts < max_attempts:
            attempts += 1
            config = {name: spec.sample(rng) for name, spec in self.space.params.items()}
            trial = Trial(
                trial_id=f"trial_{first_index + len(trials):04d}",
                config=config,
            )
            if trial.config_hash not in seen_hashes:
                seen_hashes.add(trial.config_hash)
                trials.append(trial)

        return trials

    def _grid_search(self, grid_points: int) -> list[Trial]:
        """Generate one config per point of the Cartesian parameter grid."""
        import itertools

        param_names = list(self.space.params.keys())
        param_values = [self.space.params[name].grid_values(grid_points) for name in param_names]
        # Continue numbering after existing trials so ids stay unique.
        first_index = len(self.trials)

        return [
            Trial(
                trial_id=f"trial_{first_index + offset:04d}",
                config=dict(zip(param_names, combo)),
            )
            for offset, combo in enumerate(itertools.product(*param_values))
        ]

    def record_result(
        self,
        trial_id: str,
        metrics: dict[str, float],
    ) -> None:
        """Store ``metrics`` on the matching trial and mark it completed.

        Raises:
            KeyError: If no registered trial has ``trial_id``.
        """
        for trial in self.trials:
            if trial.trial_id == trial_id:
                trial.result = metrics
                trial.status = "completed"
                return
        raise KeyError(f"Trial {trial_id} not found")

    def best_trial(
        self,
        metric: str = "loss",
        lower_is_better: bool = True,
    ) -> Trial | None:
        """Get the best completed trial by a metric.

        Trials whose value for ``metric`` is NaN are ignored: every comparison
        with NaN is false, so including them would make the winner depend on
        trial order.

        Returns:
            The best trial, or None when no completed trial reports a usable
            value for ``metric``.
        """

        def _usable(trial: Trial) -> bool:
            if trial.status != "completed" or metric not in trial.result:
                return False
            value = trial.result[metric]
            return not (isinstance(value, float) and math.isnan(value))

        candidates = [t for t in self.trials if _usable(t)]
        if not candidates:
            return None
        pick = min if lower_is_better else max
        return pick(candidates, key=lambda t: t.result[metric])

    def save_configs(self) -> Path:
        """Save all trial configs as YAML files.

        Writes one ``<trial_id>.yaml`` per trial plus a ``search_index.json``
        summary of the seed, trial count and search space.

        Returns:
            Output directory path.
        """
        import yaml

        self.output_dir.mkdir(parents=True, exist_ok=True)
        for trial in self.trials:
            cfg_path = self.output_dir / f"{trial.trial_id}.yaml"
            # Convert numpy types to native Python for YAML serialization
            native_config = {k: _to_native(v) for k, v in trial.config.items()}
            with open(cfg_path, "w", encoding="utf-8") as f:
                yaml.safe_dump(
                    {"trial_id": trial.trial_id, **native_config},
                    f,
                    default_flow_style=False,
                )

        # Save summary index
        index = {
            "seed": self.seed,
            "n_trials": len(self.trials),
            "params": {
                name: {
                    "type": spec.param_type,
                    "low": spec.low,
                    "high": spec.high,
                    # step is needed to reconstruct stepped integer grids
                    "step": spec.step,
                    "choices": spec.choices,
                    "log_scale": spec.log_scale,
                }
                for name, spec in self.space.params.items()
            },
        }
        with open(self.output_dir / "search_index.json", "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, default=str)

        return self.output_dir

    def results_table(self) -> str:
        """Format results as a text table.

        One row per completed trial; columns are the trial id, the search
        space parameters (sorted by name) and every metric reported by any
        completed trial (sorted by name, missing values shown as nan).
        """
        completed = [t for t in self.trials if t.status == "completed"]
        if not completed:
            return "No completed trials."

        # Collect all metric names
        metric_names = sorted(set().union(*(t.result.keys() for t in completed)))
        param_names = sorted(self.space.params.keys())

        # Header
        cols = ["Trial"] + param_names + metric_names
        lines = [" | ".join(f"{c:>12s}" for c in cols)]
        lines.append("-" * len(lines[0]))

        # Rows
        for trial in completed:
            parts = [f"{trial.trial_id:>12s}"]
            for p in param_names:
                val = trial.config.get(p, "")
                if isinstance(val, float):
                    parts.append(f"{val:>12.6f}")
                else:
                    parts.append(f"{val!s:>12s}")
            for m in metric_names:
                val = trial.result.get(m, float("nan"))
                parts.append(f"{val:>12.4f}")
            lines.append(" | ".join(parts))

        return "\n".join(lines)