Rhapso 0.1.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. Rhapso/__init__.py +1 -0
  2. Rhapso/data_prep/__init__.py +2 -0
  3. Rhapso/data_prep/n5_reader.py +188 -0
  4. Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
  5. Rhapso/data_prep/xml_to_dataframe.py +215 -0
  6. Rhapso/detection/__init__.py +5 -0
  7. Rhapso/detection/advanced_refinement.py +203 -0
  8. Rhapso/detection/difference_of_gaussian.py +324 -0
  9. Rhapso/detection/image_reader.py +117 -0
  10. Rhapso/detection/metadata_builder.py +130 -0
  11. Rhapso/detection/overlap_detection.py +327 -0
  12. Rhapso/detection/points_validation.py +49 -0
  13. Rhapso/detection/save_interest_points.py +265 -0
  14. Rhapso/detection/view_transform_models.py +67 -0
  15. Rhapso/fusion/__init__.py +0 -0
  16. Rhapso/fusion/affine_fusion/__init__.py +2 -0
  17. Rhapso/fusion/affine_fusion/blend.py +289 -0
  18. Rhapso/fusion/affine_fusion/fusion.py +601 -0
  19. Rhapso/fusion/affine_fusion/geometry.py +159 -0
  20. Rhapso/fusion/affine_fusion/io.py +546 -0
  21. Rhapso/fusion/affine_fusion/script_utils.py +111 -0
  22. Rhapso/fusion/affine_fusion/setup.py +4 -0
  23. Rhapso/fusion/affine_fusion_worker.py +234 -0
  24. Rhapso/fusion/multiscale/__init__.py +0 -0
  25. Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
  26. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
  27. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
  28. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
  29. Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
  30. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
  31. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
  32. Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
  33. Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
  34. Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
  35. Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
  36. Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
  37. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
  38. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
  39. Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
  40. Rhapso/fusion/multiscale_worker.py +113 -0
  41. Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
  42. Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
  43. Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
  44. Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
  45. Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
  46. Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
  47. Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
  48. Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
  49. Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
  50. Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
  51. Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
  52. Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
  53. Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
  54. Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
  55. Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
  56. Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
  57. Rhapso/matching/__init__.py +0 -0
  58. Rhapso/matching/load_and_transform_points.py +458 -0
  59. Rhapso/matching/ransac_matching.py +544 -0
  60. Rhapso/matching/save_matches.py +120 -0
  61. Rhapso/matching/xml_parser.py +302 -0
  62. Rhapso/pipelines/__init__.py +0 -0
  63. Rhapso/pipelines/ray/__init__.py +0 -0
  64. Rhapso/pipelines/ray/aws/__init__.py +0 -0
  65. Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
  66. Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
  67. Rhapso/pipelines/ray/evaluation.py +71 -0
  68. Rhapso/pipelines/ray/interest_point_detection.py +137 -0
  69. Rhapso/pipelines/ray/interest_point_matching.py +110 -0
  70. Rhapso/pipelines/ray/local/__init__.py +0 -0
  71. Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
  72. Rhapso/pipelines/ray/matching_stats.py +104 -0
  73. Rhapso/pipelines/ray/param/__init__.py +0 -0
  74. Rhapso/pipelines/ray/solver.py +120 -0
  75. Rhapso/pipelines/ray/split_dataset.py +78 -0
  76. Rhapso/solver/__init__.py +0 -0
  77. Rhapso/solver/compute_tiles.py +562 -0
  78. Rhapso/solver/concatenate_models.py +116 -0
  79. Rhapso/solver/connected_graphs.py +111 -0
  80. Rhapso/solver/data_prep.py +181 -0
  81. Rhapso/solver/global_optimization.py +410 -0
  82. Rhapso/solver/model_and_tile_setup.py +109 -0
  83. Rhapso/solver/pre_align_tiles.py +323 -0
  84. Rhapso/solver/save_results.py +97 -0
  85. Rhapso/solver/view_transforms.py +75 -0
  86. Rhapso/solver/xml_to_dataframe_solver.py +213 -0
  87. Rhapso/split_dataset/__init__.py +0 -0
  88. Rhapso/split_dataset/compute_grid_rules.py +78 -0
  89. Rhapso/split_dataset/save_points.py +101 -0
  90. Rhapso/split_dataset/save_xml.py +377 -0
  91. Rhapso/split_dataset/split_images.py +537 -0
  92. Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
  93. rhapso-0.1.92.dist-info/METADATA +39 -0
  94. rhapso-0.1.92.dist-info/RECORD +101 -0
  95. rhapso-0.1.92.dist-info/WHEEL +5 -0
  96. rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
  97. rhapso-0.1.92.dist-info/top_level.txt +2 -0
  98. tests/__init__.py +1 -0
  99. tests/test_detection.py +17 -0
  100. tests/test_matching.py +21 -0
  101. tests/test_solving.py +21 -0
@@ -0,0 +1,327 @@
+ import numpy as np
+ from bioio import BioImage
+ import bioio_tifffile
+ import zarr
+ import s3fs
+ import dask.array as da
+ import math
+ 
+ """
+ Overlap Detection figures out where image tiles overlap.
+ """
+ 
+ # BioImage is meant to be subclassed, so stub out its abstract members for the TIFF reader
+ class CustomBioImage(BioImage):
+     def standard_metadata(self):
+         pass
+ 
+     def scale(self):
+         pass
+ 
+     def time_interval(self):
+         pass
+ 
+ class OverlapDetection():
+     def __init__(self, transform_models, dataframes, dsxy, dsz, prefix, file_type):
+         self.transform_models = transform_models
+         self.image_loader_df = dataframes['image_loader']
+         self.dsxy, self.dsz = dsxy, dsz
+         self.prefix = prefix
+         self.file_type = file_type
+         self.to_process = {}
+         self.image_shape_cache = {}
+         self.max_interval_size = 0
+ 
+     def create_mipmap_transform(self):
+         """
+         Build a 4×4 homogeneous scaling matrix for the mipmap level
+         """
+         scale_matrix = np.array([
+             [self.dsxy, 0, 0, 0],
+             [0, self.dsxy, 0, 0],
+             [0, 0, self.dsz, 0],
+             [0, 0, 0, 1]
+         ])
+ 
+         return scale_matrix
+ 
+     def load_image_metadata(self, file_path):
+         if file_path in self.image_shape_cache:
+             return self.image_shape_cache[file_path]
+ 
+         if self.file_type == 'zarr':
+             s3 = s3fs.S3FileSystem(anon=False)
+             store = s3fs.S3Map(root=file_path, s3=s3)
+             zarr_array = zarr.open(store, mode='r')
+             dask_array = da.from_zarr(zarr_array)
+             dask_array = da.expand_dims(dask_array, axis=2)
+             shape = dask_array.shape
+             self.image_shape_cache[file_path] = shape
+ 
+         elif self.file_type == 'tiff':
+             img = CustomBioImage(file_path, reader=bioio_tifffile.Reader)
+             data = img.get_dask_stack()
+             shape = data.shape
+             self.image_shape_cache[file_path] = shape
+ 
+         else:
+             raise ValueError(f"Unsupported file type: {self.file_type}")
+ 
+         return shape
+ 
+     # def open_and_downsample(self, shape):
+     #     X = int(shape[5])
+     #     Y = int(shape[4])
+     #     Z = int(shape[3])
+ 
+     #     dsx = int(self.dsxy)
+     #     dsy = int(self.dsxy)
+     #     dsz = int(self.dsz)
+ 
+     #     def ceil_half_chain(n, f):
+     #         out = int(n)
+     #         while f >= 2:
+     #             out = (out + 1) // 2  # ceil(n/2)
+     #             f //= 2
+     #         return out
+ 
+     #     x_new = ceil_half_chain(X, dsx)
+     #     y_new = ceil_half_chain(Y, dsy)
+     #     z_new = ceil_half_chain(Z, dsz)
+ 
+     #     mipmap_transform = self.create_mipmap_transform()
+     #     return ((0, 0, 0), (x_new, y_new, z_new)), mipmap_transform
+ 
+     def open_and_downsample(self, shape, dsxy, dsz):
+         """
+         Downsample a 3D volume by powers of two by repeatedly halving along each axis
+         """
+         dsx = dsxy
+         dsy = dsxy
+ 
+         # downsample x dimension
+         x_new = shape[5]
+         while dsx > 1:
+             x_new = x_new // 2 if x_new % 2 == 0 else (x_new // 2) + 1
+             dsx //= 2
+ 
+         # downsample y dimension
+         y_new = shape[4]
+         while dsy > 1:
+             y_new = y_new // 2 if y_new % 2 == 0 else (y_new // 2) + 1
+             dsy //= 2
+ 
+         # downsample z dimension
+         z_new = shape[3]
+         while dsz > 1:
+             z_new = z_new // 2 if z_new % 2 == 0 else (z_new // 2) + 1
+             dsz //= 2
+ 
+         return ((0, 0, 0), (x_new, y_new, z_new))
+ 
+     def get_inverse_mipmap_transform(self, mipmap_transform):
+         """
+         Compute the inverse of the given mipmap transform
+         """
+         try:
+             inverse_scale_matrix = np.linalg.inv(mipmap_transform)
+         except np.linalg.LinAlgError:
+             print("Matrix cannot be inverted.")
+             return None
+ 
+         return inverse_scale_matrix
+ 
+     def estimate_bounds(self, a, interval):
+         """
+         Transform an axis-aligned box through a 4x4 affine
+         """
+         # set lower bounds
+         t0, t1, t2 = 0, 0, 0
+ 
+         # set upper bounds
+         if self.file_type == 'zarr':
+             s0 = interval[5] - t0
+             s1 = interval[4] - t1
+             s2 = interval[3] - t2
+         elif self.file_type == 'tiff':
+             s0 = interval[5] - t0
+             s1 = interval[4] - t1
+             s2 = interval[3] - t2
+ 
+         # transform the lower corner through the affine
+         matrix = np.array(a)
+         tt = np.dot(matrix[:, :3], [t0, t1, t2]) + matrix[:, 3]
+         r_min = np.copy(tt)
+         r_max = np.copy(tt)
+ 
+         # grow the lower and upper bounds axis by axis, depending on the sign of each matrix entry
+         for i in range(3):
+             if matrix[i, 0] < 0:
+                 r_min[i] += s0 * matrix[i, 0]
+             else:
+                 r_max[i] += s0 * matrix[i, 0]
+ 
+             if matrix[i, 1] < 0:
+                 r_min[i] += s1 * matrix[i, 1]
+             else:
+                 r_max[i] += s1 * matrix[i, 1]
+ 
+             if matrix[i, 2] < 0:
+                 r_min[i] += s2 * matrix[i, 2]
+             else:
+                 r_max[i] += s2 * matrix[i, 2]
+ 
+         return r_min[:3], r_max[:3]
+ 
+     def calculate_intersection(self, bbox1, bbox2):
+         """
+         Compute the axis-aligned intersection of two 3D boxes given as (min, max) coordinates
+         """
+         intersect_min = np.maximum(bbox1[0], bbox2[0])
+         intersect_max = np.minimum(bbox1[1], bbox2[1])
+ 
+         return (intersect_min, intersect_max)
+ 
+     def calculate_new_dims(self, lower_bound, upper_bound):
+         """
+         Compute per-axis lengths from bounds
+         """
+         new_dims = []
+         for lb, ub in zip(lower_bound, upper_bound):
+             if lb == 0:
+                 new_dims.append(ub + 1)
+             else:
+                 new_dims.append(ub - lb)
+ 
+         return new_dims
+ 
+     def floor_log2(self, n):
+         """
+         Return ⌊log2(n)⌋, clamping n ≤ 1 to 1 so the result is 0 for n ≤ 1
+         """
+         return max(0, int(math.floor(math.log2(max(1, n)))))
+ 
+     def choose_zarr_level(self):
+         """
+         Pick the highest power-of-two pyramid level (≤ 7) compatible with dsxy/dsz
+         """
+         max_level = 7
+         lvl_xy = self.floor_log2(self.dsxy)
+         lvl_z = self.floor_log2(self.dsz)
+         best = min(lvl_xy, lvl_z, max_level)
+         factor = 1 << best
+         leftovers = (max(1, self.dsxy // factor), max(1, self.dsxy // factor), max(1, self.dsz // factor))
+         return best, leftovers
+ 
+     def affine_with_half_pixel_shift(self, sx, sy, sz):
+         """
+         Build a 4x4 scaling affine that also shifts by 0.5·(scale-1) per axis so voxel centers stay aligned after
+         resampling (half-pixel compensation)
+         """
+         # translation = 0.5 * (scale - 1) per axis
+         tx = 0.5 * (sx - 1.0)
+         ty = 0.5 * (sy - 1.0)
+         tz = 0.5 * (sz - 1.0)
+ 
+         return np.array([
+             [sx, 0.0, 0.0, tx],
+             [0.0, sy, 0.0, ty],
+             [0.0, 0.0, sz, tz],
+             [0.0, 0.0, 0.0, 1.0],
+         ], dtype=float)
+ 
+     def size_interval(self, lb, ub):
+         """
+         Find the number of voxels in a 3D box with inclusive bounds
+         """
+         return int((int(ub[0]) - int(lb[0]) + 1) *
+                    (int(ub[1]) - int(lb[1]) + 1) *
+                    (int(ub[2]) - int(lb[2]) + 1))
+ 
+     def find_overlapping_area(self):
+         """
+         Compute XYZ overlap intervals of each view against every other view, accounting for
+         mipmap/downsampling and per-view affine transforms
+         """
+         for i, row_i in self.image_loader_df.iterrows():
+             view_id = f"timepoint: {row_i['timepoint']}, setup: {row_i['view_setup']}"
+ 
+             # get the inverted downsampling matrix
+             all_intervals = []
+             if self.file_type == 'zarr':
+                 level, leftovers = self.choose_zarr_level()
+                 dim_base = self.load_image_metadata(self.prefix + row_i['file_path'] + f'/{0}')
+ 
+                 # isotropic pyramid
+                 s = float(2 ** level)
+                 mipmap_of_downsample = self.affine_with_half_pixel_shift(s, s, s)
+ 
+                 # TODO - update mipmap with leftovers if other than 1
+                 _, dsxy, dsz = leftovers
+ 
+             elif self.file_type == 'tiff':
+                 dim_base = self.load_image_metadata(self.prefix + row_i['file_path'])
+                 mipmap_of_downsample = self.create_mipmap_transform()
+                 dsxy, dsz = self.dsxy, self.dsz
+                 level = None
+ 
+             downsampled_dim_base = self.open_and_downsample(dim_base, dsxy, dsz)
+             t1 = self.get_inverse_mipmap_transform(mipmap_of_downsample)
+ 
+             # compare with all other view_ids
+             for j, row_j in self.image_loader_df.iterrows():
+                 if i == j: continue
+ 
+                 view_id_other = f"timepoint: {row_j['timepoint']}, setup: {row_j['view_setup']}"
+ 
+                 if self.file_type == 'zarr':
+                     dim_other = self.load_image_metadata(self.prefix + row_j['file_path'] + f'/{0}')
+                 elif self.file_type == 'tiff':
+                     dim_other = self.load_image_metadata(self.prefix + row_j['file_path'])
+ 
+                 # get the transform matrices of both view_ids and the downsampling matrices
+                 matrix = self.transform_models.get(view_id)
+                 matrix_other = self.transform_models.get(view_id_other)
+ 
+                 if self.file_type == 'zarr':
+                     s = float(2 ** level)
+                     mipmap_of_downsample_other = self.affine_with_half_pixel_shift(s, s, s)
+                 elif self.file_type == 'tiff':
+                     mipmap_of_downsample_other = self.create_mipmap_transform()
+ 
+                 inverse_mipmap_of_downsample_other = self.get_inverse_mipmap_transform(mipmap_of_downsample_other)
+                 inverse_matrix = self.get_inverse_mipmap_transform(matrix)
+ 
+                 concatenated_matrix = np.dot(inverse_matrix, matrix_other)
+                 t2 = np.dot(inverse_mipmap_of_downsample_other, concatenated_matrix)
+ 
+                 intervals = self.estimate_bounds(t1, dim_base)
+                 intervals_other = self.estimate_bounds(t2, dim_other)
+ 
+                 bounding_boxes = tuple(map(lambda x: np.round(x).astype(int), intervals))
+                 bounding_boxes_other = tuple(map(lambda x: np.round(x).astype(int), intervals_other))
+ 
+                 # find upper and lower bounds of the intersection
+                 if np.all((bounding_boxes[1] >= bounding_boxes_other[0]) & (bounding_boxes_other[1] >= bounding_boxes[0])):
+                     intersected_boxes = self.calculate_intersection(bounding_boxes, bounding_boxes_other)
+                     intersect = self.calculate_intersection(downsampled_dim_base, intersected_boxes)
+                     intersect_dict = {
+                         'lower_bound': intersect[0],
+                         'upper_bound': intersect[1],
+                         'span': self.calculate_new_dims(intersect[0], intersect[1])
+                     }
+ 
+                     lb, ub = intersect[0], intersect[1]
+                     sz = self.size_interval(lb, ub)
+                     if sz > self.max_interval_size:
+                         self.max_interval_size = sz
+ 
+                     # track the interval (and, via max_interval_size, the largest one)
+                     all_intervals.append(intersect_dict)
+ 
+             self.to_process[view_id] = all_intervals
+ 
+         return dsxy, dsz, level, mipmap_of_downsample
+ 
+     def run(self):
+         """
+         Executes the entry point of the script.
+         """
+         dsxy, dsz, level, mipmap_of_downsample = self.find_overlapping_area()
+         return self.to_process, dsxy, dsz, level, self.max_interval_size, mipmap_of_downsample
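Note (not part of the published diff): a minimal sketch of how OverlapDetection might be driven. It assumes transform_models is keyed by the same "timepoint: {t}, setup: {s}" strings used above and holds 4x4 homogeneous affines, and that dataframes['image_loader'] carries timepoint, view_setup and file_path columns; the tile paths are hypothetical and would need to point at real TIFF files for run() to succeed.

import numpy as np
import pandas as pd
from Rhapso.detection.overlap_detection import OverlapDetection

# two hypothetical tiles, the second shifted 100 px along x
image_loader = pd.DataFrame([
    {'timepoint': 0, 'view_setup': 0, 'file_path': 'tile_0.tiff'},
    {'timepoint': 0, 'view_setup': 1, 'file_path': 'tile_1.tiff'},
])
shift = np.eye(4)
shift[0, 3] = 100.0
transform_models = {
    'timepoint: 0, setup: 0': np.eye(4),
    'timepoint: 0, setup: 1': shift,
}
detector = OverlapDetection(transform_models, {'image_loader': image_loader},
                            dsxy=4, dsz=2, prefix='/data/', file_type='tiff')
to_process, dsxy, dsz, level, max_size, mipmap = detector.run()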
@@ -0,0 +1,49 @@
+ import numpy as np
+ 
+ """
+ Points Validation prints out metrics for the results of interest point detection
+ """
+ 
+ class PointsValidation:
+     def __init__(self, consolidated_data):
+         self.consolidated_data = consolidated_data
+ 
+     def validation_suite(self):
+         total_points = sum(len(points) for points in self.consolidated_data.values())
+         print(f"\nTotal Interest Points Found: {total_points}")
+ 
+         print("\nInterest Points by View ID:")
+         for view_id, points in self.consolidated_data.items():
+ 
+             # Sort points by the z coordinate of each (x, y, z) location
+             sorted_points = sorted(points, key=lambda x: x[0][2])
+ 
+             if len(sorted_points) == 0:
+                 print(f"\nView ID: {view_id} | Num points: 0")
+                 print("\n--- Detection Stats ---")
+                 print("No points found for this view.\n")
+                 continue
+ 
+             coords = np.array([p[0] for p in sorted_points])
+             intensities = np.array([p[1] for p in sorted_points])
+ 
+             # Print metrics on interest points
+             print("\n--- Detection Stats ---")
+             print(f"Total Points: {len(coords)}")
+             print(f"Intensity: min={intensities.min():.2f}, max={intensities.max():.2f}, mean={intensities.mean():.2f}, std={intensities.std():.2f}")
+ 
+             for dim, name in zip(range(3), ['X', 'Y', 'Z']):
+                 values = coords[:, dim]
+                 print(f"{name} Range: {values.min():.2f} – {values.max():.2f} | Spread (std): {values.std():.2f}")
+ 
+             # Density per 1000x1000x1000 space
+             volume = np.ptp(coords[:, 0]) * np.ptp(coords[:, 1]) * np.ptp(coords[:, 2])
+             density = len(coords) / (volume / 1e9) if volume > 0 else 0
+             print(f"Estimated Density: {density:.2f} points per 1000³ volume")
+             print("-----------------------\n")
+ 
+     def run(self):
+         """
+         Executes the entry point of the script.
+         """
+         self.validation_suite()
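For context (again outside the diff), consolidated_data is expected to map view-id strings to lists of ((x, y, z), intensity) pairs, so a self-contained smoke test of the validation report could look like this; the coordinates and intensities are random placeholders.

import numpy as np
from Rhapso.detection.points_validation import PointsValidation

rng = np.random.default_rng(0)
points = [((float(x), float(y), float(z)), float(i))
          for x, y, z, i in rng.uniform(0, 1000, size=(50, 4))]
consolidated_data = {'timepoint: 0, setup: 0': points}
PointsValidation(consolidated_data).run()  # prints totals, intensity stats, ranges and density per view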
@@ -0,0 +1,265 @@
+ import zarr
+ import numpy as np
+ import xml.etree.ElementTree as ET
+ import s3fs
+ import boto3
+ from io import BytesIO
+ import io
+ import json
+ 
+ """
+ Save Interest Points saves interest points as N5 and updates the XML with their file paths
+ """
+ 
+ class SaveInterestPoints:
+     def __init__(self, dataframes, consolidated_data, xml_file_path, xml_output_file_path, n5_output_file_prefix, downsample_xy, downsample_z, min_intensity,
+                  max_intensity, sigma, threshold):
+         self.consolidated_data = consolidated_data
+         self.image_loader_df = dataframes['image_loader']
+         self.xml_file_path = xml_file_path
+         self.xml_output_file_path = xml_output_file_path
+         self.n5_output_file_prefix = n5_output_file_prefix
+         self.downsample_xy = downsample_xy
+         self.downsample_z = downsample_z
+         self.min_intensity = min_intensity
+         self.max_intensity = max_intensity
+         self.sigma = sigma
+         self.threshold = threshold
+         self.s3_filesystem = s3fs.S3FileSystem()
+         self.overlappingOnly = "true"
+         self.findMin = "true"
+         self.findMax = "true"
+         self.default_block_size = 300000
+ 
+     def load_xml_file(self, file_path):
+         tree = ET.parse(file_path)
+         root = tree.getroot()
+         return tree, root
+ 
+     def fetch_from_s3(self, s3, bucket_name, input_file):
+         response = s3.get_object(Bucket=bucket_name, Key=input_file)
+         return response['Body'].read().decode('utf-8')
+ 
+     def save_to_xml(self):
+         """
+         Rebuild the <ViewInterestPoints> section and write the updated XML back
+         """
+         if self.xml_file_path.startswith("s3://"):
+             bucket, key = self.xml_file_path.replace("s3://", "", 1).split("/", 1)
+             s3 = boto3.client('s3')
+             xml_string = self.fetch_from_s3(s3, bucket, key)
+             tree = ET.parse(io.BytesIO(xml_string.encode('utf-8')))
+             root = tree.getroot()
+         else:
+             tree, root = self.load_xml_file(self.xml_file_path)
+ 
+         interest_points_section = root.find('.//ViewInterestPoints')
+ 
+         if interest_points_section is None:
+             interest_points_section = ET.SubElement(root, 'ViewInterestPoints')
+             interest_points_section.text = '\n '
+ 
+         else:
+             interest_points_section.clear()
+             interest_points_section.text = '\n '
+ 
+         for view_id, _ in self.consolidated_data.items():
+             parts = view_id.split(',')
+             timepoint_part = parts[0].strip()
+             setup_part = parts[1].strip()
+ 
+             timepoint = int(timepoint_part.split(':')[1].strip())
+             setup = int(setup_part.split(':')[1].strip())
+             label = "beads"
+             params = "DOG (Spark) s={} t={} overlappingOnly={} min={} max={} downsampleXY={} downsampleZ={} minIntensity={} maxIntensity={}".format(
+                 self.sigma, self.threshold, self.overlappingOnly, self.findMin, self.findMax,
+                 self.downsample_xy, self.downsample_z, self.min_intensity, self.max_intensity)
+             value = f"tpId_{timepoint}_viewSetupId_{setup}/beads"
+ 
+             new_interest_point = ET.SubElement(interest_points_section, 'ViewInterestPointsFile', {
+                 'timepoint': str(timepoint),
+                 'setup': str(setup),
+                 'label': label,
+                 'params': params
+             })
+             new_interest_point.text = value
+             new_interest_point.tail = '\n '
+ 
+         interest_points_section.tail = '\n '
+ 
+         if self.xml_output_file_path.startswith("s3://"):
+             bucket, key = self.xml_output_file_path.replace("s3://", "", 1).split("/", 1)
+             xml_bytes = BytesIO()
+             tree.write(xml_bytes, encoding='utf-8', xml_declaration=True)
+             xml_bytes.seek(0)
+             s3 = boto3.client('s3')
+             s3.upload_fileobj(xml_bytes, bucket, key)
+ 
+         else:
+             tree.write(self.xml_output_file_path, encoding='utf-8', xml_declaration=True)
+ 
+     def write_json_to_s3(self, id_dataset_path, loc_dataset_path, attributes):
+         """
+         Write the attributes file into both the id and loc dataset directories on S3
+         """
+         bucket, key = id_dataset_path.replace("s3://", "", 1).split("/", 1)
+         json_path = key + '/attributes.json'
+         json_bytes = json.dumps(attributes).encode('utf-8')
+         s3 = boto3.client('s3')
+         s3.put_object(Bucket=bucket, Key=json_path, Body=json_bytes)
+ 
+         bucket, key = loc_dataset_path.replace("s3://", "", 1).split("/", 1)
+         json_path = key + '/attributes.json'
+         json_bytes = json.dumps(attributes).encode('utf-8')
+         s3 = boto3.client('s3')
+         s3.put_object(Bucket=bucket, Key=json_path, Body=json_bytes)
+ 
+     def save_intensities_to_n5(self, view_id, n5_path):
+         """
+         Write intensities into an N5 group
+         """
+         if self.n5_output_file_prefix.startswith("s3://"):
+             output_path = self.n5_output_file_prefix + n5_path + "/interestpoints"
+             store = s3fs.S3Map(root=output_path, s3=self.s3_filesystem, check=False)
+             root = zarr.group(store=store, overwrite=False)
+             root.attrs['n5'] = '4.0.0'
+ 
+         else:
+             store = zarr.N5Store(self.n5_output_file_prefix + n5_path + "/interestpoints")
+             root = zarr.group(store, overwrite=False)
+             root.attrs['n5'] = '4.0.0'
+ 
+         intensities_path = 'intensities'
+ 
+         if intensities_path in root:
+             try:
+                 del root[intensities_path]
+             except Exception as e:
+                 print(f"Warning: failed to delete existing dataset at {intensities_path}: {e}")
+ 
+         try:
+             if view_id in self.consolidated_data:
+                 intensities = [point[1] for point in self.consolidated_data[view_id]]
+                 dataset = root.create_dataset(
+                     intensities_path,
+                     data=intensities,
+                     dtype='f4',
+                     chunks=(self.default_block_size,),
+                     compressor=zarr.GZip()
+                 )
+                 dataset.attrs["dimensions"] = [1, len(intensities)]
+                 dataset.attrs["blockSize"] = [1, self.default_block_size]
+             else:
+                 root.create_dataset(
+                     intensities_path,
+                     shape=(0,),
+                     dtype='f4',
+                     chunks=(1,),
+                     compressor=zarr.GZip()
+                 )
+         except Exception as e:
+             print(f"Error creating intensities dataset at {intensities_path}: {e}")
+ 
+     def save_interest_points_to_n5(self, view_id, n5_path):
+         """
+         Write interest point IDs and 3D locations into an N5 group
+         """
+         if self.n5_output_file_prefix.startswith("s3://"):
+             output_path = self.n5_output_file_prefix + n5_path + "/interestpoints"
+             store = s3fs.S3Map(root=output_path, s3=self.s3_filesystem, check=False)
+             root = zarr.group(store=store, overwrite=False)
+             root.attrs["pointcloud"] = "1.0.0"
+             root.attrs["type"] = "list"
+             root.attrs["list version"] = "1.0.0"
+ 
+         else:
+             store = zarr.N5Store(self.n5_output_file_prefix + n5_path + "/interestpoints")
+             root = zarr.group(store, overwrite=False)
+             root.attrs["pointcloud"] = "1.0.0"
+             root.attrs["type"] = "list"
+             root.attrs["list version"] = "1.0.0"
+ 
+         id_dataset = "id"
+         loc_dataset = "loc"
+ 
+         if self.n5_output_file_prefix.startswith("s3://"):
+             id_path = f"{output_path}/id"
+             loc_path = f"{output_path}/loc"
+             attrs_dict = dict(root.attrs)
+             self.write_json_to_s3(id_path, loc_path, attrs_dict)
+ 
+         if (view_id in self.consolidated_data) and (len(self.consolidated_data[view_id]) > 0):
+             interest_points = [point[0] for point in self.consolidated_data[view_id]]
+             interest_point_ids = np.arange(len(interest_points), dtype=np.uint64).reshape(-1, 1)
+             n = 3
+ 
+             if id_dataset in root:
+                 del root[id_dataset]
+             root.create_dataset(
+                 id_dataset,
+                 data=interest_point_ids,
+                 dtype='u8',
+                 chunks=(self.default_block_size,),
+                 compressor=zarr.GZip()
+             )
+ 
+             if loc_dataset in root:
+                 del root[loc_dataset]
+             root.create_dataset(
+                 loc_dataset,
+                 data=interest_points,
+                 dtype='f8',
+                 chunks=(self.default_block_size, n),
+                 compressor=zarr.GZip()
+             )
+ 
+         # save as empty lists
+         else:
+             if id_dataset in root:
+                 del root[id_dataset]
+             root.create_dataset(
+                 id_dataset,
+                 shape=(0,),
+                 dtype='u8',
+                 chunks=(1,),
+                 compressor=zarr.GZip()
+             )
+ 
+             if loc_dataset in root:
+                 del root[loc_dataset]
+             root.create_dataset(
+                 loc_dataset,
+                 shape=(0,),
+                 dtype='f8',
+                 chunks=(1,),
+                 compressor=zarr.GZip()
+             )
+ 
+     def save_points(self):
+         """
+         Orchestrate saving interest points and intensities into an N5 layout and write the root attributes file
+         """
+         for _, row in self.image_loader_df.iterrows():
+             view_id = f"timepoint: {row['timepoint']}, setup: {row['view_setup']}"
+             n5_path = f"interestpoints.n5/tpId_{row['timepoint']}_viewSetupId_{row['view_setup']}/beads"
+             self.save_interest_points_to_n5(view_id, n5_path)
+             self.save_intensities_to_n5(view_id, n5_path)
+ 
+         path = self.n5_output_file_prefix + "interestpoints.n5"
+ 
+         if path.startswith("s3://"):
+             bucket_key = path.replace("s3://", "", 1)
+             store = s3fs.S3Map(root=bucket_key, s3=self.s3_filesystem, check=False)
+             root = zarr.group(store=store, overwrite=False)
+             root.attrs['n5'] = '4.0.0'
+         else:
+             store = zarr.N5Store(path)
+             root = zarr.group(store, overwrite=False)
+             root.attrs['n5'] = '4.0.0'
+ 
+     def run(self):
+         """
+         Executes the entry point of the script.
+         """
+         self.save_points()
+         self.save_to_xml()
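Finally, a hedged sketch (not part of the diff) of wiring detection output into SaveInterestPoints; every path and parameter value below is a placeholder, and a local (non-S3) prefix exercises the zarr.N5Store branch rather than the S3 one.

from Rhapso.detection.save_interest_points import SaveInterestPoints

saver = SaveInterestPoints(
    dataframes={'image_loader': image_loader},       # DataFrame from the OverlapDetection sketch above
    consolidated_data=consolidated_data,             # {view_id: [((x, y, z), intensity), ...]}
    xml_file_path='dataset.xml',                     # placeholder dataset XML
    xml_output_file_path='dataset_with_points.xml',  # placeholder output XML
    n5_output_file_prefix='output/',                 # local prefix -> zarr.N5Store branch
    downsample_xy=4, downsample_z=2,
    min_intensity=0, max_intensity=255,
    sigma=1.8, threshold=0.008,
)
saver.run()  # writes output/interestpoints.n5/... and rewrites <ViewInterestPoints> in the XML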