Rhapso 0.1.92__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Rhapso/__init__.py +1 -0
- Rhapso/data_prep/__init__.py +2 -0
- Rhapso/data_prep/n5_reader.py +188 -0
- Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
- Rhapso/data_prep/xml_to_dataframe.py +215 -0
- Rhapso/detection/__init__.py +5 -0
- Rhapso/detection/advanced_refinement.py +203 -0
- Rhapso/detection/difference_of_gaussian.py +324 -0
- Rhapso/detection/image_reader.py +117 -0
- Rhapso/detection/metadata_builder.py +130 -0
- Rhapso/detection/overlap_detection.py +327 -0
- Rhapso/detection/points_validation.py +49 -0
- Rhapso/detection/save_interest_points.py +265 -0
- Rhapso/detection/view_transform_models.py +67 -0
- Rhapso/fusion/__init__.py +0 -0
- Rhapso/fusion/affine_fusion/__init__.py +2 -0
- Rhapso/fusion/affine_fusion/blend.py +289 -0
- Rhapso/fusion/affine_fusion/fusion.py +601 -0
- Rhapso/fusion/affine_fusion/geometry.py +159 -0
- Rhapso/fusion/affine_fusion/io.py +546 -0
- Rhapso/fusion/affine_fusion/script_utils.py +111 -0
- Rhapso/fusion/affine_fusion/setup.py +4 -0
- Rhapso/fusion/affine_fusion_worker.py +234 -0
- Rhapso/fusion/multiscale/__init__.py +0 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
- Rhapso/fusion/multiscale_worker.py +113 -0
- Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
- Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
- Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
- Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
- Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
- Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
- Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
- Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
- Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
- Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
- Rhapso/matching/__init__.py +0 -0
- Rhapso/matching/load_and_transform_points.py +458 -0
- Rhapso/matching/ransac_matching.py +544 -0
- Rhapso/matching/save_matches.py +120 -0
- Rhapso/matching/xml_parser.py +302 -0
- Rhapso/pipelines/__init__.py +0 -0
- Rhapso/pipelines/ray/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
- Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
- Rhapso/pipelines/ray/evaluation.py +71 -0
- Rhapso/pipelines/ray/interest_point_detection.py +137 -0
- Rhapso/pipelines/ray/interest_point_matching.py +110 -0
- Rhapso/pipelines/ray/local/__init__.py +0 -0
- Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
- Rhapso/pipelines/ray/matching_stats.py +104 -0
- Rhapso/pipelines/ray/param/__init__.py +0 -0
- Rhapso/pipelines/ray/solver.py +120 -0
- Rhapso/pipelines/ray/split_dataset.py +78 -0
- Rhapso/solver/__init__.py +0 -0
- Rhapso/solver/compute_tiles.py +562 -0
- Rhapso/solver/concatenate_models.py +116 -0
- Rhapso/solver/connected_graphs.py +111 -0
- Rhapso/solver/data_prep.py +181 -0
- Rhapso/solver/global_optimization.py +410 -0
- Rhapso/solver/model_and_tile_setup.py +109 -0
- Rhapso/solver/pre_align_tiles.py +323 -0
- Rhapso/solver/save_results.py +97 -0
- Rhapso/solver/view_transforms.py +75 -0
- Rhapso/solver/xml_to_dataframe_solver.py +213 -0
- Rhapso/split_dataset/__init__.py +0 -0
- Rhapso/split_dataset/compute_grid_rules.py +78 -0
- Rhapso/split_dataset/save_points.py +101 -0
- Rhapso/split_dataset/save_xml.py +377 -0
- Rhapso/split_dataset/split_images.py +537 -0
- Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
- rhapso-0.1.92.dist-info/METADATA +39 -0
- rhapso-0.1.92.dist-info/RECORD +101 -0
- rhapso-0.1.92.dist-info/WHEEL +5 -0
- rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
- rhapso-0.1.92.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/test_detection.py +17 -0
- tests/test_matching.py +21 -0
- tests/test_solving.py +21 -0
|
@@ -0,0 +1,544 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from sklearn.neighbors import KDTree
|
|
3
|
+
import itertools
|
|
4
|
+
import random
|
|
5
|
+
from scipy.linalg import eigh
|
|
6
|
+
import zarr
|
|
7
|
+
from bioio import BioImage
|
|
8
|
+
import bioio_tifffile
|
|
9
|
+
import dask.array as da
|
|
10
|
+
import s3fs
|
|
11
|
+
import copy
|
|
12
|
+
import re
|
|
13
|
+
|
|
14
|
+
"""
|
|
15
|
+
Utility class that finds interest-point match candidates and filters them with RANSAC.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
class CustomBioImage(BioImage):
    """BioImage subclass whose metadata accessors are replaced by no-ops.

    ``standard_metadata``, ``scale`` and ``time_interval`` are redefined here
    as plain methods that return ``None``.
    NOTE(review): presumably this keeps the base class from parsing metadata
    when only pixel data is needed - confirm against the bioio version in use.
    """

    def standard_metadata(self):
        """Return ``None`` in place of the base-class metadata."""
        return None

    def scale(self):
        """Return ``None`` in place of the base-class scale information."""
        return None

    def time_interval(self):
        """Return ``None`` in place of the base-class time interval."""
        return None
|
|
27
|
+
|
|
28
|
+
class RansacMatching:
|
|
29
|
+
def __init__(self, data_global, num_neighbors, redundancy, significance, num_required_neighbors, match_type,
             max_epsilon, min_inlier_ratio, num_iterations, model_min_matches, regularization_weight,
             search_radius, view_registrations, input_type, image_file_prefix):
    """Store the matching configuration on the instance.

    Every argument is kept verbatim under an attribute of the same name; no
    validation or I/O happens here.
    """
    # Dataset description and image access.
    self.data_global = data_global
    self.view_registrations = view_registrations
    self.input_type = input_type
    self.image_file_prefix = image_file_prefix

    # Descriptor construction and candidate search.
    self.num_neighbors = num_neighbors
    self.redundancy = redundancy
    self.num_required_neighbors = num_required_neighbors
    self.significance = significance
    self.search_radius = search_radius

    # Model fitting and RANSAC.
    self.match_type = match_type
    self.max_epsilon = max_epsilon
    self.min_inlier_ratio = min_inlier_ratio
    self.num_iterations = num_iterations
    self.model_min_matches = model_min_matches
    self.regularization_weight = regularization_weight
|
|
47
|
+
|
|
48
|
+
def filter_inliers(self, candidates, initial_model):
    """Iteratively discard matches whose residual exceeds 4x the median residual.

    Each candidate is a tuple laid out as
    ``(idxA, pointA, view_a, label_a, idxB, pointB, view_b, label_b)``.
    On every pass a model is refitted to the current inliers with
    ``self.model_regularization``, every inlier's ``pointA`` is mapped through
    it, and matches whose distance to ``pointB`` is larger than
    ``4.0 * median`` are dropped.  The loop stops once a pass removes nothing.

    Args:
        candidates: sequence of match tuples as described above.
        initial_model: accepted for interface compatibility only; the model is
            always refitted from the inliers, so this value is not used.

    Returns:
        The surviving candidates (the same tuple objects that were passed in),
        or ``[]`` when fewer than ``self.model_min_matches`` remain.
    """
    max_trust = 4.0

    if len(candidates) < self.model_min_matches:
        return []

    inliers = list(candidates)

    while True:
        num_inliers = len(inliers)

        model = self.model_regularization([(m[1], m[5]) for m in inliers])

        # Residual of each match under the freshly fitted model.  The target
        # point lives at index 5; index 4 is the integer index of that point.
        errors = []
        for match in inliers:
            source_h = np.append(np.asarray(match[1]), 1.0)
            target = np.asarray(match[5])
            errors.append(np.linalg.norm((model @ source_h)[:3] - target))

        threshold = np.median(errors) * max_trust
        inliers = [m for m, err in zip(inliers, errors) if err <= threshold]

        if num_inliers <= len(inliers):
            break

        # Inliers only ever shrink, so once below the minimum the result is
        # necessarily empty; stop before refitting on too few points.
        if len(inliers) < self.model_min_matches:
            return []

    if num_inliers < self.model_min_matches:
        return []

    return inliers
|
|
88
|
+
|
|
89
|
+
def fit_rigid_model(self, matches):
    """Fit a rigid (rotation + translation) transform to 3D point pairs.

    Uses a closed-form quaternion solution: the cross-covariance of the
    centred point sets is packed into a symmetric 4x4 matrix whose eigenvector
    for the largest eigenvalue is the rotation quaternion (the construction
    matches Horn's absolute-orientation method).

    Args:
        matches: array-like of shape (N, 2, 3); ``matches[i][0]`` is the
            source point and ``matches[i][1]`` the target point.

    Returns:
        4x4 homogeneous matrix mapping source points onto target points.
    """
    pairs = np.array(matches)
    src = pairs[:, 0]
    dst = pairs[:, 1]
    w = np.ones(src.shape[0])  # uniform weights

    src_centroid = np.average(src, axis=0, weights=w)
    dst_centroid = np.average(dst, axis=0, weights=w)

    # Cross-covariance of the centred (and weighted) coordinates.
    cov = ((src - src_centroid) * w[:, None]).T @ (dst - dst_centroid)
    (Sxx, Sxy, Sxz), (Syx, Syy, Syz), (Szx, Szy, Szz) = cov

    quat_matrix = np.array([
        [Sxx + Syy + Szz, Syz - Szy, Szx - Sxz, Sxy - Syx],
        [Syz - Szy, Sxx - Syy - Szz, Sxy + Syx, Szx + Sxz],
        [Szx - Sxz, Sxy + Syx, -Sxx + Syy - Szz, Syz + Szy],
        [Sxy - Syx, Szx + Sxz, Syz + Szy, -Sxx - Syy + Szz],
    ])

    # The unit quaternion is the eigenvector of the largest eigenvalue.
    values, vectors = eigh(quat_matrix)
    q0, qx, qy, qz = vectors[:, np.argmax(values)]

    rotation = np.array([
        [q0*q0 + qx*qx - qy*qy - qz*qz, 2*(qx*qy - q0*qz), 2*(qx*qz + q0*qy)],
        [2*(qy*qx + q0*qz), q0*q0 - qx*qx + qy*qy - qz*qz, 2*(qy*qz - q0*qx)],
        [2*(qz*qx - q0*qy), 2*(qz*qy + q0*qx), q0*q0 - qx*qx - qy*qy + qz*qz],
    ])

    result = np.eye(4)
    result[:3, :3] = rotation
    # Translation that carries the rotated source centroid onto the target centroid.
    result[:3, 3] = dst_centroid - rotation @ src_centroid
    return result
|
|
140
|
+
|
|
141
|
+
def fit_affine_model(self, matches):
    """Fit a least-squares 3D affine transform to point pairs.

    With centred sources ``p`` and targets ``q`` the normal equations are
    ``M A = B^T`` where ``A = sum(p p^T)`` and ``B = sum(p q^T)``, hence
    ``M = (A^-1 B)^T``.  The transpose matters because the returned matrix is
    applied to column vectors (``model @ [x, y, z, 1]``).

    Args:
        matches: array-like of shape (N, 2, 3); ``matches[i][0]`` is the
            source point and ``matches[i][1]`` the target point.

    Returns:
        4x4 homogeneous matrix mapping source points onto target points.

    Raises:
        ValueError: if the source points are degenerate (``det(A) == 0``).
    """
    pairs = np.array(matches)
    P = pairs[:, 0]  # source points
    Q = pairs[:, 1]  # target points
    weights = np.ones(P.shape[0])  # uniform weights

    pc = np.average(P, axis=0, weights=weights)
    qc = np.average(Q, axis=0, weights=weights)

    P_centered = P - pc
    Q_centered = Q - qc

    # Weighted sums of outer products, accumulated in one matrix product each.
    weighted_P = P_centered * weights[:, None]
    A = weighted_P.T @ P_centered
    B = weighted_P.T @ Q_centered

    if np.linalg.det(A) == 0:
        raise ValueError("Ill-defined data points (det=0)")

    try:
        A_inv = np.linalg.inv(A)
    except np.linalg.LinAlgError:
        # If A is not invertible, use the pseudo-inverse
        A_inv = np.linalg.pinv(A)

    # A_inv @ B solves p^T X = q^T; transpose to get the matrix that acts on
    # column vectors.
    M = (A_inv @ B).T

    t = qc - M @ pc  # translation

    affine_matrix = np.eye(4)
    affine_matrix[:3, :3] = M
    affine_matrix[:3, 3] = t

    return affine_matrix
|
|
185
|
+
|
|
186
|
+
def test(self, candidates, model, max_epsilon, min_inlier_ratio, min_num_inliers):
    """Collect the candidates that agree with ``model`` and judge the model.

    A candidate ``(idxA, pointA, view_a, label_a, idxB, pointB, view_b,
    label_b)`` is an inlier when ``model @ [pointA, 1]`` lies closer than
    ``max_epsilon`` to ``pointB``.

    Args:
        candidates: sequence of 8-element match tuples.
        model: 4x4 homogeneous transform.
        max_epsilon: inlier distance threshold (exclusive).
        min_inlier_ratio: the inlier fraction must be strictly greater.
        min_num_inliers: minimum number of inliers required.

    Returns:
        ``(is_good, inliers)``.  An empty candidate list yields
        ``(False, [])`` instead of dividing by zero.
    """
    if len(candidates) == 0:
        return False, []

    inliers = []
    for idxA, pointA, view_a, label_a, idxB, pointB, view_b, label_b in candidates:
        transformed = model @ np.append(pointA, 1.0)
        if np.linalg.norm(transformed[:3] - pointB) < max_epsilon:
            inliers.append((idxA, pointA, view_a, label_a, idxB, pointB, view_b, label_b))

    inlier_ratio = len(inliers) / len(candidates)
    is_good = len(inliers) >= min_num_inliers and inlier_ratio > min_inlier_ratio

    return is_good, inliers


# The name starts with "test", which pytest would otherwise collect as a test.
test.__test__ = False
|
|
200
|
+
|
|
201
|
+
def regularize_models(self, affine, rigid):
    """Blend two dict-based 3x4 models as ``0.9 * affine + 0.1 * rigid``.

    Both inputs are mappings with the keys ``m00`` ... ``m23``; the result is
    a new dict with the same twelve keys in row-major order.
    """
    alpha = 0.1
    l1 = 1.0 - alpha

    keys = (
        'm00', 'm01', 'm02', 'm03',
        'm10', 'm11', 'm12', 'm13',
        'm20', 'm21', 'm22', 'm23',
    )

    affine_values = [affine[name] for name in keys]
    rigid_values = [rigid[name] for name in keys]

    return {
        name: l1 * a + alpha * b
        for name, a, b in zip(keys, affine_values, rigid_values)
    }
|
|
224
|
+
|
|
225
|
+
def model_regularization(self, point_pairs):
    """Fit the model selected by ``self.match_type`` to ``point_pairs``.

    ``"rigid"`` returns the rigid fit.  ``"affine"`` and ``"split-affine"``
    return ``(1 - w) * affine + w * rigid`` with
    ``w = self.regularization_weight``.

    Raises:
        SystemExit: for any other ``match_type``.
    """
    kind = self.match_type

    if kind == "rigid":
        return self.fit_rigid_model(point_pairs)

    if not (kind == "affine" or kind == "split-affine"):
        raise SystemExit(f"Unsupported match type: {self.match_type}")

    rigid_model = self.fit_rigid_model(point_pairs)
    affine_model = self.fit_affine_model(point_pairs)
    return (1 - self.regularization_weight) * affine_model + self.regularization_weight * rigid_model
|
|
236
|
+
|
|
237
|
+
def compute_ransac(self, candidates):
    """Run RANSAC over the candidate matches and return the largest inlier set.

    Each iteration fits a model to ``self.model_min_matches`` randomly chosen
    candidates, then alternates ``self.test`` and refitting on the inliers
    for as long as the model is judged good and the inlier set keeps growing.

    Args:
        candidates: list of match tuples with ``pointA`` at index 1 and
            ``pointB`` at index 5.

    Returns:
        ``(best_inliers, best_model)``; ``([], None)`` when there are fewer
        candidates than ``self.model_min_matches`` or no iteration produced
        any inlier.
    """
    best_inliers = []
    max_inliers = 0
    best_model = None

    if len(candidates) < self.model_min_matches:
        return [], None

    for _ in range(self.num_iterations):
        indices = random.sample(range(len(candidates)), self.model_min_matches)
        point_pairs = [(candidates[i][1], candidates[i][5]) for i in indices]

        try:
            regularized_model = self.model_regularization(point_pairs)
        except Exception as e:
            # A degenerate sample yields no model: skip this iteration rather
            # than testing an undefined or stale model.
            print(e)
            continue

        num_inliers = 0
        is_good, tmp_inliers = self.test(
            candidates, regularized_model, self.max_epsilon, self.min_inlier_ratio, self.model_min_matches
        )

        refit_failed = False
        while is_good and num_inliers < len(tmp_inliers):
            num_inliers = len(tmp_inliers)
            point_pairs = [(m[1], m[5]) for m in tmp_inliers]
            try:
                regularized_model = self.model_regularization(point_pairs)
            except Exception as e:
                print(e)
                refit_failed = True
                break
            is_good, tmp_inliers = self.test(
                candidates, regularized_model, self.max_epsilon, self.min_inlier_ratio, self.model_min_matches
            )

        if refit_failed:
            continue

        # NOTE(review): the best set is replaced even when the last test()
        # reported is_good == False - confirm this is intended.
        if len(tmp_inliers) > max_inliers:
            best_inliers = tmp_inliers
            max_inliers = len(tmp_inliers)
            best_model = regularized_model

    return best_inliers, best_model
|
|
270
|
+
|
|
271
|
+
def create_candidates(self, desc_a, desc_b):
    """Pair the first three relative descriptor vectors of two descriptors.

    Returns a list holding a single match set: three ``(vector_a, vector_b)``
    tuples taken index-by-index from each descriptor's
    ``'relative_descriptors'`` entry.
    """
    rel_a = desc_a['relative_descriptors']
    rel_b = desc_b['relative_descriptors']

    pairing = [(rel_a[i], rel_b[i]) for i in range(3)]
    return [pairing]
|
|
285
|
+
|
|
286
|
+
def descriptor_distance(self, desc_a, desc_b):
    """Return the smallest dissimilarity between two descriptors.

    For every match set from ``self.create_candidates`` the score is the sum
    of squared coordinate differences divided by the number of coordinates
    per vector.  Match sets that raise while being scored are skipped, so the
    result is ``inf`` when none could be scored.
    """
    lowest = float("inf")

    for pairing in self.create_candidates(desc_a, desc_b):
        try:
            side_a = np.array([vec_a for vec_a, _ in pairing])
            side_b = np.array([vec_b for _, vec_b in pairing])

            score = np.sum((side_a - side_b) ** 2) / side_a.shape[1]

            if score < lowest:
                lowest = score
        except Exception:
            continue

    return lowest
|
|
308
|
+
|
|
309
|
+
def create_simple_point_descriptors(self, tree, points_array, idx, num_required_neighbors, matcher):
    """Build one neighbourhood descriptor per point.

    ``tree.query`` is asked for ``num_required_neighbors + 1`` nearest
    neighbours of every point; the first hit of each row is dropped
    (presumably the query point itself - verify for duplicate coordinates).

    Args:
        tree: object exposing ``query(points, k=...)`` returning
            ``(distances, indices)``.
        points_array: array of point coordinates, indexable by index arrays.
        idx: sequence giving the external index of each row in ``points_array``.
        num_required_neighbors: neighbours kept per descriptor.
        matcher: dict stored on each descriptor; its ``"neighbors"`` entry
            supplies the index subsets when more neighbours than required
            are available.

    Returns:
        List of descriptor dicts, or ``[]`` when there are too few points.
    """
    k = num_required_neighbors + 1
    if len(points_array) < k:
        return []

    _, knn = tree.query(points_array, k=k)

    descriptors = []
    for row, origin in enumerate(points_array):
        nearby = points_array[knn[row][1:]]

        if len(nearby) == num_required_neighbors:
            idx_sets = [tuple(range(num_required_neighbors))]
        elif len(nearby) > num_required_neighbors:
            idx_sets = matcher["neighbors"]

        descriptors.append({
            "point_index": idx[row],
            "point": origin,
            "neighbors": nearby,
            # Neighbour positions expressed relative to the basis point.
            "relative_descriptors": nearby - origin,
            "matcher": matcher,
            "subsets": np.stack([nearby[list(c)] for c in idx_sets]),
        })

    return descriptors
|
|
345
|
+
|
|
346
|
+
def get_candidates(self, points_a, points_b, view_a, view_b, label):
    """Propose point correspondences between two views by descriptor matching.

    Descriptors are built for both point sets; for every descriptor of view A
    the best and second-best descriptor of view B within
    ``self.search_radius`` are found, and the best one is accepted when it
    passes the ratio test ``best * self.significance < second_best``.

    Args:
        points_a: sequence of ``(index, coordinates)`` pairs for view A.
        points_b: sequence of ``(index, coordinates)`` pairs for view B.
        view_a: identifier stored in each candidate for view A.
        view_b: identifier stored in each candidate for view B.
        label: interest-point label stored for both sides of each candidate.

    Returns:
        List of ``(idxA, pointA, view_a, label, idxB, pointB, view_b, label)``
        tuples; ``[]`` when either point set is empty.
    """
    difference_threshold = 3.4028235e+38  # largest finite float32 value
    max_value = float("inf")

    # Nothing to match, and zip(*[]) below could not be unpacked.
    if len(points_a) == 0 or len(points_b) == 0:
        return []

    # -- Get Points and Indexes
    idx_a, coords_a = zip(*points_a)
    idx_b, coords_b = zip(*points_b)
    points_a_array = np.array(coords_a)
    points_b_array = np.array(coords_b)

    # --- KD Trees ---
    tree_a = KDTree(points_a_array)
    tree_b = KDTree(points_b_array)

    # --- Subset Matcher ---
    subset_size = self.num_neighbors
    total_neighbors = self.num_neighbors + self.redundancy
    neighbor_indices_combinations = list(itertools.combinations(range(total_neighbors), subset_size))
    num_combinations = len(neighbor_indices_combinations)
    matcher = {
        "subset_size": subset_size,
        "num_neighbors": total_neighbors,
        "neighbors": neighbor_indices_combinations,
        "num_combinations": num_combinations,
        "num_matchings": num_combinations * num_combinations,
    }

    # --- Descriptors ---
    descriptors_a = self.create_simple_point_descriptors(
        tree_a, points_a_array, idx_a, self.num_required_neighbors, matcher
    )
    descriptors_b = self.create_simple_point_descriptors(
        tree_b, points_b_array, idx_b, self.num_required_neighbors, matcher
    )

    # --- Descriptor Matching ---
    correspondence_candidates = []

    for desc_a in descriptors_a:
        best_difference = float("inf")
        second_best_difference = float("inf")
        best_match = None

        for desc_b in descriptors_b:
            if np.linalg.norm(desc_a['point'] - desc_b['point']) > self.search_radius:
                continue

            difference = self.descriptor_distance(desc_a, desc_b)

            # Track the two smallest differences seen so far.
            if difference < best_difference:
                second_best_difference = best_difference
                best_difference = difference
                best_match = desc_b
            elif difference < second_best_difference:
                second_best_difference = difference

        # --- Lowe's Test ---
        # A finite runner-up is required, so a lone neighbour is never accepted.
        if (
            best_difference < difference_threshold
            and best_difference * self.significance < second_best_difference
            and second_best_difference != max_value
        ):
            correspondence_candidates.append((
                desc_a['point_index'],
                desc_a['point'],
                view_a,
                label,
                best_match['point_index'],
                best_match['point'],
                view_b,
                label,
            ))

    return correspondence_candidates
|
|
430
|
+
|
|
431
|
+
def get_tile_dims(self, view1):
    """Return the image dimensions of a view, in reversed array-axis order.

    ``view1`` is a string such as ``"(tpId=0, setupId=3)"``.  The matching
    entry of ``self.data_global['imageLoader']`` is located by timepoint and
    view setup, the image is opened lazily, and its 3D shape is returned
    reversed.

    Raises:
        ValueError: if no image-loader entry matches the view, or if
            ``self.input_type`` is neither ``"tiff"`` nor ``"zarr"``.  The
            previous code failed on an unassigned ``shape`` in these cases.
    """
    parts = view1.strip("()").split(", ")
    tp_id = int(parts[0].split("=")[1])
    setup_id = int(parts[1].split("=")[1])

    image_loader = self.data_global.get('imageLoader', {})

    # NOTE(review): returns on the first matching entry; entries are assumed
    # to be unique per (timepoint, view_setup) - confirm.
    for entry in image_loader:
        entry_setup = int(entry.get('view_setup', -1))
        entry_tp = int(entry.get('timepoint', -1))
        if entry_setup != setup_id or entry_tp != tp_id:
            continue

        file_path = self.image_file_prefix + entry.get('file_path')

        if self.input_type == "tiff":
            img = CustomBioImage(file_path, reader=bioio_tifffile.Reader)
            dask_array = img.get_dask_stack()[0, 0, 0, :, :, :]
        elif self.input_type == "zarr":
            s3 = s3fs.S3FileSystem(anon=False)
            full_path = f"{file_path}/0"
            store = s3fs.S3Map(root=full_path, s3=s3)
            zarr_array = zarr.open(store, mode='r')
            dask_array = da.from_zarr(zarr_array)[0, 0, :, :, :]
        else:
            raise ValueError(f"Unsupported input type: {self.input_type}")

        return dask_array.shape[::-1]

    raise ValueError(f"No image loader entry found for view {view1}")
|
|
460
|
+
|
|
461
|
+
def invert_transformation_matrix(self, view_2):
    """
    Compose all ViewTransforms of a view and return the inverse of the product.

    ``view_2`` is a string such as ``"(tpId=0, setupId=3)"``; the transforms
    are looked up in ``self.view_registrations`` under ``(tp_id, setup_id)``.
    Each transform's ``"affine"`` entry holds 12 whitespace-separated numbers
    (a row-major 3x4 matrix); entries without one are skipped.  Matrices are
    multiplied in list order.

    Raises:
        ValueError: if the view has no transforms, or a transform does not
            contain exactly 12 values.
    """
    fields = view_2.strip("()").split(", ")
    view_key = (int(fields[0].split("=")[1]), int(fields[1].split("=")[1]))

    transforms = self.view_registrations.get(view_key, [])
    if not transforms:
        raise ValueError(f"No transforms found for view {view_key}")

    composed = np.eye(4)
    for position, transform in enumerate(transforms, start=1):
        affine_str = transform.get("affine")
        if not affine_str:
            continue

        values = [float(token) for token in affine_str.strip().split()]
        if len(values) != 12:
            raise ValueError(f"Transform {position} in view {view_key} has {len(values)} values, expected 12.")

        # Promote the 3x4 affine to a 4x4 homogeneous matrix.
        homogeneous = np.eye(4)
        homogeneous[:3, :4] = np.array(values).reshape(3, 4)

        composed = composed @ homogeneous

    return np.linalg.inv(composed)
|
|
496
|
+
|
|
497
|
+
def filter_for_overlapping_points(self, points_a, points_b, view_a, view_b):
    """Keep only the points of each view that fall inside the other view.

    Every point is mapped through the inverse transform of the *other* view
    and kept when the result lies in ``[0, size)`` on all three axes, where
    ``size`` comes from ``self.data_global['viewSetup']['byId'][setup_id]``.

    Returns:
        ``(overlapping_a, overlapping_b)``: lists of ``(index, point)`` pairs,
        each in reverse input order; ``([], [])`` when either input is empty.
    """
    indexed_a = list(enumerate(points_a))
    indexed_b = list(enumerate(points_b))

    if not indexed_a or not indexed_b:
        return [], []

    def points_inside(indexed_points, other_view):
        # Points of one view that land inside the bounds of `other_view`.
        inverse = self.invert_transformation_matrix(other_view)

        other_key = tuple(map(int, re.findall(r'\d+', other_view)))
        setup_info = self.data_global['viewSetup']['byId'][other_key[1]]
        lower = (0, 0, 0)
        upper = setup_info['size']

        hits = []
        # Walk from the last point to the first, as the result order requires.
        for idx, point in reversed(indexed_points):
            x, y, z = (inverse @ np.append(point, 1.0))[:3]
            x_min, y_min, z_min = lower
            x_max, y_max, z_max = upper
            if x_min <= x < x_max and y_min <= y < y_max and z_min <= z < z_max:
                hits.append((idx, point))
        return hits

    # Check points_a against view_b's interval, then points_b against view_a's.
    overlapping_a = points_inside(indexed_a, view_b)
    overlapping_b = points_inside(indexed_b, view_a)

    return overlapping_a, overlapping_b
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import zarr
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
import s3fs
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
SaveMatches writes the matched (corresponding) interest points to an N5 container.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
class SaveMatches:
|
|
10
|
+
def __init__(self, all_results, n5_output_path, data_global, match_type):
    """Keep the matches and output settings; nothing is written here.

    Args:
        all_results: iterable of 8-element match tuples
            ``(idxA, pointA, viewA, labelA, idxB, pointB, viewB, labelB)``.
        n5_output_path: prefix placed directly before ``interestpoints.n5``
            (local path or ``s3://`` URL).
        data_global: dataset description; its ``'viewsInterestPoints'`` entry
            is read when clearing correspondences.
        match_type: stored as given.
    """
    self.match_type = match_type
    self.data_global = data_global
    self.n5_output_path = n5_output_path
    self.all_results = all_results
|
|
15
|
+
|
|
16
|
+
def save_correspondences(self):
    """
    Save correspondences for each view/label, aggregating all matches involving that view/label.

    Every match is recorded under both of its views.  For each view an
    ``idMap`` assigns consecutive integers to the sorted
    ``"<tp>,<setup>,<label>"`` keys of its partner views, and each
    correspondence is stored as ``(own_index, partner_index, partner_id)`` in
    a ``data`` array under
    ``interestpoints.n5/tpId_<tp>_viewSetupId_<setup>/<label>/correspondences/``.
    """
    def parse_view(v: str):
        tp = int(v.split("tpId=")[1].split(",")[0])
        vs = int(v.split("setupId=")[1].split(")")[0])
        return tp, vs

    def partner_key(view: str, label) -> str:
        tp, vs = parse_view(view)
        return f"{tp},{vs},{label}"

    # Group results back per view (each match is listed under both views).
    grouped_by_view = defaultdict(list)
    for idxA, _, viewA, label_a, idxB, _, viewB, label_b in self.all_results:
        grouped_by_view[viewA].append((idxA, idxB, viewB, label_b))
        grouped_by_view[viewB].append((idxB, idxA, viewA, label_a))

    # Build the idMap of every view and attach partner ids to its matches.
    id_maps = {}
    grouped_with_ids = defaultdict(list)
    for view, matches in grouped_by_view.items():
        target_keys = sorted({partner_key(other, lab) for (_, _, other, lab) in matches})
        id_map = {k: i for i, k in enumerate(target_keys)}
        id_maps[view] = id_map
        for own_idx, other_idx, other, lab in matches:
            # NOTE(review): the group key uses the partner's label; this equals
            # the view's own label only when both sides share a label - confirm.
            grouped_with_ids[view, lab].append((own_idx, other_idx, id_map[partner_key(other, lab)]))

    s3_filesystem = None

    # Save idmap and corr points per view
    for (view, label), corr_list in grouped_with_ids.items():
        if not corr_list:
            continue

        tp, vs = parse_view(view)
        full_path = f"{self.n5_output_path}interestpoints.n5/tpId_{tp}_viewSetupId_{vs}/{label}/correspondences/"

        if full_path.startswith("s3://"):
            # One filesystem object is enough for all groups.
            if s3_filesystem is None:
                s3_filesystem = s3fs.S3FileSystem()
                self.s3_filesystem = s3_filesystem
            store = s3fs.S3Map(root=full_path.replace("s3://", ""), s3=s3_filesystem, check=False)
            root = zarr.open_group(store=store, mode='a')
        else:
            # Write to Zarr N5
            store = zarr.N5Store(full_path)
            root = zarr.group(store=store, overwrite=True)

        # Delete existing 'data' array
        if "data" in root:
            del root["data"]

        # Set group-level attributes
        root.attrs.update({
            "correspondences": "1.0.0",
            "idMap": id_maps[view],
        })

        # Create dataset inside the group
        root.create_dataset(
            name="data",
            data=corr_list,
            dtype='u8',
            chunks=(min(300000, len(corr_list)), 1),
            compressor=zarr.GZip(),
        )
|
|
93
|
+
|
|
94
|
+
def clear_correspondence(self):
    """Delete stored correspondences for every view/label of the dataset.

    Opens ``<n5_output_path>interestpoints.n5`` (an S3 map for ``s3://``
    paths, an N5 store otherwise) and removes each
    ``tpId_<tp>_viewSetupId_<vs>/<label>/correspondences`` group, or only its
    ``data`` array when just that is found.  A failed deletion is printed and
    does not stop the loop.
    """
    on_s3 = self.n5_output_path.startswith("s3://")
    if on_s3:
        root_path = self.n5_output_path.replace("s3://", "") + "interestpoints.n5"
        store = s3fs.S3Map(root=root_path, s3=s3fs.S3FileSystem(), check=False)
    else:
        root_path = self.n5_output_path + "interestpoints.n5"
        store = zarr.N5Store(root_path)

    root = zarr.open_group(store=store, mode="a")

    interest_points = self.data_global['viewsInterestPoints']
    for tp, vs in list(interest_points.keys()):
        for label in interest_points[(tp, vs)]['label']:
            corr_path = f"tpId_{tp}_viewSetupId_{vs}/{label}/correspondences"
            data_path = f"{corr_path}/data"
            try:
                if corr_path in root:
                    del root[corr_path]
                elif data_path in root:
                    del root[data_path]
            except Exception as e:
                print(f"⚠️ Could not delete {corr_path}: {e}")
|
|
117
|
+
|
|
118
|
+
def run(self):
    """Clear existing correspondences, then write the current matches."""
    for step in (self.clear_correspondence, self.save_correspondences):
        step()
|