landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/cli.py ADDED
@@ -0,0 +1,252 @@
"""Unified CLI for LandmarkDiff.

Usage:
    landmarkdiff infer IMAGE --procedure rhinoplasty --intensity 65
    landmarkdiff ensemble IMAGE --n-samples 5 --strategy best_of_n
    landmarkdiff evaluate --test-dir data/test --checkpoint checkpoints/latest
    landmarkdiff config [--file configs/phaseA.yaml] [--validate]
    landmarkdiff validate INPUT OUTPUT_IMAGE [--watermark]
    landmarkdiff version

``evaluate`` imports ``scripts/run_evaluation.py`` and therefore only works
from a source checkout; the installed package does not ship ``scripts/``.
"""

from __future__ import annotations

import argparse
import logging
import sys

logger = logging.getLogger(__name__)
19
+
20
+
21
def cmd_infer(args: argparse.Namespace) -> None:
    """Run single-image inference and write the prediction to disk.

    Reads ``args.image`` with OpenCV and resizes it to 512x512. It then builds
    and loads a ``LandmarkDiffPipeline`` from ``args.mode``,
    ``args.checkpoint`` and ``args.displacement_model``. It calls
    ``generate`` with the procedure, intensity and seed from the CLI, and
    writes ``result["output"]`` to ``args.output``, creating parent
    directories as needed. With ``--watermark`` it also writes a watermarked
    copy whose file stem gets a ``_watermarked`` suffix.

    Exits with status 1 if the input image cannot be read or an output image
    cannot be written.
    """
    from pathlib import Path

    import cv2

    from landmarkdiff.inference import LandmarkDiffPipeline

    image = cv2.imread(args.image)
    if image is None:
        logger.error("Cannot read image: %s", args.image)
        sys.exit(1)

    # NOTE: non-square inputs are stretched to 512x512, not cropped or padded.
    image = cv2.resize(image, (512, 512))

    pipeline = LandmarkDiffPipeline(
        mode=args.mode,
        controlnet_checkpoint=args.checkpoint,
        displacement_model_path=args.displacement_model,
    )
    pipeline.load()

    result = pipeline.generate(
        image,
        procedure=args.procedure,
        intensity=args.intensity,
        seed=args.seed,
    )

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite reports some failures by returning False instead of raising,
    # so check it before telling the user the file exists.
    if not cv2.imwrite(str(out_path), result["output"]):
        logger.error("Cannot write output image: %s", out_path)
        sys.exit(1)
    print(f"Output saved: {out_path}")

    if args.watermark:
        from landmarkdiff.safety import SafetyValidator

        validator = SafetyValidator()
        watermarked = validator.apply_watermark(result["output"])
        # Equivalent to out_path.with_stem(...) but also works before Python 3.9.
        wm_path = out_path.with_name(f"{out_path.stem}_watermarked{out_path.suffix}")
        if not cv2.imwrite(str(wm_path), watermarked):
            logger.error("Cannot write watermarked image: %s", wm_path)
            sys.exit(1)
        print(f"Watermarked: {wm_path}")
63
+
64
+
65
def cmd_ensemble(args: argparse.Namespace) -> None:
    """Run ensemble inference for one image.

    Thin adapter: maps the parsed ``ensemble`` flags onto the keyword
    arguments of ``landmarkdiff.ensemble.ensemble_inference`` and calls it.
    ``args.output`` is forwarded as ``output_dir``.
    """
    from landmarkdiff.ensemble import ensemble_inference

    # CLI attribute name -> ensemble_inference keyword.
    call_kwargs = {
        "image_path": args.image,
        "procedure": args.procedure,
        "intensity": args.intensity,
        "output_dir": args.output,
        "n_samples": args.n_samples,
        "strategy": args.strategy,
        "mode": args.mode,
        "controlnet_checkpoint": args.checkpoint,
        "displacement_model_path": args.displacement_model,
        "seed": args.seed,
    }
    ensemble_inference(**call_kwargs)
81
+
82
+
83
def cmd_evaluate(args: argparse.Namespace) -> None:
    """Run evaluation on a test set using ``scripts/run_evaluation.py``.

    The evaluation driver is not part of the ``landmarkdiff`` package (the
    wheel ships no ``scripts/`` directory). The directory containing the
    package is therefore prepended to ``sys.path`` so that
    ``scripts.run_evaluation`` can be imported from a source checkout.

    Only ``test_dir``, ``output``, ``checkpoint`` and ``max_samples`` are
    passed to ``run_evaluation``. A warning is logged when ``--mode`` or
    ``--displacement-model`` is set to a non-default value, because those
    flags are currently not forwarded.

    Exits with status 1 if ``scripts.run_evaluation`` cannot be imported.
    """
    from pathlib import Path

    repo_root = str(Path(__file__).resolve().parent.parent)
    # Avoid stacking duplicate sys.path entries on repeated calls.
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)

    try:
        from scripts.run_evaluation import run_evaluation
    except ImportError as exc:
        # An installed wheel has no scripts/ package; fail with an actionable
        # message instead of a bare traceback.
        logger.error(
            "Cannot import scripts.run_evaluation (searched %s): %s. "
            "The evaluate command requires a source checkout of LandmarkDiff.",
            repo_root,
            exc,
        )
        sys.exit(1)

    # getattr keeps programmatic callers that build a minimal Namespace working.
    if getattr(args, "mode", "tps") != "tps" or getattr(args, "displacement_model", None):
        logger.warning(
            "--mode and --displacement-model are not forwarded to run_evaluation "
            "and have no effect."
        )

    run_evaluation(
        test_dir=args.test_dir,
        output_dir=args.output,
        checkpoint=args.checkpoint,
        max_samples=args.max_samples,
    )
97
+
98
+
99
def cmd_config(args: argparse.Namespace) -> None:
    """Dump the experiment configuration as YAML, or list validation warnings.

    The configuration is loaded from ``--file`` when given; otherwise the
    ``ExperimentConfig()`` defaults are used. With ``--validate`` the
    warnings returned by ``validate_config`` are printed instead of the dump.
    """
    from landmarkdiff.config import ExperimentConfig, load_config, validate_config

    if args.file:
        config = load_config(args.file)
    else:
        config = ExperimentConfig()

    if not args.validate:
        # Dump path: the serialisation helpers are only needed here.
        from dataclasses import asdict

        import yaml

        print(yaml.dump(asdict(config), default_flow_style=False, sort_keys=False))
        return

    issues = validate_config(config)
    if not issues:
        print("Configuration valid (no warnings).")
        return

    print("Validation warnings:")
    for issue in issues:
        print(f" - {issue}")
119
+
120
+
121
def cmd_validate(args: argparse.Namespace) -> None:
    """Check a generated image against its source image with ``SafetyValidator``.

    Prints the validation summary. Exits with status 1 when either image
    cannot be read or when the validation result does not pass.
    """
    import cv2

    from landmarkdiff.safety import SafetyValidator

    source, generated = cv2.imread(args.input), cv2.imread(args.output_image)
    if any(img is None for img in (source, generated)):
        logger.error("Cannot read input or output image.")
        sys.exit(1)

    checker = SafetyValidator(watermark_enabled=args.watermark)
    report = checker.validate(
        input_image=source,
        output_image=generated,
        face_confidence=args.face_confidence,
    )

    print(report.summary())

    if report.passed:
        return
    sys.exit(1)
148
+
149
+
150
def cmd_version(args: argparse.Namespace) -> None:
    """Print the installed LandmarkDiff version; ``args`` is not used."""
    from landmarkdiff import __version__ as installed_version

    banner = f"LandmarkDiff v{installed_version}"
    print(banner)
155
+
156
+
157
# Procedures and generation modes accepted on the command line. Shared by the
# ``infer`` and ``ensemble`` subcommands so their allowed values cannot drift.
_PROCEDURES = (
    "rhinoplasty",
    "blepharoplasty",
    "rhytidectomy",
    "orthognathic",
    "brow_lift",
    "mentoplasty",
)
_MODES = ("controlnet", "controlnet_ip", "controlnet_fast", "img2img", "tps")


def _add_generation_args(parser: argparse.ArgumentParser) -> None:
    """Register the procedure/model/sampling options common to infer and ensemble."""
    parser.add_argument(
        "--procedure",
        default="rhinoplasty",
        choices=_PROCEDURES,
        help="Procedure to simulate (default: %(default)s)",
    )
    parser.add_argument(
        "--intensity",
        type=float,
        default=65.0,
        help="Procedure intensity (default: %(default)s)",
    )
    parser.add_argument(
        "--mode",
        default="tps",
        choices=_MODES,
        help="Generation mode (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", default=None, help="ControlNet checkpoint path")
    parser.add_argument("--displacement-model", default=None, help="Displacement model path")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: %(default)s)")


def main(argv: list[str] | None = None) -> None:
    """Main CLI entry point.

    Builds the ``landmarkdiff`` argument parser, parses ``argv`` (argparse
    falls back to ``sys.argv[1:]`` when it is None), and dispatches to the
    ``cmd_*`` handler registered for the chosen subcommand. Prints help and
    exits with status 1 when no subcommand is given.
    """
    parser = argparse.ArgumentParser(
        prog="landmarkdiff",
        description="LandmarkDiff: Facial surgery outcome prediction via latent diffusion",
    )
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # --- infer ---
    p_infer = subparsers.add_parser("infer", help="Run single-image inference")
    p_infer.add_argument("image", help="Input face image path")
    _add_generation_args(p_infer)
    p_infer.add_argument("--output", default="output.png", help="Output image path")
    p_infer.add_argument(
        "--watermark", action="store_true", help="Also write a watermarked copy"
    )
    p_infer.set_defaults(func=cmd_infer)

    # --- ensemble ---
    p_ensemble = subparsers.add_parser("ensemble", help="Run ensemble inference")
    p_ensemble.add_argument("image", help="Input face image path")
    # Same procedure choices as ``infer``: previously any string was accepted here.
    _add_generation_args(p_ensemble)
    p_ensemble.add_argument("--output", default="ensemble_output", help="Output directory")
    p_ensemble.add_argument(
        "--n-samples", type=int, default=5, help="Number of samples (default: %(default)s)"
    )
    p_ensemble.add_argument(
        "--strategy",
        default="best_of_n",
        choices=["pixel_average", "weighted_average", "best_of_n", "median"],
        help="How samples are combined (default: %(default)s)",
    )
    p_ensemble.set_defaults(func=cmd_ensemble)

    # --- evaluate ---
    p_eval = subparsers.add_parser("evaluate", help="Evaluate on test set")
    p_eval.add_argument("--test-dir", required=True)
    p_eval.add_argument("--output", default="eval_results")
    # NOTE: cmd_evaluate does not pass --mode / --displacement-model to run_evaluation.
    p_eval.add_argument("--mode", default="tps")
    p_eval.add_argument("--checkpoint", default=None)
    p_eval.add_argument("--displacement-model", default=None)
    p_eval.add_argument("--max-samples", type=int, default=0)
    p_eval.set_defaults(func=cmd_evaluate)

    # --- config ---
    p_config = subparsers.add_parser("config", help="Show or validate configuration")
    p_config.add_argument("--file", default=None, help="YAML config file")
    p_config.add_argument(
        "--validate", action="store_true", help="Print validation warnings instead of the config"
    )
    p_config.set_defaults(func=cmd_config)

    # --- validate ---
    p_validate = subparsers.add_parser("validate", help="Run safety validation")
    p_validate.add_argument("input", help="Original input image")
    p_validate.add_argument("output_image", help="Generated output image")
    p_validate.add_argument(
        "--watermark", action="store_true", help="Enable watermarking in the validator"
    )
    p_validate.add_argument(
        "--face-confidence",
        type=float,
        default=1.0,
        help="Face confidence passed to the validator (default: %(default)s)",
    )
    p_validate.set_defaults(func=cmd_validate)

    # --- version ---
    p_version = subparsers.add_parser("version", help="Print version")
    p_version.set_defaults(func=cmd_version)

    args = parser.parse_args(argv)
    if not hasattr(args, "func"):
        # No subcommand given: show usage and fail.
        parser.print_help()
        sys.exit(1)

    args.func(args)


if __name__ == "__main__":
    main()
landmarkdiff/clinical.py ADDED
@@ -0,0 +1,223 @@
1
+ """Clinical edge case handling for pathological conditions.
2
+
3
+ Implements special-case logic for:
4
+ - Vitiligo: preserve depigmented patches (don't blend over them)
5
+ - Bell's palsy: disable bilateral symmetry in deformation vectors
6
+ - Keloid: flag keloid-prone areas to reduce aggressive compositing
7
+ - Ehlers-Danlos: wider influence radii for hypermobile tissue
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass, field
13
+
14
+ import cv2
15
+ import numpy as np
16
+
17
+ from landmarkdiff.landmarks import FaceLandmarks
18
+
19
+
20
@dataclass
class ClinicalFlags:
    """Per-patient clinical conditions that alter pipeline behaviour.

    Every condition defaults to off; switch a flag to True to request the
    matching condition-specific handling.
    """

    vitiligo: bool = False
    bells_palsy: bool = False
    # Which side is affected when bells_palsy is set: "left" or "right".
    bells_palsy_side: str = "left"
    keloid_prone: bool = False
    # Region names for keloid handling, e.g. ["jawline", "nose"].
    keloid_regions: list[str] = field(default_factory=list)
    ehlers_danlos: bool = False

    def has_any(self) -> bool:
        """Report whether at least one of the four condition flags is enabled.

        ``bells_palsy_side`` and ``keloid_regions`` are settings, not flags,
        and are not consulted.
        """
        return (
            self.vitiligo
            or self.bells_palsy
            or self.keloid_prone
            or self.ehlers_danlos
        )
36
+
37
+
38
def detect_vitiligo_patches(
    image: np.ndarray,
    face: FaceLandmarks,
    l_threshold: float = 85.0,
    min_patch_area: int = 200,
) -> np.ndarray:
    """Detect depigmented (vitiligo) patches on a face using LAB luminance.

    Vitiligo patches appear as high-L, low-saturation regions that are
    significantly brighter than the surrounding skin. The pipeline:

    1. Restrict to the convex hull of the face landmarks.
    2. Keep pixels whose L exceeds ``max(l_threshold, mean + 2 * std)`` of the
       face's own L values.
    3. Keep only pixels with near-neutral a/b (within 15 of 128).
    4. Drop connected regions smaller than ``min_patch_area``.

    NOTE(review): assumes an 8-bit BGR image, where OpenCV's LAB puts L in
    0-255 and offsets a/b by 128 (which the 128 comparisons below rely on).

    Args:
        image: BGR face image.
        face: Extracted landmarks; ``pixel_coords`` defines the face ROI.
        l_threshold: Absolute floor on the L threshold, on OpenCV's 8-bit L
            scale.
        min_patch_area: Minimum contour area in pixels to count as a patch.

    Returns:
        Binary mask (uint8, 0/255) of detected vitiligo patches, same height
        and width as ``image``. All zeros if the face ROI is empty.
    """
    h, w = image.shape[:2]
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Face ROI from the convex hull of all landmarks.
    coords = face.pixel_coords.astype(np.int32)
    hull = cv2.convexHull(coords)
    face_mask = np.zeros((h, w), dtype=np.uint8)
    cv2.fillConvexPoly(face_mask, hull, 255)
    in_face = face_mask > 0

    # Face-region luminance statistics.
    l_channel = lab[:, :, 0]
    face_pixels = l_channel[in_face]
    if face_pixels.size == 0:
        return np.zeros((h, w), dtype=np.uint8)

    l_mean = float(np.mean(face_pixels))
    l_std = float(np.std(face_pixels))

    # Use max(), not min(). A patch must clear both the absolute floor and this
    # face's mean + 2 std. With min() the floor (85 by default) sits below
    # typical skin mean, so ordinary skin was flagged as "bright".
    threshold = max(l_threshold, l_mean + 2.0 * l_std)
    bright = (l_channel > threshold) & in_face

    # Depigmented skin is also near-neutral: a and b close to the 128 offset.
    a_channel = lab[:, :, 1]
    b_channel = lab[:, :, 2]
    low_sat = (np.abs(a_channel - 128) < 15) & (np.abs(b_channel - 128) < 15)

    vitiligo_raw = (bright & low_sat).astype(np.uint8) * 255

    # Filter small noise patches; RETR_EXTERNAL + fill also fills holes.
    contours, _ = cv2.findContours(vitiligo_raw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result = np.zeros((h, w), dtype=np.uint8)
    for cnt in contours:
        if cv2.contourArea(cnt) >= min_patch_area:
            cv2.fillPoly(result, [cnt], 255)

    return result
98
+
99
+
100
def adjust_mask_for_vitiligo(
    mask: np.ndarray,
    vitiligo_patches: np.ndarray,
    preservation_factor: float = 0.3,
) -> np.ndarray:
    """Lower the surgical-mask weight over vitiligo patches so they show through.

    The reduction is subtractive: wherever a patch is present, the mask
    drops by ``preservation_factor``. The result is clipped to [0, 1], so
    weak mask values over a patch can reach 0 entirely.

    Args:
        mask: Float surgical mask in [0, 1].
        vitiligo_patches: Binary patch map (uint8, 0/255).
        preservation_factor: Amount subtracted over patches
            (0 = blend as usual, 1 = fully preserve).

    Returns:
        The adjusted mask, clipped to [0, 1].
    """
    # 0/255 patch map -> 0.0/1.0 weights, scaled by how strongly to preserve.
    weight = np.multiply(vitiligo_patches.astype(np.float32) / 255.0, preservation_factor)
    lowered = np.subtract(mask, weight)
    return np.clip(lowered, 0.0, 1.0)
121
+
122
+
123
# Landmark indices per facial region for each side accepted by
# get_bells_palsy_side_indices. Tuples keep the shared tables immutable; the
# function hands out fresh lists.
# NOTE(review): the "left" table uses indices (e.g. 33, 61) that MediaPipe
# Face Mesh labels as the subject's RIGHT side, i.e. the image-left side of an
# unmirrored photo. Confirm which convention callers expect.
_BELLS_PALSY_SIDE_INDICES: dict[str, dict[str, tuple[int, ...]]] = {
    "left": {
        "eye": (33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246),
        "eyebrow": (70, 63, 105, 66, 107, 55, 65, 52, 53, 46),
        "mouth_corner": (61, 146, 91, 181, 84),
        "jawline": (132, 136, 172, 58, 150, 176, 148, 149),
    },
    "right": {
        "eye": (362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398),
        "eyebrow": (300, 293, 334, 296, 336, 285, 295, 282, 283, 276),
        "mouth_corner": (291, 308, 324, 318, 402),
        "jawline": (361, 365, 397, 288, 379, 400, 377, 378),
    },
}


def get_bells_palsy_side_indices(
    side: str,
) -> dict[str, list[int]]:
    """Get landmark indices for the affected side in Bell's palsy.

    In Bell's palsy one side of the face is paralyzed, so bilateral symmetric
    deformations should not be applied; only the healthy side should be
    deformed. This returns the indices of the affected side.

    Args:
        side: Affected side, ``"left"`` or ``"right"``. Case and surrounding
            whitespace are ignored.

    Returns:
        Dict mapping region names (``eye``, ``eyebrow``, ``mouth_corner``,
        ``jawline``) to new lists of landmark indices on the affected side.
        Callers may mutate them without affecting later calls.

    Raises:
        ValueError: If ``side`` is neither left nor right.
    """
    key = str(side).strip().lower()
    try:
        table = _BELLS_PALSY_SIDE_INDICES[key]
    except KeyError:
        # Previously any non-"left" value (e.g. "Left" or a typo) silently
        # selected the right side; fail loudly instead.
        raise ValueError(f"side must be 'left' or 'right', got {side!r}") from None
    return {region: list(indices) for region, indices in table.items()}
148
+
149
+
150
def get_keloid_exclusion_mask(
    face: FaceLandmarks,
    regions: list[str],
    width: int,
    height: int,
    margin_px: int = 10,
) -> np.ndarray:
    """Build a mask of keloid-prone regions to exclude from aggressive compositing.

    For every name in ``regions`` that ``LANDMARK_REGIONS`` knows, the convex
    hull of that region's landmarks is filled with 1.0. Unknown or empty
    regions are skipped silently. The hull of a curve-like region (e.g. a
    jawline contour) covers the area between the curve and its chord, not
    just the curve itself. The filled area is then grown by ``margin_px``
    with an elliptical dilation.

    Args:
        face: Extracted landmarks (``pixel_coords`` supplies the points).
        regions: Region names prone to keloids.
        width: Output mask width.
        height: Output mask height.
        margin_px: Dilation radius in pixels; 0 or less disables dilation.

    Returns:
        Float32 mask of shape (height, width) in [0, 1]; 1 marks keloid-prone area.
    """
    from landmarkdiff.landmarks import LANDMARK_REGIONS

    mask = np.zeros((height, width), dtype=np.float32)
    coords = face.pixel_coords.astype(np.int32)

    for region_name in regions:
        region_idx = LANDMARK_REGIONS.get(region_name, [])
        if region_idx:
            cv2.fillConvexPoly(mask, cv2.convexHull(coords[region_idx]), 1.0)

    if margin_px > 0:
        side = 2 * margin_px + 1
        disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
        mask = cv2.dilate(mask, disk)

    return np.clip(mask, 0.0, 1.0)
193
+
194
+
195
def adjust_mask_for_keloid(
    mask: np.ndarray,
    keloid_mask: np.ndarray,
    reduction_factor: float = 0.5,
    blur_kernel: int = 31,
    blur_sigma: float = 10.0,
) -> np.ndarray:
    """Soften mask transitions in keloid-prone areas.

    First scales the mask down by ``reduction_factor`` where ``keloid_mask``
    is set. It then blends in a Gaussian-blurred copy of that reduced mask,
    weighted by ``keloid_mask``. This avoids hard boundaries that could
    trigger keloid formation in real surgical planning. Outside keloid
    regions the reduced mask is kept unblurred.

    Args:
        mask: Float32 surgical mask [0-1].
        keloid_mask: Float32 keloid region mask [0-1], same shape as ``mask``.
        reduction_factor: How much to reduce mask intensity in keloid areas.
        blur_kernel: Gaussian kernel size in pixels (positive, odd). The
            default of 31 suits ~512px images; scale it with resolution.
        blur_sigma: Gaussian sigma in pixels.

    Returns:
        Modified mask with gentler transitions in keloid regions, clipped to [0, 1].

    Raises:
        ValueError: If ``blur_kernel`` is not a positive odd integer.
    """
    if blur_kernel <= 0 or blur_kernel % 2 == 0:
        raise ValueError(f"blur_kernel must be a positive odd integer, got {blur_kernel}")

    # Reduce mask intensity in keloid-prone areas.
    keloid_reduction = keloid_mask * reduction_factor
    modified = mask * (1.0 - keloid_reduction)

    # Extra Gaussian blur for softer transitions ...
    blurred = cv2.GaussianBlur(modified, (blur_kernel, blur_kernel), blur_sigma)

    # ... used only in keloid regions (soft-weighted by keloid_mask).
    result = modified * (1.0 - keloid_mask) + blurred * keloid_mask
    return np.clip(result, 0.0, 1.0)