rhapso-0.1.92-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Rhapso/__init__.py +1 -0
- Rhapso/data_prep/__init__.py +2 -0
- Rhapso/data_prep/n5_reader.py +188 -0
- Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
- Rhapso/data_prep/xml_to_dataframe.py +215 -0
- Rhapso/detection/__init__.py +5 -0
- Rhapso/detection/advanced_refinement.py +203 -0
- Rhapso/detection/difference_of_gaussian.py +324 -0
- Rhapso/detection/image_reader.py +117 -0
- Rhapso/detection/metadata_builder.py +130 -0
- Rhapso/detection/overlap_detection.py +327 -0
- Rhapso/detection/points_validation.py +49 -0
- Rhapso/detection/save_interest_points.py +265 -0
- Rhapso/detection/view_transform_models.py +67 -0
- Rhapso/fusion/__init__.py +0 -0
- Rhapso/fusion/affine_fusion/__init__.py +2 -0
- Rhapso/fusion/affine_fusion/blend.py +289 -0
- Rhapso/fusion/affine_fusion/fusion.py +601 -0
- Rhapso/fusion/affine_fusion/geometry.py +159 -0
- Rhapso/fusion/affine_fusion/io.py +546 -0
- Rhapso/fusion/affine_fusion/script_utils.py +111 -0
- Rhapso/fusion/affine_fusion/setup.py +4 -0
- Rhapso/fusion/affine_fusion_worker.py +234 -0
- Rhapso/fusion/multiscale/__init__.py +0 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
- Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
- Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
- Rhapso/fusion/multiscale_worker.py +113 -0
- Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
- Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
- Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
- Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
- Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
- Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
- Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
- Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
- Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
- Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
- Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
- Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
- Rhapso/matching/__init__.py +0 -0
- Rhapso/matching/load_and_transform_points.py +458 -0
- Rhapso/matching/ransac_matching.py +544 -0
- Rhapso/matching/save_matches.py +120 -0
- Rhapso/matching/xml_parser.py +302 -0
- Rhapso/pipelines/__init__.py +0 -0
- Rhapso/pipelines/ray/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/__init__.py +0 -0
- Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
- Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
- Rhapso/pipelines/ray/evaluation.py +71 -0
- Rhapso/pipelines/ray/interest_point_detection.py +137 -0
- Rhapso/pipelines/ray/interest_point_matching.py +110 -0
- Rhapso/pipelines/ray/local/__init__.py +0 -0
- Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
- Rhapso/pipelines/ray/matching_stats.py +104 -0
- Rhapso/pipelines/ray/param/__init__.py +0 -0
- Rhapso/pipelines/ray/solver.py +120 -0
- Rhapso/pipelines/ray/split_dataset.py +78 -0
- Rhapso/solver/__init__.py +0 -0
- Rhapso/solver/compute_tiles.py +562 -0
- Rhapso/solver/concatenate_models.py +116 -0
- Rhapso/solver/connected_graphs.py +111 -0
- Rhapso/solver/data_prep.py +181 -0
- Rhapso/solver/global_optimization.py +410 -0
- Rhapso/solver/model_and_tile_setup.py +109 -0
- Rhapso/solver/pre_align_tiles.py +323 -0
- Rhapso/solver/save_results.py +97 -0
- Rhapso/solver/view_transforms.py +75 -0
- Rhapso/solver/xml_to_dataframe_solver.py +213 -0
- Rhapso/split_dataset/__init__.py +0 -0
- Rhapso/split_dataset/compute_grid_rules.py +78 -0
- Rhapso/split_dataset/save_points.py +101 -0
- Rhapso/split_dataset/save_xml.py +377 -0
- Rhapso/split_dataset/split_images.py +537 -0
- Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
- rhapso-0.1.92.dist-info/METADATA +39 -0
- rhapso-0.1.92.dist-info/RECORD +101 -0
- rhapso-0.1.92.dist-info/WHEEL +5 -0
- rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
- rhapso-0.1.92.dist-info/top_level.txt +2 -0
- tests/__init__.py +1 -0
- tests/test_detection.py +17 -0
- tests/test_matching.py +21 -0
- tests/test_solving.py +21 -0
Rhapso/fusion/affine_fusion/geometry.py
@@ -0,0 +1,159 @@
"""
Algorithm geometry primitives and utilities.
"""
import numpy as np
import torch
from nptyping import NDArray, Shape

Matrix = NDArray[Shape["3, 4"], np.float64]
AABB = tuple[int, int, int, int, int, int]


class Transform:
    """
    Registration Transform implemented in PyTorch.
    forward/backward transforms preserve the shape of the data.
    """

    def forward(
        self, data: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        raise NotImplementedError("Please implement in Transform subclass.")

    def backward(
        self, data: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        raise NotImplementedError("Please implement in Transform subclass.")


class Affine(Transform):
    """
    Rotation + Translation Registration.
    """

    def __init__(self, matrix: Matrix):
        super().__init__()
        assert matrix.shape == (
            3,
            4,
        ), f"Matrix shape is {matrix.shape}, must be (3, 4)"

        self.matrix = torch.Tensor(matrix)
        self.matrix_3x3 = self.matrix[:, :3]
        self.translation = self.matrix[:, 3]

        self.backward_matrix_3x3 = torch.linalg.inv(self.matrix_3x3)
        # Inverse of y = A x + t is x = A^-1 y - A^-1 t, so the backward
        # translation must be routed through the inverse matrix as well.
        self.backward_translation = -(self.backward_matrix_3x3 @ self.translation)

    def forward(
        self, data: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """
        Parameters:
        -----------
        data: (dims) + (3,)
            data is a list/tensor of zyx vectors.

        device: {cuda:n, 'cpu'}
            device to perform computation on.

        Returns:
        --------
        transformed_data: (dims) + (3,)
            transformed_data is identical shape to the input.
            transformed_data lives on the device specified.
        """
        assert (
            data.shape[-1] == 3
        ), f"Data shape is {data.shape}, last dimension of input data must be 3d."

        # matrix: (3, 3) -> (1,)*(dims - 1) + (3, 3)
        # data: (dims, 3) -> (dims, 3, 1)
        # Ex:
        #   (3, 3) -> (1, 1, 1, 3, 3)
        #   (z, y, x, 3) -> (z, y, x, 3, 1)
        dims = len(data.shape)
        expanded_matrix = self.matrix_3x3[(None,) * (dims - 1)].to(device)
        expanded_data = torch.unsqueeze(data, dims).to(device)

        transformed_data = expanded_matrix @ expanded_data
        transformed_data = torch.squeeze(transformed_data, -1)
        transformed_data = transformed_data + self.translation.to(device)

        return transformed_data

    def backward(
        self, data: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """
        Parameters:
        -----------
        data: (dims) + (3,)
            data is a list/tensor of zyx vectors.

        device: {cuda:n, 'cpu'}
            device to perform computation on.

        Returns:
        --------
        transformed_data: (dims) + (3,)
            transformed_data is identical shape to the input.
            transformed_data lives on the device specified.
        """
        assert (
            data.shape[-1] == 3
        ), f"Data shape is {data.shape}, last dimension of input data must be 3d."

        # matrix: (3, 3) -> (1,)*(dims - 1) + (3, 3)
        # data: (dims, 3) -> (dims, 3, 1)
        # Ex:
        #   (3, 3) -> (1, 1, 1, 3, 3)
        #   (z, y, x, 3) -> (z, y, x, 3, 1)
        dims = len(data.shape)
        expanded_matrix = self.backward_matrix_3x3[(None,) * (dims - 1)].to(device)
        expanded_data = torch.unsqueeze(data, dims).to(device)

        transformed_data = expanded_matrix @ expanded_data
        transformed_data = torch.squeeze(transformed_data, -1)
        transformed_data = transformed_data + self.backward_translation.to(device)

        return transformed_data


def aabb_3d(data) -> AABB:
    """
    Parameters:
    -----------
    data: (dims) + (3,)
        data is a list/tensor of zyx vectors.

    Returns:
    --------
    aabb: Ranges ordered in same order as components in input buffer.
    """
    assert (
        data.shape[-1] == 3
    ), f"Data shape is {data.shape}, last dimension of input data must be 3d."
    dims = len(data.shape)

    output = []
    for i in range(3):
        # Slice syntax:
        #   (slice(None, None, None)) => arr[:]
        #   (i) => arr[i]
        dim_slice = [slice(None, None, None)] * (dims - 1)
        dim_slice = tuple(dim_slice + [i])

        dim_min = torch.min(data[dim_slice]).item()
        dim_max = torch.max(data[dim_slice]).item()
        output.append(dim_min)
        output.append(dim_max)

    return tuple(output)
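To make the zyx shape conventions above concrete, here is a minimal usage sketch of these primitives (illustrative only, not part of the package; the matrix values are hypothetical):

    import numpy as np
    import torch

    from Rhapso.fusion.affine_fusion.geometry import Affine, aabb_3d

    # Hypothetical 3x4 zyx affine: scale by 2, then translate by (10, 0, -5).
    tfm = Affine(np.array([[2.0, 0.0, 0.0, 10.0],
                           [0.0, 2.0, 0.0, 0.0],
                           [0.0, 0.0, 2.0, -5.0]]))

    points = torch.rand(100, 3)              # (N, 3) buffer of zyx vectors
    device = torch.device("cpu")

    warped = tfm.forward(points, device)     # same (N, 3) shape, on `device`
    restored = tfm.backward(warped, device)  # inverse map, approximately equal to points
    bounds = aabb_3d(warped)                 # (z_min, z_max, y_min, y_max, x_min, x_max)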
Rhapso/fusion/affine_fusion/io.py
@@ -0,0 +1,546 @@
"""
Defines all standard input to fusion algorithm.
"""
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path

import boto3
import dask.array as da
import numpy as np
from numcodecs import Blosc
import re
import s3fs
import tensorstore as ts
import xmltodict
import yaml
import zarr
import fsspec

from . import geometry


def read_config_yaml(yaml_path: str) -> dict:
    with open(yaml_path, "r") as f:
        yaml_dict = yaml.safe_load(f)
    return yaml_dict


def write_config_yaml(yaml_path: str, yaml_data: dict) -> None:
    with open(yaml_path, "w") as file:
        yaml.dump(yaml_data, file)


def open_zarr_s3(bucket: str, path: str) -> ts.TensorStore:
    return ts.open({
        'driver': 'zarr',
        'kvstore': {
            'driver': 'http',
            'base_url': f'https://{bucket}.s3.us-west-2.amazonaws.com/{path}',
        },
    }).result()


class InputArray:
    def __getitem__(self, value):
        """
        Member function for slice syntax, ex: arr[0:10, 0:10]
        Value is a Python slice object.
        """
        raise NotImplementedError("Please implement in InputArray subclass.")

    @property
    def shape(self):
        raise NotImplementedError("Please implement in InputArray subclass.")


class InputDask(InputArray):
    def __init__(self, arr: da.Array):
        self.arr = arr

    def __getitem__(self, slice):
        return np.array(self.arr[slice].compute())

    @property
    def shape(self):
        return self.arr.shape


class InputTensorstore(InputArray):
    def __init__(self, arr: ts.TensorStore):
        self.arr = arr

    def __getitem__(self, slice):
        return np.array(self.arr[slice])

    @property
    def shape(self):
        return self.arr.shape


class Dataset:
    """
    Data are 5d tczyx objects.
    Transforms are 3d zyx objects.
    """

    class WriteError(Exception):
        pass

    @property
    def tile_volumes_tczyx(self) -> dict[int, InputArray]:
        """
        Dict of tile_id -> tile references.
        """
        raise NotImplementedError("Please implement in Dataset subclass.")

    @tile_volumes_tczyx.setter
    def tile_volumes_tczyx(self, value):
        raise Dataset.WriteError("tile_volumes_tczyx is read-only.")

    @property
    def tile_transforms_zyx(self) -> dict[int, list[geometry.Transform]]:
        """
        Dict of tile_id -> tile transforms.
        """
        raise NotImplementedError("Please implement in Dataset subclass.")

    @tile_transforms_zyx.setter
    def tile_transforms_zyx(self, value):
        raise Dataset.WriteError("tile_transforms_zyx is read-only.")

    @property
    def tile_resolution_zyx(self) -> tuple[float, float, float]:
        """
        Specifies absolute size of each voxel in tile volume.
        Tile resolution is used to scale tile volume into absolute space prior to registration.
        """
        raise NotImplementedError("Please implement in Dataset subclass.")

    @tile_resolution_zyx.setter
    def tile_resolution_zyx(self, value):
        raise Dataset.WriteError("tile_resolution_zyx is read-only.")


class BigStitcherDataset(Dataset):
    """
    Dataset class for loading in a BigStitcher dataset.
    Intended for the base registration channel.
    """

    def __init__(self, xml_path: str, s3_path: str, datastore: int, level: int = 0):
        self.xml_path = xml_path
        self.s3_path = s3_path

        assert datastore in [0, 1], \
            "Only 0 = Dask and 1 = Tensorstore supported."
        self.datastore = datastore  # {0 = Dask, 1 = Tensorstore}

        allowed_levels = [0, 1, 2, 3, 4, 5]
        assert level in allowed_levels, \
            f"Level {level} is not in {allowed_levels}"
        self.level = level

        self.tile_cache: dict[int, InputArray] = {}
        self.transform_cache: dict[int, list[geometry.Transform]] = {}

    @property
    def tile_volumes_tczyx(self) -> tuple[dict[int, InputArray], dict[int, str]]:
        # Returns both the tile arrays and the resolved per-tile paths.
        if len(self.tile_cache) != 0:
            tile_paths = self._extract_tile_paths(self.xml_path)
            for t_id, t_path in tile_paths.items():
                if not self.s3_path.endswith('/'):
                    self.s3_path = self.s3_path + '/'

                level_str = '/' + str(self.level)  # Ex: '/0'
                tile_paths[t_id] = self.s3_path + Path(t_path).name + level_str
            return self.tile_cache, tile_paths

        # Otherwise, fetch for first time
        tile_paths = self._extract_tile_paths(self.xml_path)
        for t_id, t_path in tile_paths.items():
            if not self.s3_path.endswith('/'):
                self.s3_path = self.s3_path + '/'

            level_str = '/' + str(self.level)  # Ex: '/0'
            tile_paths[t_id] = self.s3_path + Path(t_path).name + level_str

        tile_arrays: dict[int, InputArray] = {}
        for tile_id, t_path in tile_paths.items():
            arr = None
            if self.datastore == 0:  # Dask
                tile_zarr = da.from_zarr(t_path)
                arr = InputDask(tile_zarr)
            elif self.datastore == 1:  # Tensorstore
                # Referencing the following naming convention:
                # s3://BUCKET_NAME/DATASET_NAME/TILE/NAME/CHANNEL
                parts = t_path.split('/')
                bucket = parts[2]
                third_slash_index = len(parts[0]) + len(parts[1]) + len(parts[2]) + 3
                obj = t_path[third_slash_index:]

                tile_zarr = open_zarr_s3(bucket, obj)
                arr = InputTensorstore(tile_zarr)

            print(f'Loading Tile {tile_id} / {len(tile_paths)}')
            tile_arrays[int(tile_id)] = arr

        self.tile_cache = tile_arrays

        return tile_arrays, tile_paths

    @property
    def tile_transforms_zyx(self) -> dict[int, list[geometry.Transform]]:
        if len(self.transform_cache) != 0:
            return self.transform_cache

        # Otherwise, fetch for first time
        tile_tfms = self._extract_tile_transforms(self.xml_path)
        tile_net_tfms = self._calculate_net_transforms(tile_tfms)

        for tile_id, tfm in tile_net_tfms.items():
            # BigStitcher XYZ -> ZYX
            # Given Matrix 3x4:
            # Swap Rows 0 and 2; Swap Columns 0 and 2
            tmp = np.copy(tfm)
            tmp[[0, 2], :] = tmp[[2, 0], :]
            tmp[:, [0, 2]] = tmp[:, [2, 0]]
            tfm = tmp

            # Assemble matrix stack:
            # 1) Add base registration
            matrix_stack = [geometry.Affine(tfm)]

            # 2) Append up/down-sampling transforms
            sf = 2. ** self.level
            up = geometry.Affine(np.array([[sf, 0., 0., 0.],
                                           [0., sf, 0., 0.],
                                           [0., 0., sf, 0.]]))
            down = geometry.Affine(np.array([[1. / sf, 0., 0., 0.],
                                             [0., 1. / sf, 0., 0.],
                                             [0., 0., 1. / sf, 0.]]))
            matrix_stack.insert(0, up)
            matrix_stack.append(down)
            tile_net_tfms[int(tile_id)] = matrix_stack

        self.transform_cache = tile_net_tfms

        return tile_net_tfms

    @property
    def tile_resolution_zyx(self) -> tuple[float, float, float]:
        if self.xml_path.startswith("s3://"):
            with fsspec.open(self.xml_path, mode="rt") as f:
                data: OrderedDict = xmltodict.parse(f.read())
        else:
            with open(self.xml_path, "r") as file:
                data: OrderedDict = xmltodict.parse(file.read())

        resolution_str = data["SpimData"]["SequenceDescription"]["ViewSetups"][
            "ViewSetup"
        ][0]["voxelSize"]["size"]
        resolution_xyz = [float(num) for num in resolution_str.split(" ")]
        return tuple(resolution_xyz[::-1])

    def _extract_tile_paths(self, xml_path: str) -> dict[int, str]:
        """
        Utility called in property.
        Parses BDV xml and outputs map of setup_id -> tile path.

        Parameters
        ------------------------
        xml_path: str
            Path of xml outputted from BigStitcher.

        Returns
        ------------------------
        dict[int, str]:
            Dictionary of tile ids to tile paths.
        """
        view_paths: dict[int, str] = {}

        if xml_path.startswith("s3://"):
            with fsspec.open(xml_path, mode="rt") as f:
                data: OrderedDict = xmltodict.parse(f.read())
        else:
            with open(xml_path, "r") as file:
                data: OrderedDict = xmltodict.parse(file.read())

        parent = data["SpimData"]["SequenceDescription"]["ImageLoader"][
            "zarr"
        ]["#text"]

        for i, zgroup in enumerate(
            data["SpimData"]["SequenceDescription"]["ImageLoader"]["zgroups"][
                "zgroup"
            ]
        ):
            view_paths[i] = parent + "/" + zgroup["@path"]

        return view_paths

    def _extract_tile_transforms(self, xml_path: str) -> dict[int, list[dict]]:
        """
        Utility called in property.
        Parses BDV xml and outputs map of setup_id -> list of transformations.
        Output dictionary maps view number to list of {'@type', 'Name', 'affine'}
        where 'affine' contains the transform as string of 12 floats.

        Matrices are listed in the order of forward execution.

        Parameters
        ------------------------
        xml_path: str
            Path of xml outputted by BigStitcher.

        Returns
        ------------------------
        dict[int, list[dict]]
            Dictionary of tile ids to transform list. List entries described above.
        """

        view_transforms: dict[int, list[dict]] = {}

        if xml_path.startswith("s3://"):
            with fsspec.open(xml_path, mode="rt") as f:
                data: OrderedDict = xmltodict.parse(f.read())
        else:
            with open(xml_path, "r") as file:
                data: OrderedDict = xmltodict.parse(file.read())

        for view_reg in data["SpimData"]["ViewRegistrations"][
            "ViewRegistration"
        ]:
            tfm_stack = view_reg["ViewTransform"]
            if type(tfm_stack) is not list:
                tfm_stack = [tfm_stack]
            view_transforms[int(view_reg["@setup"])] = tfm_stack

        view_transforms = {
            view: tfs[::-1] for view, tfs in view_transforms.items()
        }

        return view_transforms

    def _calculate_net_transforms(
        self, view_transforms: dict[int, list[dict]]
    ) -> dict[int, geometry.Matrix]:
        """
        Utility called in property.
        Accumulate net transform and net translation for each matrix stack.
        Net translation =
            Sum of translation vectors converted into original nominal basis
        Net transform =
            Product of 3x3 matrices
        NOTE: Translational component (last column) is defined
        wrt the DOMAIN, not codomain.
        Implementation is informed by this given.

        Parameters
        ------------------------
        view_transforms: dict[int, list[dict]]
            Dictionary of tile ids to transforms associated with each tile.

        Returns
        ------------------------
        dict[int, np.ndarray]:
            Dictionary of tile ids to net_transform.
        """

        identity_transform = np.array(
            [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
        )
        net_transforms: dict[int, np.ndarray] = {}
        for tile_id in view_transforms:
            net_transforms[tile_id] = np.copy(identity_transform)

        for view, tfs in view_transforms.items():
            net_translation = np.zeros(3)
            net_matrix_3x3 = np.eye(3)
            curr_inverse = np.eye(3)

            # tfs is a list of dicts containing the transform under the 'affine' key
            for tf in tfs:
                nums = [float(val) for val in tf["affine"].split(" ")]
                matrix_3x3 = np.array([nums[0::4], nums[1::4], nums[2::4]])
                translation = np.array(nums[3::4])

                net_translation = net_translation + (curr_inverse @ translation)
                net_matrix_3x3 = matrix_3x3 @ net_matrix_3x3
                curr_inverse = np.linalg.inv(net_matrix_3x3)  # Update curr_inverse

            net_transforms[view] = np.hstack(
                (net_matrix_3x3, net_translation.reshape(3, 1))
            )

        return net_transforms


class BigStitcherDatasetChannel(BigStitcherDataset):
    """
    Convenience Dataset class that reuses tile registrations,
    tile shapes, and tile resolution across channels.
    tile_volumes_tczyx is overloaded with channel-specific data.

    NOTE: Only loads full resolution images/registrations.
    """

    def __init__(self, xml_path: str, s3_path: str, channel_num: int, datastore: int):
        """
        Only new information required is channel number.
        """
        super().__init__(xml_path, s3_path, datastore)
        self.channel_num = channel_num

        self.tile_cache: dict[int, InputArray] = {}

    @property
    def tile_volumes_tczyx(self) -> dict[int, InputArray]:
        """
        Load in channel-specific tiles.
        """

        if len(self.tile_cache) != 0:
            return self.tile_cache

        # Otherwise fetch for first time
        tile_arrays: dict[int, InputArray] = {}

        with open(self.xml_path, "r") as file:
            data: OrderedDict = xmltodict.parse(file.read())
        tile_id_lut = {}
        for zgroup in data['SpimData']['SequenceDescription']['ImageLoader']['zgroups']['zgroup']:
            tile_id = zgroup['@setup']
            tile_name = zgroup['path']
            s_parts = tile_name.split('_')
            location = (int(s_parts[2]),
                        int(s_parts[4]),
                        int(s_parts[6]))
            tile_id_lut[location] = int(tile_id)

        # Reference path: s3://aind-open-data/HCR_677594_2023-10-20_15-10-36/SPIM.ome.zarr/
        # Reference tilename: <tile_name, no underscores>_X_####_Y_####_Z_####_ch_###.zarr
        slash_2 = self.s3_path.find('/', self.s3_path.find('/') + 1)
        slash_3 = self.s3_path.find('/', self.s3_path.find('/', self.s3_path.find('/') + 1) + 1)
        bucket_name = self.s3_path[slash_2 + 1:slash_3]
        directory_path = self.s3_path[slash_3 + 1:]

        for p in self._list_bucket_directory(bucket_name, directory_path):
            if p.endswith('.zgroup'):
                continue

            # Data loading
            channel_num = -1
            search_result = re.search(r'(\d*)\.zarr.?$', p)
            if search_result:
                channel_num = int(search_result.group(1))
            if channel_num == self.channel_num:
                full_resolution_p = self.s3_path + p + '/0'
                s_parts = p.split('_')
                location = (int(s_parts[2]),
                            int(s_parts[4]),
                            int(s_parts[6]))
                tile_id = tile_id_lut[location]

                arr = None
                if self.datastore == 0:  # Dask
                    tile_zarr = da.from_zarr(full_resolution_p)
                    arr = InputDask(tile_zarr)

                elif self.datastore == 1:  # Tensorstore
                    # Referencing the following naming convention:
                    # s3://BUCKET_NAME/DATASET_NAME/TILE/NAME/CHANNEL
                    parts = full_resolution_p.split('/')
                    bucket = parts[2]
                    third_slash_index = len(parts[0]) + len(parts[1]) + len(parts[2]) + 3
                    obj = full_resolution_p[third_slash_index:]

                    tile_zarr = open_zarr_s3(bucket, obj)
                    arr = InputTensorstore(tile_zarr)

                print(f'Loading Tile {tile_id} / {len(tile_id_lut)}')
                tile_arrays[int(tile_id)] = arr

        self.tile_cache = tile_arrays

        return tile_arrays

    def _list_bucket_directory(self, bucket_name: str, directory_path: str):
        client = boto3.client("s3")
        result = client.list_objects(
            Bucket=bucket_name, Prefix=directory_path, Delimiter="/"
        )

        paths = []
        for o in result.get("CommonPrefixes", []):
            paths.append(o.get("Prefix"))

        # Parse the ending files from the paths
        files = []
        for p in paths:
            if p.endswith('/'):
                p = p.rstrip("/")  # Remove trailing slash from directories

            parts = p.split('/')
            files.append(parts[-1])

        return files


class OutputArray:
    def __setitem__(self, index, value):
        raise NotImplementedError("Please implement in OutputArray subclass.")


class OutputDask(OutputArray):
    def __init__(self, arr: da.Array):
        self.arr = arr

    def __setitem__(self, index, value):
        self.arr[index] = value


class OutputTensorstore(OutputArray):
    def __init__(self, arr: ts.TensorStore):
        self.arr = arr

    def __setitem__(self, index, value):
        self.arr[index].write(value).result()


@dataclass
class OutputParameters:
    path: str
    chunksize: tuple[int, int, int, int, int]
    resolution_zyx: tuple[float, float, float]
    datastore: int  # {0 == Dask, 1 == Tensorstore}
    dtype: np.dtype = np.uint16
    dimension_separator: str = "/"
    compressor = Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE)


@dataclass
class RuntimeParameters:
    """
    Simplified Runtime Parameters
    option:
        0: single process execution
        1: multiprocessing execution
        2: dask execution
    pool_size: number of processes/vCPUs for options {1, 2}
    worker_cells:
        list of cells/chunks this execution operates on
    """
    option: int
    pool_size: int
    worker_cells: list[tuple[int, int, int]]