rhapso-0.1.92-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Rhapso/__init__.py +1 -0
- Rhapso/data_prep/__init__.py +2 -0
- Rhapso/data_prep/n5_reader.py +188 -0
- Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
- Rhapso/data_prep/xml_to_dataframe.py +215 -0
- Rhapso/detection/__init__.py +5 -0
- Rhapso/detection/advanced_refinement.py +203 -0
- Rhapso/detection/difference_of_gaussian.py +324 -0
- Rhapso/detection/image_reader.py +117 -0
- Rhapso/detection/metadata_builder.py +130 -0
- Rhapso/detection/overlap_detection.py +327 -0
- Rhapso/detection/points_validation.py +49 -0
- Rhapso/detection/save_interest_points.py +265 -0
- Rhapso/detection/view_transform_models.py +67 -0
- Rhapso/fusion/__init__.py +0 -0
- Rhapso/fusion/affine_fusion/__init__.py +2 -0
- Rhapso/fusion/affine_fusion/blend.py +289 -0
- Rhapso/fusion/affine_fusion/fusion.py +601 -0
- Rhapso/fusion/affine_fusion/geometry.py +159 -0
- Rhapso/fusion/affine_fusion/io.py +546 -0
- Rhapso/fusion/affine_fusion/script_utils.py +111 -0
- Rhapso/fusion/affine_fusion/setup.py +4 -0
- Rhapso/fusion/affine_fusion_worker.py +234 -0
- Rhapso/fusion/multiscale/__init__.py +0 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
- Rhapso/fusion/multiscale_worker.py +113 -0
- Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
- Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
- Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
- Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
- Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
- Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
- Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
- Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
- Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
- Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
- Rhapso/matching/__init__.py +0 -0
- Rhapso/matching/load_and_transform_points.py +458 -0
- Rhapso/matching/ransac_matching.py +544 -0
- Rhapso/matching/save_matches.py +120 -0
- Rhapso/matching/xml_parser.py +302 -0
- Rhapso/pipelines/__init__.py +0 -0
- Rhapso/pipelines/ray/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
- Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
- Rhapso/pipelines/ray/evaluation.py +71 -0
- Rhapso/pipelines/ray/interest_point_detection.py +137 -0
- Rhapso/pipelines/ray/interest_point_matching.py +110 -0
- Rhapso/pipelines/ray/local/__init__.py +0 -0
- Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
- Rhapso/pipelines/ray/matching_stats.py +104 -0
- Rhapso/pipelines/ray/param/__init__.py +0 -0
- Rhapso/pipelines/ray/solver.py +120 -0
- Rhapso/pipelines/ray/split_dataset.py +78 -0
- Rhapso/solver/__init__.py +0 -0
- Rhapso/solver/compute_tiles.py +562 -0
- Rhapso/solver/concatenate_models.py +116 -0
- Rhapso/solver/connected_graphs.py +111 -0
- Rhapso/solver/data_prep.py +181 -0
- Rhapso/solver/global_optimization.py +410 -0
- Rhapso/solver/model_and_tile_setup.py +109 -0
- Rhapso/solver/pre_align_tiles.py +323 -0
- Rhapso/solver/save_results.py +97 -0
- Rhapso/solver/view_transforms.py +75 -0
- Rhapso/solver/xml_to_dataframe_solver.py +213 -0
- Rhapso/split_dataset/__init__.py +0 -0
- Rhapso/split_dataset/compute_grid_rules.py +78 -0
- Rhapso/split_dataset/save_points.py +101 -0
- Rhapso/split_dataset/save_xml.py +377 -0
- Rhapso/split_dataset/split_images.py +537 -0
- Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
- rhapso-0.1.92.dist-info/METADATA +39 -0
- rhapso-0.1.92.dist-info/RECORD +101 -0
- rhapso-0.1.92.dist-info/WHEEL +5 -0
- rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
- rhapso-0.1.92.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/test_detection.py +17 -0
- tests/test_matching.py +21 -0
- tests/test_solving.py +21 -0
Rhapso/fusion/affine_fusion/fusion.py

@@ -0,0 +1,601 @@
"""Core fusion algorithm."""
import logging

import dask.array as da
import numpy as np
import torch
import s3fs
import zarr
import tensorstore as ts
import ray

from . import blend, geometry, io

def initialize_fusion(
    dataset: io.Dataset,
    post_reg_tfms: list[geometry.Transform],
    output_params: io.OutputParameters,
) -> tuple[dict, dict, dict, dict, dict, tuple, tuple]:
    """
    Creates all core fusion data structures and key algorithm inputs.

    Inputs
    ------
    dataset: io.Dataset application primitive.
    post_reg_tfms: Transforms applied after registration, before output sampling.
    output_params: io.OutputParameters application primitive.

    Returns
    -------
    tile_arrays: Dictionary of input tile arrays
    tile_paths: Dictionary of input tile paths
    tile_transforms: Dictionary of (list of) registrations associated with each tile
    tile_sizes_zyx: Dictionary of tile sizes
    tile_aabbs: Dictionary of AABB of each transformed tile
    output_volume_size: Size of output volume
    output_volume_origin: Location of output volume
    """

    tile_arrays, tile_paths = dataset.tile_volumes_tczyx

    tile_transforms: dict[
        int, list[geometry.Transform]
    ] = dataset.tile_transforms_zyx
    input_resolution_zyx: tuple[
        float, float, float
    ] = dataset.tile_resolution_zyx
    iz, iy, ix = input_resolution_zyx
    scale_input_zyx = geometry.Affine(
        np.array([[iz, 0, 0, 0], [0, iy, 0, 0], [0, 0, ix, 0]])
    )

    output_resolution_zyx: tuple[
        float, float, float
    ] = output_params.resolution_zyx
    oz, oy, ox = output_resolution_zyx
    sample_output_zyx = geometry.Affine(
        np.array([[1 / oz, 0, 0, 0], [0, 1 / oy, 0, 0], [0, 0, 1 / ox, 0]])
    )
    for tile_id, tfm_list in tile_transforms.items():
        tile_transforms[tile_id] = [
            *tfm_list,
            scale_input_zyx,
            *post_reg_tfms,
            sample_output_zyx,
        ]

    tile_sizes_zyx: dict[int, tuple[int, int, int]] = {}
    tile_aabbs: dict[int, geometry.AABB] = {}
    tile_boundary_point_cloud_zyx = []

    for tile_id, tile_arr in tile_arrays.items():
        tile_sizes_zyx[tile_id] = zyx = tile_arr.shape[2:]
        # Eight corners of the (z, y, x) tile volume.
        tile_boundaries = torch.Tensor(
            [
                [0.0, 0.0, 0.0],
                [zyx[0], 0.0, 0.0],
                [0.0, zyx[1], 0.0],
                [0.0, 0.0, zyx[2]],
                [zyx[0], zyx[1], 0.0],
                [zyx[0], 0.0, zyx[2]],
                [0.0, zyx[1], zyx[2]],
                [zyx[0], zyx[1], zyx[2]],
            ]
        )

        tfm_list = tile_transforms[tile_id]
        for tfm in tfm_list:
            tile_boundaries = tfm.forward(
                tile_boundaries, device=torch.device("cpu")
            )

        tile_aabbs[tile_id] = geometry.aabb_3d(tile_boundaries)
        tile_boundary_point_cloud_zyx.extend(tile_boundaries)
    tile_boundary_point_cloud_zyx = torch.stack(
        tile_boundary_point_cloud_zyx, dim=0
    )

    # Resolve Output Volume Dimensions and Absolute Position
    global_tile_boundaries = geometry.aabb_3d(tile_boundary_point_cloud_zyx)
    OUTPUT_VOLUME_SIZE = [
        int(global_tile_boundaries[1] - global_tile_boundaries[0]),
        int(global_tile_boundaries[3] - global_tile_boundaries[2]),
        int(global_tile_boundaries[5] - global_tile_boundaries[4]),
    ]

    # Round OUTPUT_VOLUME_SIZE up to the nearest chunk multiple
    # b/c zarr-python has occasional errors writing at the boundaries.
    # This ensures a multiple of chunksize without losing data.
    remainder_0 = OUTPUT_VOLUME_SIZE[0] % output_params.chunksize[2]
    remainder_1 = OUTPUT_VOLUME_SIZE[1] % output_params.chunksize[3]
    remainder_2 = OUTPUT_VOLUME_SIZE[2] % output_params.chunksize[4]
    if remainder_0 > 0:
        OUTPUT_VOLUME_SIZE[0] += output_params.chunksize[2] - remainder_0
    if remainder_1 > 0:
        OUTPUT_VOLUME_SIZE[1] += output_params.chunksize[3] - remainder_1
    if remainder_2 > 0:
        OUTPUT_VOLUME_SIZE[2] += output_params.chunksize[4] - remainder_2
    OUTPUT_VOLUME_SIZE = tuple(OUTPUT_VOLUME_SIZE)

    OUTPUT_VOLUME_ORIGIN = (
        torch.min(tile_boundary_point_cloud_zyx[:, 0]).item(),
        torch.min(tile_boundary_point_cloud_zyx[:, 1]).item(),
        torch.min(tile_boundary_point_cloud_zyx[:, 2]).item(),
    )

    # Shift AABB's into Output Volume where
    # absolute position of output volume is moved to (0, 0, 0)
    for tile_id, t_aabb in tile_aabbs.items():
        updated_aabb = (
            t_aabb[0] - OUTPUT_VOLUME_ORIGIN[0],
            t_aabb[1] - OUTPUT_VOLUME_ORIGIN[0],
            t_aabb[2] - OUTPUT_VOLUME_ORIGIN[1],
            t_aabb[3] - OUTPUT_VOLUME_ORIGIN[1],
            t_aabb[4] - OUTPUT_VOLUME_ORIGIN[2],
            t_aabb[5] - OUTPUT_VOLUME_ORIGIN[2],
        )
        tile_aabbs[tile_id] = updated_aabb

    return (
        tile_arrays,
        tile_paths,
        tile_transforms,
        tile_sizes_zyx,
        tile_aabbs,
        OUTPUT_VOLUME_SIZE,
        OUTPUT_VOLUME_ORIGIN,
    )

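
# --- Editor's illustration (not part of the original module) ---------------
# The rounding above is "round size up to the next multiple of the chunk
# size". A minimal equivalent sketch, assuming only the standard library:
def _round_up_to_chunk(size: int, chunk: int) -> int:
    """Round `size` up to the nearest multiple of `chunk`.

    >>> _round_up_to_chunk(1000, 128)
    1024
    >>> _round_up_to_chunk(1024, 128)
    1024
    """
    import math

    return math.ceil(size / chunk) * chunk
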
def initialize_output_volume_dask(
    output_params: io.OutputParameters,
    output_volume_size: tuple[int, int, int],
) -> zarr.core.Array:
    """
    Initializes the output Zarr store described by OutputParameters.

    Inputs
    ------
    output_params: OutputParameters application instance.
    output_volume_size: output of initialize_fusion(...)

    Returns
    -------
    Zarr thread-safe datastore initialized on OutputParameters.
    """

    # Cloud execution
    if output_params.path.startswith('s3'):
        s3 = s3fs.S3FileSystem(
            config_kwargs={
                'max_pool_connections': 50,
                's3': {
                    'multipart_threshold': 64 * 1024 * 1024,  # 64 MB, avoid multipart upload for small chunks
                    'max_concurrent_requests': 20  # Increased from 10 -> 20.
                },
                'retries': {
                    'total_max_attempts': 100,
                    'mode': 'adaptive',
                }
            }
        )
        store = s3fs.S3Map(root=output_params.path, s3=s3)
        out_group = zarr.open(store=store, mode='a')
    # Local execution
    else:
        out_group = zarr.open_group(output_params.path, mode="w")

    path = "0"
    chunksize = output_params.chunksize
    datatype = output_params.dtype
    dimension_separator = "/"
    compressor = output_params.compressor
    output_volume = out_group.create_dataset(
        path,
        shape=(
            1,
            1,
            output_volume_size[0],
            output_volume_size[1],
            output_volume_size[2],
        ),
        chunks=chunksize,
        dtype=datatype,
        compressor=compressor,
        dimension_separator=dimension_separator,
        overwrite=True,
        fill_value=0,
    )

    return output_volume

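
# Editor's illustration (hypothetical path; not part of the original module):
# reading back the dataset created above, with zarr-python v2 semantics.
#
#     import zarr
#     group = zarr.open_group("fused.zarr", mode="r")
#     arr = group["0"]                     # 5D (t, c, z, y, x) array
#     print(arr.shape, arr.chunks, arr.dtype)
#     block = arr[0, 0, :128, :128, :128]  # reads only the touched chunks
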
def initialize_output_volume_tensorstore(
    output_params: io.OutputParameters,
    output_volume_size: tuple[int, int, int],
):
    """
    Returns a TensorStore handle whose writes are async; call
    .result() on a write to block until it completes.
    """
    parts = output_params.path.split('/')
    bucket = parts[2]
    path = '/'.join(parts[3:])
    chunksize = list(output_params.chunksize)
    output_shape = [1,
                    1,
                    output_volume_size[0],
                    output_volume_size[1],
                    output_volume_size[2]]

    return ts.open({
        'driver': 'zarr',
        'dtype': 'uint16',
        'kvstore': {
            'driver': 's3',
            'bucket': bucket,
            'path': path,
        },
        'create': True,
        'open': True,
        'metadata': {
            'chunks': chunksize,
            'compressor': {
                'blocksize': 0,
                'clevel': 1,
                'cname': 'zstd',
                'id': 'blosc',
                'shuffle': 1,
            },
            'dimension_separator': '/',
            'dtype': '<u2',  # little-endian uint16
            'fill_value': 0,
            'filters': None,
            'order': 'C',
            'shape': output_shape,
            'zarr_format': 2
        }
    }).result()

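
# Editor's illustration of the async write contract noted in the docstring
# (hypothetical usage; `vol` is the handle returned above):
#
#     vol = initialize_output_volume_tensorstore(output_params, (256, 256, 256))
#     future = vol[0, 0, :64, :64, :64].write(
#         np.zeros((64, 64, 64), dtype=np.uint16)
#     )
#     future.result()  # block until the chunk write has committed
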
def initialize_output_volume(
    output_params: io.OutputParameters,
    output_volume_size: tuple[int, int, int],
) -> io.OutputArray:
    """Dispatch on output_params.datastore: 0 = Dask/zarr, 1 = TensorStore."""
    output = None

    assert output_params.datastore in [0, 1], \
        "Only 0 = Dask and 1 = Tensorstore supported."
    if output_params.datastore == 0:
        output = initialize_output_volume_dask(output_params, output_volume_size)
    elif output_params.datastore == 1:
        output = initialize_output_volume_tensorstore(output_params, output_volume_size)
    return output

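
# Editor's illustration (the OutputParameters constructor signature is an
# assumption, not taken from this module; other required fields omitted):
#
#     params = io.OutputParameters(
#         path="s3://my-bucket/fused.zarr",  # hypothetical bucket/key
#         datastore=1,                       # route to TensorStore
#     )
#     vol = initialize_output_volume(params, (2048, 4096, 4096))
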
def get_cell_count_zyx(
    output_volume_size: tuple[int, int, int], cell_size: tuple[int, int, int]
) -> tuple[int, int, int]:
    """
    Total number of z, y, and x cells, returned in that order.
    Input sizes are in canonical zyx order.
    """
    z_cnt = int(np.ceil(output_volume_size[0] / cell_size[0]))
    y_cnt = int(np.ceil(output_volume_size[1] / cell_size[1]))
    x_cnt = int(np.ceil(output_volume_size[2] / cell_size[2]))

    return z_cnt, y_cnt, x_cnt

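
# Editor's worked example (illustrative values): a (100, 200, 300) output
# volume with (64, 64, 64) cells yields ceil(100/64)=2, ceil(200/64)=4,
# ceil(300/64)=5, i.e.
#
#     >>> get_cell_count_zyx((100, 200, 300), (64, 64, 64))
#     (2, 4, 5)
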
def run_fusion(
    # client,  # Uncomment for testing in jupyterlab
    dataset: io.Dataset,
    output_params: io.OutputParameters,
    cell_size: tuple[int, int, int],
    post_reg_tfms: list[geometry.Affine],
    blend_module: blend.BlendingModule,
):
    """
    Fusion algorithm.
    Inputs: Application objs initialized from input configurations.
    Output: Writes to the location given in output_params.
    """

    logging.basicConfig(
        format="%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M"
    )
    LOGGER = logging.getLogger(__name__)
    LOGGER.setLevel(logging.INFO)

    (
        tile_arrays,
        tile_paths,
        tile_transforms,
        tile_sizes_zyx,
        tile_aabbs,
        output_volume_size,
        output_volume_origin,
    ) = initialize_fusion(dataset, post_reg_tfms, output_params)

    output_volume = initialize_output_volume(output_params, output_volume_size)

    LOGGER.info(f"Number of Tiles: {len(tile_arrays)}")
    LOGGER.info(f"{output_volume_size=}")

    store = output_volume.store
    write_root = getattr(store, "root", None) or getattr(store, "path", None)
    write_ds = output_volume.path

    z_cnt, y_cnt, x_cnt = get_cell_count_zyx(output_volume_size, cell_size)
    cells = [(z, y, x) for z in range(z_cnt) for y in range(y_cnt) for x in range(x_cnt)]
    num_cells = len(cells)
    LOGGER.info(f'Coloring {num_cells} cells')

    @ray.remote
    def process_color_cell(curr_cell, tile_paths, write_root, write_ds, tile_transforms,
                           tile_sizes_zyx, tile_aabbs, output_volume, output_volume_origin, cell_size,
                           blend_module
                           ):
        z, y, x = curr_cell
        color_cell(tile_paths, write_root, write_ds, tile_transforms, tile_sizes_zyx, tile_aabbs,
                   output_volume, output_volume_origin, cell_size, blend_module, z, y, x,
                   torch.device("cpu")
                   )

        return {"cell": curr_cell}

    # Submit one Ray task per cell.
    futures = [
        process_color_cell.remote((z, y, x), tile_paths, write_root, write_ds,
                                  tile_transforms, tile_sizes_zyx, tile_aabbs, output_volume,
                                  output_volume_origin, cell_size, blend_module
                                  )
        for (z, y, x) in cells
    ]

    ray.get(futures)

    # DEBUG - serial fallback
    # for (z, y, x) in cells:
    #     color_cell(
    #         tile_paths, write_root, write_ds,
    #         tile_transforms, tile_sizes_zyx, tile_aabbs,
    #         output_volume, output_volume_origin,
    #         cell_size, blend_module,
    #         z, y, x, torch.device("cpu")
    #     )

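
# Editor's illustration (hypothetical driver code, not part of the original
# module): run_fusion assumes a Ray runtime has already been initialized by
# the caller, e.g.
#
#     import ray
#     ray.init(address="auto")  # attach to a running cluster, or ray.init()
#     run_fusion(dataset, output_params, (128, 256, 256), post_reg_tfms, blender)
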
def color_cell(
    tile_paths,
    write_root,
    write_ds,
    tile_transforms: dict[int, list[geometry.Transform]],
    tile_sizes_zyx: dict[int, tuple[int, int, int]],
    tile_aabbs: dict[int, geometry.AABB],
    output_volume: io.OutputArray,
    output_volume_origin: tuple[float, float, float],
    cell_size: tuple[int, int, int],
    blend_module: blend.BlendingModule,
    z: int,
    y: int,
    x: int,
    device: torch.device,
):
    """
    Parallelized function called in fusion.

    Inputs
    ------
    tile_paths: Dictionary of input tile paths
    write_root, write_ds: Root and dataset path of the output store
    tile_transforms: Dictionary of (list of) registrations associated with each tile
    tile_sizes_zyx: Dictionary of tile sizes
    tile_aabbs: Dictionary of AABB of each transformed tile
    output_volume: Zarr store parallel functions write to
    output_volume_origin: Location of output volume
    cell_size: operating volume of this function
    blend_module: application blending obj
    z, y, x: location of cell in terms of output volume indices
    device: torch device used for interpolation
    """

    # Cell Boundaries, exclusive stop index
    output_volume_size = output_volume.shape
    cell_box = np.array(
        [
            [z * cell_size[0], z * cell_size[0] + cell_size[0]],
            [y * cell_size[1], y * cell_size[1] + cell_size[1]],
            [x * cell_size[2], x * cell_size[2] + cell_size[2]],
        ]
    )
    cell_box[:, 1] = np.minimum(
        cell_box[:, 1], np.array(output_volume_size[2:])
    )

    cell_box = cell_box.flatten()

    # Collision Detection
    # Collision defined by overlapping intervals in all 3 dimensions.
    # Two intervals (A, B) collide if A_max > B_min and A_min < B_max.
    overlapping_tiles: list[int] = []
    for tile_id, t_aabb in tile_aabbs.items():
        if (
            (cell_box[1] > t_aabb[0] and cell_box[0] < t_aabb[1])
            and (cell_box[3] > t_aabb[2] and cell_box[2] < t_aabb[3])
            and (cell_box[5] > t_aabb[4] and cell_box[4] < t_aabb[5])
        ):
            overlapping_tiles.append(tile_id)

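    # Editor's worked example (illustrative numbers): in 1-D, intervals
    # A = [0, 10) and B = [8, 20) collide because A_max=10 > B_min=8 and
    # A_min=0 < B_max=20; A = [0, 10) and C = [10, 20) do not, since
    # A_max=10 > C_min=10 is false. The test above applies this predicate
    # independently to the z, y, and x intervals of the cell and tile AABBs.
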
    # Interpolation for cell_contributions
    cell_contributions: list[torch.Tensor] = []
    cell_contribution_tile_ids: list[int] = []
    for tile_id in overlapping_tiles:
        # Init tile coords, arange end-exclusive, +0.5 to represent voxel center
        z_indices = torch.arange(cell_box[0], cell_box[1], step=1) + 0.5
        y_indices = torch.arange(cell_box[2], cell_box[3], step=1) + 0.5
        x_indices = torch.arange(cell_box[4], cell_box[5], step=1) + 0.5
        z_indices = z_indices.to(device)
        y_indices = y_indices.to(device)
        x_indices = x_indices.to(device)

        z_grid, y_grid, x_grid = torch.meshgrid(
            z_indices, y_indices, x_indices, indexing="ij"
        )
        z_grid = torch.unsqueeze(z_grid, 0)
        y_grid = torch.unsqueeze(y_grid, 0)
        x_grid = torch.unsqueeze(x_grid, 0)

        tile_coords = torch.concatenate((z_grid, y_grid, x_grid), axis=0)
        # (3, z, y, x) -> (z, y, x, 3)
        tile_coords = torch.movedim(tile_coords, source=0, destination=3)

        # Define tile coords wrt output vol origin
        tile_coords = tile_coords + torch.Tensor(output_volume_origin).to(
            device
        )

        # Send tile_coords through inverse transforms
        # NOTE: tile_transforms list must be iterated thru in reverse
        # (z, y, x, 3) -> (z, y, x, 3)
        for tfm in reversed(tile_transforms[tile_id]):
            tile_coords = tfm.backward(tile_coords, device=device)

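        # Editor's note (illustrative): iterating the forward chain in
        # reverse and calling .backward() computes the inverse of the
        # composed map, since (A o B)^-1 = B^-1 o A^-1. E.g. for a chain
        # [registration, scale_input, sample_output], a point in output-cell
        # space is pulled back through sample_output^-1, then scale_input^-1,
        # then registration^-1, landing in raw tile space.
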
        # Calculate AABB of transformed coords
        z_min, z_max, y_min, y_max, x_min, x_max = geometry.aabb_3d(
            tile_coords
        )

        # Mini Optimization: Check true collision before executing interpolation/fusion.
        # That is, the AABB of the coordinates transformed into imagespace must
        # actually overlap the image.
        t_size_zyx = tile_sizes_zyx[tile_id]
        if not (
            (z_max > 0 and z_min < t_size_zyx[0])
            and (y_max > 0 and y_min < t_size_zyx[1])
            and (x_max > 0 and x_min < t_size_zyx[2])
        ):
            continue

        # Calculate overlapping region between transformed coords and image boundary.
        # For intervals (A, B):
        # The lower bound of the overlapping region = max(A_min, B_min)
        # The upper bound of the overlapping region = min(A_max, B_max)
        crop_min_z = torch.max(torch.Tensor([0, z_min]))
        crop_max_z = torch.min(torch.Tensor([t_size_zyx[0], z_max]))

        crop_min_y = torch.max(torch.Tensor([0, y_min]))
        crop_max_y = torch.min(torch.Tensor([t_size_zyx[1], y_max]))

        crop_min_x = torch.max(torch.Tensor([0, x_min]))
        crop_max_x = torch.min(torch.Tensor([t_size_zyx[2], x_max]))

        # Make sure crop_{min, max}_{z, y, x} are integers to be used as indices.
        # Minimum values are rounded down to the nearest integer.
        # Maximum values are rounded up to the nearest integer.
        crop_min_z = int(torch.floor(crop_min_z))
        crop_min_y = int(torch.floor(crop_min_y))
        crop_min_x = int(torch.floor(crop_min_x))

        crop_max_z = int(torch.ceil(crop_max_z))
        crop_max_y = int(torch.ceil(crop_max_y))
        crop_max_x = int(torch.ceil(crop_max_x))

        # Define tile coords wrt base image crop coordinates
        image_crop_offset = torch.Tensor(
            [crop_min_z, crop_min_y, crop_min_x]
        ).to(device)
        tile_coords = tile_coords - image_crop_offset

        # Prep inputs to interpolation
        image_crop_slice = (
            0,
            0,
            slice(crop_min_z, crop_max_z),
            slice(crop_min_y, crop_max_y),
            slice(crop_min_x, crop_max_x),
        )
        s3_read = s3fs.S3FileSystem(anon=True)
        src_path = tile_paths[tile_id]
        store = s3fs.S3Map(root=src_path, s3=s3_read)
        zarr_arr = zarr.open(store=store, mode="r")
        image_crop = zarr_arr[image_crop_slice]

        if isinstance(image_crop, da.Array):
            image_crop = image_crop.compute()

        image_crop = image_crop.astype(
            np.int32
        )  # Promote uint16 -> Pytorch compatible int32
        image_crop = torch.Tensor(image_crop).to(device)

        # Pytorch's flow field follows a different basis than the image numpy basis.
        # Change of basis to the interpolation basis, which preserves relative
        # distances/angles/positions.
        # (z, y, x, 3) -> (z, y, x, 3)
        interp_cob_matrix = torch.Tensor(
            [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]]
        )
        interp_cob = geometry.Affine(interp_cob_matrix)
        tile_coords = interp_cob.forward(tile_coords, device=device)

        # Interpolation expects the 'grid' parameter/sample locations to be
        # normalized to [-1, 1]. Very specific per-dimension normalization
        # according to the CoB: after the swap, channel 0 is x, 1 is y, 2 is z.
        crop_z_length = crop_max_z - crop_min_z
        crop_y_length = crop_max_y - crop_min_y
        crop_x_length = crop_max_x - crop_min_x
        tile_coords[:, :, :, 0] = (
            tile_coords[:, :, :, 0] - (crop_x_length / 2)
        ) / (crop_x_length / 2)
        tile_coords[:, :, :, 1] = (
            tile_coords[:, :, :, 1] - (crop_y_length / 2)
        ) / (crop_y_length / 2)
        tile_coords[:, :, :, 2] = (
            tile_coords[:, :, :, 2] - (crop_z_length / 2)
        ) / (crop_z_length / 2)

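        # Editor's worked example (illustrative): with a crop length of 100
        # along x, the normalization maps coordinate c to (c - 50) / 50, so
        # c=0 -> -1.0, c=50 -> 0.0, and c=100 -> +1.0, matching grid_sample's
        # convention that the sampling grid spans [-1, 1] across each input
        # dimension.
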
        # Final reshaping
        # image_crop: (z_in, y_in, x_in) -> (1, 1, z_in, y_in, x_in)
        # tile_coords: (z_out, y_out, x_out, 3) -> (1, z_out, y_out, x_out, 3)
        # => tile_contribution: (1, 1, z_out, y_out, x_out)
        image_crop = image_crop[(None,) * 2]
        tile_coords = torch.unsqueeze(tile_coords, 0)

        # Interpolate and Store
        tile_contribution = torch.nn.functional.grid_sample(
            image_crop, tile_coords, padding_mode="zeros", mode="nearest", align_corners=False
        )

        cell_contributions.append(tile_contribution)
        cell_contribution_tile_ids.append(tile_id)

        del tile_coords

    # Fuse all cell contributions together with the specified blend module
    fused_cell = torch.zeros((1,
                              1,
                              cell_box[1] - cell_box[0],
                              cell_box[3] - cell_box[2],
                              cell_box[5] - cell_box[4]))
    if len(cell_contributions) != 0:
        fused_cell = blend_module.blend(
            cell_contributions,
            device,
            kwargs={'chunk_tile_ids': cell_contribution_tile_ids,
                    'cell_box': cell_box}
        )
    cell_contributions = []

    # Write
    output_slice = (
        slice(0, 1),
        slice(0, 1),
        slice(cell_box[0], cell_box[1]),
        slice(cell_box[2], cell_box[3]),
        slice(cell_box[4], cell_box[5]),
    )
    # Convert from float32 -> canonical uint16
    output_chunk = np.array(fused_cell.cpu()).astype(np.uint16)

    s3_write = s3fs.S3FileSystem(anon=False)
    out_store = s3fs.S3Map(root=write_root, s3=s3_write)
    arr = zarr.open(store=out_store, mode="a")[write_ds]
    arr[output_slice] = np.ascontiguousarray(output_chunk)

    del fused_cell
    del output_chunk
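
# Editor's note (illustrative, not part of the original module): the final
# .astype(np.uint16) truncates/wraps out-of-range float values rather than
# clamping them. If the blend module could produce values outside
# [0, 65535], an explicit clamp is the safer conversion, e.g.
#
#     output_chunk = np.clip(np.array(fused_cell.cpu()), 0, 65535).astype(np.uint16)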