landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
"""Synthetic training pair generator.
|
|
2
|
+
|
|
3
|
+
Creates (input, conditioning, mask, target) tuples for ControlNet fine-tuning.
|
|
4
|
+
Pipeline: FFHQ image -> extract landmarks -> random FFD manipulation ->
|
|
5
|
+
generate conditioning + mask -> apply clinical augmentation to input.
|
|
6
|
+
|
|
7
|
+
Augmentations are applied to INPUT only, never to target (ground truth).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
from collections.abc import Iterator
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
import cv2
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from landmarkdiff.conditioning import generate_conditioning
|
|
21
|
+
from landmarkdiff.landmarks import extract_landmarks, render_landmark_image
|
|
22
|
+
from landmarkdiff.manipulation import (
|
|
23
|
+
PROCEDURE_LANDMARKS,
|
|
24
|
+
apply_procedure_preset,
|
|
25
|
+
)
|
|
26
|
+
from landmarkdiff.masking import generate_surgical_mask
|
|
27
|
+
from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
|
|
28
|
+
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
class TrainingPair:
    """One synthetic sample for ControlNet fine-tuning.

    Instances are immutable (``frozen=True``). Image sizes quoted below are
    the defaults used by ``generate_pair`` (``target_size=512``).
    """

    # Augmented network input (512x512 BGR).
    input_image: np.ndarray
    # Clean ground truth (512x512 BGR): the TPS-warped original image.
    target_image: np.ndarray
    # Rendering of the manipulated landmarks (512x512 BGR).
    conditioning: np.ndarray
    # Canny edge map (512x512 grayscale).
    canny: np.ndarray
    # Feathered surgical mask (512x512 float32).
    mask: np.ndarray
    # Name of the procedure preset that was applied.
    procedure: str
    # Manipulation intensity that was applied.
    intensity: float
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Procedure names (the keys of PROCEDURE_LANDMARKS); generate_pair samples from this list.
PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def generate_pair(
    image: np.ndarray,
    procedure: str | None = None,
    intensity: float | None = None,
    target_size: int = 512,
    rng: np.random.Generator | None = None,
) -> TrainingPair | None:
    """Generate a single training pair from a face image.

    The image is resized to ``target_size`` x ``target_size``, landmarks are
    extracted and manipulated with the chosen procedure preset, and the
    conditioning image, canny map, mask and TPS-warped target are derived
    from them. Clinical augmentation is applied to the input only.

    Args:
        image: BGR input image (any size).
        procedure: Procedure type (drawn at random from ``PROCEDURES`` if None).
        intensity: Manipulation intensity 0-100 (uniform in [30, 90) if None).
        target_size: Output resolution (width and height).
        rng: Random number generator; a fresh default generator is created
            when None.

    Returns:
        A ``TrainingPair``, or None if ``extract_landmarks`` finds no face.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Work at the output resolution from the start.
    resized = cv2.resize(image, (target_size, target_size))

    face = extract_landmarks(resized)
    if face is None:
        return None

    if procedure is None:
        # Generator.choice returns a NumPy string scalar; convert it so the
        # TrainingPair field really holds a plain ``str``.
        procedure = str(rng.choice(PROCEDURES))
    if intensity is None:
        intensity = float(rng.uniform(30, 90))

    manipulated = apply_procedure_preset(face, procedure, intensity, target_size)

    # Conditioning signals come from the manipulated landmarks.
    landmark_img = render_landmark_image(manipulated, target_size, target_size)
    _, canny, _ = generate_conditioning(manipulated, target_size, target_size)

    # The mask is built from the original (unmanipulated) landmarks.
    mask = generate_surgical_mask(face, procedure, target_size, target_size)

    # Target: warp the resized original so its landmarks follow the manipulation.
    target = warp_image_tps(resized, face.pixel_coords, manipulated.pixel_coords)

    # Augmentation is applied to the INPUT only, never to the target.
    augmented_input = apply_clinical_augmentation(resized, rng=rng)

    return TrainingPair(
        input_image=augmented_input,
        target_image=target,
        conditioning=landmark_img,
        canny=canny,
        mask=mask,
        procedure=procedure,
        intensity=intensity,
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def generate_pairs_from_directory(
    image_dir: str | Path,
    num_pairs: int = 1000,
    target_size: int = 512,
    seed: int = 42,
    quality_check: bool = True,
    min_quality: float = 45.0,
) -> Iterator[TrainingPair]:
    """Generate training pairs from a directory of face images.

    Image files are visited in sorted order and cycled until ``num_pairs``
    pairs have been produced. Generation stops early once more consecutive
    attempts have failed than there are image files (unreadable file,
    quality rejection, or no face detected all count as a failure).

    Args:
        image_dir: Directory containing face images (.jpg, .jpeg, .png, .webp).
        num_pairs: Total number of pairs to generate.
        target_size: Output resolution.
        seed: Random seed.
        quality_check: Run face verifier quality check on source images.
        min_quality: Minimum quality score to use image (0-100).

    Yields:
        TrainingPair instances.

    Raises:
        FileNotFoundError: If ``image_dir`` contains no supported image file.
            Because this is a generator, it is raised on first iteration.
    """
    rng = np.random.default_rng(seed)
    image_dir = Path(image_dir)

    extensions = {".jpg", ".jpeg", ".png", ".webp"}
    # Only regular files: a directory named like an image cannot be decoded.
    image_files = sorted(
        f for f in image_dir.iterdir() if f.is_file() and f.suffix.lower() in extensions
    )

    if not image_files:
        raise FileNotFoundError(f"No images found in {image_dir}")

    # Quality scores are cached per path so each image is analysed only once.
    quality_cache: dict[str, float] = {}
    quality_rejects = 0
    quality_errors = 0

    generated = 0
    consecutive_failures = 0
    idx = 0
    while generated < num_pairs:
        # Cycle through images.
        img_path = image_files[idx % len(image_files)]
        idx += 1

        pair = None
        image = cv2.imread(str(img_path))
        if image is not None:
            usable = True

            # Quality gate: reject low-quality source images before pair generation.
            if quality_check:
                cache_key = str(img_path)
                if cache_key not in quality_cache:
                    try:
                        from landmarkdiff.face_verifier import analyze_distortions

                        resized = cv2.resize(image, (target_size, target_size))
                        report = analyze_distortions(resized)
                        quality_cache[cache_key] = report.quality_score
                    except Exception:
                        # Best effort: an image that cannot be checked is let
                        # through, but the first failure is reported so a broken
                        # verifier does not go unnoticed.
                        quality_cache[cache_key] = 100.0
                        quality_errors += 1
                        if quality_errors == 1:
                            logger.warning(
                                "Quality check failed for %s; unchecked images are allowed through",
                                img_path,
                                exc_info=True,
                            )

                if quality_cache[cache_key] < min_quality:
                    usable = False
                    quality_rejects += 1
                    if quality_rejects % 100 == 0:
                        logger.info(" Quality filter: %d images rejected so far", quality_rejects)

            if usable:
                pair = generate_pair(image, target_size=target_size, rng=rng)

        if pair is not None:
            yield pair
            generated += 1
            consecutive_failures = 0
            continue

        # Single failure path for unreadable, rejected and face-less images.
        consecutive_failures += 1
        if consecutive_failures > len(image_files):
            logger.warning("%d consecutive failures, stopping early", consecutive_failures)
            break

    if quality_rejects > 0:
        logger.info("Quality filter: rejected %d low-quality source images", quality_rejects)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def save_pair(pair: TrainingPair, output_dir: Path, index: int) -> None:
    """Save a training pair to disk as PNG files.

    Files are named ``{index:06d}_{kind}.png`` with ``kind`` in ``input``,
    ``target``, ``conditioning``, ``canny`` and ``mask``. The float mask is
    scaled by 255 and stored as 8-bit.

    Args:
        pair: The sample to write.
        output_dir: Destination directory (created if missing); a string
            path is accepted as well.
        index: Sample number used as the file-name prefix.

    Raises:
        OSError: If OpenCV reports that a file could not be written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    prefix = f"{index:06d}"

    # Clip before scaling so values slightly outside [0, 1] saturate instead
    # of wrapping around in the uint8 cast.
    mask_u8 = (np.clip(pair.mask, 0.0, 1.0) * 255).astype(np.uint8)

    outputs = {
        "input": pair.input_image,
        "target": pair.target_image,
        "conditioning": pair.conditioning,
        "canny": pair.canny,
        "mask": mask_u8,
    }
    for kind, img in outputs.items():
        path = output_dir / f"{prefix}_{kind}.png"
        # imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(str(path), img):
            raise OSError(f"Failed to write {path}")
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
"""TPS warping for synthetic pair generation.
|
|
2
|
+
|
|
3
|
+
Only warps deformable tissue - rigid structures (teeth, sclera) get
|
|
4
|
+
rigid translation instead. Prevents "rubber teeth" from naive TPS.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import cv2
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def compute_tps_transform(
    src_pts: np.ndarray,
    dst_pts: np.ndarray,
) -> cv2.ThinPlateSplineShapeTransformer:
    """Fit an OpenCV thin-plate-spline transformer to matched point sets.

    Args:
        src_pts: Source points; any array that reshapes to (N, 2).
        dst_pts: Destination points, same number of points as ``src_pts``.

    Returns:
        The transformer after ``estimateTransformation`` has been called.

    Raises:
        ValueError: If the two point sets differ in size.
    """
    src = np.asarray(src_pts, dtype=np.float32).reshape(1, -1, 2)
    dst = np.asarray(dst_pts, dtype=np.float32).reshape(1, -1, 2)
    if src.shape != dst.shape:
        raise ValueError(
            f"src_pts and dst_pts must contain the same number of points, "
            f"got {src.shape[1]} and {dst.shape[1]}"
        )

    # Count points after reshaping: len(src_pts) is only right for (N, 2) input.
    n_points = src.shape[1]
    matches = [cv2.DMatch(i, i, 0) for i in range(n_points)]

    tps = cv2.createThinPlateSplineShapeTransformer()
    # Argument order (dst first) is kept from the original implementation.
    # NOTE(review): presumably chosen for use with warpImage - confirm with callers.
    tps.estimateTransformation(dst, src, matches)
    return tps
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _subsample_control_points(
|
|
28
|
+
src: np.ndarray,
|
|
29
|
+
dst: np.ndarray,
|
|
30
|
+
max_points: int = 80,
|
|
31
|
+
anchor_stride: int = 8,
|
|
32
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
33
|
+
"""Keep all displaced points + sparse anchors. ~80 pts instead of 478, ~30x faster."""
|
|
34
|
+
displacements = np.linalg.norm(dst - src, axis=1)
|
|
35
|
+
displaced_mask = displacements > 0.5 # moved by > 0.5px
|
|
36
|
+
displaced_idx = np.where(displaced_mask)[0]
|
|
37
|
+
|
|
38
|
+
# Add sparse anchors from non-displaced landmarks
|
|
39
|
+
non_displaced_idx = np.where(~displaced_mask)[0]
|
|
40
|
+
anchor_idx = non_displaced_idx[::anchor_stride]
|
|
41
|
+
|
|
42
|
+
selected = np.concatenate([displaced_idx, anchor_idx])
|
|
43
|
+
|
|
44
|
+
# If still too many, subsample anchors more aggressively
|
|
45
|
+
if len(selected) > max_points:
|
|
46
|
+
n_anchors = max_points - len(displaced_idx)
|
|
47
|
+
if n_anchors > 0:
|
|
48
|
+
step = max(1, len(non_displaced_idx) // n_anchors)
|
|
49
|
+
anchor_idx = non_displaced_idx[::step][:n_anchors]
|
|
50
|
+
selected = np.concatenate([displaced_idx, anchor_idx])
|
|
51
|
+
else:
|
|
52
|
+
selected = displaced_idx[:max_points]
|
|
53
|
+
|
|
54
|
+
selected = np.unique(selected)
|
|
55
|
+
return src[selected], dst[selected]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def warp_image_tps(
    image: np.ndarray,
    src_landmarks: np.ndarray,
    dst_landmarks: np.ndarray,
    rigid_mask: np.ndarray | None = None,
) -> np.ndarray:
    """Apply a TPS warp to an image with optional rigid region preservation.

    Args:
        image: Image to warp; 2-D or (H, W, C).
        src_landmarks: (N, 2) landmark positions in ``image``.
        dst_landmarks: (N, 2) positions the landmarks should move to.
        rigid_mask: Optional (H, W) mask of regions that must not deform.
            Inside it the image is shifted by the mean landmark translation
            instead of being TPS-warped. Values above 1 are treated as
            0-255 and rescaled to 0-1.

    Returns:
        The warped image, same size and dtype as ``image``.
    """
    h, w = image.shape[:2]

    src_pts = src_landmarks.astype(np.float32)
    dst_pts = dst_landmarks.astype(np.float32)

    # Subsample control points for speed.
    src_sub, dst_sub = _subsample_control_points(src_pts, dst_pts)
    map_x, map_y = _compute_tps_map(src_sub, dst_sub, w, h)

    warped = cv2.remap(
        image,
        map_x.astype(np.float32),
        map_y.astype(np.float32),
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT_101,
    )

    if rigid_mask is None:
        return warped

    # Rigid regions: shift image and mask by the same mean translation.
    shift = _compute_rigid_translation(src_pts, dst_pts, rigid_mask, w, h)
    rigid_warped = _apply_rigid_translation(image, shift)
    shifted_mask = _apply_rigid_translation(rigid_mask, shift)

    alpha = shifted_mask.astype(np.float32)
    if alpha.max() > 1:
        alpha = alpha / 255.0
    if alpha.ndim == 2 and warped.ndim == 3:
        # Broadcast over the channel axis; works for any channel count,
        # not just three.
        alpha = alpha[..., np.newaxis]

    # Composite: rigid shift inside the mask, TPS warp elsewhere.
    blended = rigid_warped * alpha + warped * (1 - alpha)
    return blended.astype(warped.dtype)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _compute_tps_map(
    src: np.ndarray,
    dst: np.ndarray,
    width: int,
    height: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build backward remap arrays from TPS control points.

    ``cv2.remap`` computes ``out(x, y) = image(map_x(x, y), map_y(x, y))``,
    so the maps must say, for each OUTPUT pixel, where to sample the input.
    For content at ``src[i]`` to appear at ``dst[i]`` the map has to satisfy
    ``map(dst[i]) == src[i]``. The displacement ``dst - src`` is therefore
    interpolated with the spline anchored at the ``dst`` positions and then
    subtracted from the pixel grid.

    Args:
        src: (n, 2) control-point positions in the input image.
        dst: (n, 2) positions those points should occupy in the output.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        ``(map_x, map_y)``, each of shape (height, width), float64. With no
        control points the identity mapping is returned.
    """
    grid_x, grid_y = np.meshgrid(
        np.arange(width, dtype=np.float64),
        np.arange(height, dtype=np.float64),
    )

    if len(src) == 0:
        return grid_x, grid_y

    # Displacement carried by each control point.
    dx = dst[:, 0] - src[:, 0]
    dy = dst[:, 1] - src[:, 1]

    # Anchor the spline at the destination positions (output coordinates).
    # Anchoring at src would leave landmarks short of dst, because the field
    # would be sampled at a different location than where it was fitted.
    weights_x = _solve_tps_weights(dst, dx)
    weights_y = _solve_tps_weights(dst, dy)

    # Evaluate the displacement field on every output pixel.
    flat_x = grid_x.ravel()
    flat_y = grid_y.ravel()
    pts = np.stack([flat_x, flat_y], axis=1)

    disp_x = _evaluate_tps(pts, dst, weights_x)
    disp_y = _evaluate_tps(pts, dst, weights_y)

    map_x = (flat_x - disp_x).reshape(height, width)
    map_y = (flat_y - disp_y).reshape(height, width)

    return map_x, map_y
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _tps_kernel(r: np.ndarray) -> np.ndarray:
|
|
145
|
+
"""TPS radial basis function: r^2 * log(r), with r=0 -> 0."""
|
|
146
|
+
result = np.zeros_like(r)
|
|
147
|
+
mask = r > 0
|
|
148
|
+
result[mask] = r[mask] ** 2 * np.log(r[mask])
|
|
149
|
+
return result
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _solve_tps_weights(
|
|
153
|
+
control_pts: np.ndarray,
|
|
154
|
+
values: np.ndarray,
|
|
155
|
+
) -> np.ndarray:
|
|
156
|
+
"""Solve TPS system -> weight vector [w1..wn, a0, a1, a2]."""
|
|
157
|
+
n = len(control_pts)
|
|
158
|
+
|
|
159
|
+
# Build kernel matrix K (vectorized)
|
|
160
|
+
diff = control_pts[:, np.newaxis, :] - control_pts[np.newaxis, :, :] # (n, n, 2)
|
|
161
|
+
r_mat = np.sqrt((diff**2).sum(axis=2)) # (n, n)
|
|
162
|
+
K = np.zeros((n, n))
|
|
163
|
+
nz = r_mat > 0
|
|
164
|
+
K[nz] = r_mat[nz] ** 2 * np.log(r_mat[nz])
|
|
165
|
+
|
|
166
|
+
# Build system matrix [K P; P^T 0]
|
|
167
|
+
P = np.hstack([np.ones((n, 1)), control_pts]) # (n, 3)
|
|
168
|
+
|
|
169
|
+
L = np.zeros((n + 3, n + 3))
|
|
170
|
+
L[:n, :n] = K
|
|
171
|
+
L[:n, n:] = P
|
|
172
|
+
L[n:, :n] = P.T
|
|
173
|
+
|
|
174
|
+
# Regularization for numerical stability
|
|
175
|
+
L[:n, :n] += np.eye(n) * 1e-6
|
|
176
|
+
|
|
177
|
+
rhs = np.zeros(n + 3)
|
|
178
|
+
rhs[:n] = values
|
|
179
|
+
|
|
180
|
+
try:
|
|
181
|
+
weights = np.linalg.solve(L, rhs)
|
|
182
|
+
except np.linalg.LinAlgError:
|
|
183
|
+
weights = np.linalg.lstsq(L, rhs, rcond=None)[0]
|
|
184
|
+
|
|
185
|
+
return weights
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _evaluate_tps(
|
|
189
|
+
points: np.ndarray,
|
|
190
|
+
control_pts: np.ndarray,
|
|
191
|
+
weights: np.ndarray,
|
|
192
|
+
) -> np.ndarray:
|
|
193
|
+
"""Evaluate TPS at arbitrary points (vectorized)."""
|
|
194
|
+
n = len(control_pts)
|
|
195
|
+
w = weights[:n]
|
|
196
|
+
a = weights[n:] # affine: a0 + a1*x + a2*y
|
|
197
|
+
|
|
198
|
+
# Affine component
|
|
199
|
+
result = a[0] + a[1] * points[:, 0] + a[2] * points[:, 1]
|
|
200
|
+
|
|
201
|
+
# Vectorized RBF evaluation in batches to limit memory
|
|
202
|
+
batch_size = 50000
|
|
203
|
+
for start in range(0, len(points), batch_size):
|
|
204
|
+
end = min(start + batch_size, len(points))
|
|
205
|
+
batch = points[start:end] # (M, 2)
|
|
206
|
+
|
|
207
|
+
# Compute all distances at once: (M, n)
|
|
208
|
+
dx = batch[:, 0:1] - control_pts[:, 0] # (M, n) via broadcasting
|
|
209
|
+
dy = batch[:, 1:2] - control_pts[:, 1] # (M, n)
|
|
210
|
+
r = np.sqrt(dx**2 + dy**2)
|
|
211
|
+
|
|
212
|
+
# TPS kernel: r^2 * log(r), with r=0 -> 0
|
|
213
|
+
kernel = np.zeros_like(r)
|
|
214
|
+
mask = r > 0
|
|
215
|
+
kernel[mask] = r[mask] ** 2 * np.log(r[mask])
|
|
216
|
+
|
|
217
|
+
# Weighted sum across all control points
|
|
218
|
+
result[start:end] += kernel @ w
|
|
219
|
+
|
|
220
|
+
return result
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _compute_rigid_translation(
|
|
224
|
+
src: np.ndarray,
|
|
225
|
+
dst: np.ndarray,
|
|
226
|
+
mask: np.ndarray,
|
|
227
|
+
width: int,
|
|
228
|
+
height: int,
|
|
229
|
+
) -> np.ndarray:
|
|
230
|
+
"""Compute mean translation for rigid regions."""
|
|
231
|
+
# Find control points inside rigid mask
|
|
232
|
+
inside = []
|
|
233
|
+
for i, (x, y) in enumerate(src):
|
|
234
|
+
ix, iy = int(x), int(y)
|
|
235
|
+
if 0 <= ix < width and 0 <= iy < height and mask[iy, ix] > 0:
|
|
236
|
+
inside.append(i)
|
|
237
|
+
|
|
238
|
+
if not inside:
|
|
239
|
+
return np.array([0.0, 0.0])
|
|
240
|
+
|
|
241
|
+
dx = np.mean(dst[inside, 0] - src[inside, 0])
|
|
242
|
+
dy = np.mean(dst[inside, 1] - src[inside, 1])
|
|
243
|
+
return np.array([dx, dy])
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _apply_rigid_translation(
    image: np.ndarray,
    translation: np.ndarray,
) -> np.ndarray:
    """Shift an image by ``translation = (tx, ty)`` pixels.

    The output has the same width and height as the input; pixels that
    would come from outside the image are filled with
    ``cv2.BORDER_REFLECT_101``.
    """
    rows, cols = image.shape[:2]
    tx, ty = translation[0], translation[1]
    # 2x3 affine matrix of a pure translation.
    matrix = np.float32([[1, 0, tx], [0, 1, ty]])
    return cv2.warpAffine(image, matrix, (cols, rows), borderMode=cv2.BORDER_REFLECT_101)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def generate_random_warp(
    landmarks: np.ndarray,
    procedure_indices: list[int],
    max_displacement: float = 15.0,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Generate randomly warped landmarks for synthetic data.

    Each listed landmark is moved by an independent offset drawn uniformly
    from ``[-max_displacement, max_displacement)`` in x and in y. The input
    array is not modified.

    Args:
        landmarks: (N, 2+) landmark coordinates.
        procedure_indices: Indices of the landmarks to displace. Indices
            outside ``[0, N)`` are skipped.
        max_displacement: Maximum absolute offset per axis.
        rng: Random number generator; a fresh default generator is created
            when None.

    Returns:
        A new array with the displaced landmarks. Floating-point input keeps
        its dtype; any other dtype is converted to float64 so the random
        offsets are not truncated.
    """
    if rng is None:
        rng = np.random.default_rng()

    result = np.array(landmarks, copy=True)
    if not np.issubdtype(result.dtype, np.floating):
        # Assigning a float offset into an integer array would truncate it.
        result = result.astype(np.float64)

    n_landmarks = len(result)
    for idx in procedure_indices:
        # Reject negative indices too: they would wrap around to the end.
        if not 0 <= idx < n_landmarks:
            continue
        dx = rng.uniform(-max_displacement, max_displacement)
        dy = rng.uniform(-max_displacement, max_displacement)
        result[idx, 0] += dx
        result[idx, 1] += dy

    return result
|