landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""Conditioning signal generation: static adjacency wireframe + auto-Canny.
|
|
2
|
+
|
|
3
|
+
Uses a pre-defined anatomical adjacency matrix (NOT dynamic Delaunay) to prevent
|
|
4
|
+
triangle inversion on drastic landmark displacements. Auto-Canny adapts thresholds
|
|
5
|
+
to skin tone (Fitzpatrick I-VI safe).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import cv2
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from landmarkdiff.landmarks import FaceLandmarks
|
|
14
|
+
|
|
15
|
+
# Static anatomical adjacency for MediaPipe 478 landmarks.
|
|
16
|
+
# Connects landmarks along anatomically meaningful contours:
|
|
17
|
+
# jawline, nasal dorsum, orbital rim, lip vermilion, eyebrow arch.
|
|
18
|
+
# This is invariant to landmark displacement (unlike Delaunay).
|
|
19
|
+
|
|
20
|
+
# Each list below is an ordered sequence of landmark indices; consecutive
# entries are joined by a line segment when the wireframe is rendered.

# Closed ring: the sequence starts and ends at landmark 10.
# NOTE(review): despite the name, this looks like the full face-outline ring
# (forehead included) rather than only the jaw - confirm against the
# MediaPipe topology.
JAWLINE_CONTOUR = [
    10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
    397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
    172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109,
    10,
]

# Closed ring: starts and ends at landmark 33.
LEFT_EYE_CONTOUR = [
    33, 7, 163, 144, 145, 153, 154, 155, 133,
    173, 157, 158, 159, 160, 161, 246,
    33,
]

# Closed ring: starts and ends at landmark 362.
RIGHT_EYE_CONTOUR = [
    362, 382, 381, 380, 374, 373, 390, 249, 263,
    466, 388, 387, 386, 385, 384, 398,
    362,
]

# Open polylines (first and last index differ).
LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]

# Open polylines (first and last index differ).
NOSE_BRIDGE = [168, 6, 197, 195, 5, 4, 1]
NOSE_TIP = [94, 2, 326, 327, 294, 278, 279, 275, 274, 460, 456, 363, 370]
NOSE_BOTTOM = [19, 1, 274, 275, 440, 344, 278, 294, 460, 305, 289, 392]

# Closed ring: starts and ends at landmark 61.
# NOTE(review): the run 308, 324, 318, 402, 317, 14, 87, 178, 88, 95, 78 is
# repeated verbatim at the end of INNER_LIPS, so those segments are drawn
# twice. This suggests the upper outer lip border may be missing from the
# wireframe - verify against the MediaPipe lip topology before changing,
# since the values define the conditioning image.
OUTER_LIPS = [
    61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
    308, 324, 318, 402, 317, 14, 87, 178, 88, 95, 78,
    61,
]

# Closed ring: starts and ends at landmark 78.
INNER_LIPS = [
    78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308,
    324, 318, 402, 317, 14, 87, 178, 88, 95,
    78,
]

# Auto-Canny threshold factors (median-relative)
# auto_canny multiplies the median of the non-zero pixels by these factors to
# get the low/high hysteresis thresholds.
_CANNY_LOW_FACTOR = 0.66
_CANNY_HIGH_FACTOR = 1.33
_CANNY_DEFAULT_MEDIAN = 128.0  # fallback when no non-zero pixels exist

# Every contour drawn by render_wireframe, in drawing order.
ALL_CONTOURS = [
    JAWLINE_CONTOUR,
    LEFT_EYE_CONTOUR,
    RIGHT_EYE_CONTOUR,
    LEFT_EYEBROW,
    RIGHT_EYEBROW,
    NOSE_BRIDGE,
    NOSE_TIP,
    NOSE_BOTTOM,
    OUTER_LIPS,
    INNER_LIPS,
]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def render_wireframe(
    face: FaceLandmarks,
    width: int | None = None,
    height: int | None = None,
    thickness: int = 1,
) -> np.ndarray:
    """Draw every contour in ``ALL_CONTOURS`` as white lines on a black canvas.

    Args:
        face: Landmarks whose first two columns are normalized (x, y); they
            are scaled by the canvas width and height before drawing.
        width: Canvas width in pixels. Falls back to ``face.image_width``
            when omitted (or falsy).
        height: Canvas height in pixels. Falls back to ``face.image_height``
            when omitted (or falsy).
        thickness: Line thickness passed to ``cv2.line``.

    Returns:
        Single-channel ``uint8`` image of shape ``(height, width)`` with the
        contour segments drawn at value 255.
    """
    canvas_w = width or face.image_width
    canvas_h = height or face.image_height
    image = np.zeros((canvas_h, canvas_w), dtype=np.uint8)

    # Work on a copy so the caller's landmark array is never scaled in place.
    scaled = face.landmarks[:, :2].copy()
    scaled[:, 0] *= canvas_w
    scaled[:, 1] *= canvas_h
    pixel_xy = scaled.astype(np.int32)  # integer cast truncates toward zero

    for indices in ALL_CONTOURS:
        # Join each landmark to its successor; closed rings repeat their
        # first index at the end, so no explicit wrap-around is needed.
        for start_idx, end_idx in zip(indices, indices[1:]):
            start_pt = tuple(pixel_xy[start_idx])
            end_pt = tuple(pixel_xy[end_idx])
            cv2.line(image, start_pt, end_pt, 255, thickness)

    return image
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def auto_canny(image: np.ndarray) -> np.ndarray:
    """Run Canny with median-derived thresholds, then thin the result.

    The low/high thresholds are ``_CANNY_LOW_FACTOR`` and
    ``_CANNY_HIGH_FACTOR`` times the median of the non-zero pixels (clamped to
    0..255). If the image has no non-zero pixel, ``_CANNY_DEFAULT_MEDIAN`` is
    used as the median instead. The edge map is then reduced with an iterative
    erode/open morphological skeleton.

    Args:
        image: Grayscale input image.

    Returns:
        Edge map with the same shape as the Canny output (uint8, 0 or 255).
    """
    nonzero_mask = image > 0
    if np.any(nonzero_mask):
        median = np.median(image[nonzero_mask])
    else:
        median = _CANNY_DEFAULT_MEDIAN

    low_threshold = int(max(0, _CANNY_LOW_FACTOR * median))
    high_threshold = int(min(255, _CANNY_HIGH_FACTOR * median))
    edges = cv2.Canny(image, low_threshold, high_threshold)

    # Morphological skeleton: on each pass keep the pixels that an opening
    # (erode then dilate) would remove, then continue from the eroded image.
    # The source comment motivates this with ControlNet blurring on edges
    # that are 2+ pixels wide.
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    skeleton = np.zeros_like(edges)
    remaining = edges.copy()

    # The pass count is capped by the larger image dimension.
    pass_limit = max(edges.shape[0], edges.shape[1])
    for _ in range(pass_limit):
        eroded = cv2.erode(remaining, kernel)
        opened = cv2.dilate(eroded, kernel)
        removed_by_opening = cv2.subtract(remaining, opened)
        skeleton = cv2.bitwise_or(skeleton, removed_by_opening)
        remaining = eroded.copy()
        if cv2.countNonZero(remaining) == 0:
            break

    return skeleton
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def generate_conditioning(
    face: FaceLandmarks,
    width: int | None = None,
    height: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build the three conditioning images for one face.

    The outputs are, in order:
        1. the image returned by ``render_landmark_image`` (described in the
           source docstring as colored BGR landmark dots),
        2. ``auto_canny`` applied to the static wireframe,
        3. the static wireframe from ``render_wireframe``.

    Args:
        face: Extracted facial landmarks.
        width: Output width; falls back to ``face.image_width`` when omitted.
        height: Output height; falls back to ``face.image_height`` when omitted.

    Returns:
        Tuple of (landmark_image, canny_edges, wireframe).
    """
    # Imported at call time, as in the original implementation.
    from landmarkdiff.landmarks import render_landmark_image

    out_w = width or face.image_width
    out_h = height or face.image_height

    dots = render_landmark_image(face, out_w, out_h)
    wire = render_wireframe(face, out_w, out_h)
    wire_edges = auto_canny(wire)

    return dots, wire_edges, wire
|
landmarkdiff/config.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""YAML-based experiment configuration for reproducible training and evaluation.
|
|
2
|
+
|
|
3
|
+
Provides typed dataclasses that can be loaded from YAML files, enabling
|
|
4
|
+
reproducible experiments with version-tracked configs.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
    from landmarkdiff.config import ExperimentConfig, TrainingConfig
|
|
8
|
+
config = ExperimentConfig.from_yaml("configs/rhinoplasty_phaseA.yaml")
|
|
9
|
+
print(config.training.learning_rate)
|
|
10
|
+
|
|
11
|
+
# Or create programmatically
|
|
12
|
+
config = ExperimentConfig(
|
|
13
|
+
experiment_name="rhino_v1",
|
|
14
|
+
training=TrainingConfig(phase="A", learning_rate=1e-5),
|
|
15
|
+
)
|
|
16
|
+
config.to_yaml("configs/rhino_v1.yaml")
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from dataclasses import asdict, dataclass, field
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
import yaml
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class ModelConfig:
    """ControlNet and base model configuration.

    Plain settings container: every field has a default, so ``ModelConfig()``
    is valid on its own. Populated from the ``model`` section of the YAML
    file by ``ExperimentConfig.from_yaml``.
    """

    # Base model identifier (presumably a Hugging Face repo id - confirm).
    base_model: str = "runwayml/stable-diffusion-v1-5"
    controlnet_conditioning_channels: int = 3
    # NOTE(review): InferenceConfig declares a field with the same name and
    # default - confirm which one consumers read.
    controlnet_conditioning_scale: float = 1.0
    use_ema: bool = True
    # NOTE(review): TrainingConfig also declares ``ema_decay`` with the same
    # default - confirm which one consumers read.
    ema_decay: float = 0.9999
    gradient_checkpointing: bool = True
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class TrainingConfig:
    """Training hyperparameters.

    Populated from the ``training`` section of the YAML file. Several fields
    can also be set through alternative YAML keys listed in
    ``_FIELD_ALIASES`` (for example ``max_steps`` -> ``max_train_steps``).
    """

    phase: str = "A"  # "A" or "B"
    learning_rate: float = 1e-5
    batch_size: int = 4
    # validate_config multiplies this by batch_size to get the effective batch.
    gradient_accumulation_steps: int = 4
    max_train_steps: int = 50000
    warmup_steps: int = 500
    mixed_precision: str = "bf16"
    seed: int = 42
    # NOTE(review): ModelConfig also declares ``ema_decay`` - confirm which
    # one consumers read.
    ema_decay: float = 0.9999

    # Optimizer
    optimizer: str = "adamw"  # "adamw", "adam8bit", "prodigy"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    weight_decay: float = 1e-2
    max_grad_norm: float = 1.0

    # LR scheduler
    lr_scheduler: str = "cosine"
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)

    # Logging intervals
    log_every: int = 100
    sample_every: int = 1000

    # Phase B specific
    identity_loss_weight: float = 0.1
    perceptual_loss_weight: float = 0.05
    use_differentiable_arcface: bool = False
    arcface_weights_path: str | None = None

    # Loss weights (alternative to individual weights)
    loss_weights: dict[str, float] = field(default_factory=dict)

    # Checkpointing
    save_every_n_steps: int = 5000
    # validate_config warns when phase is "B" and this is unset.
    resume_from_checkpoint: str | None = None
    # NOTE(review): presumably the Phase A checkpoint to start Phase B from;
    # nothing in this module reads it - confirm against the training script.
    resume_phase_a: str | None = None

    # Validation
    validate_every_n_steps: int = 2500
    num_validation_samples: int = 4
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass
class DataConfig:
    """Dataset configuration.

    Populated from the ``data`` section of the YAML file.
    """

    train_dir: str = "data/training_combined"
    val_dir: str = "data/splits/val"
    test_dir: str = "data/splits/test"
    # validate_config warns when this is not 512.
    image_size: int = 512
    num_workers: int = 4
    pin_memory: bool = True

    # Augmentation
    random_flip: bool = True
    random_rotation: float = 5.0  # degrees
    color_jitter: float = 0.1
    clinical_augment: bool = False
    geometric_augment: bool = True

    # Procedure filtering
    # A factory is used so each instance gets its own list.
    procedures: list[str] = field(
        default_factory=lambda: [
            "rhinoplasty",
            "blepharoplasty",
            "rhytidectomy",
            "orthognathic",
            "brow_lift",
            "mentoplasty",
        ]
    )
    # Stored as a tuple; to_yaml writes it as a list and _from_dict turns the
    # loaded list back into a tuple.
    intensity_range: tuple[float, float] = (30.0, 100.0)

    # Data-driven displacement
    displacement_model_path: str | None = None
    noise_scale: float = 0.1
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
class InferenceConfig:
    """Inference / generation configuration.

    Populated from the ``inference`` section of the YAML file.
    """

    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    scheduler: str = "dpmsolver++"  # "ddpm", "ddim", "dpmsolver++"
    # NOTE(review): ModelConfig declares a field with the same name and
    # default - confirm which one consumers read.
    controlnet_conditioning_scale: float = 1.0

    # Post-processing
    use_neural_postprocess: bool = False
    restore_mode: str = "codeformer"
    codeformer_fidelity: float = 0.7
    use_realesrgan: bool = True
    use_laplacian_blend: bool = True
    sharpen_strength: float = 0.25

    # Identity verification
    verify_identity: bool = True
    # NOTE(review): SafetyConfig also declares ``identity_threshold``;
    # validate_config checks only the SafetyConfig value.
    identity_threshold: float = 0.6
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass
class EvaluationConfig:
    """Evaluation configuration.

    Populated from the ``evaluation`` section of the YAML file. All metric
    and stratification switches default to enabled.
    """

    # Metric toggles
    compute_fid: bool = True
    compute_lpips: bool = True
    compute_nme: bool = True
    compute_identity: bool = True
    compute_ssim: bool = True
    # Stratification toggles
    stratify_fitzpatrick: bool = True
    stratify_procedure: bool = True
    max_eval_samples: int = 0  # 0 = all
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@dataclass
class WandbConfig:
    """Weights & Biases logging configuration.

    Populated from the ``wandb`` section of the YAML file.
    """

    enabled: bool = True
    project: str = "landmarkdiff"
    entity: str | None = None
    run_name: str | None = None
    # A factory is used so each instance gets its own list.
    tags: list[str] = field(default_factory=list)
    mode: str = "online"  # "online", "offline", "disabled"
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@dataclass
class SlurmConfig:
    """SLURM job submission parameters.

    Populated from the ``slurm`` section of the YAML file. All values are
    stored as given; this module does not interpret them.
    """

    partition: str = "batch_gpu"
    # NOTE(review): nothing in this module reads SLURM_ACCOUNT; presumably the
    # job-submission code applies the environment fallback - confirm.
    account: str = ""  # Set via YAML or SLURM_ACCOUNT env var
    gpu_type: str = "nvidia_rtx_a6000"
    num_gpus: int = 1
    mem: str = "48G"
    cpus_per_task: int = 8
    time_limit: str = "48:00:00"
    job_prefix: str = "surgery_"
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@dataclass
class SafetyConfig:
    """Clinical safety and responsible AI parameters.

    Populated from the ``safety`` section of the YAML file.
    """

    # validate_config warns when this is below 0.3.
    # NOTE(review): InferenceConfig also declares ``identity_threshold`` -
    # confirm which one consumers read.
    identity_threshold: float = 0.6
    max_displacement_fraction: float = 0.05
    watermark_enabled: bool = True
    watermark_text: str = "AI-GENERATED PREDICTION"
    ood_detection_enabled: bool = True
    ood_confidence_threshold: float = 0.3
    min_face_confidence: float = 0.5
    max_yaw_degrees: float = 45.0
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration.

    Aggregates one instance of every section dataclass plus a few top-level
    scalars. Can be built programmatically, loaded with ``from_yaml`` or
    written with ``to_yaml``.
    """

    experiment_name: str = "default"
    description: str = ""
    # NOTE(review): confirm whether this tracks the package version or a
    # separate config schema version.
    version: str = "0.3.2"

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    slurm: SlurmConfig = field(default_factory=SlurmConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)

    # Output
    output_dir: str = "outputs"

    @classmethod
    def from_yaml(cls, path: str | Path) -> ExperimentConfig:
        """Load config from a YAML file.

        Args:
            path: Location of the YAML file.

        Returns:
            A config in which every missing key or section keeps its default.
            An empty file yields ``cls()``. Keys a section dataclass does not
            declare are ignored by ``_from_dict``.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
        """
        path = Path(path)
        # Read as UTF-8 explicitly instead of relying on the platform locale.
        with open(path, encoding="utf-8") as f:
            raw = yaml.safe_load(f)

        if raw is None:
            return cls()

        def section(name: str, section_cls: type) -> Any:
            # A section written as a bare key (``model:``) loads as None;
            # treat it like a missing section rather than crashing.
            return _from_dict(section_cls, raw.get(name) or {})

        return cls(
            experiment_name=raw.get("experiment_name", "default"),
            description=raw.get("description", ""),
            version=raw.get("version", "0.3.2"),
            model=section("model", ModelConfig),
            training=section("training", TrainingConfig),
            data=section("data", DataConfig),
            inference=section("inference", InferenceConfig),
            evaluation=section("evaluation", EvaluationConfig),
            wandb=section("wandb", WandbConfig),
            slurm=section("slurm", SlurmConfig),
            safety=section("safety", SafetyConfig),
            output_dir=raw.get("output_dir", "outputs"),
        )

    def to_yaml(self, path: str | Path) -> None:
        """Save config to a YAML file.

        Parent directories are created if needed. Tuples are converted to
        lists first so the output contains only plain YAML sequences.

        Args:
            path: Destination file; overwritten if it exists.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        d = _convert_tuples(asdict(self))
        # Write as UTF-8 explicitly so from_yaml reads back the same bytes.
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(d, f, default_flow_style=False, sort_keys=False)

    def to_dict(self) -> dict:
        """Convert to a nested dictionary (tuples are left as tuples)."""
        return asdict(self)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
_FIELD_ALIASES: dict[str, str] = {
    # YAML name -> dataclass field name
    "max_steps": "max_train_steps",
    "save_interval": "save_every_n_steps",
    "sample_interval": "sample_every",
    "log_interval": "log_every",
    "adam_weight_decay": "weight_decay",
    "lr_warmup_steps": "warmup_steps",
    "resume_from": "resume_from_checkpoint",
}


def _from_dict(cls: type, d: dict) -> Any:
    """Create a dataclass from a dict, ignoring unknown keys.

    Supports field aliases so YAML configs using train_controlnet.py-style
    names (e.g. max_steps) map to dataclass fields (max_train_steps).

    A value given under a field's canonical name always wins over one given
    under an alias, regardless of the order of the keys in ``d``.

    Args:
        cls: Dataclass type to instantiate.
        d: Mapping of key -> value. ``None`` or an empty mapping yields
            ``cls()`` with all defaults.

    Returns:
        An instance of ``cls`` built from the recognised keys.
    """
    import dataclasses

    if not d:
        return cls()

    field_map = {f.name: f for f in dataclasses.fields(cls)}

    def coerce(name: str, value: Any) -> Any:
        # YAML has no tuple type, so a tuple-typed field arrives as a list.
        if isinstance(value, list) and "tuple" in str(field_map[name].type):
            return tuple(value)
        return value

    # Pass 1: keys that are real field names take precedence.
    filtered = {k: coerce(k, v) for k, v in d.items() if k in field_map}

    # Pass 2: aliases only fill fields that were not set explicitly.
    for key, value in d.items():
        canonical = _FIELD_ALIASES.get(key)
        if canonical is None or canonical not in field_map:
            continue
        if canonical not in filtered:
            filtered[canonical] = coerce(canonical, value)

    return cls(**filtered)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _convert_tuples(obj: Any) -> Any:
    """Return a copy of ``obj`` with every tuple replaced by a list.

    Dicts and lists/tuples are walked recursively and rebuilt; any other
    value is returned unchanged. Used before YAML serialization.
    """
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = _convert_tuples(value)
        return converted
    if isinstance(obj, (list, tuple)):
        return list(map(_convert_tuples, obj))
    return obj
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def load_config(
    config_path: str | Path | None = None,
    overrides: dict[str, object] | None = None,
) -> ExperimentConfig:
    """Load config with optional dot-notation overrides.

    Args:
        config_path: Path to YAML config. None returns defaults.
        overrides: Dict of "section.key" -> value overrides.
            E.g., {"training.learning_rate": 5e-6}

    Returns:
        ExperimentConfig with overrides applied. An override whose path does
        not resolve to an existing attribute is skipped and a warning is
        logged, so a mistyped key does not silently leave a default in place.
    """
    import logging

    logger = logging.getLogger(__name__)
    config = ExperimentConfig.from_yaml(config_path) if config_path else ExperimentConfig()

    missing = object()  # sentinel: distinguishes "no attribute" from None
    for key, value in (overrides or {}).items():
        *section_path, leaf = key.split(".")

        # Walk down to the object that owns the final attribute.
        target: object = config
        for part in section_path:
            target = getattr(target, part, missing)
            if target is missing:
                break

        if target is missing or not hasattr(target, leaf):
            logger.warning("Ignoring unknown config override %r", key)
            continue
        setattr(target, leaf, value)

    return config
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def validate_config(config: ExperimentConfig) -> list[str]:
    """Validate config and return list of warnings.

    The checks are advisory only; nothing is raised and ``config`` is not
    modified.

    Args:
        config: Configuration to inspect.

    Returns:
        Human-readable warning messages, in check order; empty when every
        check passes.
    """
    issues: list[str] = []
    training = config.training

    # Phase B needs a starting checkpoint; either resume field satisfies this.
    has_checkpoint = bool(training.resume_from_checkpoint or training.resume_phase_a)
    if training.phase == "B" and not has_checkpoint:
        issues.append("Phase B should resume from a Phase A checkpoint")

    eff_batch = training.batch_size * training.gradient_accumulation_steps
    if eff_batch < 8:
        issues.append(f"Effective batch size {eff_batch} < 8 may cause instability")

    if training.learning_rate > 1e-4:
        issues.append("Learning rate > 1e-4 is unusually high for fine-tuning")

    if config.data.image_size != 512:
        issues.append(f"Image size {config.data.image_size} != 512; SD1.5 expects 512")

    if config.safety.identity_threshold < 0.3:
        issues.append("Identity threshold < 0.3 may pass poor quality outputs")

    return issues
|