rhapso-0.1.92-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. Rhapso/__init__.py +1 -0
  2. Rhapso/data_prep/__init__.py +2 -0
  3. Rhapso/data_prep/n5_reader.py +188 -0
  4. Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
  5. Rhapso/data_prep/xml_to_dataframe.py +215 -0
  6. Rhapso/detection/__init__.py +5 -0
  7. Rhapso/detection/advanced_refinement.py +203 -0
  8. Rhapso/detection/difference_of_gaussian.py +324 -0
  9. Rhapso/detection/image_reader.py +117 -0
  10. Rhapso/detection/metadata_builder.py +130 -0
  11. Rhapso/detection/overlap_detection.py +327 -0
  12. Rhapso/detection/points_validation.py +49 -0
  13. Rhapso/detection/save_interest_points.py +265 -0
  14. Rhapso/detection/view_transform_models.py +67 -0
  15. Rhapso/fusion/__init__.py +0 -0
  16. Rhapso/fusion/affine_fusion/__init__.py +2 -0
  17. Rhapso/fusion/affine_fusion/blend.py +289 -0
  18. Rhapso/fusion/affine_fusion/fusion.py +601 -0
  19. Rhapso/fusion/affine_fusion/geometry.py +159 -0
  20. Rhapso/fusion/affine_fusion/io.py +546 -0
  21. Rhapso/fusion/affine_fusion/script_utils.py +111 -0
  22. Rhapso/fusion/affine_fusion/setup.py +4 -0
  23. Rhapso/fusion/affine_fusion_worker.py +234 -0
  24. Rhapso/fusion/multiscale/__init__.py +0 -0
  25. Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
  26. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
  27. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
  28. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
  29. Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
  30. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
  31. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
  32. Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
  33. Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
  34. Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
  35. Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
  36. Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
  37. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
  38. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
  39. Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
  40. Rhapso/fusion/multiscale_worker.py +113 -0
  41. Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
  42. Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
  43. Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
  44. Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
  45. Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
  46. Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
  47. Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
  48. Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
  49. Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
  50. Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
  51. Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
  52. Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
  53. Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
  54. Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
  55. Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
  56. Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
  57. Rhapso/matching/__init__.py +0 -0
  58. Rhapso/matching/load_and_transform_points.py +458 -0
  59. Rhapso/matching/ransac_matching.py +544 -0
  60. Rhapso/matching/save_matches.py +120 -0
  61. Rhapso/matching/xml_parser.py +302 -0
  62. Rhapso/pipelines/__init__.py +0 -0
  63. Rhapso/pipelines/ray/__init__.py +0 -0
  64. Rhapso/pipelines/ray/aws/__init__.py +0 -0
  65. Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
  66. Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
  67. Rhapso/pipelines/ray/evaluation.py +71 -0
  68. Rhapso/pipelines/ray/interest_point_detection.py +137 -0
  69. Rhapso/pipelines/ray/interest_point_matching.py +110 -0
  70. Rhapso/pipelines/ray/local/__init__.py +0 -0
  71. Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
  72. Rhapso/pipelines/ray/matching_stats.py +104 -0
  73. Rhapso/pipelines/ray/param/__init__.py +0 -0
  74. Rhapso/pipelines/ray/solver.py +120 -0
  75. Rhapso/pipelines/ray/split_dataset.py +78 -0
  76. Rhapso/solver/__init__.py +0 -0
  77. Rhapso/solver/compute_tiles.py +562 -0
  78. Rhapso/solver/concatenate_models.py +116 -0
  79. Rhapso/solver/connected_graphs.py +111 -0
  80. Rhapso/solver/data_prep.py +181 -0
  81. Rhapso/solver/global_optimization.py +410 -0
  82. Rhapso/solver/model_and_tile_setup.py +109 -0
  83. Rhapso/solver/pre_align_tiles.py +323 -0
  84. Rhapso/solver/save_results.py +97 -0
  85. Rhapso/solver/view_transforms.py +75 -0
  86. Rhapso/solver/xml_to_dataframe_solver.py +213 -0
  87. Rhapso/split_dataset/__init__.py +0 -0
  88. Rhapso/split_dataset/compute_grid_rules.py +78 -0
  89. Rhapso/split_dataset/save_points.py +101 -0
  90. Rhapso/split_dataset/save_xml.py +377 -0
  91. Rhapso/split_dataset/split_images.py +537 -0
  92. Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
  93. rhapso-0.1.92.dist-info/METADATA +39 -0
  94. rhapso-0.1.92.dist-info/RECORD +101 -0
  95. rhapso-0.1.92.dist-info/WHEEL +5 -0
  96. rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
  97. rhapso-0.1.92.dist-info/top_level.txt +2 -0
  98. tests/__init__.py +1 -0
  99. tests/test_detection.py +17 -0
  100. tests/test_matching.py +21 -0
  101. tests/test_solving.py +21 -0
Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py
@@ -0,0 +1,557 @@
+ """
+ Computes radial correction for microscopy data.
+ """
+
+ import asyncio
+ import logging
+ import multiprocessing as mp
+ import time
+ import urllib.parse
+ from concurrent.futures import ThreadPoolExecutor
+ from math import ceil
+ from pathlib import Path
+ from typing import List, Literal, Optional, Tuple, Union
+
+ import dask.array as da
+ import numba as nb
+ import numpy as np
+ import tensorstore as ts
+ from aind_data_schema.core.processing import DataProcess, ProcessName
+ from dask.diagnostics import ProgressBar
+ from dask.distributed import Client, LocalCluster
+ from natsort import natsorted
+ from scipy.ndimage import map_coordinates
+
+ from . import __maintainers__, __pipeline_version__, __url__, __version__
+ from .array_to_zarr import convert_array_to_zarr
+ from .utils import utils
+
+ logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M")
+ LOGGER = logging.getLogger(__name__)
+ LOGGER.setLevel(logging.INFO)
+
+
+ def calculate_corner_shift_from_pixel_size(XY_pixel_size: float) -> float:
+     """
+     Compute the corner shift value based on pixel size.
+
+     Parameters
+     ----------
+     XY_pixel_size : float
+         The size of a pixel in microns (not currently used).
+
+     Returns
+     -------
+     float
+         A constant value of 4.3.
+     """
+     return 4.3
+
+
+ def calculate_frac_cutoff_from_pixel_size(XY_pixel_size: float) -> float:
+     """
+     Compute the fractional cutoff for radial correction.
+
+     Parameters
+     ----------
+     XY_pixel_size : float
+         The size of a pixel in microns (not currently used).
+
+     Returns
+     -------
+     float
+         A constant value of 0.5.
+     """
+     return 0.5
+
+
+ # enables multithreading with prange
+ @nb.njit(parallel=True)
+ def _compute_coordinates(
+     pixels: int, cutoff: float, corner_shift: float, edge: int
+ ) -> tuple:
+     """
+     Computes the coordinates the pixels will be sampled from.
+
+     Parameters
+     ----------
+     pixels: int
+         Width or height of the image. The image is assumed to have
+         the same resolution in X and Y.
+
+     cutoff: float
+         Radius beyond which the distortion is applied.
+
+     corner_shift: float
+         How much to "pull in" corners beyond the cutoff.
+
+     edge: int
+         Number of pixels to crop from each side (e.g., due to
+         interpolation instability).
+
+     Returns
+     -------
+     np.ndarray
+         Sampling coordinates of shape (2, pixels - 2 * edge,
+         pixels - 2 * edge), shifted back to image space.
+     """
+     # coords[0] -> y-coordinates relative to center
+     # coords[1] -> x-coordinates relative to center
+     coords = np.zeros((2, pixels, pixels), dtype=np.float32)
+
+     # stores the radius of each pixel from the center
+     r = np.zeros((pixels, pixels), dtype=np.float32)
+
+     # First pass: calculate centered coordinates and radius for
+     # every pixel in the image, parallelized with prange
+     for i in nb.prange(pixels):
+         for j in range(pixels):
+             # Shift (i, j) so the origin is at the center
+             y = i - pixels // 2
+             x = j - pixels // 2
+             coords[0, i, j] = y
+             coords[1, i, j] = x
+
+             # Radius from the center
+             r[i, j] = np.sqrt(x * x + y * y)
+
+     # Maximum radius, used to normalize the distortion
+     rmax = r.max()
+
+     # Second pass: apply radial distortion
+     r_piece = np.zeros_like(r)
+     angles = np.zeros_like(r)
+
+     for i in nb.prange(pixels):
+         for j in range(pixels):
+             r_val = r[i, j]
+             # coords[0] is y, coords[1] is x; arctan2 takes (y, x)
+             angle = np.arctan2(coords[0, i, j], coords[1, i, j])
+
+             # Sampling radii beyond the cutoff are pushed outward, which
+             # pulls corner content inward in the output image
+             if r_val > cutoff:
+                 r_val += (r_val - cutoff) * corner_shift / (rmax - cutoff)
+
+             r_piece[i, j] = r_val
+             angles[i, j] = angle
+             coords[0, i, j] = r_val * np.sin(angle)
+             coords[1, i, j] = r_val * np.cos(angle)
+
+     # Crop edges and shift back to image space
+     cropped = coords[:, edge:-edge, edge:-edge]
+
+     cropped += pixels // 2
+
+     return cropped
+
+
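The distortion above is a piecewise-linear stretch of the sampling radius: inside the cutoff the coordinates are untouched; beyond it, the shift grows linearly from 0 at the cutoff to corner_shift at rmax. A minimal sketch of the same mapping for a single radius (plain NumPy, no numba; rmax is approximated here by the corner radius, whereas the function takes it from the discrete pixel grid):

import numpy as np

def shifted_radius(r: float, cutoff: float, corner_shift: float, rmax: float) -> float:
    """Radial mapping from _compute_coordinates, applied to one radius."""
    if r > cutoff:
        r += (r - cutoff) * corner_shift / (rmax - cutoff)
    return r

pixels = 2048                     # hypothetical tile width
cutoff = pixels * 0.5             # 1024.0, from frac_cutoff = 0.5
rmax = np.sqrt(2.0) * pixels / 2  # corner radius, ~1448.2
print(shifted_radius(cutoff, cutoff, 4.3, rmax))  # 1024.0 (unchanged at the cutoff)
print(shifted_radius(rmax, cutoff, 4.3, rmax))    # ~1452.5 (rmax + 4.3 at the corner)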
+ def _process_plane(args):
+     """Helper function to process a single z-plane for parallel execution."""
+     z, plane, coords, order = args
+     warp_coords = np.zeros((2, *coords[0].shape), dtype=np.float32)
+     warp_coords[0] = coords[0]
+     warp_coords[1] = coords[1]
+     return z, map_coordinates(plane, warp_coords, order=order, mode="constant")
+
+
+ def radial_correction(
+     tile_data: np.ndarray,
+     corner_shift: Optional[float] = 5.5,
+     frac_cutoff: Optional[float] = 0.5,
+     mode: Union[Literal["2d"], Literal["3d"]] = "3d",
+     order: int = 1,
+     max_workers: Optional[int] = None,
+ ) -> np.ndarray:
+     """
+     Apply radial correction to a tile with optimized performance.
+
+     Parameters
+     ----------
+     tile_data : np.ndarray
+         The 3D tile data (Z, Y, X) to be corrected.
+     corner_shift : Optional[float]
+         The amount of radial shift to apply (default is 5.5).
+     frac_cutoff : Optional[float]
+         Fraction of the tile width at which correction begins (default is 0.5).
+     mode : Union[Literal["2d"], Literal["3d"]]
+         Processing mode: "2d" for plane-wise processing or "3d" for the
+         full volume (default is "3d").
+     order : int
+         Interpolation order for map_coordinates (default is 1).
+     max_workers : Optional[int]
+         Maximum number of worker threads for parallel processing
+         (default is None, which uses the CPU count).
+
+     Returns
+     -------
+     np.ndarray
+         The corrected tile.
+     """
+     edge = ceil(corner_shift / np.sqrt(2)) + 1
+     shape = tile_data.shape
+     pixels = shape[1]  # Assume square XY plane
+     cutoff = pixels * frac_cutoff
+
+     # Compute the coordinate warp using numba
+     coords = _compute_coordinates(pixels, cutoff, corner_shift, edge)
+
+     # Calculate the new shape after edge cropping
+     new_shape = np.array(shape) - [0, edge * 2, edge * 2]
+     LOGGER.info(f"New shape: {new_shape} - Mode {mode} - Cutoff: {cutoff}")
+
+     # Different processing paths based on mode
+     if mode == "2d":
+         # Process each z-plane separately in parallel
+         result = np.zeros(new_shape, dtype=tile_data.dtype)
+
+         # Use ThreadPoolExecutor for parallel processing of z-planes
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             tasks = [(z, tile_data[z], coords, order) for z in range(shape[0])]
+             for z, processed_plane in executor.map(_process_plane, tasks):
+                 result[z] = processed_plane
+
+         return result
+
+     else:  # 3D mode
+         # Create the full 3D warping coordinates array
+         warp_coords = np.zeros((3, *new_shape), dtype=np.float32)
+
+         # Z coordinates remain unchanged
+         for z in range(new_shape[0]):
+             warp_coords[0, z] = z
+
+         # Apply the pre-computed X-Y coordinate transformations to each z-plane
+         warp_coords[1] = np.repeat(
+             coords[0][np.newaxis, :, :], new_shape[0], axis=0
+         )
+         warp_coords[2] = np.repeat(
+             coords[1][np.newaxis, :, :], new_shape[0], axis=0
+         )
+
+         # Process the entire volume at once
+         return map_coordinates(
+             tile_data, warp_coords, order=order, mode="constant"
+         )
+
+
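As a quick illustration, a minimal sketch of calling radial_correction on a synthetic volume (the data is random; the expected output shape follows from the edge-cropping arithmetic above):

import numpy as np

# Hypothetical 64 x 512 x 512 (Z, Y, X) tile of random uint16 data
tile = np.random.randint(0, 2**16, size=(64, 512, 512), dtype=np.uint16)

corrected = radial_correction(
    tile, corner_shift=4.3, frac_cutoff=0.5, mode="3d", order=1
)

# edge = ceil(4.3 / sqrt(2)) + 1 = 5, so 10 pixels are cropped in Y and in X
print(corrected.shape)  # expected: (64, 502, 502)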
+ def read_zarr(
+     dataset_path: str,
+     compute: Optional[bool] = True,
+ ) -> Tuple:
+     """
+     Reads a zarr dataset.
+
+     Parameters
+     ----------
+     dataset_path: str
+         Path where the dataset is stored.
+
+     compute: Optional[bool]
+         Computes the lazy dask graph.
+         Default: True
+
+     Returns
+     -------
+     Tuple[ArrayLike, da.Array]
+         The in-memory array (None if compute is False)
+         and the lazy dask array.
+     """
+     tile = None
+
+     cluster = LocalCluster(
+         n_workers=mp.cpu_count(), threads_per_worker=1, memory_limit="auto"
+     )
+     client = Client(cluster)
+
+     # Explicitly use the threaded scheduler for reading (much faster)
+     try:
+         tile_lazy = da.from_zarr(dataset_path).squeeze()
+
+         if compute:
+             with ProgressBar():
+                 tile = tile_lazy.compute(scheduler="threads")
+     finally:
+         client.close()
+         cluster.close()
+
+     return tile, tile_lazy
+
+
+ async def read_zarr_tensorstore(
+     dataset_path: str, scale: str, driver: Optional[str] = "zarr"
+ ) -> Tuple:
+     """
+     Reads a zarr dataset from the local filesystem or an S3 bucket.
+
+     Parameters
+     ----------
+     dataset_path: str
+         Path where the dataset is stored. Can be a local path
+         or an S3 path (s3://...).
+
+     scale: str
+         Multiscale level to load.
+
+     driver: Optional[str]
+         Tensorstore driver.
+         Default: zarr
+
+     Returns
+     -------
+     Tuple
+         The materialized array and the open TensorStore handle.
+     """
+     # Parse the URL properly using urllib
+     parsed_url = urllib.parse.urlparse(dataset_path)
+
+     if parsed_url.scheme == "s3":
+         # Handle S3 path
+         bucket = parsed_url.netloc
+         # Remove the leading slash if present
+         key = parsed_url.path.lstrip("/")
+         LOGGER.info(f"Parsed {parsed_url} -> bucket: {bucket}, key: {key}")
+
+         ts_spec = {
+             "driver": str(driver),
+             "kvstore": {
+                 "driver": "s3",
+                 "bucket": bucket,
+                 "path": key,
+             },
+             "path": str(scale),
+         }
+     else:
+         # Local file handling
+         ts_spec = {
+             "driver": str(driver),
+             "kvstore": {
+                 "driver": "file",
+                 "path": str(dataset_path),
+             },
+             "path": str(scale),
+         }
+
+     tile_lazy = await ts.open(ts_spec)
+     tile = await tile_lazy.read()
+     return tile, tile_lazy
+
+
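Because read_zarr_tensorstore is a coroutine, callers drive it with asyncio.run, as apply_corr_to_zarr_tile does below. A minimal sketch, with a hypothetical S3 dataset path:

import asyncio

# Hypothetical dataset location; scale "0" is the full-resolution level
data, handle = asyncio.run(
    read_zarr_tensorstore(
        "s3://my-bucket/my-dataset/Tile_X_0000.ome.zarr",
        scale="0",
        driver="zarr",
    )
)
print(data.shape, handle.dtype)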
+ def apply_corr_to_zarr_tile(
+     dataset_path: str,
+     scale: str,
+     corner_shift: Optional[float] = 5.5,
+     frac_cutoff: Optional[float] = 0.5,
+     z_size_threshold: Optional[int] = 400,
+     order: Optional[int] = 1,
+     max_workers: Optional[int] = None,
+     driver: Optional[str] = "zarr",
+ ) -> np.ndarray:
+     """
+     Load a Zarr tile, apply radial correction, and return the corrected tile.
+
+     Parameters
+     ----------
+     dataset_path : str
+         Path to the Zarr file containing the tile.
+
+     scale: str
+         Multiscale level to load.
+
+     corner_shift : Optional[float]
+         The amount of shift to apply to corners (default is 5.5).
+
+     frac_cutoff : Optional[float]
+         Fraction of the tile width where correction starts (default is 0.5).
+
+     z_size_threshold: Optional[int]
+         Z extent below which full-3D correction is applied;
+         larger volumes are processed plane-wise in 2D.
+
+     order: Optional[int]
+         Interpolation order.
+         Default: 1
+
+     max_workers: Optional[int]
+         Max number of workers.
+         Default: None
+
+     driver: Optional[str]
+         Zarr driver to read the data.
+         Default: zarr
+
+     Returns
+     -------
+     np.ndarray
+         The corrected tile.
+     """
+     if z_size_threshold < 0:
+         raise ValueError(
+             f"Please provide a non-negative threshold: {z_size_threshold}"
+         )
+
+     # Read the zarr dataset
+     data_in_memory, lazy_array = asyncio.run(
+         read_zarr_tensorstore(dataset_path, scale=scale, driver=driver)
+     )
+     data_in_memory = data_in_memory.squeeze()
+     z_size = data_in_memory.shape[-3]
+
+     LOGGER.info(f"Dataset shape {data_in_memory.shape}")
+
+     # Small volumes are warped in a single 3D pass;
+     # larger ones are corrected plane by plane
+     mode = "3d" if z_size < z_size_threshold else "2d"
+
+     output_radial = radial_correction(
+         tile_data=data_in_memory,
+         corner_shift=corner_shift,
+         frac_cutoff=frac_cutoff,
+         mode=mode,
+         order=order,
+         max_workers=max_workers,
+     )
+
+     return output_radial
+
+
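A minimal sketch of correcting one multiscale level end to end (the path is hypothetical; the threshold shown is the default):

corrected = apply_corr_to_zarr_tile(
    dataset_path="s3://my-bucket/my-dataset/Tile_X_0000.ome.zarr",
    scale="0",
    corner_shift=4.3,
    frac_cutoff=0.5,
    z_size_threshold=400,  # volumes with fewer than 400 z-planes use full-3D mode
    driver="zarr",
)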
+ def correct_and_save_tile(
+     dataset_loc: str,
+     output_path: str,
+     resolution_zyx: List[float],
+     scale: str = "0",
+     n_lvls: Optional[int] = 4,
+     driver: Optional[str] = "zarr",
+ ):
+     """
+     Corrects and saves a single tile.
+
+     Parameters
+     ----------
+     dataset_loc: str
+         Path to the dataset to be corrected.
+     output_path: str
+         Path to save the corrected dataset.
+     resolution_zyx: List[float]
+         Voxel size in the format [z, y, x].
+     scale: str
+         Multiscale level to load.
+         Default: "0"
+     n_lvls: Optional[int]
+         Number of downsampled levels to write.
+         Default: 4
+     driver: Optional[str]
+         Driver to read the data with tensorstore.
+         Default: "zarr"
+     """
+
+     corner_shift = calculate_corner_shift_from_pixel_size(resolution_zyx[1])
+     frac_cutoff = calculate_frac_cutoff_from_pixel_size(resolution_zyx[1])
+
+     LOGGER.info(f"Input: {dataset_loc} - Output: {output_path}")
+     LOGGER.info(f"Corner Shift: {corner_shift} pixels")
+     LOGGER.info(f"Fraction Cutoff: {frac_cutoff}")
+
+     start_time = time.time()
+     corrected_tile = apply_corr_to_zarr_tile(
+         dataset_loc, scale, corner_shift, frac_cutoff, driver=driver
+     )
+     end_time = time.time()
+     LOGGER.info(
+         f"Time to correct: {end_time - start_time} seconds -> New shape {corrected_tile.shape}"
+     )
+
+     convert_array_to_zarr(
+         array=corrected_tile,
+         voxel_size=resolution_zyx,
+         chunk_size=[128] * 3,
+         output_path=str(output_path),
+         n_lvls=n_lvls,
+     )
+
+     data_process = None
+     # TODO: activate this when aind-data-schema 2.0 is out
+     # DataProcess(
+     #     name=ProcessName.IMAGE_RADIAL_CORRECTION,
+     #     software_version=__version__,
+     #     start_date_time=start_time,
+     #     end_date_time=end_time,
+     #     input_location=dataset_loc,
+     #     output_location=output_path,
+     #     code_version=__version__,
+     #     code_url=__url__,
+     #     parameters={
+     #         "corner_shift": corner_shift,
+     #         "frac_cutoff": frac_cutoff,
+     #     },
+     # )
+
+     return data_process
+
+
+ def main(
+     data_folder: str,
+     results_folder: str,
+     acquisition_path: str,
+     tilenames: List[str],
+     driver: Optional[str] = "zarr",
+ ):
+     """
+     Applies radial correction to multiple tiles.
+
+     Parameters
+     ----------
+     data_folder: str
+         Folder where the data is stored.
+
+     results_folder: str
+         Results folder. It can be a local path or
+         an S3 bucket.
+
+     acquisition_path: str
+         Path where the acquisition.json is.
+
+     tilenames: List[str]
+         Tiles to process, e.g.,
+         [Tile_X_000...ome.zarr, ..., ]
+
+     driver: Optional[str]
+         Driver to read the data with tensorstore.
+         Default: "zarr"
+     """
+     zyx_voxel_size = utils.get_voxel_resolution(
+         acquisition_path=acquisition_path
+     )
+     LOGGER.info(f"Voxel ZYX resolution: {zyx_voxel_size}")
+
+     data_processes = []
+     for tilename in tilenames:
+         zarr_path = f"{data_folder}/{tilename}"
+         output_path = f"{results_folder}/{tilename}"
+         data_process = correct_and_save_tile(
+             dataset_loc=zarr_path,
+             output_path=output_path,
+             resolution_zyx=zyx_voxel_size,
+             n_lvls=6,
+             driver=driver,
+         )
+
+     # utils.generate_processing(
+     #     data_processes=data_processes,
+     #     dest_processing=results_folder,
+     #     processor_full_name=__maintainers__[0],
+     #     pipeline_version=__pipeline_version__,
+     #     prefix='radial_correction'
+     # )
+
+
+ if __name__ == "__main__":
+     main()
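Taken together, a minimal sketch of driving this module directly (the paths, bucket, and tile name are hypothetical; the voxel resolution is read from acquisition.json by the package's utils module):

main(
    data_folder="s3://my-bucket/my-dataset",
    results_folder="s3://my-bucket/my-dataset/image_radial_correction",
    acquisition_path="/data/acquisition.json",
    tilenames=["Tile_X_0000_Y_0000_Z_0000_ch_488.ome.zarr"],
    driver="zarr",
)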
Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py
@@ -0,0 +1,98 @@
+ """
+ Runs radial correction on the set of tiles provided
+ in the data directory.
+ """
+
+ import os
+ import sys
+ from pathlib import Path
+
+ from aind_z1_radial_correction import radial_correction
+ from aind_z1_radial_correction.utils import utils
+
+
+ def run():
+     """
+     Main run file in Code Ocean.
+     """
+
+     data_folder = os.path.abspath("../data")
+     results_folder = os.path.abspath("../results")
+
+     # The worker scheduler path has the radial parameters and acquisition
+     worker_scheduler_path = list(Path(data_folder).glob("worker*"))[0]
+
+     acquisition_path = f"{worker_scheduler_path}/acquisition.json"
+     data_description_path = f"{worker_scheduler_path}/data_description.json"
+
+     radial_correction_parameters_path = (
+         f"{worker_scheduler_path}/radial_correction_parameters.json"
+     )
+
+     required_input_elements = [
+         acquisition_path,
+         radial_correction_parameters_path,
+         data_description_path,
+     ]
+
+     missing_files = utils.validate_capsule_inputs(required_input_elements)
+
+     if len(missing_files):
+         raise ValueError(
+             f"The following files are missing from the capsule input: {missing_files}"
+         )
+
+     radial_correction_parameters = utils.read_json_as_dict(
+         radial_correction_parameters_path
+     )
+
+     tilenames = radial_correction_parameters.get("tilenames", [])
+     worker_id = radial_correction_parameters.get("worker_id", None)
+     bucket_name = radial_correction_parameters.get("bucket_name", None)
+     input_s3_dataset_path = radial_correction_parameters.get(
+         "input_s3_dataset_path", None
+     )
+     tensorstore_driver = radial_correction_parameters.get(
+         "tensorstore_driver", "zarr3"
+     )
+     write_to_s3 = radial_correction_parameters.get("write_to_s3", True)
+
+     print(f"Worker ID: {worker_id} processing {len(tilenames)} tiles!")
+
+     write_folder = results_folder
+     if bucket_name is not None and write_to_s3:
+         data_description = utils.read_json_as_dict(data_description_path)
+         dataset_name = data_description.get("name", None)
+         if not dataset_name:
+             raise ValueError(
+                 f"Dataset name not found in data_description.json: {data_description_path}"
+             )
+
+         write_folder = (
+             f"s3://{bucket_name}/{dataset_name}/image_radial_correction"
+         )
+
+         if input_s3_dataset_path is not None:
+             data_folder = f"s3://{bucket_name}/{input_s3_dataset_path}"
+
+     if len(tilenames):
+         radial_correction.main(
+             data_folder=data_folder,
+             results_folder=write_folder,
+             acquisition_path=acquisition_path,
+             tilenames=tilenames,
+             driver=tensorstore_driver,
+         )
+
+         # Write the output path to a file
+         with open(
+             f"{results_folder}/output_path_worker_{worker_id}.txt", "w"
+         ) as f:
+             f.write(write_folder)
+
+     else:
+         print(f"Nothing to do! Tilenames: {tilenames}")
+
+
+ if __name__ == "__main__":
+     run()
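For reference, a minimal sketch of a radial_correction_parameters.json, inferred from the .get() calls in run() above (every value is hypothetical):

import json

# Hypothetical contents of radial_correction_parameters.json,
# matching the keys read by run()
params = {
    "tilenames": ["Tile_X_0000_Y_0000_Z_0000_ch_488.ome.zarr"],
    "worker_id": 0,
    "bucket_name": "my-bucket",
    "input_s3_dataset_path": "my-dataset/SPIM.ome.zarr",
    "tensorstore_driver": "zarr3",
    "write_to_s3": True,
}

with open("radial_correction_parameters.json", "w") as f:
    json.dump(params, f, indent=2)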
Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py
@@ -0,0 +1,3 @@
+ """
+ Utility functions
+ """