Rhapso 0.1.92__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Rhapso/__init__.py +1 -0
- Rhapso/data_prep/__init__.py +2 -0
- Rhapso/data_prep/n5_reader.py +188 -0
- Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
- Rhapso/data_prep/xml_to_dataframe.py +215 -0
- Rhapso/detection/__init__.py +5 -0
- Rhapso/detection/advanced_refinement.py +203 -0
- Rhapso/detection/difference_of_gaussian.py +324 -0
- Rhapso/detection/image_reader.py +117 -0
- Rhapso/detection/metadata_builder.py +130 -0
- Rhapso/detection/overlap_detection.py +327 -0
- Rhapso/detection/points_validation.py +49 -0
- Rhapso/detection/save_interest_points.py +265 -0
- Rhapso/detection/view_transform_models.py +67 -0
- Rhapso/fusion/__init__.py +0 -0
- Rhapso/fusion/affine_fusion/__init__.py +2 -0
- Rhapso/fusion/affine_fusion/blend.py +289 -0
- Rhapso/fusion/affine_fusion/fusion.py +601 -0
- Rhapso/fusion/affine_fusion/geometry.py +159 -0
- Rhapso/fusion/affine_fusion/io.py +546 -0
- Rhapso/fusion/affine_fusion/script_utils.py +111 -0
- Rhapso/fusion/affine_fusion/setup.py +4 -0
- Rhapso/fusion/affine_fusion_worker.py +234 -0
- Rhapso/fusion/multiscale/__init__.py +0 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
- Rhapso/fusion/multiscale_worker.py +113 -0
- Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
- Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
- Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
- Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
- Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
- Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
- Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
- Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
- Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
- Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
- Rhapso/matching/__init__.py +0 -0
- Rhapso/matching/load_and_transform_points.py +458 -0
- Rhapso/matching/ransac_matching.py +544 -0
- Rhapso/matching/save_matches.py +120 -0
- Rhapso/matching/xml_parser.py +302 -0
- Rhapso/pipelines/__init__.py +0 -0
- Rhapso/pipelines/ray/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
- Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
- Rhapso/pipelines/ray/evaluation.py +71 -0
- Rhapso/pipelines/ray/interest_point_detection.py +137 -0
- Rhapso/pipelines/ray/interest_point_matching.py +110 -0
- Rhapso/pipelines/ray/local/__init__.py +0 -0
- Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
- Rhapso/pipelines/ray/matching_stats.py +104 -0
- Rhapso/pipelines/ray/param/__init__.py +0 -0
- Rhapso/pipelines/ray/solver.py +120 -0
- Rhapso/pipelines/ray/split_dataset.py +78 -0
- Rhapso/solver/__init__.py +0 -0
- Rhapso/solver/compute_tiles.py +562 -0
- Rhapso/solver/concatenate_models.py +116 -0
- Rhapso/solver/connected_graphs.py +111 -0
- Rhapso/solver/data_prep.py +181 -0
- Rhapso/solver/global_optimization.py +410 -0
- Rhapso/solver/model_and_tile_setup.py +109 -0
- Rhapso/solver/pre_align_tiles.py +323 -0
- Rhapso/solver/save_results.py +97 -0
- Rhapso/solver/view_transforms.py +75 -0
- Rhapso/solver/xml_to_dataframe_solver.py +213 -0
- Rhapso/split_dataset/__init__.py +0 -0
- Rhapso/split_dataset/compute_grid_rules.py +78 -0
- Rhapso/split_dataset/save_points.py +101 -0
- Rhapso/split_dataset/save_xml.py +377 -0
- Rhapso/split_dataset/split_images.py +537 -0
- Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
- rhapso-0.1.92.dist-info/METADATA +39 -0
- rhapso-0.1.92.dist-info/RECORD +101 -0
- rhapso-0.1.92.dist-info/WHEEL +5 -0
- rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
- rhapso-0.1.92.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/test_detection.py +17 -0
- tests/test_matching.py +21 -0
- tests/test_solving.py +21 -0
|
@@ -0,0 +1,557 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Computes radial correction in microscopic data
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import multiprocessing as mp
|
|
8
|
+
import time
|
|
9
|
+
import urllib.parse
|
|
10
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
11
|
+
from math import ceil
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import List, Literal, Optional, Tuple, Union
|
|
14
|
+
|
|
15
|
+
import dask.array as da
|
|
16
|
+
import numba as nb
|
|
17
|
+
import numpy as np
|
|
18
|
+
import tensorstore as ts
|
|
19
|
+
from aind_data_schema.core.processing import DataProcess, ProcessName
|
|
20
|
+
from dask.diagnostics import ProgressBar
|
|
21
|
+
from dask.distributed import Client, LocalCluster
|
|
22
|
+
from natsort import natsorted
|
|
23
|
+
from scipy.ndimage import map_coordinates
|
|
24
|
+
|
|
25
|
+
from . import __maintainers__, __pipeline_version__, __url__, __version__
|
|
26
|
+
from .array_to_zarr import convert_array_to_zarr
|
|
27
|
+
from .utils import utils
|
|
28
|
+
|
|
29
|
+
# Configure the logging output format and create the module-level logger,
# which reports at INFO level and above.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M",
    format="%(asctime)s %(message)s",
)
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def calculate_corner_shift_from_pixel_size(XY_pixel_size: float) -> float:
    """
    Return the corner shift used by the radial correction.

    Parameters
    ----------
    XY_pixel_size : float
        The size of a pixel in microns. Accepted for interface
        compatibility; the value is currently ignored.

    Returns
    -------
    float
        Always 4.3, regardless of the pixel size.
    """
    fixed_corner_shift = 4.3
    return fixed_corner_shift
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def calculate_frac_cutoff_from_pixel_size(XY_pixel_size: float) -> float:
    """
    Return the fractional cutoff used by the radial correction.

    Parameters
    ----------
    XY_pixel_size : float
        The size of a pixel in microns. Accepted for interface
        compatibility; the value is currently ignored.

    Returns
    -------
    float
        Always 0.5, regardless of the pixel size.
    """
    fixed_frac_cutoff = 0.5
    return fixed_frac_cutoff
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# enables multithreading with prange
@nb.njit(parallel=True)
def _compute_coordinates(
    pixels: int, cutoff: float, corner_shift: float, edge: int
) -> np.ndarray:
    """
    Computes the coordinates where the
    pixels will be moved.

    Parameters
    ----------
    pixels: int
        Width or height of the image. It is assumed that the image
        has the same resolution in XY.

    cutoff: float
        Radius beyond which distortion is applied.

    corner_shift: float
        Radial shift added at the maximum radius. The shift grows
        linearly from 0 at ``cutoff`` to ``corner_shift`` at the corner.

    edge: int
        Number of pixels to crop from each side
        (e.g., due to interpolation instability).

    Returns
    -------
    np.ndarray
        float32 array of shape (2, pixels - 2 * edge, pixels - 2 * edge).
        Index 0 holds the y (row) sampling coordinates and index 1 the
        x (column) sampling coordinates, both in image space.
    """
    # Offset that moves the origin to the image center (and back).
    half = pixels // 2

    # coords[0] -> y-coordinates relative to center
    # coords[1] -> x-coordinates relative to center
    coords = np.zeros((2, pixels, pixels), dtype=np.float32)

    # stores the radius of each pixel from the center.
    r = np.zeros((pixels, pixels), dtype=np.float32)

    # First pass: centered coordinates and radius for every pixel,
    # parallelized over rows with prange.
    for i in nb.prange(pixels):
        for j in range(pixels):
            y = i - half
            x = j - half
            coords[0, i, j] = y
            coords[1, i, j] = x
            r[i, j] = np.sqrt(x * x + y * y)

    # Maximum radius, used to normalize the distortion.
    rmax = r.max()

    # Second pass: apply the radial distortion in polar form.
    for i in nb.prange(pixels):
        for j in range(pixels):
            r_val = r[i, j]
            # coords 0 is y, coords 1 is x -> arctan2(y, x)
            angle = np.arctan2(coords[0, i, j], coords[1, i, j])

            # Only pixels farther from the center than the cutoff are moved.
            # The branch cannot be taken when rmax == cutoff, so the
            # denominator is never zero here.
            if r_val > cutoff:
                r_val += (r_val - cutoff) * corner_shift / (rmax - cutoff)

            coords[0, i, j] = r_val * np.sin(angle)
            coords[1, i, j] = r_val * np.cos(angle)

    # Crop edges and shift back to image space. An explicit stop index is
    # used because ``edge:-edge`` would yield an empty array for edge == 0.
    cropped = coords[:, edge : pixels - edge, edge : pixels - edge]
    cropped += half

    return cropped
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _process_plane(args):
|
|
146
|
+
"""Helper function to process a single z-plane for parallel execution"""
|
|
147
|
+
z, plane, coords, order = args
|
|
148
|
+
warp_coords = np.zeros((2, *coords[0].shape), dtype=np.float32)
|
|
149
|
+
warp_coords[0] = coords[0]
|
|
150
|
+
warp_coords[1] = coords[1]
|
|
151
|
+
return z, map_coordinates(plane, warp_coords, order=order, mode="constant")
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def radial_correction(
    tile_data: np.ndarray,
    corner_shift: Optional[float] = 5.5,
    frac_cutoff: Optional[float] = 0.5,
    mode: Union[Literal["2d"], Literal["3d"]] = "3d",
    order: int = 1,
    max_workers: Optional[int] = None,
) -> np.ndarray:
    """
    Apply radial correction to a tile.

    Parameters
    ----------
    tile_data : np.ndarray
        The 3D tile data (Z, Y, X) to be corrected. The XY plane must
        be square.
    corner_shift : Optional[float]
        The amount of radial shift to apply (default is 5.5).
    frac_cutoff : Optional[float]
        Fraction of the image width from which the correction starts
        (default is 0.5). The cutoff radius is ``pixels * frac_cutoff``.
    mode : Union[Literal["2d"], Literal["3d"]]
        "2d" resamples each z-plane separately in a thread pool; any
        other value resamples the full volume in a single call
        (default is "3d").
    order : int
        Interpolation order for map_coordinates (default is 1).
    max_workers : Optional[int]
        Maximum number of worker threads in "2d" mode. None lets
        ``ThreadPoolExecutor`` pick its default.

    Returns
    -------
    np.ndarray
        The corrected tile with shape (Z, Y - 2 * edge, X - 2 * edge),
        where ``edge = ceil(corner_shift / sqrt(2)) + 1``.

    Raises
    ------
    ValueError
        If ``tile_data`` is not 3D or its XY plane is not square.
    """
    shape = tile_data.shape
    # The warp is computed for a single square plane; fail early with a
    # clear message instead of a late broadcasting error.
    if len(shape) != 3 or shape[1] != shape[2]:
        raise ValueError(
            "Expected a 3D tile (Z, Y, X) with a square XY plane, "
            f"got shape {shape}"
        )

    edge = ceil(corner_shift / np.sqrt(2)) + 1
    pixels = shape[1]
    cutoff = pixels * frac_cutoff

    # Compute the warp to transform coordinates using numba
    coords = _compute_coordinates(pixels, cutoff, corner_shift, edge)

    # Calculate new shape after edge cropping
    new_shape = np.array(shape) - [0, edge * 2, edge * 2]
    LOGGER.info(f"New shape: {new_shape} - Mode {mode} - Cutoff: {cutoff}")

    if mode == "2d":
        # Process each z-plane separately in parallel
        result = np.zeros(new_shape, dtype=tile_data.dtype)
        tasks = ((z, tile_data[z], coords, order) for z in range(shape[0]))

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for z, processed_plane in executor.map(_process_plane, tasks):
                result[z] = processed_plane

        return result

    # 3D mode: build the full warping coordinates array. Broadcasting
    # assignments fill it without the full-volume temporaries that
    # np.repeat would allocate.
    warp_coords = np.empty((3, *new_shape), dtype=np.float32)

    # Z coordinates remain unchanged
    warp_coords[0] = np.arange(new_shape[0], dtype=np.float32)[
        :, np.newaxis, np.newaxis
    ]

    # The same pre-computed Y/X warp is applied to every z-plane
    warp_coords[1] = coords[0]
    warp_coords[2] = coords[1]

    # Process the entire volume at once
    return map_coordinates(
        tile_data, warp_coords, order=order, mode="constant"
    )
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def read_zarr(
    dataset_path: str,
    compute: Optional[bool] = True,
) -> Tuple:
    """
    Reads a zarr dataset

    Parameters
    ----------
    dataset_path: str
        Path where the dataset is stored.

    compute: Optional[bool]
        Computes the lazy dask graph.
        Default: True

    Returns
    -------
    Tuple[ArrayLike, da.Array]
        ArrayLike or None if compute is false
        Lazy dask array
    """
    tile_lazy = da.from_zarr(dataset_path).squeeze()

    tile = None
    if compute:
        # The read is explicitly executed on the threaded scheduler, so no
        # distributed cluster is started: it would never receive any work
        # and would only cost start-up time and one process per CPU.
        with ProgressBar():
            tile = tile_lazy.compute(scheduler="threads")

    return tile, tile_lazy
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
async def read_zarr_tensorstore(
    dataset_path: str, scale: str, driver: Optional[str] = "zarr"
) -> Tuple:
    """
    Reads one multiscale level of a zarr dataset with tensorstore,
    from the local filesystem or an S3 bucket.

    Parameters
    ----------
    dataset_path: str
        Path where the dataset is stored. Can be a local path
        or an S3 path (s3://...)
    scale: str
        Multiscale to load
    driver: Optional[str]
        Tensorstore driver
        Default: zarr

    Returns
    -------
    Tuple
        ``(tile, tile_lazy)``: the result of awaiting ``tile_lazy.read()``
        and the handle returned by ``ts.open``.
    """
    parsed_url = urllib.parse.urlparse(dataset_path)

    if parsed_url.scheme == "s3":
        bucket = parsed_url.netloc
        # Remove leading slash if present
        key = parsed_url.path.lstrip("/")
        LOGGER.info("Reading from S3 - bucket: %s - key: %s", bucket, key)
        kvstore = {"driver": "s3", "bucket": bucket, "path": key}
    else:
        # Anything that is not an s3:// URL is treated as a local path
        kvstore = {"driver": "file", "path": str(dataset_path)}

    # Only the key-value store differs between S3 and local datasets
    ts_spec = {
        "driver": str(driver),
        "kvstore": kvstore,
        "path": str(scale),
    }

    tile_lazy = await ts.open(ts_spec)
    tile = await tile_lazy.read()
    return tile, tile_lazy
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def apply_corr_to_zarr_tile(
    dataset_path: str,
    scale: str,
    corner_shift: Optional[float] = 5.5,
    frac_cutoff: Optional[float] = 0.5,
    z_size_threshold: Optional[int] = 400,
    order: Optional[int] = 1,
    max_workers: Optional[int] = None,
    driver: Optional[str] = "zarr",
) -> np.ndarray:
    """
    Read a zarr tile with tensorstore and return it radially corrected.

    Parameters
    ----------
    dataset_path : str
        Path to the zarr dataset containing the tile.

    scale: str
        Multiscale level to load.

    corner_shift : Optional[float]
        Shift applied to the corners (default is 5.5).

    frac_cutoff : Optional[float]
        Fractional radius where the correction starts (default is 0.5).

    z_size_threshold: Optional[int]
        Tiles with fewer z-planes than this value are corrected in
        "3d" mode; all other tiles are corrected in "2d" mode.
        Default: 400

    order: Optional[int]
        Interpolation order.
        Default: 1

    max_workers: Optional[int]
        Max number of workers.
        Default: None

    driver: Optional[str]
        Tensorstore driver used to read the data.
        Default: zarr

    Returns
    -------
    np.ndarray
        The corrected tile.

    Raises
    ------
    ValueError
        If ``z_size_threshold`` is negative.
    """
    if z_size_threshold < 0:
        raise ValueError(
            f"Please, provide a correct threshold: {z_size_threshold}"
        )

    # Load the requested scale fully into memory
    data_in_memory, _lazy_handle = asyncio.run(
        read_zarr_tensorstore(dataset_path, scale=scale, driver=driver)
    )
    data_in_memory = data_in_memory.squeeze()
    z_size = data_in_memory.shape[-3]

    LOGGER.info(f"Dataset shape {data_in_memory.shape}")

    # Short stacks are warped as one volume, taller ones plane by plane
    mode = "3d" if z_size < z_size_threshold else "2d"

    return radial_correction(
        tile_data=data_in_memory,
        corner_shift=corner_shift,
        frac_cutoff=frac_cutoff,
        mode=mode,
        order=order,
        max_workers=max_workers,
    )
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def correct_and_save_tile(
    dataset_loc: str,
    output_path: str,
    resolution_zyx: List[float],
    scale: str = "0",
    n_lvls: Optional[int] = 4,
    driver: Optional[str] = "zarr",
):
    """
    Radially correct one tile and write the result as zarr.

    Parameters
    ----------
    dataset_loc: str
        Path to the dataset to be corrected.
    output_path: str
        Path to save the corrected dataset.
    resolution_zyx: List[float]
        Voxel size in the format [z, y, x].
    scale: str
        Multiscale to load the data.
        Default: 0
    n_lvls: Optional[int]
        Number of levels forwarded to ``convert_array_to_zarr``.
        Default: 4
    driver: Optional[str]
        Driver to read the data with tensorstore.
        Default: "zarr"

    Returns
    -------
    None
        The processing record is not built yet (see the TODO below),
        so ``None`` is always returned.
    """
    # The Y resolution is the value handed to both parameter helpers
    xy_pixel_size = resolution_zyx[1]
    corner_shift = calculate_corner_shift_from_pixel_size(xy_pixel_size)
    frac_cutoff = calculate_frac_cutoff_from_pixel_size(xy_pixel_size)

    LOGGER.info(f"Input: {dataset_loc} - Output: {output_path}")
    LOGGER.info(f"Corner Shift: {corner_shift} pixels")
    LOGGER.info(f"Fraction Cutoff: {frac_cutoff}")

    start_time = time.time()
    corrected_tile = apply_corr_to_zarr_tile(
        dataset_path=dataset_loc,
        scale=scale,
        corner_shift=corner_shift,
        frac_cutoff=frac_cutoff,
        driver=driver,
    )
    end_time = time.time()
    LOGGER.info(
        f"Time to correct: {end_time - start_time} seconds -> New shape {corrected_tile.shape}"
    )

    convert_array_to_zarr(
        array=corrected_tile,
        voxel_size=resolution_zyx,
        chunk_size=[128, 128, 128],
        output_path=str(output_path),
        n_lvls=n_lvls,
    )

    # TODO: activate this when aind-data-schema 2.0 is out
    # DataProcess(
    #     name=ProcessName.IMAGE_RADIAL_CORRECTION,
    #     software_version=__version__,
    #     start_date_time=start_time,
    #     end_date_time=end_time,
    #     input_location=dataset_loc,
    #     output_location=output_path,
    #     code_version=__version__,
    #     code_url=__url__,
    #     parameters={
    #         'corner_shift': corner_shift,
    #         'frac_cutoff': frac_cutoff
    #     },
    # )
    data_process = None

    return data_process
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def main(
    data_folder: str,
    results_folder: str,
    acquisition_path: str,
    tilenames: List[str],
    driver: Optional[str] = "zarr",
):
    """
    Radial correction to multiple tiles.

    The voxel resolution is read once from the acquisition file and
    every tile in ``tilenames`` is corrected and saved under
    ``results_folder`` using its original tile name.

    Parameters
    ----------
    data_folder: str
        Folder where the data is stored.

    results_folder: str
        Results folder. It could be a local path or
        a S3 bucket.

    acquisition_path: str
        Path where the acquisition.json is.

    tilenames: List[str]
        Tiles to process. E.g.,
        [Tile_X_000...ome.zarr, ..., ]

    driver: Optional[str]
        Driver to read the data with tensorstore
        Default: "zarr"

    """
    zyx_voxel_size = utils.get_voxel_resolution(
        acquisition_path=acquisition_path
    )
    LOGGER.info(f"Voxel ZYX resolution: {zyx_voxel_size}")

    # Per-tile processing records. correct_and_save_tile currently returns
    # None, so the list stays empty until that function builds real records.
    data_processes = []
    for tilename in tilenames:
        data_process = correct_and_save_tile(
            dataset_loc=f"{data_folder}/{tilename}",
            output_path=f"{results_folder}/{tilename}",
            resolution_zyx=zyx_voxel_size,
            n_lvls=6,
            driver=driver,
        )
        if data_process is not None:
            data_processes.append(data_process)

    # utils.generate_processing(
    #     data_processes=data_processes,
    #     dest_processing=results_folder,
    #     processor_full_name=__maintainers__[0],
    #     pipeline_version=__pipeline_version__,
    #     prefix='radial_correction'
    # )
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
if __name__ == "__main__":
    # main() has four required arguments, so calling it bare always raised
    # a TypeError. Read the arguments from the command line instead.
    import argparse

    parser = argparse.ArgumentParser(
        description="Apply radial correction to a set of zarr tiles."
    )
    parser.add_argument(
        "--data-folder", required=True, help="Folder where the tiles are stored."
    )
    parser.add_argument(
        "--results-folder",
        required=True,
        help="Local path or S3 location for the corrected tiles.",
    )
    parser.add_argument(
        "--acquisition-path", required=True, help="Path to acquisition.json."
    )
    parser.add_argument(
        "--tilenames", nargs="+", required=True, help="Tile names to process."
    )
    parser.add_argument(
        "--driver", default="zarr", help="Tensorstore driver used for reading."
    )
    cli_args = parser.parse_args()

    main(
        data_folder=cli_args.data_folder,
        results_folder=cli_args.results_folder,
        acquisition_path=cli_args.acquisition_path,
        tilenames=cli_args.tilenames,
        driver=cli_args.driver,
    )
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Runs radial correction in a set of tiles provided
|
|
3
|
+
to the data directory
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from aind_z1_radial_correction import radial_correction
|
|
11
|
+
from aind_z1_radial_correction.utils import utils
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def run():
    """
    Main run file in Code Ocean

    Reads the capsule inputs from the first ``worker*`` folder found in
    ``../data``, runs the radial correction for the tiles listed in
    ``radial_correction_parameters.json`` and writes the output location
    to ``../results/output_path_worker_<worker_id>.txt``.

    Raises
    ------
    ValueError
        If no ``worker*`` folder exists, required input files are
        missing, the dataset name is missing from data_description.json,
        or ``input_s3_dataset_path`` is given without ``bucket_name``.
    """

    data_folder = os.path.abspath("../data")
    results_folder = os.path.abspath("../results")

    # Worker scheduler path has radial parameters and acquisition.
    # Fail with a clear message instead of an IndexError when it is absent.
    worker_scheduler_path = next(Path(data_folder).glob("worker*"), None)
    if worker_scheduler_path is None:
        raise ValueError(
            f"No 'worker*' folder found in the capsule input: {data_folder}"
        )

    acquisition_path = f"{worker_scheduler_path}/acquisition.json"
    data_description_path = f"{worker_scheduler_path}/data_description.json"

    radial_correction_parameters_path = (
        f"{worker_scheduler_path}/radial_correction_parameters.json"
    )

    required_input_elements = [
        acquisition_path,
        radial_correction_parameters_path,
        data_description_path,
    ]

    missing_files = utils.validate_capsule_inputs(required_input_elements)

    if len(missing_files):
        raise ValueError(
            f"We miss the following files in the capsule input: {missing_files}"
        )

    radial_correction_parameters = utils.read_json_as_dict(
        radial_correction_parameters_path
    )

    tilenames = radial_correction_parameters.get("tilenames", [])
    worker_id = radial_correction_parameters.get("worker_id", None)
    bucket_name = radial_correction_parameters.get("bucket_name", None)
    input_s3_dataset_path = radial_correction_parameters.get(
        "input_s3_dataset_path", None
    )
    tensorstore_driver = radial_correction_parameters.get(
        "tensorstore_driver", "zarr3"
    )
    write_to_s3 = radial_correction_parameters.get("write_to_s3", True)

    print(f"Worker ID: {worker_id} processing {len(tilenames)} tiles!")

    # Results go to the local results folder unless an S3 bucket is given
    write_folder = results_folder
    if bucket_name is not None and write_to_s3:
        data_description = utils.read_json_as_dict(data_description_path)
        dataset_name = data_description.get("name", None)
        if not dataset_name:
            raise ValueError(
                f"Dataset name not found in data_description.json: {data_description_path}"
            )

        write_folder = (
            f"s3://{bucket_name}/{dataset_name}/image_radial_correction"
        )

    if input_s3_dataset_path is not None:
        # Without a bucket the input location would become "s3://None/..."
        if bucket_name is None:
            raise ValueError(
                "'input_s3_dataset_path' requires 'bucket_name' in "
                f"{radial_correction_parameters_path}"
            )
        data_folder = f"s3://{bucket_name}/{input_s3_dataset_path}"

    if len(tilenames):
        radial_correction.main(
            data_folder=data_folder,
            results_folder=write_folder,
            acquisition_path=acquisition_path,
            tilenames=tilenames,
            driver=tensorstore_driver,
        )

        # Write the output path to a file
        with open(
            f"{results_folder}/output_path_worker_{worker_id}.txt",
            "w",
            encoding="utf-8",
        ) as f:
            f.write(write_folder)

    else:
        print(f"Nothing to do! Tilenames: {tilenames}")


if __name__ == "__main__":
    run()
|