landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""LandmarkDiff: Anatomically-conditioned latent diffusion for facial surgery simulation."""

__version__ = "0.2.3"

# Public submodules of the package. Nothing is imported eagerly here; listing
# the submodule names in __all__ makes ``from landmarkdiff import *`` import
# each of them on demand.
__all__ = [
    "api_client",
    "arcface_torch",
    "audit",
    "augmentation",
    "benchmark",
    "checkpoint_manager",
    "cli",
    "clinical",
    "conditioning",
    "config",
    "curriculum",
    "data",
    "data_version",
    "displacement_model",
    "ensemble",
    "evaluation",
    "experiment_tracker",
    "face_verifier",
    "fid",
    "hyperparam",
    "inference",
    "landmarks",
    "log",
    "losses",
    "manipulation",
    "masking",
    "metrics_agg",
    "metrics_viz",
    "model_registry",
    "morphometry",
    "postprocess",
    "safety",
    "synthetic",
    "validation",
]
|
landmarkdiff/__main__.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""CLI entry point for python -m landmarkdiff."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import NoReturn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _error(msg: str) -> NoReturn:
|
|
12
|
+
"""Print error to stderr and exit."""
|
|
13
|
+
print(f"error: {msg}", file=sys.stderr)
|
|
14
|
+
sys.exit(1)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _validate_image_path(path_str: str) -> Path:
|
|
18
|
+
"""Validate that the image path exists and looks like an image file."""
|
|
19
|
+
p = Path(path_str)
|
|
20
|
+
if not p.exists():
|
|
21
|
+
_error(f"file not found: {path_str}")
|
|
22
|
+
if not p.is_file():
|
|
23
|
+
_error(f"not a file: {path_str}")
|
|
24
|
+
return p
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _build_parser(version: str) -> argparse.ArgumentParser:
    """Return the top-level ``landmarkdiff`` parser with its subcommands attached."""
    root = argparse.ArgumentParser(
        prog="landmarkdiff",
        description="Facial surgery outcome prediction from clinical photography",
    )
    root.add_argument("--version", action="version", version=f"landmarkdiff {version}")
    commands = root.add_subparsers(dest="command")

    # infer: simulate a procedure on a single photo
    infer_cmd = commands.add_parser("infer", help="Run inference on an image")
    infer_cmd.add_argument("image", type=str, help="Path to input face image")
    procedure_names = [
        "rhinoplasty",
        "blepharoplasty",
        "rhytidectomy",
        "orthognathic",
        "brow_lift",
        "mentoplasty",
    ]
    infer_cmd.add_argument(
        "--procedure",
        type=str,
        default="rhinoplasty",
        choices=procedure_names,
        help="Surgical procedure to simulate (default: rhinoplasty)",
    )
    infer_cmd.add_argument(
        "--intensity",
        type=float,
        default=60.0,
        help="Deformation intensity, 0-100 (default: 60)",
    )
    infer_cmd.add_argument(
        "--mode",
        type=str,
        default="tps",
        choices=["tps", "controlnet", "img2img", "controlnet_ip"],
        help="Inference mode (default: tps, others require GPU)",
    )
    infer_cmd.add_argument(
        "--output",
        type=str,
        default="output/",
        help="Output directory (default: output/)",
    )
    infer_cmd.add_argument(
        "--steps",
        type=int,
        default=30,
        help="Number of diffusion steps (default: 30)",
    )
    infer_cmd.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )

    # landmarks: detect and render the face mesh
    mesh_cmd = commands.add_parser("landmarks", help="Extract and visualize landmarks")
    mesh_cmd.add_argument("image", type=str, help="Path to input face image")
    mesh_cmd.add_argument(
        "--output",
        type=str,
        default="output/landmarks.png",
        help="Output path for landmark visualization (default: output/landmarks.png)",
    )

    # demo: web UI, takes no options
    commands.add_parser("demo", help="Launch Gradio web demo")
    return root


def main() -> None:
    """Entry point for ``python -m landmarkdiff``.

    Parses ``sys.argv``, prints the help text when no subcommand is given, and
    dispatches to the matching ``_run_*`` handler. Ctrl-C exits with status
    130; any other ``Exception`` is reported through ``_error`` (status 1).
    """
    from landmarkdiff import __version__

    parser = _build_parser(__version__)
    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        return

    # Lambdas keep handler lookup lazy, exactly as the original if/elif chain did.
    dispatch = {
        "infer": lambda: _run_inference(args),
        "landmarks": lambda: _run_landmarks(args),
        "demo": lambda: _run_demo(),
    }
    try:
        handler = dispatch.get(args.command)
        if handler is not None:
            handler()
    except KeyboardInterrupt:
        sys.exit(130)  # conventional 128 + SIGINT
    except Exception as exc:  # top-level CLI boundary: one-line message instead of a traceback
        _error(str(exc))
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _run_inference(args: argparse.Namespace) -> None:
    """Run the ``infer`` subcommand and save ``prediction.png`` under ``args.output``.

    The input is resized to 512x512 and its landmarks are deformed with the
    selected procedure preset. In ``tps`` mode the image is then warped on the
    CPU so the detected landmarks move to their deformed positions; every other
    mode builds a ``LandmarkDiffPipeline`` on a CUDA device and runs diffusion.

    Exits through ``_error`` when ``--intensity`` is outside 0-100, ``--steps``
    is below 1 for a diffusion mode, the input file is missing, a diffusion
    mode is requested without CUDA, or no face is detected.
    """
    # Validate cheap arguments before paying for numpy/PIL/model imports.
    if not (0 <= args.intensity <= 100):  # chained comparison also rejects NaN
        _error(f"intensity must be between 0 and 100, got {args.intensity}")
    if args.mode != "tps" and args.steps < 1:
        _error(f"steps must be at least 1, got {args.steps}")

    image_path = _validate_image_path(args.image)

    import numpy as np
    from PIL import Image

    from landmarkdiff.landmarks import extract_landmarks
    from landmarkdiff.manipulation import apply_procedure_preset

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    prediction_path = output_dir / "prediction.png"

    # Close the source file deterministically instead of relying on Pillow's lazy loader.
    with Image.open(image_path) as opened:
        img = opened.convert("RGB").resize((512, 512))
    img_array = np.array(img)

    landmarks = extract_landmarks(img_array)
    if landmarks is None:
        _error("no face detected in image")

    deformed = apply_procedure_preset(landmarks, args.procedure, intensity=args.intensity)

    if args.mode == "tps":
        from landmarkdiff.synthetic.tps_warp import warp_image_tps

        # Rescale landmark pixel coordinates into the 512x512 working frame.
        src_pts = landmarks.pixel_coords[:, :2].copy()
        dst_pts = deformed.pixel_coords[:, :2].copy()
        src_pts[:, 0] *= 512 / landmarks.image_width
        src_pts[:, 1] *= 512 / landmarks.image_height
        dst_pts[:, 0] *= 512 / deformed.image_width
        dst_pts[:, 1] *= 512 / deformed.image_height
        warped = warp_image_tps(img_array, src_pts, dst_pts)
        Image.fromarray(warped).save(str(prediction_path))
        print(f"saved tps result to {prediction_path}")
        return

    import torch

    # Diffusion modes are pinned to CUDA below; fail with an actionable
    # message instead of an opaque device error deep inside model loading.
    if not torch.cuda.is_available():
        _error(f"mode '{args.mode}' requires a CUDA GPU; use --mode tps to run on CPU")

    from landmarkdiff.inference import LandmarkDiffPipeline

    pipeline = LandmarkDiffPipeline(mode=args.mode, device=torch.device("cuda"))
    pipeline.load()
    result = pipeline.generate(
        img_array,
        procedure=args.procedure,
        intensity=args.intensity,
        num_inference_steps=args.steps,
        seed=args.seed,
    )
    result["output"].save(str(prediction_path))
    print(f"saved result to {prediction_path}")
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _run_landmarks(args: argparse.Namespace) -> None:
    """Run the ``landmarks`` subcommand.

    Detects the face mesh on a 512x512 RGB resize of ``args.image``, renders
    it, saves the rendering to ``args.output`` (creating parent directories),
    and prints the landmark count and detection confidence. Exits through
    ``_error`` when the file is missing or no face is detected.
    """
    import numpy as np
    from PIL import Image

    from landmarkdiff.landmarks import extract_landmarks, render_landmark_image

    source = _validate_image_path(args.image)
    pixels = np.array(Image.open(source).convert("RGB").resize((512, 512)))

    face = extract_landmarks(pixels)
    if face is None:
        _error("no face detected in image")

    rendering = render_landmark_image(face, 512, 512)

    destination = Path(args.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(rendering).save(str(destination))

    print(f"saved landmark mesh to {destination}")
    print(f"detected {len(face.landmarks)} landmarks, confidence {face.confidence:.2f}")
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _run_demo() -> None:
    """Launch the Gradio web demo defined in ``scripts/app.py``.

    ``scripts`` is not among the packaged top-level modules (the wheel ships
    only ``landmarkdiff``), so the import succeeds only when ``scripts`` is
    importable, e.g. when running from a repository checkout. Import failures
    are reported through ``_error`` with a message naming the real cause.
    """
    try:
        from scripts.app import build_app
    except ImportError as exc:
        # ImportError.name holds the module that failed to import; blame gradio
        # only when gradio itself (or one of its submodules) is what is missing.
        missing = (exc.name or "").partition(".")[0]
        if missing == "gradio":
            _error("gradio not installed - run: pip install landmarkdiff[app]")
        _error(
            f"cannot import the demo app ({exc}); scripts/app.py is part of the "
            "source repository, not the installed package - run from a repository checkout"
        )

    # Kept outside the try so failures while building or serving the app
    # surface with their real message rather than as a missing-gradio error.
    demo = build_app()
    demo.launch()
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
if __name__ == "__main__":
    # Reached via ``python -m landmarkdiff`` or by executing this file directly.
    main()
|
|
landmarkdiff/api_client.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
"""Python client for the LandmarkDiff REST API.
|
|
2
|
+
|
|
3
|
+
Provides a clean interface for interacting with the FastAPI server,
|
|
4
|
+
handling image encoding/decoding, error handling, and session management.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from landmarkdiff.api_client import LandmarkDiffClient
|
|
8
|
+
|
|
9
|
+
client = LandmarkDiffClient("http://localhost:8000")
|
|
10
|
+
|
|
11
|
+
# Single prediction
|
|
12
|
+
result = client.predict("patient.png", procedure="rhinoplasty", intensity=65)
|
|
13
|
+
result.save("output.png")
|
|
14
|
+
|
|
15
|
+
# Face analysis
|
|
16
|
+
analysis = client.analyze("patient.png")
|
|
17
|
+
print(f"Fitzpatrick type: {analysis['fitzpatrick_type']}")
|
|
18
|
+
|
|
19
|
+
# Batch processing
|
|
20
|
+
results = client.batch_predict(
|
|
21
|
+
["patient1.png", "patient2.png"],
|
|
22
|
+
procedure="blepharoplasty",
|
|
23
|
+
)
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import base64
|
|
29
|
+
from dataclasses import dataclass, field
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
from typing import Any
|
|
32
|
+
|
|
33
|
+
import cv2
|
|
34
|
+
import numpy as np
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class LandmarkDiffAPIError(Exception):
    """Base class for errors reported by ``LandmarkDiffClient`` about the LandmarkDiff server."""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class PredictionResult:
    """Result from a single prediction.

    Attributes:
        output_image: Predicted image array. When produced by
            ``LandmarkDiffClient`` it comes from ``cv2.imdecode`` and is
            therefore in OpenCV's BGR channel order.
        procedure: Procedure name that was requested.
        intensity: Intensity that was requested.
        confidence: Confidence reported by the server (0.0 if absent).
        landmarks_before: Optional landmarks of the input face.
        landmarks_after: Optional landmarks of the predicted face.
        metrics: Metrics reported by the server.
        metadata: Extra server metadata; failed batch items store ``error``
            and ``path`` here.
    """

    output_image: np.ndarray
    procedure: str
    intensity: float
    confidence: float = 0.0
    landmarks_before: list[Any] | None = None
    landmarks_after: list[Any] | None = None
    metrics: dict[str, float] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def save(self, path: str | Path, fmt: str = ".png") -> None:
        """Save the output image to ``path``.

        OpenCV picks the encoder from the file extension of ``path``. When
        ``path`` has no extension, ``fmt`` (with or without a leading dot) is
        appended so the write does not fail for lack of an encoder.

        Raises:
            OSError: If OpenCV reports that the file could not be written
                (e.g. the directory does not exist). OpenCV itself may raise
                ``cv2.error`` for an unsupported extension.
        """
        out_path = Path(path)
        if not out_path.suffix:
            ext = fmt if fmt.startswith(".") else f".{fmt}"
            out_path = out_path.with_suffix(ext)
        # cv2.imwrite signals most failures by returning False rather than raising.
        if not cv2.imwrite(str(out_path), self.output_image):
            raise OSError(f"Failed to write image to {out_path}")

    def show(self) -> None:
        """Display the output image in an OpenCV window and wait for a key press (requires GUI)."""
        cv2.imshow("LandmarkDiff Prediction", self.output_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class LandmarkDiffClient:
    """Client for the LandmarkDiff REST API.

    Uses a lazily created ``requests.Session`` and translates connection
    failures, timeouts and HTTP error statuses into ``LandmarkDiffAPIError``.

    Args:
        base_url: Server URL (e.g. "http://localhost:8000"); trailing slashes
            are stripped.
        timeout: Timeout in seconds applied to every HTTP request.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session: Any = None

    def _get_session(self) -> Any:
        """Return the shared ``requests.Session``, creating it on first use.

        Raises:
            ImportError: If the optional ``requests`` dependency is missing.
        """
        if self._session is None:
            try:
                import requests
            except ImportError:
                raise ImportError("requests required. Install with: pip install requests") from None
            # requests.Session has no session-wide timeout setting (assigning
            # ``session.timeout`` is ignored), so _request() passes it per call.
            self._session = requests.Session()
        return self._session

    def _request(self, method: str, endpoint: str, **kwargs: Any) -> Any:
        """Call ``session.<method>`` on ``base_url + endpoint`` and return the response.

        Args:
            method: Session method name, e.g. ``"get"`` or ``"post"``.
            endpoint: Path starting with ``/``.
            **kwargs: Forwarded to the session method; ``timeout`` defaults
                to ``self.timeout``.

        Returns:
            The ``requests.Response`` once ``raise_for_status()`` has passed.

        Raises:
            LandmarkDiffAPIError: On connection failure, timeout, or an HTTP
                error status.
        """
        session = self._get_session()
        import requests  # deferred: requests is an optional dependency

        kwargs.setdefault("timeout", self.timeout)
        try:
            resp = getattr(session, method)(f"{self.base_url}{endpoint}", **kwargs)
            resp.raise_for_status()
        except requests.ConnectionError as e:  # also covers ConnectTimeout
            raise LandmarkDiffAPIError(
                f"Cannot connect to LandmarkDiff server at {self.base_url}. "
                "Make sure the LandmarkDiff API server is running and reachable."
            ) from e
        except requests.Timeout as e:
            raise LandmarkDiffAPIError(
                f"Request to {self.base_url}{endpoint} timed out after {self.timeout}s"
            ) from e
        except requests.HTTPError as e:
            raise LandmarkDiffAPIError(
                f"Server returned error {e.response.status_code}: {e.response.text[:200]}"
            ) from e
        return resp

    def _read_image(self, image_path: str | Path) -> bytes:
        """Read an image file as raw bytes.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
        """
        path = Path(image_path)
        if not path.exists():
            raise FileNotFoundError(f"Image not found: {path}")
        return path.read_bytes()

    def _decode_base64_image(self, b64_string: str) -> np.ndarray:
        """Decode a base64-encoded image into a color (BGR) numpy array via OpenCV.

        Raises:
            ValueError: If OpenCV cannot decode the bytes.
        """
        img_bytes = base64.b64decode(b64_string)
        arr = np.frombuffer(img_bytes, np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("Failed to decode base64 image")
        return img

    # ------------------------------------------------------------------
    # API methods
    # ------------------------------------------------------------------

    def health(self) -> dict[str, Any]:
        """Check server health.

        Returns:
            Dict with status and version info.

        Raises:
            LandmarkDiffAPIError: If server is unreachable or returns an error.
        """
        return self._request("get", "/health").json()

    def procedures(self) -> list[str]:
        """List available surgical procedures.

        Returns:
            List of procedure names (empty if the response has no
            ``procedures`` key).

        Raises:
            LandmarkDiffAPIError: If server is unreachable or returns an error.
        """
        return self._request("get", "/procedures").json().get("procedures", [])

    def predict(
        self,
        image_path: str | Path,
        procedure: str = "rhinoplasty",
        intensity: float = 65.0,
        seed: int = 42,
    ) -> PredictionResult:
        """Run surgical outcome prediction.

        Args:
            image_path: Path to input face image.
            procedure: Surgical procedure type.
            intensity: Intensity of the modification (0-100).
            seed: Random seed for reproducibility.

        Returns:
            PredictionResult with output image and metadata.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            LandmarkDiffAPIError: If server is unreachable, times out, or
                returns an error.
            KeyError: If the response lacks ``output_image``.
            ValueError: If the returned image cannot be decoded.
        """
        image_bytes = self._read_image(image_path)

        # NOTE(review): the upload is always labelled image/png whatever the
        # source format; presumably the server decodes by content - confirm.
        files = {"image": ("image.png", image_bytes, "image/png")}
        data = {
            "procedure": procedure,
            "intensity": str(intensity),
            "seed": str(seed),
        }

        result = self._request("post", "/predict", files=files, data=data).json()
        output_img = self._decode_base64_image(result["output_image"])

        return PredictionResult(
            output_image=output_img,
            procedure=procedure,
            intensity=intensity,
            confidence=result.get("confidence", 0.0),
            metrics=result.get("metrics", {}),
            metadata=result.get("metadata", {}),
        )

    def analyze(self, image_path: str | Path) -> dict[str, Any]:
        """Analyze a face image without generating a prediction.

        Returns face landmarks, Fitzpatrick type, pose estimation, etc.

        Args:
            image_path: Path to input face image.

        Returns:
            Dict with analysis results.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            LandmarkDiffAPIError: If server is unreachable or returns an error.
        """
        image_bytes = self._read_image(image_path)
        files = {"image": ("image.png", image_bytes, "image/png")}
        return self._request("post", "/analyze", files=files).json()

    def batch_predict(
        self,
        image_paths: list[str | Path],
        procedure: str = "rhinoplasty",
        intensity: float = 65.0,
        seed: int = 42,
    ) -> list[PredictionResult]:
        """Run batch prediction on multiple images, one request per image.

        Failures do not abort the batch: a failed item yields a placeholder
        result with a black 512x512 image and ``metadata`` containing
        ``error`` and ``path``.

        Args:
            image_paths: List of image file paths.
            procedure: Procedure to apply to all images.
            intensity: Intensity for all images.
            seed: Base random seed; image ``i`` uses ``seed + i``.

        Returns:
            List of PredictionResult objects, in input order.
        """
        results = []
        for i, path in enumerate(image_paths):
            try:
                result = self.predict(
                    path,
                    procedure=procedure,
                    intensity=intensity,
                    seed=seed + i,
                )
                results.append(result)
            except Exception as e:
                # Best-effort batch: record the failure instead of raising.
                results.append(
                    PredictionResult(
                        output_image=np.zeros((512, 512, 3), dtype=np.uint8),
                        procedure=procedure,
                        intensity=intensity,
                        metadata={"error": str(e), "path": str(path)},
                    )
                )
        return results

    def close(self) -> None:
        """Close the HTTP session, if one was created."""
        if self._session is not None:
            self._session.close()
            self._session = None

    def __enter__(self) -> LandmarkDiffClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    def __repr__(self) -> str:
        return f"LandmarkDiffClient(base_url='{self.base_url}')"
|