landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/fid.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""Self-contained FID computation using InceptionV3 feature extraction.
|
|
2
|
+
|
|
3
|
+
Avoids dependency on torch-fidelity by implementing FID directly.
|
|
4
|
+
Supports GPU acceleration, batched processing, and caching.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from landmarkdiff.fid import compute_fid_from_dirs, compute_fid_from_arrays
|
|
8
|
+
|
|
9
|
+
# From directories
|
|
10
|
+
fid = compute_fid_from_dirs("path/to/real", "path/to/generated")
|
|
11
|
+
|
|
12
|
+
# From numpy arrays
|
|
13
|
+
fid = compute_fid_from_arrays(real_images, generated_images)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
import torch
|
|
25
|
+
import torch.nn as nn
|
|
26
|
+
from torch.utils.data import DataLoader, Dataset
|
|
27
|
+
|
|
28
|
+
HAS_TORCH = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
HAS_TORCH = False
|
|
31
|
+
Dataset = object # type: ignore[misc,assignment]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _load_inception_v3() -> Any:
    """Build the InceptionV3 feature extractor used for FID.

    Returns:
        A torchvision InceptionV3 initialised with the ``IMAGENET1K_V1``
        weights, with its final fully connected layer replaced by an identity
        and switched to evaluation mode.
    """
    from torchvision.models import Inception_V3_Weights, inception_v3

    net = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
    # Dropping the classifier head makes the forward pass return the pooled
    # activations (2048-dim) instead of class logits.
    net.fc = nn.Identity()
    # nn.Module.eval() returns the module itself, so this hands back ``net``.
    return net.eval()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ImageFolderDataset(Dataset):
    """Simple dataset that loads images from a directory.

    Only regular files directly inside ``directory`` whose suffix is one of
    ``.jpg``, ``.jpeg``, ``.png``, ``.webp`` or ``.bmp`` (case-insensitive) are
    used; they are kept in sorted path order.

    Args:
        directory: Folder to scan (not recursive).
        image_size: Side length every image is resized to.
    """

    def __init__(self, directory: str | Path, image_size: int = 299):
        self.directory = Path(directory)
        exts = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
        self.files = sorted(
            f for f in self.directory.iterdir() if f.suffix.lower() in exts and f.is_file()
        )
        self.image_size = image_size

    def __len__(self) -> int:
        """Return the number of image files found in the directory."""
        return len(self.files)

    def __getitem__(self, idx: int) -> Any:
        """Load image ``idx`` as a normalized ``(3, image_size, image_size)`` tensor.

        An unreadable file does not abort the run: a warning is emitted and an
        all-zero tensor (not passed through ImageNet normalization) is
        returned in its place.
        """
        import warnings

        import cv2

        path = self.files[idx]
        img = cv2.imread(str(path))
        if img is None:
            # Keep the best-effort fallback, but make it visible: a zero image
            # still contributes to the feature statistics and shifts the score.
            warnings.warn(
                f"Could not read image {path}; substituting an all-zero tensor",
                RuntimeWarning,
                stacklevel=2,
            )
            return torch.zeros(3, self.image_size, self.image_size)
        img = cv2.resize(img, (self.image_size, self.image_size))
        # OpenCV decodes to BGR; the network expects RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Scale to [0, 1], move channels first, then apply ImageNet statistics.
        t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
        return _imagenet_normalize(t)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class NumpyArrayDataset(Dataset):
    """Dataset view over an in-memory list of image arrays.

    Args:
        images: Image arrays; grayscale (2-D), 3-channel and 4-channel inputs
            are converted to RGB on access.
        image_size: Side length every image is resized to.
    """

    def __init__(self, images: list[np.ndarray], image_size: int = 299):
        self.image_size = image_size
        self.images = images

    def __len__(self) -> int:
        """Return how many images are wrapped."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Any:
        """Return image ``idx`` as a normalized channels-first float tensor."""
        import cv2

        side = self.image_size
        frame = self.images[idx]
        if frame.shape[:2] != (side, side):
            frame = cv2.resize(frame, (side, side))

        # Choose the colour conversion from the channel layout; any other
        # layout is passed through untouched.
        if frame.ndim == 2:
            conversion = cv2.COLOR_GRAY2RGB
        elif frame.shape[2] == 4:
            conversion = cv2.COLOR_BGRA2RGB
        elif frame.shape[2] == 3:
            conversion = cv2.COLOR_BGR2RGB
        else:
            conversion = None
        if conversion is not None:
            frame = cv2.cvtColor(frame, conversion)

        scaled = frame.astype(np.float32) / 255.0
        return _imagenet_normalize(torch.from_numpy(scaled).permute(2, 0, 1))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _imagenet_normalize(t: torch.Tensor) -> torch.Tensor:
|
|
103
|
+
"""Apply ImageNet normalization."""
|
|
104
|
+
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
|
105
|
+
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
|
106
|
+
return (t - mean) / std
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _extract_features(
|
|
110
|
+
model: nn.Module,
|
|
111
|
+
dataloader: DataLoader,
|
|
112
|
+
device: torch.device,
|
|
113
|
+
) -> np.ndarray:
|
|
114
|
+
"""Extract InceptionV3 pool3 features from a dataloader."""
|
|
115
|
+
features = []
|
|
116
|
+
with torch.no_grad():
|
|
117
|
+
for batch in dataloader:
|
|
118
|
+
batch = batch.to(device)
|
|
119
|
+
feat = model(batch)
|
|
120
|
+
if isinstance(feat, tuple):
|
|
121
|
+
feat = feat[0]
|
|
122
|
+
features.append(feat.cpu().numpy())
|
|
123
|
+
return np.concatenate(features, axis=0)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _compute_statistics(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
127
|
+
"""Compute mean and covariance of feature vectors."""
|
|
128
|
+
if features.shape[0] < 2:
|
|
129
|
+
raise ValueError(f"FID requires at least 2 images, got {features.shape[0]}")
|
|
130
|
+
mu = np.mean(features, axis=0)
|
|
131
|
+
sigma = np.cov(features, rowvar=False)
|
|
132
|
+
return mu, sigma
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _calculate_fid(
|
|
136
|
+
mu1: np.ndarray,
|
|
137
|
+
sigma1: np.ndarray,
|
|
138
|
+
mu2: np.ndarray,
|
|
139
|
+
sigma2: np.ndarray,
|
|
140
|
+
) -> float:
|
|
141
|
+
"""Calculate FID given two sets of statistics.
|
|
142
|
+
|
|
143
|
+
FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1*sigma2))
|
|
144
|
+
"""
|
|
145
|
+
from scipy.linalg import sqrtm
|
|
146
|
+
|
|
147
|
+
diff = mu1 - mu2
|
|
148
|
+
covmean = sqrtm(sigma1 @ sigma2)
|
|
149
|
+
|
|
150
|
+
# Handle numerical instability
|
|
151
|
+
if np.iscomplexobj(covmean):
|
|
152
|
+
covmean = covmean.real
|
|
153
|
+
|
|
154
|
+
fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
|
|
155
|
+
return float(max(fid, 0.0))
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def compute_fid_from_dirs(
    real_dir: str | Path,
    generated_dir: str | Path,
    batch_size: int = 32,
    num_workers: int = 4,
    device: str | None = None,
) -> float:
    """Compute FID between two directories of images.

    Args:
        real_dir: Path to real images.
        generated_dir: Path to generated images.
        batch_size: Batch size for feature extraction.
        num_workers: DataLoader workers.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either directory holds fewer than 2 usable images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    # Scan both directories and validate their size before loading the model,
    # so bad input fails immediately rather than after the weights are loaded.
    # The covariance estimate needs at least two samples per set.
    real_ds = ImageFolderDataset(real_dir)
    gen_ds = ImageFolderDataset(generated_dir)
    for label, dataset in (("real", real_ds), ("generated", gen_ds)):
        if len(dataset) < 2:
            raise ValueError(
                f"FID requires at least 2 images per directory, "
                f"found {len(dataset)} in the {label} directory"
            )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)

    model = _load_inception_v3().to(dev)

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only run just triggers a DataLoader warning.
    pin = dev.type == "cuda"
    real_loader = DataLoader(
        real_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin
    )
    gen_loader = DataLoader(
        gen_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=pin
    )

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)

    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def compute_fid_from_arrays(
    real_images: list[np.ndarray],
    generated_images: list[np.ndarray],
    batch_size: int = 32,
    device: str | None = None,
) -> float:
    """Compute FID from lists of numpy arrays.

    Args:
        real_images: List of (H, W, 3) BGR uint8 images.
        generated_images: List of (H, W, 3) BGR uint8 images.
        batch_size: Batch size for feature extraction.
        device: "cuda" or "cpu". Auto-detects if None.

    Returns:
        FID score (lower = better).

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If either list holds fewer than 2 images.
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch required for FID computation")

    # Validate before loading the model: the covariance estimate needs at
    # least two samples per set, and an empty list would otherwise fail deep
    # inside feature concatenation with an unrelated-looking message.
    for label, images in (("real", real_images), ("generated", generated_images)):
        if len(images) < 2:
            raise ValueError(f"FID requires at least 2 {label} images, got {len(images)}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device)

    model = _load_inception_v3().to(dev)

    real_ds = NumpyArrayDataset(real_images)
    gen_ds = NumpyArrayDataset(generated_images)

    # The arrays are already in memory, so batches are built in-process.
    real_loader = DataLoader(real_ds, batch_size=batch_size, num_workers=0)
    gen_loader = DataLoader(gen_ds, batch_size=batch_size, num_workers=0)

    real_features = _extract_features(model, real_loader, dev)
    gen_features = _extract_features(model, gen_loader, dev)

    mu_real, sigma_real = _compute_statistics(real_features)
    mu_gen, sigma_gen = _compute_statistics(gen_features)

    return _calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
|
|
landmarkdiff/hyperparam.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
"""Hyperparameter search utilities for systematic ControlNet tuning.
|
|
2
|
+
|
|
3
|
+
Supports grid search, random search, and Bayesian-inspired adaptive search
|
|
4
|
+
over training hyperparameters. Generates YAML configs for each trial and
|
|
5
|
+
tracks results for comparison.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
from landmarkdiff.hyperparam import HyperparamSearch, SearchSpace
|
|
9
|
+
|
|
10
|
+
space = SearchSpace()
|
|
11
|
+
space.add_float("learning_rate", 1e-6, 1e-4, log_scale=True)
|
|
12
|
+
space.add_choice("optimizer", ["adamw", "adam8bit"])
|
|
13
|
+
space.add_int("batch_size", 2, 8, step=2)
|
|
14
|
+
|
|
15
|
+
search = HyperparamSearch(space, output_dir="hp_search")
|
|
16
|
+
for trial in search.generate_trials(strategy="random", n_trials=20):
|
|
17
|
+
print(trial.config)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import hashlib
|
|
23
|
+
import json
|
|
24
|
+
import math
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from typing import Any
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _to_native(val: Any) -> Any:
|
|
31
|
+
"""Convert numpy/non-standard types to native Python for YAML serialization."""
|
|
32
|
+
if hasattr(val, "item"): # numpy scalar
|
|
33
|
+
return val.item()
|
|
34
|
+
return val
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class ParamSpec:
    """Specification for a single hyperparameter.

    ``param_type`` selects which of the remaining fields are used:
    ``low``/``high`` (and ``log_scale``) for "float", ``low``/``high``/``step``
    for "int", and ``choices`` for "choice".
    """

    name: str
    param_type: str  # "float", "int", "choice"
    low: float | None = None
    high: float | None = None
    step: float | None = None
    log_scale: bool = False
    choices: list[Any] | None = None

    def sample(self, rng) -> Any:
        """Sample a value from this parameter spec.

        Args:
            rng: Random source providing ``integers`` and ``uniform``
                (a ``numpy.random.Generator`` is what the search engine passes).

        Returns:
            One of ``choices`` (the original object), a ``float`` in
            ``[low, high]``, or an ``int`` on the configured grid.

        Raises:
            ValueError: If ``param_type`` is unknown or a "choice" parameter
                has no choices.
        """
        if self.param_type == "choice":
            if not self.choices:
                raise ValueError(f"Parameter {self.name!r} has no choices to sample from")
            # Draw an index rather than calling rng.choice(): numpy would first
            # coerce the list to an array, turning mixed-type choices into a
            # single dtype (e.g. ["adamw", 8] -> strings).
            return self.choices[int(rng.integers(0, len(self.choices)))]
        if self.param_type == "float":
            if self.log_scale:
                # Uniform in log space, then mapped back.
                log_low = math.log(self.low)
                log_high = math.log(self.high)
                return float(math.exp(rng.uniform(log_low, log_high)))
            return float(rng.uniform(self.low, self.high))
        if self.param_type == "int":
            if self.step and self.step > 1:
                n_steps = int((self.high - self.low) / self.step) + 1
                idx = rng.integers(0, n_steps)
                return int(self.low + idx * self.step)
            # The upper bound of integers() is exclusive, hence the +1.
            return int(rng.integers(int(self.low), int(self.high) + 1))
        raise ValueError(f"Unknown param type: {self.param_type}")

    def grid_values(self, n_points: int = 5) -> list[Any]:
        """Generate grid values for this parameter.

        Args:
            n_points: Number of evenly spaced points for "float" parameters
                (spaced in log space when ``log_scale`` is set). Ignored for
                "int" and "choice" parameters.

        Returns:
            The grid as a list; empty for an unknown ``param_type`` or when
            ``n_points`` is not positive.
        """
        if self.param_type == "choice":
            return list(self.choices or [])
        if self.param_type == "int":
            if self.step and self.step > 1:
                vals = []
                v = self.low
                while v <= self.high:
                    vals.append(int(v))
                    v += self.step
                return vals
            return list(range(int(self.low), int(self.high) + 1))
        if self.param_type == "float":
            if n_points <= 0:
                return []
            if n_points == 1:
                # A single point has no spacing to compute; the general
                # formula would divide by zero.
                return [float(self.low)]
            if self.log_scale:
                log_low = math.log(self.low)
                log_high = math.log(self.high)
                return [
                    float(math.exp(log_low + i * (log_high - log_low) / (n_points - 1)))
                    for i in range(n_points)
                ]
            return [
                float(self.low + i * (self.high - self.low) / (n_points - 1))
                for i in range(n_points)
            ]
        return []
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SearchSpace:
    """Collection of named hyperparameter specs forming a search space.

    Each ``add_*`` method registers (or replaces) a parameter under ``name``
    and returns the space itself so calls can be chained.
    """

    def __init__(self) -> None:
        self.params: dict[str, ParamSpec] = {}

    def add_float(
        self,
        name: str,
        low: float,
        high: float,
        log_scale: bool = False,
    ) -> SearchSpace:
        """Register a continuous parameter on ``[low, high]``."""
        spec = ParamSpec(
            name=name,
            param_type="float",
            low=low,
            high=high,
            log_scale=log_scale,
        )
        self.params[name] = spec
        return self

    def add_int(
        self,
        name: str,
        low: int,
        high: int,
        step: int = 1,
    ) -> SearchSpace:
        """Register an integer parameter on ``[low, high]`` with stride ``step``."""
        spec = ParamSpec(
            name=name,
            param_type="int",
            low=low,
            high=high,
            step=step,
        )
        self.params[name] = spec
        return self

    def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
        """Register a categorical parameter drawn from ``choices``."""
        spec = ParamSpec(name=name, param_type="choice", choices=choices)
        self.params[name] = spec
        return self

    def __len__(self) -> int:
        """Number of registered parameters."""
        return len(self.params)

    def __contains__(self, name: str) -> bool:
        """Whether a parameter called ``name`` has been registered."""
        return name in self.params
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@dataclass
class Trial:
    """One hyperparameter configuration together with its outcome."""

    trial_id: str
    config: dict[str, Any]
    result: dict[str, float] = field(default_factory=dict)
    status: str = "pending"  # pending, running, completed, failed

    @property
    def config_hash(self) -> str:
        """First 8 hex characters of the MD5 of the key-sorted JSON config.

        Values JSON cannot encode natively are stringified, so equal configs
        hash equally regardless of key order.
        """
        canonical = json.dumps(self.config, default=str, sort_keys=True)
        digest = hashlib.md5(canonical.encode()).hexdigest()
        return digest[:8]
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class HyperparamSearch:
    """Hyperparameter search engine.

    Generates trial configurations from a search space, stores the metrics
    reported for each trial, and can write the configurations to disk.

    Args:
        space: Search space definition.
        output_dir: Directory to save trial configs and results.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        space: SearchSpace,
        output_dir: str | Path = "hp_search",
        seed: int = 42,
    ) -> None:
        self.space = space
        self.output_dir = Path(output_dir)
        self.seed = seed
        self.trials: list[Trial] = []

    def generate_trials(
        self,
        strategy: str = "random",
        n_trials: int = 20,
        grid_points: int = 5,
    ) -> list[Trial]:
        """Generate trial configurations.

        The new trials are also appended to ``self.trials``.

        Args:
            strategy: "random" or "grid".
            n_trials: Number of trials for random search.
            grid_points: Points per continuous dimension for grid search.

        Returns:
            List of Trial objects with configs.

        Raises:
            ValueError: If ``strategy`` is neither "random" nor "grid".
        """
        if strategy == "grid":
            trials = self._grid_search(grid_points)
        elif strategy == "random":
            trials = self._random_search(n_trials)
        else:
            raise ValueError(f"Unknown strategy: {strategy}. Use 'random' or 'grid'.")

        self.trials.extend(trials)
        return trials

    def _random_search(self, n_trials: int) -> list[Trial]:
        """Generate up to ``n_trials`` distinct random trial configs.

        Sampling stops early if ``n_trials * 10`` draws do not produce enough
        configs that differ from each other and from already stored trials.
        """
        import numpy as np

        rng = np.random.default_rng(self.seed)
        # Seed the dedup set with earlier trials: the generator restarts from
        # the same seed on every call, so a repeated call would otherwise
        # return exact copies of the existing trials.
        seen_hashes: set[str] = {t.config_hash for t in self.trials}
        # Continue numbering after the stored trials so IDs stay unique.
        first_index = len(self.trials)
        trials: list[Trial] = []

        max_attempts = n_trials * 10
        attempts = 0
        while len(trials) < n_trials and attempts < max_attempts:
            attempts += 1
            config = {name: spec.sample(rng) for name, spec in self.space.params.items()}
            trial = Trial(
                trial_id=f"trial_{first_index + len(trials):04d}",
                config=config,
            )
            if trial.config_hash not in seen_hashes:
                seen_hashes.add(trial.config_hash)
                trials.append(trial)

        return trials

    def _grid_search(self, grid_points: int) -> list[Trial]:
        """Generate one trial per point of the Cartesian grid over all params."""
        import itertools

        param_names = list(self.space.params.keys())
        param_values = [self.space.params[name].grid_values(grid_points) for name in param_names]

        # Continue numbering after the stored trials so IDs stay unique.
        first_index = len(self.trials)
        return [
            Trial(
                trial_id=f"trial_{first_index + offset:04d}",
                config=dict(zip(param_names, combo)),
            )
            for offset, combo in enumerate(itertools.product(*param_values))
        ]

    def record_result(
        self,
        trial_id: str,
        metrics: dict[str, float],
    ) -> None:
        """Record results for a trial and mark it completed.

        Raises:
            KeyError: If no stored trial has ``trial_id``.
        """
        for trial in self.trials:
            if trial.trial_id == trial_id:
                # Copy so later mutation of the caller's dict cannot alter
                # the recorded result.
                trial.result = dict(metrics)
                trial.status = "completed"
                return
        raise KeyError(f"Trial {trial_id} not found")

    def best_trial(
        self,
        metric: str = "loss",
        lower_is_better: bool = True,
    ) -> Trial | None:
        """Get the best completed trial by a metric.

        Trials whose value for ``metric`` is NaN are ignored: min()/max() do
        not order NaN, so a NaN entry could otherwise be returned as "best".

        Returns:
            The winning trial, or None when no completed trial reports a
            usable value for ``metric``.
        """

        def _has_usable_metric(trial: Trial) -> bool:
            if trial.status != "completed" or metric not in trial.result:
                return False
            value = trial.result[metric]
            return not (isinstance(value, float) and math.isnan(value))

        candidates = [t for t in self.trials if _has_usable_metric(t)]
        if not candidates:
            return None
        pick = min if lower_is_better else max
        return pick(candidates, key=lambda t: t.result[metric])

    def save_configs(self) -> Path:
        """Save all trial configs as YAML files.

        Writes one ``<trial_id>.yaml`` per trial plus a ``search_index.json``
        describing the seed, trial count and search space.

        Returns:
            Output directory path.
        """
        import yaml

        self.output_dir.mkdir(parents=True, exist_ok=True)
        for trial in self.trials:
            cfg_path = self.output_dir / f"{trial.trial_id}.yaml"
            # Convert numpy types to native Python for YAML serialization
            native_config = {k: _to_native(v) for k, v in trial.config.items()}
            with open(cfg_path, "w", encoding="utf-8") as f:
                yaml.safe_dump(
                    {"trial_id": trial.trial_id, **native_config},
                    f,
                    default_flow_style=False,
                )

        # Save summary index
        index = {
            "seed": self.seed,
            "n_trials": len(self.trials),
            "params": {
                name: {
                    "type": spec.param_type,
                    "low": spec.low,
                    "high": spec.high,
                    "step": spec.step,
                    "choices": spec.choices,
                    "log_scale": spec.log_scale,
                }
                for name, spec in self.space.params.items()
            },
        }
        with open(self.output_dir / "search_index.json", "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, default=str)

        return self.output_dir

    def results_table(self) -> str:
        """Format results of completed trials as a text table.

        Returns:
            A header row, a dashed rule and one row per completed trial, or
            "No completed trials." when there is nothing to show.
        """
        completed = [t for t in self.trials if t.status == "completed"]
        if not completed:
            return "No completed trials."

        # Collect all metric names
        metric_names = sorted(set().union(*(t.result.keys() for t in completed)))
        param_names = sorted(self.space.params.keys())

        # Header
        cols = ["Trial"] + param_names + metric_names
        lines = [" | ".join(f"{c:>12s}" for c in cols)]
        lines.append("-" * len(lines[0]))

        # Rows
        for trial in completed:
            parts = [f"{trial.trial_id:>12s}"]
            for p in param_names:
                val = trial.config.get(p, "")
                if isinstance(val, float):
                    parts.append(f"{val:>12.6f}")
                else:
                    parts.append(f"{val!s:>12s}")
            for m in metric_names:
                # A metric missing from this trial is shown as nan.
                val = trial.result.get(m, float("nan"))
                parts.append(f"{val:>12.4f}")
            lines.append(" | ".join(parts))

        return "\n".join(lines)
|