landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,278 @@
1
+ """Conditioning signal generation: static adjacency wireframe + auto-Canny.
2
+
3
+ Uses a pre-defined anatomical adjacency matrix (NOT dynamic Delaunay) to prevent
4
+ triangle inversion on drastic landmark displacements. Auto-Canny adapts thresholds
5
+ to skin tone (Fitzpatrick I-VI safe).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import cv2
11
+ import numpy as np
12
+
13
+ from landmarkdiff.landmarks import FaceLandmarks
14
+
15
# Static anatomical adjacency for MediaPipe 478 landmarks.
# Each list is an ordered chain of landmark indices; consecutive entries are
# joined by a line segment when the wireframe is rendered. Because the chains
# are fixed, the topology does not depend on landmark positions (unlike a
# Delaunay triangulation recomputed per face).

# Closed loop: first and last index are both 10.
JAWLINE_CONTOUR = [
    10, 338, 297, 332, 284, 251, 389, 356, 454, 323,
    361, 288, 397, 365, 379, 378, 400, 377, 152, 148,
    176, 149, 150, 136, 172, 58, 132, 93, 234, 127,
    162, 21, 54, 103, 67, 109, 10,
]

# Closed loop: starts and ends at 33.
LEFT_EYE_CONTOUR = [
    33, 7, 163, 144, 145, 153, 154, 155, 133,
    173, 157, 158, 159, 160, 161, 246, 33,
]

# Closed loop: starts and ends at 362.
RIGHT_EYE_CONTOUR = [
    362, 382, 381, 380, 374, 373, 390, 249, 263,
    466, 388, 387, 386, 385, 384, 398, 362,
]

# Open chains (not closed back onto their first index).
LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]

NOSE_BRIDGE = [168, 6, 197, 195, 5, 4, 1]
NOSE_TIP = [94, 2, 326, 327, 294, 278, 279, 275, 274, 460, 456, 363, 370]
NOSE_BOTTOM = [19, 1, 274, 275, 440, 344, 278, 294, 460, 305, 289, 392]

# Closed loop: starts and ends at 61.
OUTER_LIPS = [
    61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308,
    324, 318, 402, 317, 14, 87, 178, 88, 95, 78, 61,
]

# Closed loop: starts and ends at 78.
INNER_LIPS = [
    78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308,
    324, 318, 402, 317, 14, 87, 178, 88, 95, 78,
]

# Auto-Canny threshold factors, applied to the median of the non-zero pixels.
_CANNY_LOW_FACTOR = 0.66
_CANNY_HIGH_FACTOR = 1.33
_CANNY_DEFAULT_MEDIAN = 128.0  # fallback when no non-zero pixels exist

# Every chain drawn by the wireframe renderer, in drawing order.
ALL_CONTOURS = [
    JAWLINE_CONTOUR,
    LEFT_EYE_CONTOUR,
    RIGHT_EYE_CONTOUR,
    LEFT_EYEBROW,
    RIGHT_EYEBROW,
    NOSE_BRIDGE,
    NOSE_TIP,
    NOSE_BOTTOM,
    OUTER_LIPS,
    INNER_LIPS,
]
174
+
175
+
176
def render_wireframe(
    face: FaceLandmarks,
    width: int | None = None,
    height: int | None = None,
    thickness: int = 1,
) -> np.ndarray:
    """Draw every contour in ``ALL_CONTOURS`` as white lines on a black canvas.

    Args:
        face: Facial landmarks; the first two columns of ``face.landmarks``
            are multiplied by the canvas size, i.e. treated as normalized x/y.
        width: Canvas width. Falls back to ``face.image_width`` when falsy.
        height: Canvas height. Falls back to ``face.image_height`` when falsy.
        thickness: Line thickness in pixels.

    Returns:
        A ``(height, width)`` uint8 image with lines drawn at value 255.
    """
    w = width or face.image_width
    h = height or face.image_height

    # Scale x by width and y by height, then truncate to integer pixels.
    scaled = face.landmarks[:, :2].copy()
    scaled[:, 0] *= w
    scaled[:, 1] *= h
    pixels = scaled.astype(np.int32)

    canvas = np.zeros((h, w), dtype=np.uint8)
    for contour in ALL_CONTOURS:
        # Join each landmark to its successor in the chain.
        for start_idx, end_idx in zip(contour, contour[1:]):
            cv2.line(canvas, tuple(pixels[start_idx]), tuple(pixels[end_idx]), 255, thickness)

    return canvas
209
+
210
+
211
def auto_canny(image: np.ndarray) -> np.ndarray:
    """Run Canny with median-derived thresholds, then thin the result.

    The lower and upper thresholds are ``_CANNY_LOW_FACTOR`` and
    ``_CANNY_HIGH_FACTOR`` times the median of the non-zero pixels (or of
    ``_CANNY_DEFAULT_MEDIAN`` when the image has no non-zero pixel), clamped
    to ``[0, 255]``. The edge map is then reduced by iterative morphological
    skeletonization with a 3x3 cross element.

    Args:
        image: Grayscale input image.

    Returns:
        Binary edge map (uint8, 0 or 255).
    """
    lit_pixels = image[image > 0]
    median = np.median(lit_pixels) if lit_pixels.size else _CANNY_DEFAULT_MEDIAN
    low = int(max(0, _CANNY_LOW_FACTOR * median))
    high = int(min(255, _CANNY_HIGH_FACTOR * median))

    edges = cv2.Canny(image, low, high)

    # Morphological skeletonization: at each step keep what an opening
    # (erode then dilate) would remove, then continue on the eroded image.
    kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    skeleton = np.zeros_like(edges)
    remaining = edges.copy()

    # The longer image side bounds the number of erosion passes.
    for _ in range(max(edges.shape[:2])):
        eroded = cv2.erode(remaining, kernel)
        opened = cv2.dilate(eroded, kernel)
        skeleton = cv2.bitwise_or(skeleton, cv2.subtract(remaining, opened))
        remaining = eroded
        if cv2.countNonZero(remaining) == 0:
            break

    return skeleton
247
+
248
+
249
def generate_conditioning(
    face: FaceLandmarks,
    width: int | None = None,
    height: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build the three conditioning images for one face.

    The result is, in order:
        1. the landmark rendering from ``render_landmark_image``,
        2. ``auto_canny`` applied to the static wireframe,
        3. the static wireframe itself from ``render_wireframe``.

    Args:
        face: Extracted facial landmarks.
        width: Output width. Falls back to ``face.image_width`` when falsy.
        height: Output height. Falls back to ``face.image_height`` when falsy.

    Returns:
        Tuple of (landmark_image, canny_edges, wireframe).
    """
    from landmarkdiff.landmarks import render_landmark_image

    out_w = width or face.image_width
    out_h = height or face.image_height

    landmark_img = render_landmark_image(face, out_w, out_h)
    wireframe = render_wireframe(face, out_w, out_h)
    return landmark_img, auto_canny(wireframe), wireframe
landmarkdiff/config.py ADDED
@@ -0,0 +1,358 @@
1
+ """YAML-based experiment configuration for reproducible training and evaluation.
2
+
3
+ Provides typed dataclasses that can be loaded from YAML files, enabling
4
+ reproducible experiments with version-tracked configs.
5
+
6
+ Usage:
7
+ from landmarkdiff.config import ExperimentConfig
8
+ config = ExperimentConfig.from_yaml("configs/rhinoplasty_phaseA.yaml")
9
+ print(config.training.learning_rate)
10
+
11
+ # Or create programmatically
12
+ config = ExperimentConfig(
13
+ experiment_name="rhino_v1",
14
+ training=TrainingConfig(phase="A", learning_rate=1e-5),
15
+ )
16
+ config.to_yaml("configs/rhino_v1.yaml")
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from dataclasses import asdict, dataclass, field
22
+ from pathlib import Path
23
+ from typing import Any
24
+
25
+ import yaml
26
+
27
+
28
@dataclass
class ModelConfig:
    """Settings for the base model and its ControlNet.

    Every field has a default, so ``ModelConfig()`` is a valid configuration.
    """

    base_model: str = "runwayml/stable-diffusion-v1-5"
    controlnet_conditioning_channels: int = 3
    controlnet_conditioning_scale: float = 1.0
    use_ema: bool = True
    ema_decay: float = 0.9999
    gradient_checkpointing: bool = True
38
+
39
+
40
+ @dataclass
41
+ class TrainingConfig:
42
+ """Training hyperparameters."""
43
+
44
+ phase: str = "A" # "A" or "B"
45
+ learning_rate: float = 1e-5
46
+ batch_size: int = 4
47
+ gradient_accumulation_steps: int = 4
48
+ max_train_steps: int = 50000
49
+ warmup_steps: int = 500
50
+ mixed_precision: str = "bf16"
51
+ seed: int = 42
52
+ ema_decay: float = 0.9999
53
+
54
+ # Optimizer
55
+ optimizer: str = "adamw" # "adamw", "adam8bit", "prodigy"
56
+ adam_beta1: float = 0.9
57
+ adam_beta2: float = 0.999
58
+ weight_decay: float = 1e-2
59
+ max_grad_norm: float = 1.0
60
+
61
+ # LR scheduler
62
+ lr_scheduler: str = "cosine"
63
+ lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
64
+
65
+ # Logging intervals
66
+ log_every: int = 100
67
+ sample_every: int = 1000
68
+
69
+ # Phase B specific
70
+ identity_loss_weight: float = 0.1
71
+ perceptual_loss_weight: float = 0.05
72
+ use_differentiable_arcface: bool = False
73
+ arcface_weights_path: str | None = None
74
+
75
+ # Loss weights (alternative to individual weights)
76
+ loss_weights: dict[str, float] = field(default_factory=dict)
77
+
78
+ # Checkpointing
79
+ save_every_n_steps: int = 5000
80
+ resume_from_checkpoint: str | None = None
81
+ resume_phase_a: str | None = None
82
+
83
+ # Validation
84
+ validate_every_n_steps: int = 2500
85
+ num_validation_samples: int = 4
86
+
87
+
88
+ @dataclass
89
+ class DataConfig:
90
+ """Dataset configuration."""
91
+
92
+ train_dir: str = "data/training_combined"
93
+ val_dir: str = "data/splits/val"
94
+ test_dir: str = "data/splits/test"
95
+ image_size: int = 512
96
+ num_workers: int = 4
97
+ pin_memory: bool = True
98
+
99
+ # Augmentation
100
+ random_flip: bool = True
101
+ random_rotation: float = 5.0 # degrees
102
+ color_jitter: float = 0.1
103
+ clinical_augment: bool = False
104
+ geometric_augment: bool = True
105
+
106
+ # Procedure filtering
107
+ procedures: list[str] = field(
108
+ default_factory=lambda: [
109
+ "rhinoplasty",
110
+ "blepharoplasty",
111
+ "rhytidectomy",
112
+ "orthognathic",
113
+ "brow_lift",
114
+ "mentoplasty",
115
+ ]
116
+ )
117
+ intensity_range: tuple[float, float] = (30.0, 100.0)
118
+
119
+ # Data-driven displacement
120
+ displacement_model_path: str | None = None
121
+ noise_scale: float = 0.1
122
+
123
+
124
@dataclass
class InferenceConfig:
    """Options for inference / image generation.

    Every field has a default, so ``InferenceConfig()`` is a valid
    configuration.
    """

    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    scheduler: str = "dpmsolver++"  # "ddpm", "ddim", "dpmsolver++"
    controlnet_conditioning_scale: float = 1.0

    # --- Post-processing ---
    use_neural_postprocess: bool = False
    restore_mode: str = "codeformer"
    codeformer_fidelity: float = 0.7
    use_realesrgan: bool = True
    use_laplacian_blend: bool = True
    sharpen_strength: float = 0.25

    # --- Identity verification ---
    verify_identity: bool = True
    identity_threshold: float = 0.6
144
+
145
+
146
@dataclass
class EvaluationConfig:
    """Switches selecting which evaluation metrics and stratifications to run.

    Every field has a default, so ``EvaluationConfig()`` is a valid
    configuration.
    """

    compute_fid: bool = True
    compute_lpips: bool = True
    compute_nme: bool = True
    compute_identity: bool = True
    compute_ssim: bool = True
    stratify_fitzpatrick: bool = True
    stratify_procedure: bool = True
    max_eval_samples: int = 0  # 0 = all
158
+
159
+
160
+ @dataclass
161
+ class WandbConfig:
162
+ """Weights & Biases logging configuration."""
163
+
164
+ enabled: bool = True
165
+ project: str = "landmarkdiff"
166
+ entity: str | None = None
167
+ run_name: str | None = None
168
+ tags: list[str] = field(default_factory=list)
169
+ mode: str = "online" # "online", "offline", "disabled"
170
+
171
+
172
@dataclass
class SlurmConfig:
    """Parameters used when submitting a SLURM job.

    Every field has a default, so ``SlurmConfig()`` is a valid configuration.
    """

    partition: str = "batch_gpu"
    account: str = ""  # Set via YAML or SLURM_ACCOUNT env var
    gpu_type: str = "nvidia_rtx_a6000"
    num_gpus: int = 1
    mem: str = "48G"
    cpus_per_task: int = 8
    time_limit: str = "48:00:00"
    job_prefix: str = "surgery_"
184
+
185
+
186
@dataclass
class SafetyConfig:
    """Thresholds and switches for clinical-safety / responsible-AI checks.

    Every field has a default, so ``SafetyConfig()`` is a valid configuration.
    """

    identity_threshold: float = 0.6
    max_displacement_fraction: float = 0.05
    watermark_enabled: bool = True
    watermark_text: str = "AI-GENERATED PREDICTION"
    ood_detection_enabled: bool = True
    ood_confidence_threshold: float = 0.3
    min_face_confidence: float = 0.5
    max_yaw_degrees: float = 45.0
198
+
199
+
200
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration.

    Aggregates one instance of each section dataclass plus a few scalar
    fields. Every field has a default, so ``ExperimentConfig()`` is valid.
    """

    experiment_name: str = "default"
    description: str = ""
    # NOTE(review): this default is not the wheel version (0.2.3) -- confirm
    # whether it is meant to be a config-schema version.
    version: str = "0.3.2"

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig)
    slurm: SlurmConfig = field(default_factory=SlurmConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)

    # Output
    output_dir: str = "outputs"

    @classmethod
    def from_yaml(cls, path: str | Path) -> ExperimentConfig:
        """Load config from a YAML file.

        Missing keys, and keys whose value is YAML ``null`` (e.g. a bare
        ``model:`` line), fall back to the defaults. An empty file yields
        ``cls()``.

        Args:
            path: Location of the YAML file.

        Returns:
            The populated configuration.

        Raises:
            TypeError: If the document's top level is not a mapping.
        """
        path = Path(path)
        with open(path, encoding="utf-8") as f:
            raw = yaml.safe_load(f)

        if raw is None:
            return cls()
        if not isinstance(raw, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {path}, got {type(raw).__name__}"
            )

        def scalar(key: str, default: Any) -> Any:
            # A key that is present but null behaves like a missing key.
            found = raw.get(key)
            return default if found is None else found

        def section(key: str) -> dict:
            # ``model:`` with no body parses to None; treat it as empty.
            return raw.get(key) or {}

        return cls(
            experiment_name=scalar("experiment_name", "default"),
            description=scalar("description", ""),
            version=scalar("version", "0.3.2"),
            model=_from_dict(ModelConfig, section("model")),
            training=_from_dict(TrainingConfig, section("training")),
            data=_from_dict(DataConfig, section("data")),
            inference=_from_dict(InferenceConfig, section("inference")),
            evaluation=_from_dict(EvaluationConfig, section("evaluation")),
            wandb=_from_dict(WandbConfig, section("wandb")),
            slurm=_from_dict(SlurmConfig, section("slurm")),
            safety=_from_dict(SafetyConfig, section("safety")),
            output_dir=scalar("output_dir", "outputs"),
        )

    def to_yaml(self, path: str | Path) -> None:
        """Save config to a YAML file, creating parent directories as needed.

        Tuples are converted to lists first so the output stays loadable by
        ``yaml.safe_load``.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        d = _convert_tuples(asdict(self))
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(d, f, default_flow_style=False, sort_keys=False)

    def to_dict(self) -> dict:
        """Convert to a nested dictionary (tuple fields remain tuples)."""
        return asdict(self)
256
+
257
+
258
+ _FIELD_ALIASES: dict[str, str] = {
259
+ # YAML name -> dataclass field name
260
+ "max_steps": "max_train_steps",
261
+ "save_interval": "save_every_n_steps",
262
+ "sample_interval": "sample_every",
263
+ "log_interval": "log_every",
264
+ "adam_weight_decay": "weight_decay",
265
+ "lr_warmup_steps": "warmup_steps",
266
+ "resume_from": "resume_from_checkpoint",
267
+ }
268
+
269
+
270
+ def _from_dict(cls: type, d: dict) -> Any:
271
+ """Create a dataclass from a dict, ignoring unknown keys.
272
+
273
+ Supports field aliases so YAML configs using train_controlnet.py-style
274
+ names (e.g. max_steps) map to dataclass fields (max_train_steps).
275
+ """
276
+ import dataclasses
277
+
278
+ field_map = {f.name: f for f in dataclasses.fields(cls)}
279
+ filtered = {}
280
+ for k, v in d.items():
281
+ # Resolve aliases
282
+ canonical = _FIELD_ALIASES.get(k, k)
283
+ if canonical not in field_map:
284
+ continue
285
+ # Don't overwrite if the canonical name was already set explicitly
286
+ if canonical in filtered:
287
+ continue
288
+ # Convert lists back to tuples where the field type is tuple
289
+ f = field_map[canonical]
290
+ if isinstance(v, list) and "tuple" in str(f.type):
291
+ v = tuple(v)
292
+ filtered[canonical] = v
293
+ return cls(**filtered)
294
+
295
+
296
+ def _convert_tuples(obj: Any) -> Any:
297
+ """Recursively convert tuples to lists for YAML serialization."""
298
+ if isinstance(obj, dict):
299
+ return {k: _convert_tuples(v) for k, v in obj.items()}
300
+ if isinstance(obj, (list, tuple)):
301
+ return [_convert_tuples(item) for item in obj]
302
+ return obj
303
+
304
+
305
def load_config(
    config_path: str | Path | None = None,
    overrides: dict[str, object] | None = None,
) -> ExperimentConfig:
    """Load config with optional dot-notation overrides.

    Args:
        config_path: Path to YAML config. A falsy value returns defaults.
        overrides: Dict of "section.key" -> value overrides.
            E.g., {"training.learning_rate": 5e-6}. A key without a dot
            addresses a top-level attribute of the config.

    Returns:
        ExperimentConfig with overrides applied. Overrides whose path does
        not resolve to an existing attribute are skipped and reported with a
        logged warning rather than raising.
    """
    import logging

    config = ExperimentConfig.from_yaml(config_path) if config_path else ExperimentConfig()
    if not overrides:
        return config

    logger = logging.getLogger(__name__)
    for key, value in overrides.items():
        *section_path, leaf = key.split(".")

        # Walk down to the object that owns the final attribute.
        target = config
        for part in section_path:
            if not hasattr(target, part):
                break
            target = getattr(target, part)
        else:
            # Only existing attributes may be overridden; never create new ones.
            if hasattr(target, leaf):
                setattr(target, leaf, value)
                continue

        # A typo in an override would otherwise vanish without a trace.
        logger.warning("Ignoring unknown config override %r", key)

    return config
336
+
337
+
338
def validate_config(config: ExperimentConfig) -> list[str]:
    """Validate config and return list of warnings.

    The checks are advisory: nothing is raised and ``config`` is not
    modified. An empty list means no check triggered.

    Args:
        config: The configuration to inspect.

    Returns:
        Human-readable warning messages, in the order the checks run.
    """
    warnings = []
    training = config.training

    # Either resume field counts as a starting checkpoint for Phase B.
    has_resume_source = bool(training.resume_from_checkpoint or training.resume_phase_a)
    if training.phase == "B" and not has_resume_source:
        warnings.append("Phase B should resume from a Phase A checkpoint")

    eff_batch = training.batch_size * training.gradient_accumulation_steps
    if eff_batch < 8:
        warnings.append(f"Effective batch size {eff_batch} < 8 may cause instability")

    if training.learning_rate > 1e-4:
        warnings.append("Learning rate > 1e-4 is unusually high for fine-tuning")

    if config.data.image_size != 512:
        warnings.append(f"Image size {config.data.image_size} != 512; SD1.5 expects 512")

    if config.safety.identity_threshold < 0.3:
        warnings.append("Identity threshold < 0.3 may pass poor quality outputs")

    return warnings