landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,40 @@
1
+ """LandmarkDiff: Anatomically-conditioned latent diffusion for facial surgery simulation."""
2
+
3
__version__ = "0.2.3"

# Public submodules of the package, kept in alphabetical order. For a
# package, ``from landmarkdiff import *`` imports every submodule named here.
__all__ = [
    "api_client", "arcface_torch", "audit", "augmentation", "benchmark",
    "checkpoint_manager", "cli", "clinical", "conditioning", "config",
    "curriculum", "data", "data_version", "displacement_model", "ensemble",
    "evaluation", "experiment_tracker", "face_verifier", "fid", "hyperparam",
    "inference", "landmarks", "log", "losses", "manipulation",
    "masking", "metrics_agg", "metrics_viz", "model_registry", "morphometry",
    "postprocess", "safety", "synthetic", "validation",
]
@@ -0,0 +1,207 @@
1
+ """CLI entry point for python -m landmarkdiff."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import NoReturn
9
+
10
+
11
+ def _error(msg: str) -> NoReturn:
12
+ """Print error to stderr and exit."""
13
+ print(f"error: {msg}", file=sys.stderr)
14
+ sys.exit(1)
15
+
16
+
17
+ def _validate_image_path(path_str: str) -> Path:
18
+ """Validate that the image path exists and looks like an image file."""
19
+ p = Path(path_str)
20
+ if not p.exists():
21
+ _error(f"file not found: {path_str}")
22
+ if not p.is_file():
23
+ _error(f"not a file: {path_str}")
24
+ return p
25
+
26
+
27
def _build_parser(version: str) -> argparse.ArgumentParser:
    """Build the ``landmarkdiff`` argument parser with its subcommands.

    Args:
        version: Version string reported by ``--version``.

    Returns:
        The configured top-level parser (subcommand stored in ``command``).
    """
    parser = argparse.ArgumentParser(
        prog="landmarkdiff",
        description="Facial surgery outcome prediction from clinical photography",
    )
    parser.add_argument("--version", action="version", version=f"landmarkdiff {version}")

    subparsers = parser.add_subparsers(dest="command")

    # inference
    infer = subparsers.add_parser("infer", help="Run inference on an image")
    infer.add_argument("image", type=str, help="Path to input face image")
    infer.add_argument(
        "--procedure",
        type=str,
        default="rhinoplasty",
        choices=[
            "rhinoplasty",
            "blepharoplasty",
            "rhytidectomy",
            "orthognathic",
            "brow_lift",
            "mentoplasty",
        ],
        help="Surgical procedure to simulate (default: rhinoplasty)",
    )
    infer.add_argument(
        "--intensity",
        type=float,
        default=60.0,
        help="Deformation intensity, 0-100 (default: 60)",
    )
    infer.add_argument(
        "--mode",
        type=str,
        default="tps",
        choices=["tps", "controlnet", "img2img", "controlnet_ip"],
        help="Inference mode (default: tps, others require GPU)",
    )
    infer.add_argument(
        "--output",
        type=str,
        default="output/",
        help="Output directory (default: output/)",
    )
    infer.add_argument(
        "--steps",
        type=int,
        default=30,
        help="Number of diffusion steps (default: 30)",
    )
    infer.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )

    # landmarks
    lm = subparsers.add_parser("landmarks", help="Extract and visualize landmarks")
    lm.add_argument("image", type=str, help="Path to input face image")
    lm.add_argument(
        "--output",
        type=str,
        default="output/landmarks.png",
        help="Output path for landmark visualization (default: output/landmarks.png)",
    )

    # demo
    subparsers.add_parser("demo", help="Launch Gradio web demo")

    return parser


def main(argv: list[str] | None = None) -> None:
    """Parse command-line arguments and dispatch to the selected subcommand.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so ``main()`` with
            no arguments behaves exactly as before.

    Prints help and returns when no subcommand is given. Exits with status
    130 on Ctrl-C and with status 1 (via ``_error``) when a subcommand raises.
    """
    from landmarkdiff import __version__

    parser = _build_parser(__version__)
    args = parser.parse_args(argv)

    if args.command is None:
        parser.print_help()
        return

    try:
        if args.command == "infer":
            _run_inference(args)
        elif args.command == "landmarks":
            _run_landmarks(args)
        elif args.command == "demo":
            _run_demo()
    except KeyboardInterrupt:
        # Conventional shell status for termination by SIGINT (128 + 2).
        sys.exit(130)
    except Exception as exc:
        # Some exceptions stringify to "" (e.g. a bare ``RuntimeError()``);
        # fall back to the type name so the user never sees a bare "error: ".
        _error(str(exc) or type(exc).__name__)
117
+
118
+
119
def _run_inference(args: argparse.Namespace) -> None:
    """Run the ``infer`` subcommand and write ``prediction.png`` into ``--output``.

    The input image is converted to RGB and resized to 512x512 before
    landmark extraction, and the procedure preset deforms those landmarks.
    ``--mode tps`` warps the image directly with a thin-plate spline; every
    other mode runs ``LandmarkDiffPipeline`` on a CUDA device.

    Exits via ``_error`` for an out-of-range intensity, a missing input
    file, no detected face, or a non-tps mode without CUDA.
    """
    import numpy as np
    from PIL import Image

    from landmarkdiff.landmarks import extract_landmarks
    from landmarkdiff.manipulation import apply_procedure_preset

    if not (0 <= args.intensity <= 100):
        _error(f"intensity must be between 0 and 100, got {args.intensity}")

    image_path = _validate_image_path(args.image)

    # ``with`` closes the file handle deterministically (Pillow can keep it
    # open after ``open`` for some formats).
    with Image.open(image_path) as source:
        img = source.convert("RGB").resize((512, 512))
    img_array = np.array(img)

    landmarks = extract_landmarks(img_array)
    if landmarks is None:
        _error("no face detected in image")

    deformed = apply_procedure_preset(landmarks, args.procedure, intensity=args.intensity)

    # Create the output directory only after the input has been validated
    # and a face was found, so trivially failing runs leave nothing behind.
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "prediction.png"

    if args.mode == "tps":
        from landmarkdiff.synthetic.tps_warp import warp_image_tps

        # Rescale landmark pixel coordinates into the 512x512 working image.
        src = landmarks.pixel_coords[:, :2].copy()
        dst = deformed.pixel_coords[:, :2].copy()
        src[:, 0] *= 512 / landmarks.image_width
        src[:, 1] *= 512 / landmarks.image_height
        dst[:, 0] *= 512 / deformed.image_width
        dst[:, 1] *= 512 / deformed.image_height
        warped = warp_image_tps(img_array, src, dst)
        Image.fromarray(warped).save(str(output_path))
        print(f"saved tps result to {output_path}")
        return

    import torch

    from landmarkdiff.inference import LandmarkDiffPipeline

    # The diffusion modes are hard-wired to CUDA below; fail fast with an
    # actionable message instead of an opaque torch error during loading.
    if not torch.cuda.is_available():
        _error(f"--mode {args.mode} requires a CUDA GPU; use --mode tps to run on CPU")

    pipeline = LandmarkDiffPipeline(mode=args.mode, device=torch.device("cuda"))
    pipeline.load()
    result = pipeline.generate(
        img_array,
        procedure=args.procedure,
        intensity=args.intensity,
        num_inference_steps=args.steps,
        seed=args.seed,
    )
    result["output"].save(str(output_path))
    print(f"saved result to {output_path}")
171
+
172
+
173
def _run_landmarks(args: argparse.Namespace) -> None:
    """Run the ``landmarks`` subcommand: render the detected face mesh to ``--output``.

    The input is converted to RGB and resized to 512x512, landmarks are
    extracted, and the rendered mesh is saved (parent directories are
    created as needed). Exits via ``_error`` when no face is detected.
    """
    import numpy as np
    from PIL import Image

    from landmarkdiff.landmarks import extract_landmarks, render_landmark_image

    side = 512  # working resolution for both detection and rendering

    source_path = _validate_image_path(args.image)
    with Image.open(source_path) as raw:
        rgb = np.array(raw.convert("RGB").resize((side, side)))

    face = extract_landmarks(rgb)
    if face is None:
        _error("no face detected in image")

    mesh_image = render_landmark_image(face, side, side)

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(mesh_image).save(str(destination))

    print(f"saved landmark mesh to {destination}")
    print(f"detected {len(face.landmarks)} landmarks, confidence {face.confidence:.2f}")
194
+
195
+
196
def _run_demo() -> None:
    """Run the ``demo`` subcommand: build and launch the Gradio web app.

    ``build_app`` is imported from ``scripts.app``, which is not one of the
    ``landmarkdiff`` package's own modules. The import can therefore fail
    either because Gradio is missing or because ``scripts`` itself is not
    importable; the error message distinguishes the two.
    """
    try:
        from scripts.app import build_app
    except ImportError as exc:
        # ``ImportError.name`` names the module that could not be imported;
        # only blame Gradio when Gradio is actually what is missing.
        missing = getattr(exc, "name", None) or ""
        if missing.startswith("gradio"):
            _error("gradio not installed - run: pip install landmarkdiff[app]")
        _error(
            f"demo app not available ({exc}); the demo needs scripts/app.py, "
            "which is not included in the installed landmarkdiff package"
        )

    # Build and launch outside the ``try`` so ImportErrors raised at runtime
    # are not misreported as a missing Gradio install.
    demo = build_app()
    demo.launch()
204
+
205
+
206
if __name__ == "__main__":
    # ``python -m landmarkdiff``: main() returns None on success, which
    # SystemExit maps to exit status 0.
    raise SystemExit(main())
@@ -0,0 +1,316 @@
1
+ """Python client for the LandmarkDiff REST API.
2
+
3
+ Provides a clean interface for interacting with the FastAPI server,
4
+ handling image encoding/decoding, error handling, and session management.
5
+
6
+ Usage:
7
+ from landmarkdiff.api_client import LandmarkDiffClient
8
+
9
+ client = LandmarkDiffClient("http://localhost:8000")
10
+
11
+ # Single prediction
12
+ result = client.predict("patient.png", procedure="rhinoplasty", intensity=65)
13
+ result.save("output.png")
14
+
15
+ # Face analysis
16
+ analysis = client.analyze("patient.png")
17
+ print(f"Fitzpatrick type: {analysis['fitzpatrick_type']}")
18
+
19
+ # Batch processing
20
+ results = client.batch_predict(
21
+ ["patient1.png", "patient2.png"],
22
+ procedure="blepharoplasty",
23
+ )
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import base64
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Any
32
+
33
+ import cv2
34
+ import numpy as np
35
+
36
+
37
class LandmarkDiffAPIError(Exception):
    """Raised when the LandmarkDiff server is unreachable or answers with an error.

    The client re-raises low-level HTTP failures (e.g. connection errors and
    non-2xx responses) as this single exception type.
    """
39
+
40
+
41
+ @dataclass
42
+ class PredictionResult:
43
+ """Result from a single prediction."""
44
+
45
+ output_image: np.ndarray
46
+ procedure: str
47
+ intensity: float
48
+ confidence: float = 0.0
49
+ landmarks_before: list[Any] | None = None
50
+ landmarks_after: list[Any] | None = None
51
+ metrics: dict[str, float] = field(default_factory=dict)
52
+ metadata: dict[str, Any] = field(default_factory=dict)
53
+
54
+ def save(self, path: str | Path, fmt: str = ".png") -> None:
55
+ """Save the output image to a file."""
56
+ cv2.imwrite(str(path), self.output_image)
57
+
58
+ def show(self) -> None:
59
+ """Display the output image (requires GUI)."""
60
+ cv2.imshow("LandmarkDiff Prediction", self.output_image)
61
+ cv2.waitKey(0)
62
+ cv2.destroyAllWindows()
63
+
64
+
65
+ class LandmarkDiffClient:
66
+ """Client for the LandmarkDiff REST API.
67
+
68
+ Args:
69
+ base_url: Server URL (e.g. "http://localhost:8000").
70
+ timeout: Request timeout in seconds.
71
+ """
72
+
73
+ def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
74
+ self.base_url = base_url.rstrip("/")
75
+ self.timeout = timeout
76
+ self._session: Any = None
77
+
78
+ def _get_session(self) -> Any:
79
+ """Lazy-initialize requests session."""
80
+ if self._session is None:
81
+ try:
82
+ import requests
83
+ except ImportError:
84
+ raise ImportError("requests required. Install with: pip install requests") from None
85
+ self._session = requests.Session()
86
+ self._session.timeout = self.timeout
87
+ return self._session
88
+
89
+ def _read_image(self, image_path: str | Path) -> bytes:
90
+ """Read image file as bytes."""
91
+ path = Path(image_path)
92
+ if not path.exists():
93
+ raise FileNotFoundError(f"Image not found: {path}")
94
+ return path.read_bytes()
95
+
96
+ def _decode_base64_image(self, b64_string: str) -> np.ndarray:
97
+ """Decode a base64-encoded image to numpy array."""
98
+ img_bytes = base64.b64decode(b64_string)
99
+ arr = np.frombuffer(img_bytes, np.uint8)
100
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
101
+ if img is None:
102
+ raise ValueError("Failed to decode base64 image")
103
+ return img
104
+
105
+ # ------------------------------------------------------------------
106
+ # API methods
107
+ # ------------------------------------------------------------------
108
+
109
+ def health(self) -> dict[str, Any]:
110
+ """Check server health.
111
+
112
+ Returns:
113
+ Dict with status and version info.
114
+
115
+ Raises:
116
+ LandmarkDiffAPIError: If server is unreachable or returns an error.
117
+ """
118
+ session = self._get_session()
119
+ try:
120
+ resp = session.get(f"{self.base_url}/health")
121
+ resp.raise_for_status()
122
+ return resp.json()
123
+ except Exception as e:
124
+ import requests
125
+
126
+ if isinstance(e, requests.ConnectionError):
127
+ raise LandmarkDiffAPIError(
128
+ f"Cannot connect to LandmarkDiff server at {self.base_url}. "
129
+ f"Make sure the server is running (python -m landmarkdiff serve)."
130
+ ) from None
131
+ elif isinstance(e, requests.HTTPError):
132
+ raise LandmarkDiffAPIError(
133
+ f"Server returned error {e.response.status_code}: {e.response.text[:200]}"
134
+ ) from None
135
+ else:
136
+ raise
137
+
138
+ def procedures(self) -> list[str]:
139
+ """List available surgical procedures.
140
+
141
+ Returns:
142
+ List of procedure names.
143
+
144
+ Raises:
145
+ LandmarkDiffAPIError: If server is unreachable or returns an error.
146
+ """
147
+ session = self._get_session()
148
+ try:
149
+ resp = session.get(f"{self.base_url}/procedures")
150
+ resp.raise_for_status()
151
+ return resp.json().get("procedures", [])
152
+ except Exception as e:
153
+ import requests
154
+
155
+ if isinstance(e, requests.ConnectionError):
156
+ raise LandmarkDiffAPIError(
157
+ f"Cannot connect to LandmarkDiff server at {self.base_url}. "
158
+ f"Make sure the server is running (python -m landmarkdiff serve)."
159
+ ) from None
160
+ elif isinstance(e, requests.HTTPError):
161
+ raise LandmarkDiffAPIError(
162
+ f"Server returned error {e.response.status_code}: {e.response.text[:200]}"
163
+ ) from None
164
+ else:
165
+ raise
166
+
167
+ def predict(
168
+ self,
169
+ image_path: str | Path,
170
+ procedure: str = "rhinoplasty",
171
+ intensity: float = 65.0,
172
+ seed: int = 42,
173
+ ) -> PredictionResult:
174
+ """Run surgical outcome prediction.
175
+
176
+ Args:
177
+ image_path: Path to input face image.
178
+ procedure: Surgical procedure type.
179
+ intensity: Intensity of the modification (0-100).
180
+ seed: Random seed for reproducibility.
181
+
182
+ Returns:
183
+ PredictionResult with output image and metadata.
184
+ """
185
+ session = self._get_session()
186
+ image_bytes = self._read_image(image_path)
187
+
188
+ files = {"image": ("image.png", image_bytes, "image/png")}
189
+ data = {
190
+ "procedure": procedure,
191
+ "intensity": str(intensity),
192
+ "seed": str(seed),
193
+ }
194
+
195
+ resp = session.post(f"{self.base_url}/predict", files=files, data=data)
196
+ try:
197
+ resp.raise_for_status()
198
+ result = resp.json()
199
+
200
+ # Decode output image
201
+ output_img = self._decode_base64_image(result["output_image"])
202
+
203
+ return PredictionResult(
204
+ output_image=output_img,
205
+ procedure=procedure,
206
+ intensity=intensity,
207
+ confidence=result.get("confidence", 0.0),
208
+ metrics=result.get("metrics", {}),
209
+ metadata=result.get("metadata", {}),
210
+ )
211
+ except Exception as e:
212
+ import requests
213
+
214
+ if isinstance(e, requests.ConnectionError):
215
+ raise LandmarkDiffAPIError(
216
+ f"Cannot connect to LandmarkDiff server at {self.base_url}. "
217
+ f"Make sure the server is running (python -m landmarkdiff serve)."
218
+ ) from None
219
+ elif isinstance(e, requests.HTTPError):
220
+ raise LandmarkDiffAPIError(
221
+ f"Server returned error {e.response.status_code}: {e.response.text[:200]}"
222
+ ) from None
223
+ else:
224
+ raise
225
+
226
+ def analyze(self, image_path: str | Path) -> dict[str, Any]:
227
+ """Analyze a face image without generating a prediction.
228
+
229
+ Returns face landmarks, Fitzpatrick type, pose estimation, etc.
230
+
231
+ Args:
232
+ image_path: Path to input face image.
233
+
234
+ Returns:
235
+ Dict with analysis results.
236
+
237
+ Raises:
238
+ LandmarkDiffAPIError: If server is unreachable or returns an error.
239
+ """
240
+ session = self._get_session()
241
+ image_bytes = self._read_image(image_path)
242
+
243
+ files = {"image": ("image.png", image_bytes, "image/png")}
244
+ try:
245
+ resp = session.post(f"{self.base_url}/analyze", files=files)
246
+ resp.raise_for_status()
247
+ return resp.json()
248
+ except Exception as e:
249
+ import requests
250
+
251
+ if isinstance(e, requests.ConnectionError):
252
+ raise LandmarkDiffAPIError(
253
+ f"Cannot connect to LandmarkDiff server at {self.base_url}. "
254
+ f"Make sure the server is running (python -m landmarkdiff serve)."
255
+ ) from None
256
+ elif isinstance(e, requests.HTTPError):
257
+ raise LandmarkDiffAPIError(
258
+ f"Server returned error {e.response.status_code}: {e.response.text[:200]}"
259
+ ) from None
260
+ else:
261
+ raise
262
+
263
+ def batch_predict(
264
+ self,
265
+ image_paths: list[str | Path],
266
+ procedure: str = "rhinoplasty",
267
+ intensity: float = 65.0,
268
+ seed: int = 42,
269
+ ) -> list[PredictionResult]:
270
+ """Run batch prediction on multiple images.
271
+
272
+ Args:
273
+ image_paths: List of image file paths.
274
+ procedure: Procedure to apply to all images.
275
+ intensity: Intensity for all images.
276
+ seed: Base random seed.
277
+
278
+ Returns:
279
+ List of PredictionResult objects.
280
+ """
281
+ results = []
282
+ for i, path in enumerate(image_paths):
283
+ try:
284
+ result = self.predict(
285
+ path,
286
+ procedure=procedure,
287
+ intensity=intensity,
288
+ seed=seed + i,
289
+ )
290
+ results.append(result)
291
+ except Exception as e:
292
+ # Create a failed result
293
+ results.append(
294
+ PredictionResult(
295
+ output_image=np.zeros((512, 512, 3), dtype=np.uint8),
296
+ procedure=procedure,
297
+ intensity=intensity,
298
+ metadata={"error": str(e), "path": str(path)},
299
+ )
300
+ )
301
+ return results
302
+
303
+ def close(self) -> None:
304
+ """Close the HTTP session."""
305
+ if self._session is not None:
306
+ self._session.close()
307
+ self._session = None
308
+
309
+ def __enter__(self) -> LandmarkDiffClient:
310
+ return self
311
+
312
+ def __exit__(self, *args: Any) -> None:
313
+ self.close()
314
+
315
+ def __repr__(self) -> str:
316
+ return f"LandmarkDiffClient(base_url='{self.base_url}')"