Rhapso 0.1.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. Rhapso/__init__.py +1 -0
  2. Rhapso/data_prep/__init__.py +2 -0
  3. Rhapso/data_prep/n5_reader.py +188 -0
  4. Rhapso/data_prep/s3_big_stitcher_reader.py +55 -0
  5. Rhapso/data_prep/xml_to_dataframe.py +215 -0
  6. Rhapso/detection/__init__.py +5 -0
  7. Rhapso/detection/advanced_refinement.py +203 -0
  8. Rhapso/detection/difference_of_gaussian.py +324 -0
  9. Rhapso/detection/image_reader.py +117 -0
  10. Rhapso/detection/metadata_builder.py +130 -0
  11. Rhapso/detection/overlap_detection.py +327 -0
  12. Rhapso/detection/points_validation.py +49 -0
  13. Rhapso/detection/save_interest_points.py +265 -0
  14. Rhapso/detection/view_transform_models.py +67 -0
  15. Rhapso/fusion/__init__.py +0 -0
  16. Rhapso/fusion/affine_fusion/__init__.py +2 -0
  17. Rhapso/fusion/affine_fusion/blend.py +289 -0
  18. Rhapso/fusion/affine_fusion/fusion.py +601 -0
  19. Rhapso/fusion/affine_fusion/geometry.py +159 -0
  20. Rhapso/fusion/affine_fusion/io.py +546 -0
  21. Rhapso/fusion/affine_fusion/script_utils.py +111 -0
  22. Rhapso/fusion/affine_fusion/setup.py +4 -0
  23. Rhapso/fusion/affine_fusion_worker.py +234 -0
  24. Rhapso/fusion/multiscale/__init__.py +0 -0
  25. Rhapso/fusion/multiscale/aind_hcr_data_transformation/__init__.py +19 -0
  26. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/__init__.py +3 -0
  27. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/czi_to_zarr.py +698 -0
  28. Rhapso/fusion/multiscale/aind_hcr_data_transformation/compress/zarr_writer.py +265 -0
  29. Rhapso/fusion/multiscale/aind_hcr_data_transformation/models.py +81 -0
  30. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/__init__.py +3 -0
  31. Rhapso/fusion/multiscale/aind_hcr_data_transformation/utils/utils.py +526 -0
  32. Rhapso/fusion/multiscale/aind_hcr_data_transformation/zeiss_job.py +249 -0
  33. Rhapso/fusion/multiscale/aind_z1_radial_correction/__init__.py +21 -0
  34. Rhapso/fusion/multiscale/aind_z1_radial_correction/array_to_zarr.py +257 -0
  35. Rhapso/fusion/multiscale/aind_z1_radial_correction/radial_correction.py +557 -0
  36. Rhapso/fusion/multiscale/aind_z1_radial_correction/run_capsule.py +98 -0
  37. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/__init__.py +3 -0
  38. Rhapso/fusion/multiscale/aind_z1_radial_correction/utils/utils.py +266 -0
  39. Rhapso/fusion/multiscale/aind_z1_radial_correction/worker.py +89 -0
  40. Rhapso/fusion/multiscale_worker.py +113 -0
  41. Rhapso/fusion/neuroglancer_link_gen/__init__.py +8 -0
  42. Rhapso/fusion/neuroglancer_link_gen/dispim_link.py +235 -0
  43. Rhapso/fusion/neuroglancer_link_gen/exaspim_link.py +127 -0
  44. Rhapso/fusion/neuroglancer_link_gen/hcr_link.py +368 -0
  45. Rhapso/fusion/neuroglancer_link_gen/iSPIM_top.py +47 -0
  46. Rhapso/fusion/neuroglancer_link_gen/link_utils.py +239 -0
  47. Rhapso/fusion/neuroglancer_link_gen/main.py +299 -0
  48. Rhapso/fusion/neuroglancer_link_gen/ng_layer.py +1434 -0
  49. Rhapso/fusion/neuroglancer_link_gen/ng_state.py +1123 -0
  50. Rhapso/fusion/neuroglancer_link_gen/parsers.py +336 -0
  51. Rhapso/fusion/neuroglancer_link_gen/raw_link.py +116 -0
  52. Rhapso/fusion/neuroglancer_link_gen/utils/__init__.py +4 -0
  53. Rhapso/fusion/neuroglancer_link_gen/utils/shader_utils.py +85 -0
  54. Rhapso/fusion/neuroglancer_link_gen/utils/transfer.py +43 -0
  55. Rhapso/fusion/neuroglancer_link_gen/utils/utils.py +303 -0
  56. Rhapso/fusion/neuroglancer_link_gen_worker.py +30 -0
  57. Rhapso/matching/__init__.py +0 -0
  58. Rhapso/matching/load_and_transform_points.py +458 -0
  59. Rhapso/matching/ransac_matching.py +544 -0
  60. Rhapso/matching/save_matches.py +120 -0
  61. Rhapso/matching/xml_parser.py +302 -0
  62. Rhapso/pipelines/__init__.py +0 -0
  63. Rhapso/pipelines/ray/__init__.py +0 -0
  64. Rhapso/pipelines/ray/aws/__init__.py +0 -0
  65. Rhapso/pipelines/ray/aws/alignment_pipeline.py +227 -0
  66. Rhapso/pipelines/ray/aws/config/__init__.py +0 -0
  67. Rhapso/pipelines/ray/evaluation.py +71 -0
  68. Rhapso/pipelines/ray/interest_point_detection.py +137 -0
  69. Rhapso/pipelines/ray/interest_point_matching.py +110 -0
  70. Rhapso/pipelines/ray/local/__init__.py +0 -0
  71. Rhapso/pipelines/ray/local/alignment_pipeline.py +167 -0
  72. Rhapso/pipelines/ray/matching_stats.py +104 -0
  73. Rhapso/pipelines/ray/param/__init__.py +0 -0
  74. Rhapso/pipelines/ray/solver.py +120 -0
  75. Rhapso/pipelines/ray/split_dataset.py +78 -0
  76. Rhapso/solver/__init__.py +0 -0
  77. Rhapso/solver/compute_tiles.py +562 -0
  78. Rhapso/solver/concatenate_models.py +116 -0
  79. Rhapso/solver/connected_graphs.py +111 -0
  80. Rhapso/solver/data_prep.py +181 -0
  81. Rhapso/solver/global_optimization.py +410 -0
  82. Rhapso/solver/model_and_tile_setup.py +109 -0
  83. Rhapso/solver/pre_align_tiles.py +323 -0
  84. Rhapso/solver/save_results.py +97 -0
  85. Rhapso/solver/view_transforms.py +75 -0
  86. Rhapso/solver/xml_to_dataframe_solver.py +213 -0
  87. Rhapso/split_dataset/__init__.py +0 -0
  88. Rhapso/split_dataset/compute_grid_rules.py +78 -0
  89. Rhapso/split_dataset/save_points.py +101 -0
  90. Rhapso/split_dataset/save_xml.py +377 -0
  91. Rhapso/split_dataset/split_images.py +537 -0
  92. Rhapso/split_dataset/xml_to_dataframe_split.py +219 -0
  93. rhapso-0.1.92.dist-info/METADATA +39 -0
  94. rhapso-0.1.92.dist-info/RECORD +101 -0
  95. rhapso-0.1.92.dist-info/WHEEL +5 -0
  96. rhapso-0.1.92.dist-info/licenses/LICENSE +21 -0
  97. rhapso-0.1.92.dist-info/top_level.txt +2 -0
  98. tests/__init__.py +1 -0
  99. tests/test_detection.py +17 -0
  100. tests/test_matching.py +21 -0
  101. tests/test_solving.py +21 -0
@@ -0,0 +1,159 @@
1
+ """
2
+ Algorithm geometry primitives and utilities.
3
+ """
4
+ import numpy as np
5
+ import torch
6
+ from nptyping import NDArray, Shape
7
+
8
# A (3, 4) float64 affine matrix: 3x3 linear part followed by a translation column.
Matrix = NDArray[Shape["3, 4"], np.float64]
# Axis-aligned bounding box flattened to (min_0, max_0, min_1, max_1, min_2, max_2).
# NOTE(review): aabb_3d fills this from Tensor.item(), which gives floats for
# floating-point tensors, despite the int annotation.
AABB = tuple[int, int, int, int, int, int]
10
+
11
+
12
class Transform:
    """
    Base class for registration transforms implemented with PyTorch.

    Subclasses must override both ``forward`` and ``backward``; each takes a
    tensor of vectors plus a target device and must return a tensor with the
    same shape as its input.
    """

    def forward(self, data: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Map ``data`` through the transform on ``device`` (subclass hook)."""
        raise NotImplementedError("Please implement in Transform subclass.")

    def backward(self, data: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Map ``data`` through the reverse direction on ``device`` (subclass hook)."""
        raise NotImplementedError("Please implement in Transform subclass.")
27
+
28
+
29
class Affine(Transform):
    """
    Rotation + Translation Registration.

    Built from a (3, 4) matrix ``[M | t]``:
      * ``forward``  maps each vector ``x`` to ``M @ x + t``
      * ``backward`` maps each vector ``y`` to ``inv(M) @ y - t``

    NOTE(review): ``backward`` is the exact inverse of ``forward`` only when
    ``inv(M) @ t == t`` (e.g. ``M`` is the identity or ``t`` is zero); in
    general ``backward(forward(x)) == x + inv(M) @ t - t``. The io module
    documents its net translations as defined w.r.t. the DOMAIN, which may be
    why this convention exists, so the math is intentionally left unchanged --
    confirm against callers before altering.
    """

    def __init__(self, matrix: Matrix):
        super().__init__()
        assert matrix.shape == (3, 4), (
            f"Matrix shape is {matrix.shape}, must be (3, 4)"
        )

        # torch.Tensor(...) converts to torch's default floating dtype.
        self.matrix = torch.Tensor(matrix)
        self.matrix_3x3 = self.matrix[:, :3]
        self.translation = self.matrix[:, 3]

        self.backward_matrix_3x3 = torch.linalg.inv(self.matrix_3x3)
        # Added after inv(M) is applied in backward(); see class docstring.
        self.backward_translation = -self.translation

    def forward(
        self, data: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """
        Parameters:
        -----------
        data: (dims) + (3,)
            data is a list/tensor of zyx vectors.

        device: {cuda:n, 'cpu'}
            device to perform computation on.

        Returns:
        --------
        transformed_data: (dims) + (3,)
            ``M @ x + t`` for every vector ``x``; identical shape to the
            input and living on ``device``.
        """
        return self._apply(self.matrix_3x3, self.translation, data, device)

    def backward(
        self, data: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """
        Parameters:
        -----------
        data: (dims) + (3,)
            data is a list/tensor of zyx vectors.

        device: {cuda:n, 'cpu'}
            device to perform computation on.

        Returns:
        --------
        transformed_data: (dims) + (3,)
            ``inv(M) @ y - t`` for every vector ``y``; identical shape to the
            input and living on ``device``.
        """
        return self._apply(
            self.backward_matrix_3x3, self.backward_translation, data, device
        )

    @staticmethod
    def _apply(
        matrix_3x3: torch.Tensor,
        translation: torch.Tensor,
        data: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """Compute ``matrix_3x3 @ v + translation`` for every 3-vector ``v`` in ``data``."""
        assert data.shape[-1] == 3, (
            f"Data shape is {data.shape}, last dimension of input data must be 3d."
        )

        # Broadcast the matrix over all leading dims of data:
        #   matrix: (3, 3)        -> (1,)*(dims - 1) + (3, 3)
        #   data:   (..., 3)      -> (..., 3, 1)
        # Ex: (3, 3) -> (1, 1, 1, 3, 3); (z, y, x, 3) -> (z, y, x, 3, 1)
        dims = len(data.shape)
        expanded_matrix = matrix_3x3[(None,) * (dims - 1)].to(device)
        expanded_data = torch.unsqueeze(data, dims).to(device)

        transformed_data = torch.squeeze(expanded_matrix @ expanded_data, -1)
        return transformed_data + translation.to(device)
127
+
128
+
129
def aabb_3d(data) -> AABB:
    """
    Compute the axis-aligned bounding box of a tensor of 3-vectors.

    Parameters:
    -----------
    data: (dims) + (3,)
        Tensor of zyx vectors; any number of leading dimensions is accepted.

    Returns:
    --------
    aabb: (min_0, max_0, min_1, max_1, min_2, max_2)
        Ranges ordered in same order as components in input buffer. Values
        come from ``Tensor.item()``, so they are Python floats for
        floating-point input.
    """
    assert data.shape[-1] == 3, (
        f"Data shape is {data.shape}, last dimension of input data must be 3d."
    )

    output = []
    for i in range(3):
        # data[..., i] selects component i across every leading dimension.
        component = data[..., i]
        output.append(torch.min(component).item())
        output.append(torch.max(component).item())

    return tuple(output)
@@ -0,0 +1,546 @@
1
+ """
2
+ Defines all standard input to fusion algorithm.
3
+ """
4
+ from collections import OrderedDict
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+ import boto3
9
+ import dask.array as da
10
+ import numpy as np
11
+ from numcodecs import Blosc
12
+ import re
13
+ import s3fs
14
+ import tensorstore as ts
15
+ import xmltodict
16
+ import yaml
17
+ import zarr
18
+ import fsspec
19
+
20
+ from . import geometry
21
+
22
+
23
def read_config_yaml(yaml_path: str) -> dict:
    """Parse the YAML file at ``yaml_path`` with ``yaml.safe_load`` and return the result."""
    with open(yaml_path, "r") as stream:
        return yaml.safe_load(stream)
27
+
28
+
29
def write_config_yaml(yaml_path: str, yaml_data: dict) -> None:
    """Serialize ``yaml_data`` to ``yaml_path`` with ``yaml.dump``, overwriting the file."""
    with open(yaml_path, "w") as out_stream:
        yaml.dump(yaml_data, out_stream)
32
+
33
+
34
def open_zarr_s3(bucket: str, path: str, region: str = "us-west-2") -> ts.TensorStore:
    """
    Open a zarr array stored in S3 through TensorStore's HTTP key-value store.

    The array is addressed via the bucket's virtual-hosted-style HTTPS
    endpoint and no credentials are supplied, so this presumably requires the
    objects to be publicly readable -- confirm for private buckets.

    Parameters
    ----------
    bucket: str
        Bucket name.
    path: str
        Key prefix of the zarr array inside the bucket (appended after a '/',
        so it should not start with one).
    region: str
        AWS region hosting the bucket. Defaults to ``"us-west-2"``.

    Returns
    -------
    ts.TensorStore
        The opened store (``ts.open`` is resolved synchronously via ``.result()``).
    """
    spec = {
        'driver': 'zarr',
        'kvstore': {
            'driver': 'http',
            'base_url': f'https://{bucket}.s3.{region}.amazonaws.com/{path}',
        },
    }
    return ts.open(spec).result()
42
+
43
+
44
class InputArray:
    """
    Read-side array interface consumed by the fusion code.

    Concrete subclasses wrap a storage backend and must supply subscript
    reads (``arr[0:10, 0:10]``) plus a ``shape`` property.
    """

    def __getitem__(self, value):
        """
        Read the region selected by the subscript ``value``.

        ``value`` is whatever appeared inside the brackets: a Python slice
        object or a tuple of them.
        """
        raise NotImplementedError("Please implement in InputArray subclass.")

    @property
    def shape(self):
        """Shape of the wrapped array; must be provided by the subclass."""
        raise NotImplementedError("Please implement in InputArray subclass.")
55
+
56
+
57
class InputDask(InputArray):
    """InputArray backed by a lazily evaluated dask array."""

    def __init__(self, arr: da.Array):
        self.arr = arr

    def __getitem__(self, index):
        """
        Compute ``arr[index]`` and return it as an in-memory NumPy array.

        ``index`` is the subscript (slice or tuple of slices); named so it does
        not shadow the builtin ``slice``.
        """
        return np.array(self.arr[index].compute())

    @property
    def shape(self):
        """Shape of the wrapped dask array."""
        return self.arr.shape
67
+
68
+
69
class InputTensorstore(InputArray):
    """InputArray backed by a TensorStore handle."""

    def __init__(self, arr: ts.TensorStore):
        self.arr = arr

    def __getitem__(self, index):
        """
        Read ``arr[index]`` and return it as a NumPy array.

        ``index`` is the subscript (slice or tuple of slices); named so it does
        not shadow the builtin ``slice``.
        """
        return np.array(self.arr[index])

    @property
    def shape(self):
        """Shape of the wrapped TensorStore."""
        return self.arr.shape
79
+
80
+
81
class Dataset:
    """
    Abstract description of a fusion input dataset.

    Data are 5d tczyx objects.
    Transforms are 3d zyx objects.

    Every public property is read-only: assigning to one raises
    ``Dataset.WriteError``; reading one on this base class raises
    ``NotImplementedError``.
    """

    class WriteError(Exception):
        """Raised on an attempt to assign to a read-only Dataset property."""

    @property
    def tile_volumes_tczyx(self) -> dict[int, InputArray]:
        """Mapping of tile_id to a reference to that tile's volume."""
        raise NotImplementedError("Please implement in Dataset subclass.")

    @tile_volumes_tczyx.setter
    def tile_volumes_tczyx(self, value):
        raise Dataset.WriteError("tile_volumes_tczyx is read-only.")

    @property
    def tile_transforms_zyx(self) -> dict[int, list[geometry.Transform]]:
        """Mapping of tile_id to that tile's list of transforms."""
        raise NotImplementedError("Please implement in Dataset subclass.")

    @tile_transforms_zyx.setter
    def tile_transforms_zyx(self, value):
        raise Dataset.WriteError("tile_transforms_zyx is read-only.")

    @property
    def tile_resolution_zyx(self) -> tuple[float, float, float]:
        """
        Absolute voxel size of the tile volumes, in z/y/x order.

        Used to scale tile volumes into absolute space prior to registration.
        """
        raise NotImplementedError("Please implement in Dataset subclass.")

    @tile_resolution_zyx.setter
    def tile_resolution_zyx(self, value):
        raise Dataset.WriteError("tile_resolution_zyx is read-only.")
123
+
124
+
125
class BigStitcherDataset(Dataset):
    """
    Dataset class for loading in BigStitcher Dataset.
    Intended for the base registration channel.

    Parameters
    ----------
    xml_path: str
        BigStitcher XML; a local path or an ``s3://`` URI.
    s3_path: str
        ``s3://`` prefix holding the tile zarr groups. A trailing '/' is
        appended (in place) when tile paths are resolved.
    datastore: int
        0 = Dask, 1 = Tensorstore.
    level: int
        Pyramid level to read, one of 0..5. Each tile's transform stack is
        wrapped in a ``2**level`` up-scale / down-scale pair, presumably so
        the level-0 registration can act on level-``level`` voxel coords.
    """

    def __init__(self, xml_path: str, s3_path: str, datastore: int, level: int = 0):
        self.xml_path = xml_path
        self.s3_path = s3_path

        assert datastore in [0, 1], \
            f"Only 0 = Dask and 1 = Tensorstore supported."
        self.datastore = datastore  # {0 = Dask, 1 = Tensorstore}

        allowed_levels = [0, 1, 2, 3, 4, 5]
        assert level in allowed_levels, \
            f"Level {level} is not in {allowed_levels}"
        self.level = level

        self.tile_cache: dict[int, InputArray] = {}
        self.transform_cache: dict[int, list[geometry.Transform]] = {}

    @property
    def tile_volumes_tczyx(self) -> tuple[dict[int, InputArray], dict[int, str]]:
        """
        Open every tile at ``self.level``; arrays are cached after the first call.

        Returns
        -------
        tuple[dict[int, InputArray], dict[int, str]]
            ``(tile_arrays, tile_paths)``, both keyed by tile id. NOTE: this
            override returns a 2-tuple, not the bare dict the base class
            declares.
        """
        # Tile paths are re-derived from the XML on every access (only once).
        tile_paths = self._resolve_tile_paths()
        if len(self.tile_cache) != 0:
            return self.tile_cache, tile_paths

        # Otherwise, fetch for first time
        tile_arrays: dict[int, InputArray] = {}
        for tile_id, t_path in tile_paths.items():
            arr = None
            if self.datastore == 0:  # Dask
                arr = InputDask(da.from_zarr(t_path))
            elif self.datastore == 1:  # Tensorstore
                # Referencing the following naming convention:
                # s3://BUCKET_NAME/DATASET_NAME/TILE/NAME/CHANNEL
                bucket, obj = self._split_s3_uri(t_path)
                arr = InputTensorstore(open_zarr_s3(bucket, obj))

            print(f'Loading Tile {tile_id} / {len(tile_paths)}')
            tile_arrays[int(tile_id)] = arr

        self.tile_cache = tile_arrays

        return tile_arrays, tile_paths

    @property
    def tile_transforms_zyx(self) -> dict[int, list[geometry.Transform]]:
        """
        Per-tile transform stacks in zyx order; cached after the first call.

        Each stack is ``[up, registration, down]``: scale by ``2**level``, the
        tile's net BigStitcher registration, then scale by ``1 / 2**level``.
        """
        if len(self.transform_cache) != 0:
            return self.transform_cache

        # Otherwise, fetch for first time
        tile_tfms = self._extract_tile_transforms(self.xml_path)
        tile_net_tfms = self._calculate_net_transforms(tile_tfms)

        # Up/down-sampling transforms depend only on the level, so build once.
        sf = 2. ** self.level
        up = geometry.Affine(np.array([[sf, 0., 0., 0.],
                                       [0., sf, 0., 0.],
                                       [0., 0., sf, 0.]]))
        down = geometry.Affine(np.array([[1./sf, 0., 0., 0.],
                                         [0., 1./sf, 0., 0.],
                                         [0., 0., 1./sf, 0.]]))

        tile_stacks: dict[int, list[geometry.Transform]] = {}
        for tile_id, tfm in tile_net_tfms.items():
            # BigStitcher XYZ -> ZYX
            # Given Matrix 3x4:
            # Swap Rows 0 and 2; Swap Columns 0 and 2
            tfm_zyx = np.copy(tfm)
            tfm_zyx[[0, 2], :] = tfm_zyx[[2, 0], :]
            tfm_zyx[:, [0, 2]] = tfm_zyx[:, [2, 0]]

            tile_stacks[int(tile_id)] = [up, geometry.Affine(tfm_zyx), down]

        self.transform_cache = tile_stacks

        return tile_stacks

    @property
    def tile_resolution_zyx(self) -> tuple[float, float, float]:
        """Voxel size of the first ViewSetup in the XML, reordered xyz -> zyx."""
        data = self._read_xml(self.xml_path)

        view_setups = self._as_list(
            data["SpimData"]["SequenceDescription"]["ViewSetups"]["ViewSetup"]
        )
        resolution_str = view_setups[0]["voxelSize"]["size"]
        resolution_xyz = [float(num) for num in resolution_str.split()]
        return tuple(resolution_xyz[::-1])

    def _extract_tile_paths(self, xml_path: str) -> dict[int, str]:
        """
        Utility called in property.
        Parses BDV xml and outputs map of setup_id -> tile path.

        Tiles are keyed by the zgroup's ``setup`` attribute (so keys agree with
        ``_extract_tile_transforms``); the zgroup's position in the XML is used
        only if that attribute is absent. The zgroup path may be stored either
        as a ``path`` attribute or as a ``<path>`` child element.

        Parameters
        ------------------------
        xml_path: str
            Path of xml outputted from BigStitcher.

        Returns
        ------------------------
        dict[int, str]:
            Dictionary of tile ids to tile paths.
        """
        data = self._read_xml(xml_path)
        image_loader = data["SpimData"]["SequenceDescription"]["ImageLoader"]

        # xmltodict yields a dict (with '#text') only when <zarr> has attributes.
        zarr_node = image_loader["zarr"]
        parent = zarr_node["#text"] if isinstance(zarr_node, dict) else zarr_node

        view_paths: dict[int, str] = {}
        zgroups = self._as_list(image_loader["zgroups"]["zgroup"])
        for i, zgroup in enumerate(zgroups):
            tile_id = int(zgroup.get("@setup", i))
            rel_path = zgroup.get("@path", zgroup.get("path"))
            view_paths[tile_id] = parent + "/" + rel_path

        return view_paths

    def _extract_tile_transforms(self, xml_path: str) -> dict[int, list[dict]]:
        """
        Utility called in property.
        Parses BDV xml and outputs map of setup_id -> list of transformations
        Output dictionary maps view number to list of {'@type', 'Name', 'affine'}
        where 'affine' contains the transform as string of 12 floats.

        Matrices are listed in the order of forward execution (the XML order
        reversed).

        Parameters
        ------------------------
        xml_path: str
            Path of xml outputted by BigStitcher.

        Returns
        ------------------------
        dict[int, list[dict]]
            Dictionary of tile ids to transform list. List entries described above.
        """
        data = self._read_xml(xml_path)

        view_transforms: dict[int, list[dict]] = {}
        registrations = self._as_list(
            data["SpimData"]["ViewRegistrations"]["ViewRegistration"]
        )
        for view_reg in registrations:
            tfm_stack = self._as_list(view_reg["ViewTransform"])
            view_transforms[int(view_reg["@setup"])] = tfm_stack[::-1]

        return view_transforms

    def _calculate_net_transforms(
        self, view_transforms: dict[int, list[dict]]
    ) -> dict[int, geometry.Matrix]:
        """
        Utility called in property.
        Accumulate net transform and net translation for each matrix stack.
        Net translation =
            Sum of translation vectors converted into original nominal basis
        Net transform =
            Product of 3x3 matrices
        NOTE: Translational component (last column) is defined
        wrt to the DOMAIN, not codomain.
        Implementation is informed by this given.

        Parameters
        ------------------------
        view_transforms: dict[int, list[dict]]
            Dictionary of tile ids to transforms associated with each tile.

        Returns
        ------------------------
        dict[int, np.ndarray]:
            Dictionary of tile ids to net_transform (3x4). A tile with an
            empty transform list gets the identity.
        """
        net_transforms: dict[int, np.ndarray] = {}

        for view, tfs in view_transforms.items():
            net_translation = np.zeros(3)
            net_matrix_3x3 = np.eye(3)
            curr_inverse = np.eye(3)

            # tfs is a list of dicts holding the transform under the 'affine' key.
            for tf in tfs:
                nums = [float(val) for val in tf["affine"].split()]
                # NOTE(review): the translation is read from indices 3, 7, 11
                # (row-packed 3x4), but nums[0::4] etc. are COLUMNS of that
                # matrix, so matrix_3x3 is its transpose. Harmless for
                # symmetric (e.g. scaling) matrices -- confirm before relying
                # on rotations/shears.
                matrix_3x3 = np.array([nums[0::4], nums[1::4], nums[2::4]])
                translation = np.array(nums[3::4])

                net_translation = net_translation + (curr_inverse @ translation)
                net_matrix_3x3 = matrix_3x3 @ net_matrix_3x3
                curr_inverse = np.linalg.inv(net_matrix_3x3)

            net_transforms[view] = np.hstack(
                (net_matrix_3x3, net_translation.reshape(3, 1))
            )

        return net_transforms

    def _resolve_tile_paths(self) -> dict[int, str]:
        """Map tile id -> '<s3_path>/<tile dir name>/<level>' (normalizes s3_path)."""
        if not self.s3_path.endswith('/'):
            self.s3_path = self.s3_path + '/'

        level_str = '/' + str(self.level)  # Ex: '/0'
        return {
            t_id: self.s3_path + Path(t_path).name + level_str
            for t_id, t_path in self._extract_tile_paths(self.xml_path).items()
        }

    @staticmethod
    def _split_s3_uri(uri: str) -> tuple[str, str]:
        """Split 's3://bucket/key...' into ('bucket', 'key...')."""
        parts = uri.split('/', 3)
        bucket = parts[2]
        key = parts[3] if len(parts) > 3 else ''
        return bucket, key

    @staticmethod
    def _read_xml(xml_path: str) -> dict:
        """Parse a BigStitcher XML from a local path or an s3:// URI."""
        if xml_path.startswith("s3://"):
            with fsspec.open(xml_path, mode="rt") as f:
                return xmltodict.parse(f.read())
        with open(xml_path, "r") as file:
            return xmltodict.parse(file.read())

    @staticmethod
    def _as_list(node) -> list:
        """xmltodict returns a bare dict for a single repeated element; wrap it."""
        if node is None:
            return []
        return node if isinstance(node, list) else [node]
386
+
387
+
388
class BigStitcherDatasetChannel(BigStitcherDataset):
    """
    Convenience Dataset class that reuses tile registrations,
    tile shapes, and tile resolution across channels.
    Tile volumes is overloaded with channel-specific data.

    NOTE: Only loads full resolution images/registrations.
    """

    def __init__(self, xml_path: str, s3_path: str, channel_num: int, datastore: int):
        """
        Only new information required is channel number.
        """
        super().__init__(xml_path, s3_path, datastore)
        self.channel_num = channel_num

        self.tile_cache: dict[int, InputArray] = {}

    @property
    def tile_volumes_tczyx(self) -> dict[int, InputArray]:
        """
        Load in channel-specific tiles (cached after the first call).

        Sub-directories of ``s3_path`` whose names end in ``<digits>.zarr``
        with ``<digits> == channel_num`` are opened at level 0 and keyed by the
        setup id of the XML zgroup at the same X/Y/Z grid location.

        Returns a dict only (the parent class returns an (arrays, paths) tuple).
        """
        if len(self.tile_cache) != 0:
            return self.tile_cache

        # Otherwise fetch for first time
        tile_arrays: dict[int, InputArray] = {}

        # Accept local or s3:// XML, like the parent class.
        if self.xml_path.startswith("s3://"):
            with fsspec.open(self.xml_path, mode="rt") as f:
                data: OrderedDict = xmltodict.parse(f.read())
        else:
            with open(self.xml_path, "r") as file:
                data: OrderedDict = xmltodict.parse(file.read())

        zgroups = data['SpimData']['SequenceDescription']['ImageLoader']['zgroups']['zgroup']
        if not isinstance(zgroups, list):
            # xmltodict yields a single dict (not a list) for one zgroup.
            zgroups = [zgroups]

        tile_id_lut: dict[tuple[int, int, int], int] = {}
        for zgroup in zgroups:
            # The path may be a <path> child element or a 'path' attribute.
            tile_name = zgroup.get('path', zgroup.get('@path'))
            tile_id_lut[self._parse_tile_location(tile_name)] = int(zgroup['@setup'])

        # Reference path: s3://aind-open-data/HCR_677594_2023-10-20_15-10-36/SPIM.ome.zarr/
        # Reference tilename: <tile_name, no underscores>_X_####_Y_####_Z_####_ch_###.zarr
        if not self.s3_path.endswith('/'):
            self.s3_path = self.s3_path + '/'
        _, _, bucket_name, directory_path = self.s3_path.split('/', 3)

        for p in self._list_bucket_directory(bucket_name, directory_path):
            if p.endswith('.zgroup'):
                continue

            # Channel number = digits immediately before '.zarr'; \d+ so that
            # names without digits are skipped instead of crashing on int('').
            search_result = re.search(r'(\d+)\.zarr.?$', p)
            if search_result is None or int(search_result.group(1)) != self.channel_num:
                continue

            full_resolution_p = self.s3_path + p + '/0'
            tile_id = tile_id_lut[self._parse_tile_location(p)]

            arr = None
            if self.datastore == 0:  # Dask
                arr = InputDask(da.from_zarr(full_resolution_p))
            elif self.datastore == 1:  # Tensorstore
                # Referencing the following naming convention:
                # s3://BUCKET_NAME/DATASET_NAME/TILE/NAME/CHANNEL
                _, _, bucket, obj = full_resolution_p.split('/', 3)
                arr = InputTensorstore(open_zarr_s3(bucket, obj))

            print(f'Loading Tile {tile_id} / {len(tile_id_lut)}')
            tile_arrays[int(tile_id)] = arr

        self.tile_cache = tile_arrays

        return tile_arrays

    @staticmethod
    def _parse_tile_location(tile_name: str) -> tuple[int, int, int]:
        """Extract the (X, Y, Z) grid indices from '<name>_X_####_Y_####_Z_####_...'."""
        s_parts = tile_name.split('_')
        return (int(s_parts[2]), int(s_parts[4]), int(s_parts[6]))

    def _list_bucket_directory(self, bucket_name: str, directory_path: str):
        """
        List the immediate "sub-directory" names under ``directory_path``.

        Uses S3 CommonPrefixes with ``Delimiter="/"`` and returns the last path
        component of each prefix (trailing '/' removed). Paginates, so results
        beyond one response page are included; returns [] when nothing matches.
        """
        client = boto3.client("s3")
        paginator = client.get_paginator("list_objects_v2")

        files = []
        for page in paginator.paginate(
            Bucket=bucket_name, Prefix=directory_path, Delimiter="/"
        ):
            for common_prefix in page.get("CommonPrefixes", []):
                prefix = common_prefix.get("Prefix").rstrip("/")
                files.append(prefix.split("/")[-1])

        return files
498
+
499
+
500
class OutputArray:
    """
    Write-side array interface for fused output.

    Concrete subclasses must implement item assignment (``arr[index] = value``).
    """

    def __setitem__(self, index, value):
        raise NotImplementedError("Please implement in OutputArray subclass.")
503
+
504
+
505
class OutputDask(OutputArray):
    """OutputArray that forwards item assignment to a wrapped dask array."""

    def __init__(self, arr: da.Array):
        self.arr = arr

    def __setitem__(self, index, value):
        """Assign ``value`` to ``index`` using the wrapped array's own ``__setitem__``."""
        target = self.arr
        target[index] = value
511
+
512
+
513
class OutputTensorstore(OutputArray):
    """OutputArray that writes into a TensorStore and waits for each write."""

    def __init__(self, arr: ts.TensorStore):
        self.arr = arr

    def __setitem__(self, index, value):
        """Write ``value`` into the ``index`` region, blocking until the write resolves."""
        pending_write = self.arr[index].write(value)
        pending_write.result()
519
+
520
+
521
@dataclass
class OutputParameters:
    """
    Settings describing the fused output volume.

    Fields (in constructor order): path, chunksize (5 ints), resolution_zyx
    (voxel size in z/y/x order), datastore (0 == Dask, 1 == Tensorstore),
    dtype (default np.uint16), dimension_separator (default "/").

    ``compressor`` carries no annotation, so it is a plain class attribute
    rather than a dataclass field: it is not a constructor argument and one
    Blosc instance is shared by every OutputParameters object.
    """
    path: str
    chunksize: tuple[int, int, int, int, int]  # presumably tczyx order -- TODO confirm
    resolution_zyx: tuple[float, float, float]
    datastore: int  # {0 == Dask, 1 == Tensorstore}
    dtype: np.dtype = np.uint16
    dimension_separator: str = "/"
    # Class attribute (no annotation): not a field, shared across instances.
    compressor = Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE)
530
+
531
+
532
@dataclass
class RuntimeParameters:
    """
    Simplified runtime parameters.

    option:
        Execution mode -- 0: single process execution,
        1: multiprocessing execution, 2: dask execution.
    pool_size:
        Number of processes/vCPUs used by options {1, 2}.
    worker_cells:
        List of cells/chunks this execution operates on.
    """
    option: int  # 0 = single process, 1 = multiprocessing, 2 = dask
    pool_size: int  # processes / vCPUs for options 1 and 2
    worker_cells: list[tuple[int, int, int]]  # cells/chunks handled by this run