rhapso-0.1.92-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. Rhapso/__init__.py +1 -0
  2. Rhapso/data_prep/__init__.py +2 -0
  3. Rhapso/data_prep/n5_reader.py +188 -0
  4. Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
  5. Rhapso/data_prep/xml_to_dataframe.py +215 -0
  6. Rhapso/detection/__init__.py +5 -0
  7. Rhapso/detection/advanced_refinement.py +203 -0
  8. Rhapso/detection/difference_of_gaussian.py +324 -0
  9. Rhapso/detection/image_reader.py +117 -0
  10. Rhapso/detection/metadata_builder.py +130 -0
  11. Rhapso/detection/overlap_detection.py +327 -0
  12. Rhapso/detection/points_validation.py +49 -0
  13. Rhapso/detection/save_interest_points.py +265 -0
  14. Rhapso/detection/view_transform_models.py +67 -0
  15. Rhapso/fusion/__init__.py +0 -0
  16. Rhapso/fusion/affine_fusion/__init__.py +2 -0
  17. Rhapso/fusion/affine_fusion/blend.py +289 -0
  18. Rhapso/fusion/affine_fusion/fusion.py +601 -0
  19. Rhapso/fusion/affine_fusion/geometry.py +159 -0
  20. Rhapso/fusion/affine_fusion/io.py +546 -0
  21. Rhapso/fusion/affine_fusion/script_utils.py +111 -0
  22. Rhapso/fusion/affine_fusion/setup.py +4 -0
  23. Rhapso/fusion/affine_fusion_worker.py +234 -0
  24. Rhapso/fusion/multiscale/__init__.py +0 -0
  25. Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
  26. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
  27. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
  28. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
  29. Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
  30. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
  31. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
  32. Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
  33. Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
  34. Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
  35. Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
  36. Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
  37. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
  38. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
  39. Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
  40. Rhapso/fusion/multiscale_worker.py +113 -0
  41. Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
  42. Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
  43. Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
  44. Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
  45. Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
  46. Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
  47. Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
  48. Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
  49. Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
  50. Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
  51. Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
  52. Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
  53. Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
  54. Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
  55. Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
  56. Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
  57. Rhapso/matching/__init__.py +0 -0
  58. Rhapso/matching/load_and_transform_points.py +458 -0
  59. Rhapso/matching/ransac_matching.py +544 -0
  60. Rhapso/matching/save_matches.py +120 -0
  61. Rhapso/matching/xml_parser.py +302 -0
  62. Rhapso/pipelines/__init__.py +0 -0
  63. Rhapso/pipelines/ray/__init__.py +0 -0
  64. Rhapso/pipelines/ray/aws/__init__.py +0 -0
  65. Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
  66. Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
  67. Rhapso/pipelines/ray/evaluation.py +71 -0
  68. Rhapso/pipelines/ray/interest_point_detection.py +137 -0
  69. Rhapso/pipelines/ray/interest_point_matching.py +110 -0
  70. Rhapso/pipelines/ray/local/__init__.py +0 -0
  71. Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
  72. Rhapso/pipelines/ray/matching_stats.py +104 -0
  73. Rhapso/pipelines/ray/param/__init__.py +0 -0
  74. Rhapso/pipelines/ray/solver.py +120 -0
  75. Rhapso/pipelines/ray/split_dataset.py +78 -0
  76. Rhapso/solver/__init__.py +0 -0
  77. Rhapso/solver/compute_tiles.py +562 -0
  78. Rhapso/solver/concatenate_models.py +116 -0
  79. Rhapso/solver/connected_graphs.py +111 -0
  80. Rhapso/solver/data_prep.py +181 -0
  81. Rhapso/solver/global_optimization.py +410 -0
  82. Rhapso/solver/model_and_tile_setup.py +109 -0
  83. Rhapso/solver/pre_align_tiles.py +323 -0
  84. Rhapso/solver/save_results.py +97 -0
  85. Rhapso/solver/view_transforms.py +75 -0
  86. Rhapso/solver/xml_to_dataframe_solver.py +213 -0
  87. Rhapso/split_dataset/__init__.py +0 -0
  88. Rhapso/split_dataset/compute_grid_rules.py +78 -0
  89. Rhapso/split_dataset/save_points.py +101 -0
  90. Rhapso/split_dataset/save_xml.py +377 -0
  91. Rhapso/split_dataset/split_images.py +537 -0
  92. Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
  93. rhapso-0.1.92.dist-info/METADATA +39 -0
  94. rhapso-0.1.92.dist-info/RECORD +101 -0
  95. rhapso-0.1.92.dist-info/WHEEL +5 -0
  96. rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
  97. rhapso-0.1.92.dist-info/top_level.txt +2 -0
  98. tests/__init__.py +1 -0
  99. tests/test_detection.py +17 -0
  100. tests/test_matching.py +21 -0
  101. tests/test_solving.py +21 -0
Rhapso/matching/ransac_matching.py
@@ -0,0 +1,544 @@
+ import numpy as np
+ from sklearn.neighbors import KDTree
+ import itertools
+ import random
+ from scipy.linalg import eigh
+ import zarr
+ from bioio import BioImage
+ import bioio_tifffile
+ import dask.array as da
+ import s3fs
+ import copy
+ import re
+
+ """
+ Utility class to find interest point match candidates and filter them with RANSAC.
+ """
+
+ class CustomBioImage(BioImage):
+     def standard_metadata(self):
+         pass
+
+     def scale(self):
+         pass
+
+     def time_interval(self):
+         pass
+
+ class RansacMatching:
+     def __init__(self, data_global, num_neighbors, redundancy, significance, num_required_neighbors, match_type,
+                  max_epsilon, min_inlier_ratio, num_iterations, model_min_matches, regularization_weight,
+                  search_radius, view_registrations, input_type, image_file_prefix):
+         self.data_global = data_global
+         self.num_neighbors = num_neighbors
+         self.redundancy = redundancy
+         self.significance = significance
+         self.num_required_neighbors = num_required_neighbors
+         self.match_type = match_type
+         self.max_epsilon = max_epsilon
+         self.min_inlier_ratio = min_inlier_ratio
+         self.num_iterations = num_iterations
+         self.model_min_matches = model_min_matches
+         self.regularization_weight = regularization_weight
+         self.search_radius = search_radius
+         self.view_registrations = view_registrations
+         self.input_type = input_type
+         self.image_file_prefix = image_file_prefix
+
+     def filter_inliers(self, candidates, initial_model):
+         max_trust = 4.0
+
+         if len(candidates) < self.model_min_matches:
+             return []
+
+         model_copy = copy.deepcopy(initial_model)
+         inliers = candidates[:]
+         temp = []
+
+         while True:
+             temp = copy.deepcopy(inliers)
+             num_inliers = len(inliers)
+
+             point_pairs = [(m[1], m[5]) for m in inliers]
+             model_copy = self.model_regularization(point_pairs)
+
+             # Apply model and collect errors
+             errors = []
+             for match in temp:
+                 p1 = np.array(match[1])
+                 p2 = np.array(match[5])  # pointB lives at index 5 of the candidate tuple
+                 p1_h = np.append(p1, 1.0)
+                 p1_trans = model_copy @ p1_h
+                 error = np.linalg.norm(p1_trans[:3] - p2)
+                 errors.append(error)
+
+             median_error = np.median(errors)
+             threshold = median_error * max_trust
+
+             # Filter based on threshold
+             inliers = [m for m, err in zip(temp, errors) if err <= threshold]
+
+             if num_inliers <= len(inliers):
+                 break
+
+         if num_inliers < self.model_min_matches:
+             return []
+
+         return inliers
+
+     def fit_rigid_model(self, matches):
+         # Closed-form absolute orientation (Horn's quaternion method)
+         matches = np.array(matches)  # shape (N, 2, 3)
+         P = matches[:, 0]  # source points
+         Q = matches[:, 1]  # target points
+         weights = np.ones(P.shape[0])  # uniform weights for now
+
+         sum_w = np.sum(weights)
+
+         # Weighted centroids
+         pc = np.average(P, axis=0, weights=weights)
+         qc = np.average(Q, axis=0, weights=weights)
+
+         # Centered and weighted coordinates
+         P_centered = (P - pc) * weights[:, None]
+         Q_centered = Q - qc
+
+         # Cross-covariance matrix S
+         S = P_centered.T @ Q_centered  # shape: (3, 3)
+         Sxx, Sxy, Sxz = S[0]
+         Syx, Syy, Syz = S[1]
+         Szx, Szy, Szz = S[2]
+
+         # Build 4x4 N matrix for quaternion extraction
+         N = np.array([
+             [Sxx + Syy + Szz, Syz - Szy, Szx - Sxz, Sxy - Syx],
+             [Syz - Szy, Sxx - Syy - Szz, Sxy + Syx, Szx + Sxz],
+             [Szx - Sxz, Sxy + Syx, -Sxx + Syy - Szz, Syz + Szy],
+             [Sxy - Syx, Szx + Sxz, Syz + Szy, -Sxx - Syy + Szz]
+         ])
+
+         # Find eigenvector with largest eigenvalue
+         eigenvalues, eigenvectors = eigh(N)
+         q = eigenvectors[:, np.argmax(eigenvalues)]  # q = [q0, qx, qy, qz]
+         q0, qx, qy, qz = q
+
+         # Convert quaternion to rotation matrix
+         R = np.array([
+             [q0*q0 + qx*qx - qy*qy - qz*qz, 2*(qx*qy - q0*qz), 2*(qx*qz + q0*qy)],
+             [2*(qy*qx + q0*qz), q0*q0 - qx*qx + qy*qy - qz*qz, 2*(qy*qz - q0*qx)],
+             [2*(qz*qx - q0*qy), 2*(qz*qy + q0*qx), q0*q0 - qx*qx - qy*qy + qz*qz]
+         ])
+
+         # Compute translation
+         t = qc - R @ pc
+
+         # Combine into 4x4 rigid transformation matrix
+         rigid_matrix = np.eye(4)
+         rigid_matrix[:3, :3] = R
+         rigid_matrix[:3, 3] = t
+
+         return rigid_matrix
+
+     def fit_affine_model(self, matches):
+         matches = np.array(matches)  # shape (N, 2, 3)
+         P = matches[:, 0]  # source points
+         Q = matches[:, 1]  # target points
+         weights = np.ones(P.shape[0])  # uniform weights
+
+         ws = np.sum(weights)
+
+         pc = np.average(P, axis=0, weights=weights)
+         qc = np.average(Q, axis=0, weights=weights)
+
+         P_centered = P - pc
+         Q_centered = Q - qc
+
+         A = np.zeros((3, 3))
+         B = np.zeros((3, 3))
+
+         for i in range(P.shape[0]):
+             w = weights[i]
+             p = P_centered[i]
+             q = Q_centered[i]
+
+             A += w * np.outer(p, p)
+             B += w * np.outer(p, q)
+
+         det = np.linalg.det(A)
+         if det == 0:
+             raise ValueError("Ill-defined data points (det=0)")
+
+         try:
+             A_inv = np.linalg.inv(A)
+         except np.linalg.LinAlgError:
+             # If A is not invertible, use the pseudo-inverse
+             A_inv = np.linalg.pinv(A)
+
+         M = A_inv @ B  # 3x3 transformation matrix
+
+         t = qc - M @ pc  # translation
+
+         affine_matrix = np.eye(4)
+         affine_matrix[:3, :3] = M
+         affine_matrix[:3, 3] = t
+
+         return affine_matrix
+
+     def test(self, candidates, model, max_epsilon, min_inlier_ratio, min_num_inliers):
+         inliers = []
+         for idxA, pointA, view_a, label_a, idxB, pointB, view_b, label_b in candidates:
+             p1_hom = np.append(pointA, 1.0)
+             transformed = model @ p1_hom
+             distance = np.linalg.norm(transformed[:3] - pointB)
+
+             if distance < max_epsilon:
+                 inliers.append((idxA, pointA, view_a, label_a, idxB, pointB, view_b, label_b))
+
+         ir = len(inliers) / len(candidates)
+         is_good = len(inliers) >= min_num_inliers and ir > min_inlier_ratio
+
+         return is_good, inliers
+
+     def regularize_models(self, affine, rigid):
+         alpha = 0.1
+         l1 = 1.0 - alpha
+
+         def to_array(model):
+             return [
+                 model['m00'], model['m01'], model['m02'], model['m03'],
+                 model['m10'], model['m11'], model['m12'], model['m13'],
+                 model['m20'], model['m21'], model['m22'], model['m23'],
+             ]
+
+         afs = to_array(affine)
+         bfs = to_array(rigid)
+         rfs = [l1 * a + alpha * b for a, b in zip(afs, bfs)]
+
+         keys = [
+             'm00', 'm01', 'm02', 'm03',
+             'm10', 'm11', 'm12', 'm13',
+             'm20', 'm21', 'm22', 'm23',
+         ]
+         regularized = dict(zip(keys, rfs))
+
+         return regularized
+
+     def model_regularization(self, point_pairs):
+         if self.match_type == "rigid":
+             regularized_model = self.fit_rigid_model(point_pairs)
+         elif self.match_type == "affine" or self.match_type == "split-affine":
+             rigid_model = self.fit_rigid_model(point_pairs)
+             affine_model = self.fit_affine_model(point_pairs)
+             regularized_model = (1 - self.regularization_weight) * affine_model + self.regularization_weight * rigid_model
+         else:
+             raise SystemExit(f"Unsupported match type: {self.match_type}")
+
+         return regularized_model
+
+     def compute_ransac(self, candidates):
+         best_inliers = []
+         max_inliers = 0
+         best_model = None
+
+         if len(candidates) < self.model_min_matches:
+             return [], None
+
+         for _ in range(self.num_iterations):
+             indices = random.sample(range(len(candidates)), self.model_min_matches)
+             min_matches = [candidates[i] for i in indices]
+
+             try:
+                 point_pairs = [(m[1], m[5]) for m in min_matches]
+                 regularized_model = self.model_regularization(point_pairs)
+             except Exception as e:
+                 print(e)
+                 continue  # skip this sample if no model could be fit
+
+             num_inliers = 0
+             is_good, tmp_inliers = self.test(candidates, regularized_model, self.max_epsilon, self.min_inlier_ratio, self.model_min_matches)
+
+             while is_good and num_inliers < len(tmp_inliers):
+                 num_inliers = len(tmp_inliers)
+                 point_pairs = [(i[1], i[5]) for i in tmp_inliers]
+                 regularized_model = self.model_regularization(point_pairs)
+                 is_good, tmp_inliers = self.test(candidates, regularized_model, self.max_epsilon, self.min_inlier_ratio, self.model_min_matches)
+
+             if len(tmp_inliers) > max_inliers:
+                 best_inliers = tmp_inliers
+                 max_inliers = len(tmp_inliers)
+                 best_model = regularized_model
+
+         return best_inliers, best_model
+
+     def create_candidates(self, desc_a, desc_b):
+         # Currently pairs only the first subset: the three relative vectors of each descriptor
+         match_list = []
+
+         for a in range(1):
+             for b in range(1):
+
+                 matches = []
+                 for i in range(3):
+                     point_match = (desc_a['relative_descriptors'][i], desc_b['relative_descriptors'][i])
+                     matches.append(point_match)
+
+                 match_list.append(matches)
+
+         return match_list
+
+     def descriptor_distance(self, desc_a, desc_b):
+         matches_list = self.create_candidates(desc_a, desc_b)
+
+         best_similarity = float("inf")
+         best_match_set = None
+
+         for matches in matches_list:
+             try:
+                 points_a = np.array([pa for pa, _ in matches])
+                 points_b = np.array([pb for _, pb in matches])
+
+                 squared_diff_sum = np.sum((points_a - points_b) ** 2)
+                 similarity = squared_diff_sum / points_a.shape[1]
+
+                 if similarity < best_similarity:
+                     best_similarity = similarity
+                     best_match_set = matches
+
+             except Exception:
+                 continue
+
+         return best_similarity
+
+     def create_simple_point_descriptors(self, tree, points_array, idx, num_required_neighbors, matcher):
+         k = num_required_neighbors + 1
+         if len(points_array) < k:
+             return []
+
+         _, indices = tree.query(points_array, k=k)
+
+         descriptors = []
+         for i, basis_point in enumerate(points_array):
+             try:
+                 neighbor_idxs = indices[i][1:]
+                 neighbors = points_array[neighbor_idxs]
+
+                 if len(neighbors) == num_required_neighbors:
+                     idx_sets = [tuple(range(num_required_neighbors))]
+                 elif len(neighbors) > num_required_neighbors:
+                     idx_sets = matcher["neighbors"]
+                 else:
+                     continue  # not enough neighbors to build a descriptor
+
+                 relative_vectors = neighbors - basis_point
+
+                 # Final descriptor representation (as dict)
+                 descriptor = {
+                     "point_index": idx[i],
+                     "point": basis_point,
+                     "neighbors": neighbors,
+                     "relative_descriptors": relative_vectors,
+                     "matcher": matcher,
+                     "subsets": np.stack([neighbors[list(c)] for c in idx_sets])
+                 }
+
+                 descriptors.append(descriptor)
+
+             except Exception:
+                 raise
+
+         return descriptors
+
+     def get_candidates(self, points_a, points_b, view_a, view_b, label):
+         difference_threshold = 3.4028235e+38  # float32 max (Java's Float.MAX_VALUE)
+         max_value = float("inf")
+
+         # -- Get Points and Indexes
+         idx_a, coords_a = zip(*points_a)
+         idx_b, coords_b = zip(*points_b)
+         points_a_array = np.array(coords_a)
+         points_b_array = np.array(coords_b)
+
+         # --- KD Trees ---
+         tree_a = KDTree(points_a_array)
+         tree_b = KDTree(points_b_array)
+
+         # --- Subset Matcher ---
+         subset_size = self.num_neighbors
+         total_neighbors = self.num_neighbors + self.redundancy
+         neighbor_indices_combinations = list(itertools.combinations(range(total_neighbors), subset_size))
+         num_combinations = len(neighbor_indices_combinations)
+         num_matchings = num_combinations * num_combinations
+         matcher = {
+             "subset_size": subset_size,
+             "num_neighbors": total_neighbors,
+             "neighbors": neighbor_indices_combinations,
+             "num_combinations": num_combinations,
+             "num_matchings": num_matchings
+         }
+
+         # --- Descriptors ---
+         descriptors_a = self.create_simple_point_descriptors(tree_a, points_a_array, idx_a, self.num_required_neighbors, matcher)
+         descriptors_b = self.create_simple_point_descriptors(tree_b, points_b_array, idx_b, self.num_required_neighbors, matcher)
+
+         # --- Descriptor Matching ---
+         correspondence_candidates = []
+
+         out_of_radius = 0
+         passed_lowes = 0
+         first_if = 0
+         second_if = 0
+
+         for desc_a in descriptors_a:
+             best_difference = float("inf")
+             second_best_difference = float("inf")
+             best_match = None
+             second_best_match = None
+
+             for desc_b in descriptors_b:
+                 if np.linalg.norm(desc_a['point'] - desc_b['point']) > self.search_radius:
+                     out_of_radius += 1
+                     continue
+
+                 difference = self.descriptor_distance(desc_a, desc_b)
+
+                 if difference < second_best_difference:
+                     second_best_difference = difference
+                     second_best_match = desc_b
+                     first_if += 1
+
+                     if second_best_difference < best_difference:
+                         tmp_diff = second_best_difference
+                         tmp_match = second_best_match
+                         second_best_difference = best_difference
+                         second_best_match = best_match
+                         best_difference = tmp_diff
+                         best_match = tmp_match
+                         second_if += 1
+
+             # --- Lowe's Test ---
+             if best_difference < difference_threshold and best_difference * self.significance < second_best_difference and second_best_difference != max_value:
+                 correspondence_candidates.append((
+                     desc_a['point_index'],
+                     desc_a['point'],
+                     view_a,
+                     label,
+                     best_match['point_index'],
+                     best_match['point'],
+                     view_b,
+                     label
+                 ))
+                 passed_lowes += 1
+
+         # print(f"out of range: {out_of_radius}, first if: {first_if}, second if: {second_if}, passed lowes: {passed_lowes}")
+         return correspondence_candidates
+
+     def get_tile_dims(self, view1):
+         stripped = view1.strip("()")
+         parts = stripped.split(", ")
+         tp_id = int(parts[0].split("=")[1])
+         setup_id = int(parts[1].split("=")[1])
+
+         image_loader = self.data_global.get('imageLoader', {})
+
+         # Loop through all view entries in the image loader
+         for entry in image_loader:
+             entry_setup = int(entry.get('view_setup', -1))
+             entry_tp = int(entry.get('timepoint', -1))
+
+             if entry_setup == setup_id and entry_tp == tp_id:
+                 file_path = self.image_file_prefix + entry.get('file_path')
+                 if self.input_type == "tiff":
+                     img = CustomBioImage(file_path, reader=bioio_tifffile.Reader)
+                     dask_array = img.get_dask_stack()[0, 0, 0, :, :, :]
+                     shape = dask_array.shape
+
+                 elif self.input_type == "zarr":
+                     s3 = s3fs.S3FileSystem(anon=False)
+                     full_path = f"{file_path}/0"
+                     store = s3fs.S3Map(root=full_path, s3=s3)
+                     zarr_array = zarr.open(store, mode='r')
+                     dask_array = da.from_zarr(zarr_array)[0, 0, :, :, :]
+                     shape = dask_array.shape
+
+                 # Reverse (z, y, x) to (x, y, z) for the matched view
+                 return shape[::-1]
+
+     def invert_transformation_matrix(self, view_2):
+         """
+         Compose and invert all ViewTransforms for the given view key (timepoint, setup).
+         """
+         stripped = view_2.strip("()")
+         parts = stripped.split(", ")
+         tp_id = int(parts[0].split("=")[1])
+         setup_id = int(parts[1].split("=")[1])
+         view_key = (tp_id, setup_id)
+
+         # Get all transforms for this view
+         transforms = self.view_registrations.get(view_key, [])
+         if not transforms:
+             raise ValueError(f"No transforms found for view {view_key}")
+
+         final_matrix = np.eye(4)
+
+         for i, transform in enumerate(transforms):
+             affine_str = transform.get("affine")
+             if not affine_str:
+                 continue
+
+             values = [float(x) for x in affine_str.strip().split()]
+             if len(values) != 12:
+                 raise ValueError(f"Transform {i+1} in view {view_key} has {len(values)} values, expected 12.")
+
+             matrix3x4 = np.array(values).reshape(3, 4)
+             matrix4x4 = np.eye(4)
+             matrix4x4[:3, :4] = matrix3x4
+
+             # Combine with running matrix
+             final_matrix = final_matrix @ matrix4x4
+
+         # Return the inverse
+         return np.linalg.inv(final_matrix)
+
+     def filter_for_overlapping_points(self, points_a, points_b, view_a, view_b):
+         points_a = list(enumerate(points_a))
+         points_b = list(enumerate(points_b))
+
+         if not points_a or not points_b:
+             return [], []
+
+         # Check points_a against view_b's interval
+         overlapping_a = []
+         tinv_b = self.invert_transformation_matrix(view_b)
+
+         view_b_key = tuple(map(int, re.findall(r'\d+', view_b)))
+         dim_b = self.data_global['viewSetup']['byId'][view_b_key[1]]
+         interval_b = {'min': (0, 0, 0), 'max': dim_b['size']}
+
+         for i in reversed(range(len(points_a))):
+             idx, point = points_a[i]
+             p_h = np.append(point, 1.0)
+             transformed = tinv_b @ p_h
+             x, y, z = transformed[:3]
+             x_min, y_min, z_min = interval_b['min']
+             x_max, y_max, z_max = interval_b['max']
+
+             if x_min <= x < x_max and y_min <= y < y_max and z_min <= z < z_max:
+                 overlapping_a.append((idx, point))
+                 del points_a[i]
+
+         # Check points_b against view_a's interval
+         overlapping_b = []
+         tinv_a = self.invert_transformation_matrix(view_a)
+
+         view_a_key = tuple(map(int, re.findall(r'\d+', view_a)))
+         dim_a = self.data_global['viewSetup']['byId'][view_a_key[1]]
+         interval_a = {'min': (0, 0, 0), 'max': dim_a['size']}
+
+         for i in reversed(range(len(points_b))):
+             idx, point = points_b[i]
+             p_h = np.append(point, 1.0)
+             transformed = tinv_a @ p_h
+             x, y, z = transformed[:3]
+             x_min, y_min, z_min = interval_a['min']
+             x_max, y_max, z_max = interval_a['max']
+
+             if x_min <= x < x_max and y_min <= y < y_max and z_min <= z < z_max:
+                 overlapping_b.append((idx, point))
+                 del points_b[i]
+
+         return overlapping_a, overlapping_b
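
For orientation, here is a minimal usage sketch of the matcher above on synthetic data. It is not part of the package: the parameter values, the synthetic point lists, and the view/label strings are illustrative assumptions.

    # Hypothetical driver (assumed parameters, synthetic data)
    import numpy as np

    rng = np.random.default_rng(0)
    cloud = rng.uniform(0, 100, size=(50, 3))
    points_a = [(i, p) for i, p in enumerate(cloud)]
    # points_b is the same cloud shifted by a known 5-unit translation in x
    points_b = [(i, p + np.array([5.0, 0.0, 0.0])) for i, p in points_a]

    matcher = RansacMatching(
        data_global={}, num_neighbors=3, redundancy=0, significance=2.0,
        num_required_neighbors=3, match_type="rigid", max_epsilon=5.0,
        min_inlier_ratio=0.1, num_iterations=200, model_min_matches=4,
        regularization_weight=0.1, search_radius=20.0, view_registrations={},
        input_type="tiff", image_file_prefix="")

    candidates = matcher.get_candidates(points_a, points_b, "(tpId=0, setupId=0)", "(tpId=0, setupId=1)", "beads")
    inliers, model = matcher.compute_ransac(candidates)
    # model should be close to a pure translation of (5, 0, 0)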
Rhapso/matching/save_matches.py
@@ -0,0 +1,120 @@
+ import zarr
+ from collections import defaultdict
+ import s3fs
+
+ """
+ SaveMatches saves (matched) corresponding interest points to N5 format.
+ """
+
+ class SaveMatches:
+     def __init__(self, all_results, n5_output_path, data_global, match_type):
+         self.all_results = all_results
+         self.n5_output_path = n5_output_path
+         self.data_global = data_global
+         self.match_type = match_type
+
+     def save_correspondences(self):
+         """
+         Save correspondences for each view/label, aggregating all matches involving that view/label.
+         Print a detailed summary with breakdowns.
+         """
+         def parse_view(v: str):
+             tp = int(v.split("tpId=")[1].split(",")[0])
+             vs = int(v.split("setupId=")[1].split(")")[0])
+             return tp, vs
+
+         # Group results back per view
+         grouped_by_viewA = defaultdict(list)
+         for idxA, _, viewA, label_a, idxB, _, viewB, label_b in self.all_results:
+             grouped_by_viewA[viewA].append((idxA, idxB, viewB, label_b))
+             grouped_by_viewA[viewB].append((idxB, idxA, viewA, label_a))
+
+         # Create idMap per view of all corresponding groups
+         idMaps = {}
+         for viewA, matches in grouped_by_viewA.items():
+             target_keys = sorted({
+                 f"{tpB},{vsB},{labB}"
+                 for (_iA, _iB, viewB, labB) in matches
+                 for (tpB, vsB) in [parse_view(viewB)]
+             })
+             idMaps[viewA] = {k: i for i, k in enumerate(target_keys)}
+
+         # Format data for injection
+         grouped_with_ids = defaultdict(list)
+         for viewA, matches in grouped_by_viewA.items():
+             idMap = idMaps[viewA]
+             for idxA, idxB, viewB, label in matches:
+                 tp = int(viewB.split("tpId=")[1].split(",")[0])
+                 vs = int(viewB.split("setupId=")[1].split(")")[0])
+                 key = f"{tp},{vs},{label}"
+                 view_id = idMap[key]
+                 grouped_with_ids[viewA, label].append((idxA, idxB, view_id))
+
+         # Save idMap and corr points per view
+         for (viewA, labelA), corr_list in grouped_with_ids.items():
+             tpA = int(viewA.split("tpId=")[1].split(",")[0])
+             vsA = int(viewA.split("setupId=")[1].split(")")[0])
+             idMap = idMaps[viewA]
+
+             if len(corr_list) == 0:
+                 continue
+
+             # Output path
+             full_path = f"{self.n5_output_path}interestpoints.n5/tpId_{tpA}_viewSetupId_{vsA}/{labelA}/correspondences/"
+
+             if full_path.startswith("s3://"):
+                 path = full_path.replace("s3://", "")
+                 self.s3_filesystem = s3fs.S3FileSystem()
+                 store = s3fs.S3Map(root=path, s3=self.s3_filesystem, check=False)
+                 root = zarr.open_group(store=store, mode='a')
+             else:
+                 # Write to Zarr N5
+                 store = zarr.N5Store(full_path)
+                 root = zarr.group(store=store, overwrite=True)
+
+             # Delete existing 'data' array
+             if "data" in root:
+                 del root["data"]
+
+             # Set group-level attributes
+             root.attrs.update({
+                 "correspondences": "1.0.0",
+                 "idMap": idMap
+             })
+
+             # Create dataset inside the group
+             root.create_dataset(
+                 name="data",
+                 data=corr_list,
+                 dtype='u8',
+                 chunks=(min(300000, len(corr_list)), 1),
+                 compressor=zarr.GZip()
+             )
+
+     def clear_correspondence(self):
+         if self.n5_output_path.startswith("s3://"):
+             root_path = self.n5_output_path.replace("s3://", "") + "interestpoints.n5"
+             s3 = s3fs.S3FileSystem()
+             store = s3fs.S3Map(root=root_path, s3=s3, check=False)
+         else:
+             root_path = self.n5_output_path + "interestpoints.n5"
+             store = zarr.N5Store(root_path)
+
+         root = zarr.open_group(store=store, mode="a")
+
+         views = list(self.data_global['viewsInterestPoints'].keys())
+         for tp, vs in views:
+             labels = self.data_global['viewsInterestPoints'][(tp, vs)]['label']
+             for label in labels:
+                 corr_path = f"tpId_{tp}_viewSetupId_{vs}/{label}/correspondences"
+                 try:
+                     if corr_path in root:
+                         del root[corr_path]
+                     elif f"{corr_path}/data" in root:
+                         del root[f"{corr_path}/data"]
+                 except Exception as e:
+                     print(f"⚠️ Could not delete {corr_path}: {e}")
+
+     def run(self):
+         self.clear_correspondence()
+         self.save_correspondences()
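
As with the matcher, a minimal, hypothetical sketch of driving SaveMatches follows; the result tuples, the output path, and the metadata dictionary are illustrative assumptions, not values shipped in the package.

    # Hypothetical driver (assumed inputs); writes to a local N5 store
    all_results = [
        # (idxA, pointA, viewA, labelA, idxB, pointB, viewB, labelB)
        (0, (1.0, 2.0, 3.0), "(tpId=0, setupId=0)", "beads",
         7, (1.5, 2.0, 3.0), "(tpId=0, setupId=1)", "beads"),
    ]
    data_global = {"viewsInterestPoints": {(0, 0): {"label": ["beads"]},
                                           (0, 1): {"label": ["beads"]}}}

    saver = SaveMatches(all_results, "/tmp/rhapso/", data_global, "rigid")
    saver.run()  # clears stale correspondences, then writes data + idMap per view/label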