Rhapso 0.1.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101):
  1. Rhapso/__init__.py +1 -0
  2. Rhapso/data_prep/__init__.py +2 -0
  3. Rhapso/data_prep/n5_reader.py +188 -0
  4. Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
  5. Rhapso/data_prep/xml_to_dataframe.py +215 -0
  6. Rhapso/detection/__init__.py +5 -0
  7. Rhapso/detection/advanced_refinement.py +203 -0
  8. Rhapso/detection/difference_of_gaussian.py +324 -0
  9. Rhapso/detection/image_reader.py +117 -0
  10. Rhapso/detection/metadata_builder.py +130 -0
  11. Rhapso/detection/overlap_detection.py +327 -0
  12. Rhapso/detection/points_validation.py +49 -0
  13. Rhapso/detection/save_interest_points.py +265 -0
  14. Rhapso/detection/view_transform_models.py +67 -0
  15. Rhapso/fusion/__init__.py +0 -0
  16. Rhapso/fusion/affine_fusion/__init__.py +2 -0
  17. Rhapso/fusion/affine_fusion/blend.py +289 -0
  18. Rhapso/fusion/affine_fusion/fusion.py +601 -0
  19. Rhapso/fusion/affine_fusion/geometry.py +159 -0
  20. Rhapso/fusion/affine_fusion/io.py +546 -0
  21. Rhapso/fusion/affine_fusion/script_utils.py +111 -0
  22. Rhapso/fusion/affine_fusion/setup.py +4 -0
  23. Rhapso/fusion/affine_fusion_worker.py +234 -0
  24. Rhapso/fusion/multiscale/__init__.py +0 -0
  25. Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
  26. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
  27. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
  28. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
  29. Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
  30. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
  31. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
  32. Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
  33. Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
  34. Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
  35. Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
  36. Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
  37. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
  38. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
  39. Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
  40. Rhapso/fusion/multiscale_worker.py +113 -0
  41. Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
  42. Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
  43. Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
  44. Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
  45. Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
  46. Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
  47. Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
  48. Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
  49. Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
  50. Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
  51. Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
  52. Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
  53. Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
  54. Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
  55. Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
  56. Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
  57. Rhapso/matching/__init__.py +0 -0
  58. Rhapso/matching/load_and_transform_points.py +458 -0
  59. Rhapso/matching/ransac_matching.py +544 -0
  60. Rhapso/matching/save_matches.py +120 -0
  61. Rhapso/matching/xml_parser.py +302 -0
  62. Rhapso/pipelines/__init__.py +0 -0
  63. Rhapso/pipelines/ray/__init__.py +0 -0
  64. Rhapso/pipelines/ray/aws/__init__.py +0 -0
  65. Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
  66. Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
  67. Rhapso/pipelines/ray/evaluation.py +71 -0
  68. Rhapso/pipelines/ray/interest_point_detection.py +137 -0
  69. Rhapso/pipelines/ray/interest_point_matching.py +110 -0
  70. Rhapso/pipelines/ray/local/__init__.py +0 -0
  71. Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
  72. Rhapso/pipelines/ray/matching_stats.py +104 -0
  73. Rhapso/pipelines/ray/param/__init__.py +0 -0
  74. Rhapso/pipelines/ray/solver.py +120 -0
  75. Rhapso/pipelines/ray/split_dataset.py +78 -0
  76. Rhapso/solver/__init__.py +0 -0
  77. Rhapso/solver/compute_tiles.py +562 -0
  78. Rhapso/solver/concatenate_models.py +116 -0
  79. Rhapso/solver/connected_graphs.py +111 -0
  80. Rhapso/solver/data_prep.py +181 -0
  81. Rhapso/solver/global_optimization.py +410 -0
  82. Rhapso/solver/model_and_tile_setup.py +109 -0
  83. Rhapso/solver/pre_align_tiles.py +323 -0
  84. Rhapso/solver/save_results.py +97 -0
  85. Rhapso/solver/view_transforms.py +75 -0
  86. Rhapso/solver/xml_to_dataframe_solver.py +213 -0
  87. Rhapso/split_dataset/__init__.py +0 -0
  88. Rhapso/split_dataset/compute_grid_rules.py +78 -0
  89. Rhapso/split_dataset/save_points.py +101 -0
  90. Rhapso/split_dataset/save_xml.py +377 -0
  91. Rhapso/split_dataset/split_images.py +537 -0
  92. Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
  93. rhapso-0.1.92.dist-info/METADATA +39 -0
  94. rhapso-0.1.92.dist-info/RECORD +101 -0
  95. rhapso-0.1.92.dist-info/WHEEL +5 -0
  96. rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
  97. rhapso-0.1.92.dist-info/top_level.txt +2 -0
  98. tests/__init__.py +1 -0
  99. tests/test_detection.py +17 -0
  100. tests/test_matching.py +21 -0
  101. tests/test_solving.py +21 -0
@@ -0,0 +1,601 @@
1
+ """Core fusion algorithm."""
2
+ import logging
3
+
4
+ import dask.array as da
5
+ import numpy as np
6
+ import torch
7
+ import s3fs
8
+ import zarr
9
+ import tensorstore as ts
10
+ import ray
11
+
12
+ from . import blend, geometry, io
13
+
14
def initialize_fusion(
    dataset: io.Dataset,
    post_reg_tfms: list[geometry.Transform],
    output_params: io.OutputParameters,
) -> tuple[dict, dict, dict, dict, dict, tuple, tuple]:
    """
    Creates all core fusion data structures and key algorithm inputs.

    Inputs
    ------
    dataset: Dataset application primitive. Supplies the tile arrays/paths
        (``tile_volumes_tczyx``), per-tile registrations
        (``tile_transforms_zyx``) and the input voxel size
        (``tile_resolution_zyx``).
    post_reg_tfms: Transforms applied to every tile after its registration
        has been scaled by the input resolution and before sampling into
        output voxels.
    output_params: OutputParameters application primitive. Supplies the
        output voxel size (``resolution_zyx``) and the 5D (t, c, z, y, x)
        ``chunksize``.

    Returns
    -------
    tile_arrays: Dictionary of input tile arrays, shaped (t, c, z, y, x)
    tile_paths: Tile paths, as provided by the dataset
    tile_transforms: Dictionary of (list of) transforms taking each tile's
        voxel coordinates to output voxel units (before the origin shift)
    tile_sizes: Dictionary of tile sizes (z, y, x)
    tile_aabbs: Dictionary of AABB of each transformed tile, expressed
        relative to the output volume origin
    output_volume_size: Size (z, y, x) of output volume, rounded up to a
        multiple of the output chunksize
    output_volume_origin: Location (z, y, x) of output volume

    Raises
    ------
    ValueError: if the dataset has no tiles.
    """
    tile_arrays, tile_paths = dataset.tile_volumes_tczyx
    if not tile_arrays:
        raise ValueError("Dataset contains no tiles; nothing to fuse.")

    # Shallow-copy so extending the per-tile transform lists below does not
    # mutate the dataset's own dictionary (a second call on the same dataset
    # would otherwise append the resolution/output transforms twice).
    tile_transforms: dict[int, list[geometry.Transform]] = dict(
        dataset.tile_transforms_zyx
    )

    # Tile voxel indices -> physical units.
    iz, iy, ix = dataset.tile_resolution_zyx
    scale_input_zyx = geometry.Affine(
        np.array([[iz, 0, 0, 0], [0, iy, 0, 0], [0, 0, ix, 0]])
    )
    # Physical units -> output voxel units.
    oz, oy, ox = output_params.resolution_zyx
    sample_output_zyx = geometry.Affine(
        np.array([[1 / oz, 0, 0, 0], [0, 1 / oy, 0, 0], [0, 0, 1 / ox, 0]])
    )
    for tile_id, tfm_list in tile_transforms.items():
        tile_transforms[tile_id] = [
            *tfm_list,
            scale_input_zyx,
            *post_reg_tfms,
            sample_output_zyx,
        ]

    tile_sizes_zyx: dict[int, tuple[int, int, int]] = {}
    tile_aabbs: dict[int, geometry.AABB] = {}
    transformed_corners: list[torch.Tensor] = []
    for tile_id, tile_arr in tile_arrays.items():
        zyx = tile_arr.shape[2:]
        tile_sizes_zyx[tile_id] = zyx
        # The 8 corners of the tile's voxel-space bounding box.
        tile_boundaries = torch.Tensor(
            [
                [0.0, 0.0, 0.0],
                [zyx[0], 0.0, 0.0],
                [0.0, zyx[1], 0.0],
                [0.0, 0.0, zyx[2]],
                [zyx[0], zyx[1], 0.0],
                [zyx[0], 0.0, zyx[2]],
                [0.0, zyx[1], zyx[2]],
                [zyx[0], zyx[1], zyx[2]],
            ]
        )
        for tfm in tile_transforms[tile_id]:
            tile_boundaries = tfm.forward(
                tile_boundaries, device=torch.device("cpu")
            )

        tile_aabbs[tile_id] = geometry.aabb_3d(tile_boundaries)
        transformed_corners.append(tile_boundaries)
    tile_boundary_point_cloud_zyx = torch.cat(transformed_corners, dim=0)

    # Resolve Output Volume Dimensions and Absolute Position
    global_tile_boundaries = geometry.aabb_3d(tile_boundary_point_cloud_zyx)
    padded_size = []
    for axis in range(3):
        extent = int(
            global_tile_boundaries[2 * axis + 1] - global_tile_boundaries[2 * axis]
        )
        chunk = output_params.chunksize[axis + 2]
        # Round up to the nearest multiple of the chunk (ceil division)
        # b/c zarr-python has occasional errors writing at the boundaries.
        # This ensures a multiple of chunksize without losing data.
        padded_size.append(-(-extent // chunk) * chunk)
    output_volume_size = tuple(padded_size)

    output_volume_origin = tuple(
        torch.min(tile_boundary_point_cloud_zyx[:, axis]).item()
        for axis in range(3)
    )

    # Shift AABB's into Output Volume where
    # absolute position of output volume is moved to (0, 0, 0).
    # AABB layout: (z_min, z_max, y_min, y_max, x_min, x_max).
    for tile_id, t_aabb in tile_aabbs.items():
        tile_aabbs[tile_id] = tuple(
            t_aabb[i] - output_volume_origin[i // 2] for i in range(6)
        )

    return (
        tile_arrays,
        tile_paths,
        tile_transforms,
        tile_sizes_zyx,
        tile_aabbs,
        output_volume_size,
        output_volume_origin,
    )
149
+
150
+
151
def initialize_output_volume_dask(
    output_params: io.OutputParameters,
    output_volume_size: tuple[int, int, int],
) -> zarr.core.Array:
    """
    Create the output zarr array "0" inside a freshly (re)created group at
    ``output_params.path``.

    Inputs
    ------
    output_params: OutputParameters application instance (path, chunksize,
        dtype and compressor of the output).
    output_volume_size: (z, y, x) size of the output volume, as returned by
        initialize_fusion(...).

    Returns
    -------
    Zarr array of shape (1, 1, z, y, x), fill value 0, "/" chunk-key
    separator, overwriting any existing array at "0".
    """
    if output_params.path.startswith("s3://"):
        # Cloud execution: open the group once, through an explicitly
        # configured S3 client. mode="w" (re)creates the group, as the
        # previous default-client open of the same path did.
        # NOTE(review): 'multipart_threshold' / 'max_concurrent_requests' are
        # boto3 transfer / AWS CLI settings; confirm botocore's Config(s3=...)
        # actually honors them.
        s3 = s3fs.S3FileSystem(
            config_kwargs={
                'max_pool_connections': 50,
                's3': {
                    'multipart_threshold': 64 * 1024 * 1024,  # 64 MB, avoid multipart upload for small chunks
                    'max_concurrent_requests': 20  # Increased from 10 -> 20.
                },
                'retries': {
                    'total_max_attempts': 100,
                    'mode': 'adaptive',
                }
            }
        )
        store = s3fs.S3Map(root=output_params.path, s3=s3)
        out_group = zarr.open_group(store=store, mode="w")
    else:
        # Local execution
        out_group = zarr.open_group(output_params.path, mode="w")

    return out_group.create_dataset(
        "0",
        shape=(
            1,
            1,
            output_volume_size[0],
            output_volume_size[1],
            output_volume_size[2],
        ),
        chunks=output_params.chunksize,
        dtype=output_params.dtype,
        compressor=output_params.compressor,
        dimension_separator="/",
        overwrite=True,
        fill_value=0,
    )
212
+
213
+
214
def initialize_output_volume_tensorstore(
    output_params: io.OutputParameters,
    output_volume_size: tuple[int, int, int],
):
    """
    Open (creating it if absent) the (1, 1, z, y, x) uint16 output array on
    S3 as a zarr v2 array through TensorStore.

    Inputs
    ------
    output_params: OutputParameters application instance. ``path`` must be
        of the form ``s3://bucket/key``; ``chunksize`` gives the 5D chunk
        shape. NOTE(review): ``dtype`` and ``compressor`` are not read here;
        the array is always uint16 with blosc/zstd compression.
    output_volume_size: (z, y, x) size of the output volume, as returned by
        initialize_fusion(...).

    Returns
    -------
    The opened TensorStore (the open future is resolved before returning).
    Writes through it are asynchronous; call ``.result()`` on the returned
    future to perform/await a write.

    Raises
    ------
    ValueError: if ``output_params.path`` is not an ``s3://bucket/...`` URL.
    """
    # "s3://bucket/a/b" -> bucket "bucket", key "a/b"
    scheme, _, location = output_params.path.partition("://")
    bucket, _, path = location.partition("/")
    if scheme != "s3" or not bucket:
        raise ValueError(
            "TensorStore output requires an s3://bucket/key path, "
            f"got {output_params.path!r}."
        )

    chunksize = list(output_params.chunksize)
    output_shape = [
        1,
        1,
        output_volume_size[0],
        output_volume_size[1],
        output_volume_size[2],
    ]

    return ts.open({
        'driver': 'zarr',
        'dtype': 'uint16',
        'kvstore': {
            'driver': 's3',
            'bucket': bucket,
            'path': path,
        },
        'create': True,
        'open': True,
        'metadata': {
            'chunks': chunksize,
            'compressor': {
                'blocksize': 0,
                'clevel': 1,
                'cname': 'zstd',
                'id': 'blosc',
                'shuffle': 1,
            },
            'dimension_separator': '/',
            'dtype': '<u2',
            'fill_value': 0,
            'filters': None,
            'order': 'C',
            'shape': output_shape,
            'zarr_format': 2
        }
    }).result()
260
+
261
+
262
def initialize_output_volume(
    output_params: io.OutputParameters,
    output_volume_size: tuple[int, int, int],
) -> io.OutputArray:
    """
    Create the output volume with the backend selected by
    ``output_params.datastore``: 0 = zarr (initialize_output_volume_dask),
    1 = TensorStore (initialize_output_volume_tensorstore).

    Raises
    ------
    ValueError: if ``output_params.datastore`` is neither 0 nor 1.
    """
    # An explicit raise (not assert) so the check survives `python -O`.
    datastore = output_params.datastore
    if datastore == 0:
        return initialize_output_volume_dask(output_params, output_volume_size)
    if datastore == 1:
        return initialize_output_volume_tensorstore(
            output_params, output_volume_size
        )
    raise ValueError(
        f"Only 0 = Dask and 1 = Tensorstore supported, got {datastore!r}."
    )
275
+
276
+
277
def get_cell_count_zyx(
    output_volume_size: tuple[int, int, int], cell_size: tuple[int, int, int]
) -> tuple[int, int, int]:
    """
    Number of cells needed to tile the output volume, as (z, y, x) counts.

    Both arguments are in canonical (z, y, x) order. A partial cell at the
    end of an axis counts as a whole cell (ceiling division).
    """
    counts = [
        int(np.ceil(output_volume_size[axis] / cell_size[axis]))
        for axis in range(3)
    ]
    return tuple(counts)
289
+
290
+
291
def run_fusion(
    # client, # Uncomment for testing in jupyterlab
    dataset: io.Dataset,
    output_params: io.OutputParameters,
    cell_size: tuple[int, int, int],
    post_reg_tfms: list[geometry.Affine],
    blend_module: blend.BlendingModule,
):
    """
    Fusion algorithm.

    Splits the output volume into cells of ``cell_size`` voxels (z, y, x)
    and fuses each cell in its own Ray task via color_cell(...), which
    writes its result into the output array.

    Inputs: Application objs initialized from input configurations.
    Output: Writes to location in output params.
    """
    logging.basicConfig(
        format="%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M"
    )
    LOGGER = logging.getLogger(__name__)
    LOGGER.setLevel(logging.INFO)

    (
        tile_arrays,
        tile_paths,
        tile_transforms,
        tile_sizes_zyx,
        tile_aabbs,
        output_volume_size,
        output_volume_origin,
    ) = initialize_fusion(dataset, post_reg_tfms, output_params)

    output_volume = initialize_output_volume(output_params, output_volume_size)

    LOGGER.info(f"Number of Tiles: {len(tile_arrays)}")
    LOGGER.info(f"{output_volume_size=}")

    # Workers re-open the output array from its store root + array path.
    # NOTE(review): relies on zarr-array attributes (.store, .path); confirm
    # this also holds for TensorStore outputs (datastore == 1).
    store = output_volume.store
    write_root = getattr(store, "root", None) or getattr(store, "path", None)
    write_ds = output_volume.path

    z_cnt, y_cnt, x_cnt = get_cell_count_zyx(output_volume_size, cell_size)
    cells = [
        (z, y, x) for z in range(z_cnt) for y in range(y_cnt) for x in range(x_cnt)
    ]
    LOGGER.info(f'Coloring {len(cells)} cells')

    @ray.remote
    def process_color_cell(curr_cell, tile_paths, write_root, write_ds,
                           tile_transforms, tile_sizes_zyx, tile_aabbs,
                           output_volume, output_volume_origin, cell_size,
                           blend_module):
        z, y, x = curr_cell
        color_cell(tile_paths, write_root, write_ds, tile_transforms,
                   tile_sizes_zyx, tile_aabbs, output_volume,
                   output_volume_origin, cell_size, blend_module, z, y, x,
                   torch.device("cpu"))
        return {"cell": curr_cell}

    # Arguments shared by every task are placed in the Ray object store once;
    # passing them by value would re-serialize them for every cell. Ray
    # resolves top-level ObjectRef arguments to their values inside the task.
    shared_arg_refs = [
        ray.put(arg)
        for arg in (
            tile_paths, write_root, write_ds, tile_transforms, tile_sizes_zyx,
            tile_aabbs, output_volume, output_volume_origin, cell_size,
            blend_module,
        )
    ]

    # One task per cell.
    futures = [process_color_cell.remote(cell, *shared_arg_refs) for cell in cells]
    ray.get(futures)

    # DEBUG: to run serially, call color_cell(...) for each (z, y, x) in
    # `cells` with the same arguments and torch.device("cpu") instead of
    # dispatching Ray tasks.
367
+
368
def _crop_interval(coord_min, coord_max, tile_length):
    """
    Integer [start, stop) index range along one tile axis covering the
    sampled interval [coord_min, coord_max], clipped to [0, tile_length].

    Precondition (checked by the caller): coord_max > 0 and
    coord_min < tile_length, with tile_length an integer tile dimension.
    """
    # Overlap of intervals (A, B): lower = max(A_min, B_min) rounded down,
    # upper = min(A_max, B_max) rounded up, so they can be used as indices.
    start = int(torch.floor(torch.max(torch.Tensor([0, coord_min]))))
    stop = int(torch.ceil(torch.min(torch.Tensor([tile_length, coord_max]))))
    # If all samples lie on one integer plane (coord_min == coord_max), floor
    # and ceil give an empty range -> empty crop and division by zero during
    # normalization. Widen to one voxel; the precondition guarantees
    # start + 1 <= tile_length.
    return start, max(stop, start + 1)


def _collect_tile_contributions(
    overlapping_tiles,
    tile_paths,
    tile_transforms,
    tile_sizes_zyx,
    cell_box,
    output_volume_origin,
    device,
):
    """
    Resample each overlapping tile onto the cell's output voxel grid.

    Returns (contributions, tile_ids): one (1, 1, z, y, x) tensor for every
    tile whose transformed samples really overlap its image, plus the ids of
    those tiles in the same order.
    """
    contributions: list[torch.Tensor] = []
    contribution_tile_ids: list[int] = []
    if not overlapping_tiles:
        return contributions, contribution_tile_ids

    # Output voxel centers of this cell, shared by all tiles: arange is
    # end-exclusive, +0.5 moves to voxel centers. Shape (z, y, x, 3), defined
    # w.r.t. the absolute output volume origin.
    axes = [
        (torch.arange(cell_box[2 * d], cell_box[2 * d + 1], step=1) + 0.5).to(device)
        for d in range(3)
    ]
    grids = torch.meshgrid(*axes, indexing="ij")
    cell_coords = torch.stack(grids, dim=-1) + torch.Tensor(output_volume_origin).to(device)

    # Pytorch's grid_sample orders sample components (x, y, z), unlike the
    # (z, y, x) numpy basis; this change of basis swaps components 0 and 2.
    interp_cob = geometry.Affine(
        torch.Tensor([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]])
    )
    s3_read = s3fs.S3FileSystem(anon=True)

    for tile_id in overlapping_tiles:
        # clone() keeps the shared grid intact however the transforms treat
        # their input.
        tile_coords = cell_coords.clone()
        # Send coords through inverse transforms; the list must be iterated
        # in reverse. (z, y, x, 3) -> (z, y, x, 3)
        for tfm in reversed(tile_transforms[tile_id]):
            tile_coords = tfm.backward(tile_coords, device=device)

        # Mini Optimization: skip the tile unless the AABB of the transformed
        # coordinates actually overlaps the image.
        z_min, z_max, y_min, y_max, x_min, x_max = geometry.aabb_3d(tile_coords)
        t_size_zyx = tile_sizes_zyx[tile_id]
        if not (
            (z_max > 0 and z_min < t_size_zyx[0])
            and (y_max > 0 and y_min < t_size_zyx[1])
            and (x_max > 0 and x_min < t_size_zyx[2])
        ):
            continue

        crop_min_z, crop_max_z = _crop_interval(z_min, z_max, t_size_zyx[0])
        crop_min_y, crop_max_y = _crop_interval(y_min, y_max, t_size_zyx[1])
        crop_min_x, crop_max_x = _crop_interval(x_min, x_max, t_size_zyx[2])

        # Define tile coords wrt base image crop coordinates
        tile_coords = tile_coords - torch.Tensor(
            [crop_min_z, crop_min_y, crop_min_x]
        ).to(device)

        # Read the (t=0, c=0) crop of the tile.
        image_crop_slice = (
            0,
            0,
            slice(crop_min_z, crop_max_z),
            slice(crop_min_y, crop_max_y),
            slice(crop_min_x, crop_max_x),
        )
        src_store = s3fs.S3Map(root=tile_paths[tile_id], s3=s3_read)
        image_crop = zarr.open(store=src_store, mode="r")[image_crop_slice]
        if isinstance(image_crop, da.Array):
            image_crop = image_crop.compute()
        # Promote uint16 -> Pytorch compatible int32
        image_crop = torch.Tensor(image_crop.astype(np.int32)).to(device)

        # (z, y, x, 3) -> (z, y, x, 3) in the interpolation basis
        tile_coords = interp_cob.forward(tile_coords, device=device)

        # grid_sample expects sample locations normalized to [-1, 1]; after
        # the change of basis component 0 is x, 1 is y and 2 is z.
        crop_lengths_xyz = (
            crop_max_x - crop_min_x,
            crop_max_y - crop_min_y,
            crop_max_z - crop_min_z,
        )
        for component, length in enumerate(crop_lengths_xyz):
            half = length / 2
            tile_coords[:, :, :, component] = (
                tile_coords[:, :, :, component] - half
            ) / half

        # image_crop: (z_in, y_in, x_in) -> (1, 1, z_in, y_in, x_in)
        # tile_coords: (z_out, y_out, x_out, 3) -> (1, z_out, y_out, x_out, 3)
        # => tile_contribution: (1, 1, z_out, y_out, x_out)
        tile_contribution = torch.nn.functional.grid_sample(
            image_crop[(None,) * 2],
            torch.unsqueeze(tile_coords, 0),
            padding_mode="zeros",
            mode="nearest",
            align_corners=False,
        )
        contributions.append(tile_contribution)
        contribution_tile_ids.append(tile_id)

    return contributions, contribution_tile_ids


def color_cell(
    tile_paths,
    write_root,
    write_ds,
    tile_transforms: dict[int, list[geometry.Transform]],
    tile_sizes_zyx: dict[int, tuple[int, int, int]],
    tile_aabbs: dict[int, geometry.AABB],
    output_volume: io.OutputArray,
    output_volume_origin: tuple[float, float, float],
    cell_size: tuple[int, int, int],
    blend_module: blend.BlendingModule,
    z: int,
    y: int,
    x: int,
    device: torch.device,
):
    """
    Fuse one output cell and write it to the output store.
    Parallelized function called in fusion.

    Inputs
    -------
    tile_paths: Per-tile source locations, indexed by tile id; each is
        opened read-only (anonymous S3) as a 5D (t, c, z, y, x) zarr array
    write_root: S3 root of the output zarr store
    write_ds: Path of the output array inside write_root
    tile_transforms: Dictionary of (list of) transforms per tile, applied in
        reverse through .backward()
    tile_sizes_zyx: Dictionary of tile sizes
    tile_aabbs: Dictionary of AABB of each transformed tile, relative to the
        output volume origin
    output_volume: Output array; only its .shape (1, 1, z, y, x) is read here
    output_volume_origin: Location of output volume
    cell_size: Operating volume (z, y, x) of this function
    blend_module: Application blending obj
    z, y, x: Cell indices; the cell starts at voxel
        (z * cell_size[0], y * cell_size[1], x * cell_size[2])
    device: torch device used for coordinate math and interpolation
    """
    # Cell Boundaries, exclusive stop index, clipped to the output volume.
    output_volume_size = output_volume.shape
    cell_box = np.array(
        [
            [z * cell_size[0], z * cell_size[0] + cell_size[0]],
            [y * cell_size[1], y * cell_size[1] + cell_size[1]],
            [x * cell_size[2], x * cell_size[2] + cell_size[2]],
        ]
    )
    cell_box[:, 1] = np.minimum(cell_box[:, 1], np.array(output_volume_size[2:]))
    # Flattened to (z_start, z_stop, y_start, y_stop, x_start, x_stop).
    cell_box = cell_box.flatten()

    # Collision Detection
    # Collision defined by overlapping intervals in all 3 dimensions.
    # Two intervals (A, B) collide if A_max is not <= B_min and A_min is not >= B_max.
    overlapping_tiles: list[int] = [
        tile_id
        for tile_id, t_aabb in tile_aabbs.items()
        if (cell_box[1] > t_aabb[0] and cell_box[0] < t_aabb[1])
        and (cell_box[3] > t_aabb[2] and cell_box[2] < t_aabb[3])
        and (cell_box[5] > t_aabb[4] and cell_box[4] < t_aabb[5])
    ]

    cell_contributions, cell_contribution_tile_ids = _collect_tile_contributions(
        overlapping_tiles,
        tile_paths,
        tile_transforms,
        tile_sizes_zyx,
        cell_box,
        output_volume_origin,
        device,
    )

    # Fuse all cell contributions together with specified blend module;
    # a cell that no tile covers is written as zeros.
    if cell_contributions:
        fused_cell = blend_module.blend(
            cell_contributions,
            device,
            kwargs={'chunk_tile_ids': cell_contribution_tile_ids,
                    'cell_box': cell_box}
        )
    else:
        fused_cell = torch.zeros((1,
                                  1,
                                  cell_box[1] - cell_box[0],
                                  cell_box[3] - cell_box[2],
                                  cell_box[5] - cell_box[4]))

    # Write
    output_slice = (
        slice(0, 1),
        slice(0, 1),
        slice(cell_box[0], cell_box[1]),
        slice(cell_box[2], cell_box[3]),
        slice(cell_box[4], cell_box[5]),
    )
    # Convert from float32 -> canonical uint16
    output_chunk = np.array(fused_cell.cpu()).astype(np.uint16)

    s3_write = s3fs.S3FileSystem(anon=False)
    out_store = s3fs.S3Map(root=write_root, s3=s3_write)
    arr = zarr.open(store=out_store, mode="a")[write_ds]
    arr[output_slice] = np.ascontiguousarray(output_chunk)