landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,208 @@
1
+ """Synthetic training pair generator.
2
+
3
+ Creates (input, conditioning, mask, target) tuples for ControlNet fine-tuning.
4
+ Pipeline: FFHQ image -> extract landmarks -> random FFD manipulation ->
5
+ generate conditioning + mask -> apply clinical augmentation to input.
6
+
7
+ Augmentations are applied to INPUT only, never to target (ground truth).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ from collections.abc import Iterator
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+
17
+ import cv2
18
+ import numpy as np
19
+
20
+ from landmarkdiff.conditioning import generate_conditioning
21
+ from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
22
+ from landmarkdiff.manipulation import (
23
+ PROCEDURE_LANDMARKS,
24
+ apply_procedure_preset,
25
+ )
26
+ from landmarkdiff.masking import generate_surgical_mask
27
+ from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
28
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
@dataclass(frozen=True)
class TrainingPair:
    """Immutable record holding one ControlNet fine-tuning sample.

    The image sizes quoted below are for the default ``target_size`` of 512
    used by ``generate_pair``; other sizes follow whatever was requested there.
    """

    # Network input: the resized source face after clinical augmentation
    # (512x512 BGR).
    input_image: np.ndarray
    # Ground truth: the un-augmented source, TPS-warped onto the manipulated
    # landmarks (512x512 BGR).
    target_image: np.ndarray
    # Rendering of the manipulated landmarks (512x512 BGR).
    conditioning: np.ndarray
    # Canny edge map (512x512 grayscale).
    canny: np.ndarray
    # Feathered surgical mask (512x512 float32); save_pair scales it by 255.
    mask: np.ndarray
    # Name of the procedure preset that produced the manipulation.
    procedure: str
    # Manipulation intensity passed to the procedure preset.
    intensity: float
44
+
45
+
46
# Names of every supported procedure preset, in mapping order.
PROCEDURES = list(PROCEDURE_LANDMARKS)
47
+
48
+
49
def generate_pair(
    image: np.ndarray,
    procedure: str | None = None,
    intensity: float | None = None,
    target_size: int = 512,
    rng: np.random.Generator | None = None,
) -> TrainingPair | None:
    """Generate a single training pair from a face image.

    The image is resized to ``target_size`` x ``target_size``, its landmarks
    are extracted and displaced by the procedure preset, and the un-augmented
    resized image is TPS-warped onto the displaced landmarks to form the
    target. Clinical augmentation is applied to the input only, never to the
    target (ground truth).

    Args:
        image: BGR input image (any size).
        procedure: Procedure type (drawn from ``PROCEDURES`` if None).
        intensity: Manipulation intensity 0-100 (uniform in [30, 90) if None).
        target_size: Output resolution (square).
        rng: Random number generator; a fresh unseeded one is used if None.

    Returns:
        TrainingPair or None if face detection fails.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Resize to target
    resized = cv2.resize(image, (target_size, target_size))

    # Extract landmarks
    face = extract_landmarks(resized)
    if face is None:
        return None

    # Random procedure and intensity if not specified. The draw order
    # (procedure first, then intensity) is kept so seeded runs reproduce.
    if procedure is None:
        # rng.choice returns numpy.str_; store the plain str the dataclass declares.
        procedure = str(rng.choice(PROCEDURES))
    if intensity is None:
        intensity = rng.uniform(30, 90)
    # Normalise caller-supplied ints and numpy scalars to a built-in float.
    intensity = float(intensity)

    # Manipulate landmarks
    manipulated = apply_procedure_preset(face, procedure, intensity, target_size)

    # Generate conditioning from manipulated landmarks
    landmark_img = render_landmark_image(manipulated, target_size, target_size)
    _, canny, _ = generate_conditioning(manipulated, target_size, target_size)

    # Generate mask (built from the original, pre-manipulation landmarks)
    mask = generate_surgical_mask(face, procedure, target_size, target_size)

    # Generate target: TPS warp the original image to match manipulated landmarks
    target = warp_image_tps(resized, face.pixel_coords, manipulated.pixel_coords)

    # Apply clinical augmentation to INPUT only (never target)
    augmented_input = apply_clinical_augmentation(resized, rng=rng)

    return TrainingPair(
        input_image=augmented_input,
        target_image=target,
        conditioning=landmark_img,
        canny=canny,
        mask=mask,
        procedure=procedure,
        intensity=intensity,
    )
111
+
112
+
113
def generate_pairs_from_directory(
    image_dir: str | Path,
    num_pairs: int = 1000,
    target_size: int = 512,
    seed: int = 42,
    quality_check: bool = True,
    min_quality: float = 45.0,
) -> Iterator[TrainingPair]:
    """Generate training pairs from a directory of face images.

    Images are visited in sorted path order and cycled until ``num_pairs``
    pairs have been yielded. Generation stops early, with a warning, once more
    consecutive attempts have failed than there are images; an attempt fails
    when the file cannot be decoded, the image is rejected by the quality
    gate, or no face is detected.

    This is a generator: nothing runs (including the empty-directory check)
    until the first item is requested.

    Args:
        image_dir: Directory containing face images.
        num_pairs: Total number of pairs to generate.
        target_size: Output resolution.
        seed: Random seed.
        quality_check: Run face verifier quality check on source images.
        min_quality: Minimum quality score to use image (0-100).

    Yields:
        TrainingPair instances.

    Raises:
        FileNotFoundError: If ``image_dir`` contains no supported image files.
    """
    rng = np.random.default_rng(seed)
    image_dir = Path(image_dir)

    extensions = {".jpg", ".jpeg", ".png", ".webp"}
    # is_file() keeps directories that merely carry an image suffix out of the pool.
    image_files = sorted(
        f for f in image_dir.iterdir() if f.is_file() and f.suffix.lower() in extensions
    )

    if not image_files:
        raise FileNotFoundError(f"No images found in {image_dir}")

    # Quality scores are cached per path because images are revisited when cycling.
    quality_cache: dict[str, float] = {}
    quality_rejects = 0

    generated = 0
    consecutive_failures = 0
    idx = 0
    while generated < num_pairs:
        # Cycle through images
        img_path = image_files[idx % len(image_files)]
        idx += 1

        pair = None
        image = cv2.imread(str(img_path))
        if image is not None:
            rejected = False
            if quality_check:
                # Quality gate: reject low-quality source images before pair generation
                cache_key = str(img_path)
                if cache_key not in quality_cache:
                    quality_cache[cache_key] = _score_source_quality(image, target_size)
                rejected = quality_cache[cache_key] < min_quality

            if rejected:
                quality_rejects += 1
                if quality_rejects % 100 == 0:
                    logger.info(" Quality filter: %d images rejected so far", quality_rejects)
            else:
                pair = generate_pair(image, target_size=target_size, rng=rng)

        if pair is not None:
            yield pair
            generated += 1
            consecutive_failures = 0
            continue

        # Single failure path for unreadable, rejected, and face-less images, so
        # every kind of early stop is reported the same way.
        consecutive_failures += 1
        if consecutive_failures > len(image_files):
            logger.warning("%d consecutive failures, stopping early", consecutive_failures)
            break

    if quality_rejects > 0:
        logger.info("Quality filter: rejected %d low-quality source images", quality_rejects)


def _score_source_quality(image: np.ndarray, target_size: int) -> float:
    """Return the face-verifier quality score for *image*, or 100.0 if it cannot be computed."""
    try:
        # Imported lazily and inside the try so that a missing or failing
        # verifier only disables the gate instead of aborting generation.
        from landmarkdiff.face_verifier import analyze_distortions

        resized = cv2.resize(image, (target_size, target_size))
        return analyze_distortions(resized).quality_score
    except Exception:
        # Best-effort by design: can't check -- allow through, but leave a trace.
        logger.debug("Quality check failed; allowing image through", exc_info=True)
        return 100.0
197
+
198
+
199
def save_pair(pair: TrainingPair, output_dir: str | Path, index: int) -> None:
    """Save a training pair to disk as five PNG files.

    Files are named ``{index:06d}_{input,target,conditioning,canny,mask}.png``
    inside ``output_dir``, which is created if missing.

    Args:
        pair: The training pair to write.
        output_dir: Destination directory (str or Path).
        index: Sample index used as the zero-padded filename prefix.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    prefix = f"{index:06d}"

    # Clip before scaling so values marginally outside [0, 1] cannot wrap
    # around when cast to uint8; in-range masks produce identical bytes.
    mask_u8 = (np.clip(pair.mask, 0.0, 1.0) * 255).astype(np.uint8)

    outputs = (
        ("input", pair.input_image),
        ("target", pair.target_image),
        ("conditioning", pair.conditioning),
        ("canny", pair.canny),
        ("mask", mask_u8),
    )
    for suffix, array in outputs:
        path = output_dir / f"{prefix}_{suffix}.png"
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(str(path), array):
            logger.warning("Failed to write %s", path)
@@ -0,0 +1,273 @@
1
+ """TPS warping for synthetic pair generation.
2
+
3
+ Only warps deformable tissue - rigid structures (teeth, sclera) get
4
+ rigid translation instead. Prevents "rubber teeth" from naive TPS.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import cv2
10
+ import numpy as np
11
+
12
+
13
def compute_tps_transform(
    src_pts: np.ndarray,
    dst_pts: np.ndarray,
) -> cv2.ThinPlateSplineShapeTransformer:
    """Fit an OpenCV thin-plate-spline transformer from src to dst points.

    Args:
        src_pts: Source points; any array that reshapes to (N, 2).
        dst_pts: Destination points, same count as ``src_pts``.

    Returns:
        The estimated ``cv2.ThinPlateSplineShapeTransformer``.
    """
    src = src_pts.reshape(1, -1, 2).astype(np.float32)
    dst = dst_pts.reshape(1, -1, 2).astype(np.float32)

    # Count points on the reshaped array: len(src_pts) is only the point count
    # when the input is already shaped (N, 2).
    n_points = src.shape[1]
    matches = [cv2.DMatch(i, i, 0) for i in range(n_points)]

    tps = cv2.createThinPlateSplineShapeTransformer()
    # NOTE(review): dst is passed first on purpose in the original code --
    # presumably so warpImage maps src content onto dst; confirm against OpenCV docs.
    tps.estimateTransformation(dst, src, matches)
    return tps
25
+
26
+
27
def _subsample_control_points(
    src: np.ndarray,
    dst: np.ndarray,
    max_points: int = 80,
    anchor_stride: int = 8,
) -> tuple[np.ndarray, np.ndarray]:
    """Reduce the control-point set to displaced points plus sparse anchors.

    Points that moved by more than 0.5 px are kept; every ``anchor_stride``-th
    stationary point is added as an anchor. If that exceeds ``max_points``, the
    anchors are thinned to fit the remaining budget, and if the displaced
    points alone exceed the budget only the first ``max_points`` of them (by
    index) are kept. Returned rows are in ascending landmark-index order.
    """
    moved = np.linalg.norm(dst - src, axis=1) > 0.5
    moved_idx = np.flatnonzero(moved)
    static_idx = np.flatnonzero(~moved)

    keep = np.concatenate([moved_idx, static_idx[::anchor_stride]])

    if len(keep) > max_points:
        budget = max_points - len(moved_idx)
        if budget <= 0:
            # No room for anchors at all: truncate the displaced points.
            keep = moved_idx[:max_points]
        else:
            # Widen the stride so roughly `budget` anchors span the stationary set.
            stride = max(1, len(static_idx) // budget)
            keep = np.concatenate([moved_idx, static_idx[::stride][:budget]])

    keep = np.unique(keep)
    return src[keep], dst[keep]
56
+
57
+
58
def warp_image_tps(
    image: np.ndarray,
    src_landmarks: np.ndarray,
    dst_landmarks: np.ndarray,
    rigid_mask: np.ndarray | None = None,
) -> np.ndarray:
    """Apply TPS warp to an image with optional rigid region preservation.

    Args:
        image: Image to warp.
        src_landmarks: (N, 2) landmark pixel coordinates in ``image``.
        dst_landmarks: (N, 2) target landmark pixel coordinates.
        rigid_mask: Optional mask of rigid regions. Inside it the image is
            shifted by the mean landmark translation instead of being
            TPS-warped. Values above 1 are treated as a 0-255 mask.

    Returns:
        The warped image. When ``rigid_mask`` is given the composite is cast
        to uint8.
    """
    h, w = image.shape[:2]

    src_pts = src_landmarks.astype(np.float32)
    dst_pts = dst_landmarks.astype(np.float32)

    # Subsample control points for speed (478 -> ~80)
    src_sub, dst_sub = _subsample_control_points(src_pts, dst_pts)

    # Compute TPS coefficients on subsampled points
    map_x, map_y = _compute_tps_map(src_sub, dst_sub, w, h)

    # Warp the image
    warped = cv2.remap(
        image,
        map_x.astype(np.float32),
        map_y.astype(np.float32),
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT_101,
    )

    if rigid_mask is None:
        return warped

    # For rigid regions, compute mean translation and apply rigidly
    rigid_translation = _compute_rigid_translation(src_pts, dst_pts, rigid_mask, w, h)
    rigid_warped = _apply_rigid_translation(image, rigid_translation)

    # Translate the mask to match the rigidly-shifted content
    translated_mask = _apply_rigid_translation(rigid_mask, rigid_translation)

    # Composite: use rigid warp in rigid regions, TPS elsewhere
    alpha = translated_mask.astype(np.float32)
    if alpha.max() > 1:
        alpha = alpha / 255.0
    # Add a broadcastable channel axis only when the image has one, so
    # grayscale and 4-channel images composite as well as 3-channel BGR.
    if alpha.ndim == 2 and warped.ndim == 3:
        alpha = alpha[:, :, np.newaxis]

    blended = rigid_warped * alpha + warped * (1 - alpha)
    return blended.astype(np.uint8)
100
+
101
+
102
def _compute_tps_map(
    src: np.ndarray,
    dst: np.ndarray,
    width: int,
    height: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build ``cv2.remap``-style sampling grids from TPS control points.

    The displacement ``dst - src`` is interpolated with a thin-plate spline
    anchored at the ``src`` points, evaluated at every pixel, and subtracted
    from the pixel's own coordinates. With no control points the identity
    grid is returned.

    Returns:
        ``(map_x, map_y)``, each a float64 array of shape (height, width).
    """
    # Per-control-point displacement, one component at a time.
    shift_x = dst[:, 0] - src[:, 0]
    shift_y = dst[:, 1] - src[:, 1]

    # Pixel coordinate grids (x varies along columns, y along rows).
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    cols = cols.astype(np.float64)
    rows = rows.astype(np.float64)

    if len(src) == 0:
        return cols.copy(), rows.copy()

    # One spline per displacement component.
    coeffs_x = _solve_tps_weights(src, shift_x)
    coeffs_y = _solve_tps_weights(src, shift_y)

    # Evaluate both splines on the flattened pixel grid.
    xs = cols.ravel()
    ys = rows.ravel()
    pixels = np.stack([xs, ys], axis=1)

    field_x = _evaluate_tps(pixels, src, coeffs_x)
    field_y = _evaluate_tps(pixels, src, coeffs_y)

    return (
        (xs - field_x).reshape(height, width),
        (ys - field_y).reshape(height, width),
    )
142
+
143
+
144
def _tps_kernel(r: np.ndarray) -> np.ndarray:
    """Evaluate the thin-plate-spline basis ``r^2 * log(r)`` element-wise.

    Entries where ``r`` is not positive are left at 0, which avoids taking
    ``log(0)``. The result has the same shape and dtype as ``r``.
    """
    out = np.zeros_like(r)
    positive = r > 0
    radii = r[positive]
    out[positive] = np.square(radii) * np.log(radii)
    return out
150
+
151
+
152
def _solve_tps_weights(
    control_pts: np.ndarray,
    values: np.ndarray,
) -> np.ndarray:
    """Solve the thin-plate-spline interpolation system for one scalar field.

    Args:
        control_pts: (n, 2) control point coordinates.
        values: (n,) field values to interpolate at the control points.

    Returns:
        Vector of length n + 3: the n radial weights followed by the affine
        coefficients ``[a0, a1, a2]`` of ``a0 + a1*x + a2*y``.
    """
    count = len(control_pts)

    # Pairwise distances between control points, shape (count, count).
    delta = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :]
    dist = np.sqrt(np.sum(delta**2, axis=2))

    # Radial kernel r^2 * log(r); zero-distance entries stay at 0.
    kernel = np.zeros((count, count))
    positive = dist > 0
    kernel[positive] = dist[positive] ** 2 * np.log(dist[positive])
    # Small diagonal term for numerical stability.
    kernel += np.eye(count) * 1e-6

    # Affine basis [1, x, y] per control point, shape (count, 3).
    affine = np.hstack([np.ones((count, 1)), control_pts])

    # Assemble the block system [[K, P], [P^T, 0]].
    system = np.zeros((count + 3, count + 3))
    system[:count, :count] = kernel
    system[:count, count:] = affine
    system[count:, :count] = affine.T

    target = np.zeros(count + 3)
    target[:count] = values

    try:
        return np.linalg.solve(system, target)
    except np.linalg.LinAlgError:
        # Singular system (e.g. degenerate point layout): least squares instead.
        return np.linalg.lstsq(system, target, rcond=None)[0]
186
+
187
+
188
def _evaluate_tps(
    points: np.ndarray,
    control_pts: np.ndarray,
    weights: np.ndarray,
) -> np.ndarray:
    """Evaluate a fitted thin-plate spline at arbitrary 2-D points.

    Args:
        points: (M, 2) query coordinates.
        control_pts: (n, 2) control points the spline was fitted on.
        weights: Length n + 3 vector: n radial weights, then the affine
            coefficients ``[a0, a1, a2]``.

    Returns:
        (M,) array of spline values.
    """
    count = len(control_pts)
    radial = weights[:count]
    affine = weights[count:]

    # Start from the affine part a0 + a1*x + a2*y ...
    values = affine[0] + affine[1] * points[:, 0] + affine[2] * points[:, 1]

    # ... then add the radial part in chunks so the (M, n) distance matrix
    # never has to exist in full.
    chunk = 50000
    total = len(points)
    for lo in range(0, total, chunk):
        hi = min(lo + chunk, total)
        block_pts = points[lo:hi]

        # Broadcast (m, 1) against (n,) to get all pairwise offsets: (m, n).
        off_x = block_pts[:, 0:1] - control_pts[:, 0]
        off_y = block_pts[:, 1:2] - control_pts[:, 1]
        dist = np.sqrt(off_x**2 + off_y**2)

        # r^2 * log(r), leaving zero-distance entries at 0.
        basis = np.zeros_like(dist)
        positive = dist > 0
        basis[positive] = dist[positive] ** 2 * np.log(dist[positive])

        values[lo:hi] += basis @ radial

    return values
221
+
222
+
223
def _compute_rigid_translation(
    src: np.ndarray,
    dst: np.ndarray,
    mask: np.ndarray,
    width: int,
    height: int,
) -> np.ndarray:
    """Compute mean translation for rigid regions.

    A control point counts as rigid when its (truncated) source pixel lies
    inside the image and the mask is non-zero there. For a multi-channel mask
    any non-zero channel counts.

    Args:
        src: (N, 2) source landmark coordinates (x, y).
        dst: (N, 2) destination landmark coordinates.
        mask: Rigid-region mask, (H, W) or (H, W, C).
        width: Image width used for the bounds check.
        height: Image height used for the bounds check.

    Returns:
        Array ``[dx, dy]``: the mean ``dst - src`` over the rigid control
        points, or ``[0.0, 0.0]`` when none fall inside the mask.
    """
    src = np.asarray(src)
    dst = np.asarray(dst)
    no_shift = np.array([0.0, 0.0])
    if len(src) == 0:
        return no_shift

    # Truncate towards zero, matching int() on each coordinate.
    cols = src[:, 0].astype(np.int64)
    rows = src[:, 1].astype(np.int64)

    in_bounds = (cols >= 0) & (cols < width) & (rows >= 0) & (rows < height)
    inside = np.flatnonzero(in_bounds)
    if inside.size == 0:
        return no_shift

    hits = mask[rows[inside], cols[inside]] > 0
    if hits.ndim > 1:
        # Multi-channel mask: collapse the channel axes instead of letting an
        # array-valued truth test raise.
        hits = hits.any(axis=tuple(range(1, hits.ndim)))
    inside = inside[hits]
    if inside.size == 0:
        return no_shift

    dx = np.mean(dst[inside, 0] - src[inside, 0])
    dy = np.mean(dst[inside, 1] - src[inside, 1])
    return np.array([dx, dy])
244
+
245
+
246
def _apply_rigid_translation(
    image: np.ndarray,
    translation: np.ndarray,
) -> np.ndarray:
    """Shift an image by ``translation = [dx, dy]`` pixels.

    The output keeps the input's width and height; pixels shifted in from
    outside are filled by reflecting the image at its border.
    """
    rows, cols = image.shape[:2]
    shift_x, shift_y = translation[0], translation[1]
    # 2x3 affine matrix for a pure translation.
    affine = np.array([[1, 0, shift_x], [0, 1, shift_y]], dtype=np.float32)
    return cv2.warpAffine(image, affine, (cols, rows), borderMode=cv2.BORDER_REFLECT_101)
254
+
255
+
256
def generate_random_warp(
    landmarks: np.ndarray,
    procedure_indices: list[int],
    max_displacement: float = 15.0,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Return a copy of ``landmarks`` with the selected points randomly jittered.

    Each index in ``procedure_indices`` that is below ``len(landmarks)`` gets
    an independent uniform offset in ``[-max_displacement, max_displacement)``
    added to its x and y coordinates. Indices at or past the end are skipped
    without drawing random numbers. The input array is not modified.
    """
    rng = rng or np.random.default_rng()
    jittered = landmarks.copy()
    limit = len(landmarks)

    for point in procedure_indices:
        if point >= limit:
            continue
        # Draw x before y so a seeded generator yields the same sequence.
        offset_x = rng.uniform(-max_displacement, max_displacement)
        offset_y = rng.uniform(-max_displacement, max_displacement)
        jittered[point, 0] += offset_x
        jittered[point, 1] += offset_y

    return jittered