volsegtools 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- volsegtools/__init__.py +0 -0
- volsegtools/_cli/__init__.py +1 -0
- volsegtools/_cli/molstar_preprocessor.py +79 -0
- volsegtools/abc/__init__.py +6 -0
- volsegtools/abc/converter.py +69 -0
- volsegtools/abc/data_handle.py +24 -0
- volsegtools/abc/downsampler.py +26 -0
- volsegtools/abc/kernel.py +8 -0
- volsegtools/abc/preprocessor.py +38 -0
- volsegtools/abc/serializer.py +12 -0
- volsegtools/converter/__init__.py +1 -0
- volsegtools/converter/map_converter.py +148 -0
- volsegtools/core/__init__.py +5 -0
- volsegtools/core/bounds.py +12 -0
- volsegtools/core/downsampling_parameters.py +28 -0
- volsegtools/core/gaussian_kernel_3D.py +16 -0
- volsegtools/core/lattice_kind.py +8 -0
- volsegtools/core/vector.py +9 -0
- volsegtools/downsampler/__init__.py +2 -0
- volsegtools/downsampler/base_downsampler.py +20 -0
- volsegtools/downsampler/hierarchy_downsampler.py +253 -0
- volsegtools/model/__init__.py +13 -0
- volsegtools/model/chunking_mode.py +16 -0
- volsegtools/model/metadata.py +50 -0
- volsegtools/model/opaque_data_handle.py +112 -0
- volsegtools/model/storing_parameters.py +52 -0
- volsegtools/model/working_store.py +142 -0
- volsegtools/preprocessor/__init__.py +2 -0
- volsegtools/preprocessor/preprocessor.py +75 -0
- volsegtools/preprocessor/preprocessor_builder.py +110 -0
- volsegtools/serialization/__init__.py +1 -0
- volsegtools/serialization/bcif_serializer.py +318 -0
- volsegtools/typing.py +12 -0
- volsegtools-0.0.0.dist-info/METADATA +22 -0
- volsegtools-0.0.0.dist-info/RECORD +38 -0
- volsegtools-0.0.0.dist-info/WHEEL +5 -0
- volsegtools-0.0.0.dist-info/entry_points.txt +2 -0
- volsegtools-0.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import collections
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
from typing import List, Tuple
|
|
6
|
+
|
|
7
|
+
import dask
|
|
8
|
+
import dask.array as da
|
|
9
|
+
import dask_image.ndfilters as dask_filter
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from volsegtools.core import (
|
|
13
|
+
Bounds,
|
|
14
|
+
DownsamplingParameters,
|
|
15
|
+
LatticeKind,
|
|
16
|
+
Vector3,
|
|
17
|
+
to_bytes,
|
|
18
|
+
)
|
|
19
|
+
from volsegtools.downsampler import BaseDownsampler
|
|
20
|
+
from volsegtools.model import (
|
|
21
|
+
ChannelMetadata,
|
|
22
|
+
DescriptiveStatistics,
|
|
23
|
+
FlatChannelIterator,
|
|
24
|
+
OpaqueDataHandle,
|
|
25
|
+
StoringParameters,
|
|
26
|
+
TimeFrameMetadata,
|
|
27
|
+
)
|
|
28
|
+
from volsegtools.model.working_store import WorkingStore
|
|
29
|
+
|
|
30
|
+
# Voxel-count cutoff: when the product of the input array's shape is at or
# below this value, exactly one downsampling step is performed (see
# HierarchyDownsampler._calculate_downsampling_steps_count).
MIN_GRID_SIZE = 100**1
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class HierarchyDownsampler(BaseDownsampler):
    """Downsampler that produces a hierarchy of resolution levels.

    Every channel found in the working store for the given lattice is
    repeatedly convolved with the configured kernel and decimated by two
    along each of its three axes. The levels selected by the downsampling
    parameters are written back to the working store.
    """

    # This value was used in the previous version of the preprocessor.
    KERNEL_PARAMETERS: Tuple[int, int, int] = (1, 4, 6)

    def __init__(self):
        _parameters = DownsamplingParameters()
        super().__init__(_parameters)

    async def downsample_lattice(
        self, data: OpaqueDataHandle
    ) -> List[OpaqueDataHandle]:
        """Downsample every channel stored for the lattice behind ``data``.

        Parameters
        ----------
        data: OpaqueDataHandle
            Handle whose metadata (``kind`` and ``lattice_id``) selects the
            lattice group in the working store.

        Returns
        -------
        List[OpaqueDataHandle]:
            one handle per stored (channel, resolution level) pair, each
            carrying its own copy of the metadata.
        """
        # Function-scope import: only this method needs it.
        import copy

        # We have to get the actual data from the zarr store.
        store = WorkingStore.instance
        channel_iter = FlatChannelIterator(
            store.get_data_group(data.metadata.kind).require_group(
                data.metadata.lattice_id
            )
        )

        ret_value = []

        # Each time frame can have multiple channels, this flat channel
        # iterator iterates over each channel in each resolution for this
        # particular data lattice.
        # The channels are snapshotted first because the loop body writes new
        # resolution groups into the very group that is being iterated.
        for channel_info in list(channel_iter):
            if 1 in channel_info.data.shape:
                logging.warning(
                    f"Skipping {channel_info.time} ch{channel_info.channel}: "
                    f"shape {channel_info.data.shape} has an axis of size 1"
                )
                continue

            time_frame_id = int(channel_info.time.split("_")[-1])
            channel_id = int(channel_info.channel)

            # This is always true, because we are reading from the zarr store.
            dask_arr = da.from_zarr(
                url=channel_info.data,
                chunks=channel_info.data.chunks,
            )

            current_level_data = dask_arr

            downsampling_steps = self._calculate_downsampling_steps_count(
                dask_arr,
            )
            logging.info(f"Downsampling steps {downsampling_steps}")

            downsampling_levels = self._calculate_downsampling_levels(
                dask_arr,
                downsampling_steps=downsampling_steps,
                # TODO: this should be changeable with some
                # "additional_downsampler_config" which would be just an
                # arbitrary dictionary.
                factor=8,
            )
            logging.info(f"Downsampling levels {downsampling_levels}")

            for step in range(downsampling_steps):
                # TODO use step to compute the voxel size
                current_ratio = 2 ** (step + 1)
                logging.info(
                    f"Currently downsampling r{current_ratio}, {channel_info.time} and ch{channel_info.channel}"
                )
                downsampled_data = dask_filter.convolve(
                    current_level_data,
                    self.parameters.kernel.as_ndarray(),
                    mode="mirror",
                    cval=0.0,
                )
                # Keep every second voxel along each axis, halving the
                # lattice dimensions.
                downsampled_data = downsampled_data[::2, ::2, ::2]

                if current_ratio not in downsampling_levels:
                    # The level is not stored, but it still has to become the
                    # input of the next step; otherwise the next level would
                    # be computed from stale data and labelled with the wrong
                    # ratio.
                    current_level_data = downsampled_data
                    continue

                if self.parameters.acceptance_threshold is not None:
                    logging.info("Using the acceptance threshold")
                    # NOTE(review): the previous code evaluated
                    # `downsampled_data[downsampled_data >= threshold]` here
                    # and discarded the result, so the threshold has no
                    # effect unless the data is converted to a mask below.
                    # Confirm the intended behaviour for non-mask data.

                if self.parameters.is_mask:
                    # TODO: this is inefficient point of usage, merge this
                    # acceptance threshold check
                    logging.info("Converting to Mask")
                    downsampled_data = da.where(
                        downsampled_data > self.parameters.acceptance_threshold, 1, 0
                    )

                stats = DescriptiveStatistics(
                    *dask.compute(
                        da.mean(downsampled_data),
                        da.std(downsampled_data),
                        downsampled_data.max(),
                        downsampled_data.min(),
                    )
                )

                data_ref = OpaqueDataHandle(downsampled_data)
                # Copy the metadata: sharing the input's object would make
                # every returned handle (and the input itself) report the
                # values written for the last processed level.
                data_ref.metadata = copy.deepcopy(data.metadata)

                # Only change things that are really different.
                data_ref.metadata.id = time_frame_id
                data_ref.metadata.resolution = current_ratio
                data_ref.metadata.lattice_dimensions = Vector3(
                    downsampled_data.shape[0],
                    downsampled_data.shape[1],
                    downsampled_data.shape[2],
                )
                data_ref.metadata.channels.append(
                    ChannelMetadata(channel_id, stats)
                )

                ret_value.append(data_ref)

                params = StoringParameters(
                    resolution_level=current_ratio,
                    time_frame=time_frame_id,
                    channel=channel_id,
                    storage_dtype=np.byte
                    if self.parameters.is_mask
                    else downsampled_data.dtype,
                    lattice_kind=data.metadata.kind,
                )
                WorkingStore.instance.store_lattice_time_frame(
                    params,
                    downsampled_data,
                    data.metadata.lattice_id,
                )
                current_level_data = downsampled_data
        return ret_value

    def _calculate_downsampling_steps_count(
        self,
        data: da.Array,
        downsampling_factor: int = 8,
    ) -> int:
        """Calculates the number of steps that are going to be taken during
        the downsampling of the input data.

        Parameters
        ----------
        data: da.Array
            The input data that shall be downsampled.
        downsampling_factor: int
            The factor of downsampling.

        Returns
        -------
        int:
            the number of downsampling steps.
        """

        steps_count: int = 0

        # Steps are calculated either from bounds provided as downsampling
        # parameters, if any. In that case the maximal bound has priority over
        # the minimal bound as it is the user's decision.
        # Otherwise we have to compute them manually.
        if self.parameters.downsampling_level_bounds:
            level_bounds: Bounds = self.parameters.downsampling_level_bounds
            if level_bounds.min:
                steps_count = int(math.log2(level_bounds.min))
            if level_bounds.max:
                steps_count = int(math.log2(level_bounds.max))
        else:
            input_grid_size: float = math.prod(data.shape)

            if input_grid_size <= MIN_GRID_SIZE:
                return 1

            file_size_in_bytes = data.dtype.itemsize * input_grid_size
            size_per_downsampling = file_size_in_bytes / to_bytes(
                self.parameters.size_per_level_bounds_in_mb.min
            )
            steps_count = int(math.log(size_per_downsampling, downsampling_factor))

        return steps_count

    def _calculate_downsampling_levels(
        self,
        data: da.Array,
        factor: int,
        downsampling_steps: int,
    ) -> List[int]:
        """Selects which downsampling ratios are going to be stored.

        Parameters
        ----------
        data: da.Array
            The input data that shall be downsampled.
        factor: int
            Logarithm base used when relating the input size to the maximal
            size per level.
        downsampling_steps: int
            Number of downsampling steps; the candidate levels are
            ``2**1 .. 2**downsampling_steps``.

        Returns
        -------
        List[int]:
            the ratios that pass the level bounds and the size bound.

        Raises
        ------
        RuntimeError
            If no level is left after filtering.
        """
        levels: List[int] = [2**x for x in range(1, downsampling_steps + 1)]

        if self.parameters.downsampling_level_bounds:
            level_bounds: Bounds = self.parameters.downsampling_level_bounds
            if level_bounds.max:
                levels = [x for x in levels if x <= level_bounds.max]
            if level_bounds.min:
                levels = [x for x in levels if x >= level_bounds.min]

        size_per_level: int = self.parameters.size_per_level_bounds_in_mb.max
        if size_per_level:
            # TODO: make this a parameters
            input_grid_size: float = math.prod(data.shape)
            file_size_in_bytes = data.dtype.itemsize * input_grid_size
            # Exponent of the first ratio 2**n that is kept by the size
            # bound; size_per_level is given in MB, hence the 1024**2.
            # TODO: Make the MB -> bytes conversion a function or something
            min_level_exponent = math.ceil(
                math.log(
                    file_size_in_bytes / (size_per_level * 1024**2),
                    factor,
                )
            )
            levels = [x for x in levels if x >= 2**min_level_exponent]

        if not levels:
            raise RuntimeError(
                "No downsamplings could be saved because the max size per "
                f"channel ({size_per_level}) is too low"
            )
        return levels
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from .chunking_mode import ChunkingMode
|
|
2
|
+
|
|
3
|
+
# from .downsampling_data import Data, ChannelInfo, FlatChannelIterator
|
|
4
|
+
from .metadata import (
|
|
5
|
+
ChannelMetadata,
|
|
6
|
+
DescriptiveStatistics,
|
|
7
|
+
Metadata,
|
|
8
|
+
OriginalTimeFrameMetadata,
|
|
9
|
+
TimeFrameMetadata,
|
|
10
|
+
)
|
|
11
|
+
from .opaque_data_handle import ChannelInfo, FlatChannelIterator, OpaqueDataHandle
|
|
12
|
+
from .storing_parameters import StoringParameters
|
|
13
|
+
from .working_store import WorkingStore
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from volsegtools.core import LatticeKind, Vector3
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclasses.dataclass
class DescriptiveStatistics:
    """Represents statistics that should be collected for some data set for
    it to be representable in CIF.

    The field order is part of the interface: instances may be built
    positionally as ``DescriptiveStatistics(mean, std, max, min)``.
    """

    # Arithmetic mean of the data set.
    mean: float
    # Standard deviation of the data set.
    std: float
    # Largest value in the data set.
    max: float
    # Smallest value in the data set.
    min: float
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclasses.dataclass
class ChannelMetadata:
    """Metadata of a single channel: its identifier and its statistics."""

    # Numeric identifier of the channel.
    id: int
    # Descriptive statistics collected for the channel's data.
    statistics: DescriptiveStatistics
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclasses.dataclass
class TimeFrameMetadata:
    """Metadata describing one time frame of a lattice.

    Every field has a default, so an empty instance can be created and
    filled in incrementally. ``axis_order`` used to be declared twice; the
    redundant second declaration was removed, which leaves the field order
    (and therefore positional construction) unchanged.
    """

    # TODO: rename to 'name'
    axis_order: Vector3 = dataclasses.field(default_factory=Vector3)
    # Identifier of the lattice this time frame belongs to.
    lattice_id: str = "unknown"
    # Whether the lattice is a volume or a segmentation.
    kind: LatticeKind = dataclasses.field(default=LatticeKind.VOLUME)
    # Numeric identifier of the time frame; -1 means "not set".
    id: int = -1
    # Resolution level of the data; -1 means "not set".
    resolution: int = -1
    origin: Vector3 = dataclasses.field(default_factory=Vector3)
    lattice_dimensions: Vector3 = dataclasses.field(default_factory=Vector3)
    voxel_size: Vector3 = dataclasses.field(default_factory=Vector3)
    # Per-channel metadata; a fresh list is created for every instance.
    channels: List[ChannelMetadata] = dataclasses.field(default_factory=list)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclasses.dataclass
class OriginalTimeFrameMetadata(TimeFrameMetadata):
    """Time frame metadata of the original (input) data.

    Adds no new fields; ``axis_order`` is re-declared with the same type and
    default as in ``TimeFrameMetadata``.
    """

    # NOTE(review): identical to the inherited declaration — confirm whether
    # a different default was intended here.
    axis_order: Vector3 = dataclasses.field(default_factory=Vector3)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclasses.dataclass
class Metadata:
    """Top-level metadata container: the original time frame plus a list of
    further time frames.
    """

    # Metadata of the original data; a default instance is created when
    # none is supplied.
    original_time_frame: OriginalTimeFrameMetadata = dataclasses.field(
        default_factory=OriginalTimeFrameMetadata
    )
    # Metadata of the individual time frames; starts out empty.
    time_frames: List[TimeFrameMetadata] = dataclasses.field(default_factory=list)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
|
|
3
|
+
import dask.array as da
|
|
4
|
+
import numpy as np
|
|
5
|
+
from numpy.typing import ArrayLike
|
|
6
|
+
from zarr.core.array import Array as ZarrArray
|
|
7
|
+
|
|
8
|
+
from volsegtools.abc.data_handle import DataHandle
|
|
9
|
+
from volsegtools.model.metadata import TimeFrameMetadata
|
|
10
|
+
|
|
11
|
+
# cuPy is an optional import to the volseg-tools (as CUDA may not be available
|
|
12
|
+
# everywhere)
|
|
13
|
+
try:
    import cupy as cp
except ImportError:
    # cuPy is not installed: nothing is defined in its place, so the name
    # `cp` does not exist in this module in that case.
    # TODO: define the array type as something inaccessible.
    ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TimeFrameIterator:
    """Placeholder without any behaviour yet."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ResolutionIterator:
    """Placeholder without any behaviour yet."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ChannelIterator:
    """Placeholder without any behaviour yet."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclasses.dataclass
class ChannelInfo:
    """One channel array together with the names of the groups it lives in.

    Instances are produced by ``FlatChannelIterator``.
    """

    # Name of the resolution group the array was found in.
    resolution: str
    # Name of the time frame group the array was found in.
    time: str
    # Name of the array inside the time frame group.
    channel: str
    # The array itself. Annotated with the imported ``ZarrArray`` alias: the
    # bare name ``zarr`` is not imported in this module, so ``zarr.Array``
    # raised a NameError when the class was created.
    data: ZarrArray
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class FlatChannelIterator:
    """Iterator that flattens a three-level group hierarchy.

    The wrapped group is expected to contain sub-groups (resolutions), which
    contain sub-groups (time frames), which contain arrays (channels). Each
    array is yielded as a ``ChannelInfo`` carrying the three names.
    """

    def __init__(self, group):
        self.group = group
        # Generator is created here but only starts traversing on first use.
        self._iter = self._group_iter()

    def _group_iter(self):
        """Walk resolution groups, then time frame groups, then arrays."""
        for resolution_name, resolution_node in self.group.groups():
            for time_name, time_node in resolution_node.groups():
                yield from (
                    ChannelInfo(resolution_name, time_name, channel_name, array)
                    for channel_name, array in time_node.arrays()
                )

    def __iter__(self):
        return self

    def __next__(self) -> ChannelInfo:
        return next(self._iter)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class OpaqueDataHandle(DataHandle):
|
|
56
|
+
"""Wrapper around basic volseg-tools data model.
|
|
57
|
+
|
|
58
|
+
The data held by the data model can potentially be very large and have to
|
|
59
|
+
be storred in files in the files system (usually by leveraging Zarr store).
|
|
60
|
+
|
|
61
|
+
This is a wrapper around this file that is passed between the different
|
|
62
|
+
stages and holds the reference to the data stored in the file system. If
|
|
63
|
+
unwrapped then it is directly loaded to the memory as an array.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
def __init__(self, reference):
|
|
67
|
+
if type(reference) is np.ndarray:
|
|
68
|
+
if reference.base is None:
|
|
69
|
+
raise RuntimeError("Data reference can be made only from a view")
|
|
70
|
+
self._metadata = TimeFrameMetadata()
|
|
71
|
+
self._internal_repr = reference
|
|
72
|
+
|
|
73
|
+
@property
|
|
74
|
+
def metadata(self) -> TimeFrameMetadata:
|
|
75
|
+
return self._metadata
|
|
76
|
+
|
|
77
|
+
@metadata.setter
|
|
78
|
+
def metadata(self, new_metadata) -> None:
|
|
79
|
+
self._metadata = new_metadata
|
|
80
|
+
|
|
81
|
+
def access(self) -> ArrayLike:
|
|
82
|
+
return self._internal_repr
|
|
83
|
+
|
|
84
|
+
def unwrap(self, target="numpy") -> ArrayLike:
|
|
85
|
+
"""
|
|
86
|
+
Supported unwrapping targets are: numpy, dask, and optionally cupy
|
|
87
|
+
"""
|
|
88
|
+
match target:
|
|
89
|
+
case "numpy":
|
|
90
|
+
return self._repr_to_numpy_arr()
|
|
91
|
+
case "dask":
|
|
92
|
+
return self._repr_to_dask_arr()
|
|
93
|
+
case "cupy":
|
|
94
|
+
return self._repr_to_cupy_arr()
|
|
95
|
+
case _:
|
|
96
|
+
raise RuntimeError("Unknown unwrapping kind")
|
|
97
|
+
|
|
98
|
+
def _repr_to_numpy_arr(self):
|
|
99
|
+
if type(self._internal_repr) is np.ndarray:
|
|
100
|
+
return self._internal_repr.copy()
|
|
101
|
+
elif type(self._internal_repr) is ZarrArray:
|
|
102
|
+
return self._internal_repr[:]
|
|
103
|
+
|
|
104
|
+
raise RuntimeError(
|
|
105
|
+
f"Unknown internal representation {type(self._internal_repr)}"
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
def _repr_to_dask_arr(self):
|
|
109
|
+
return None
|
|
110
|
+
|
|
111
|
+
def _repr_to_cupy_arr(self):
|
|
112
|
+
return None
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import numpy.typing
|
|
3
|
+
import pydantic
|
|
4
|
+
from zarr.abc.codec import BytesBytesCodec
|
|
5
|
+
from zarr.codecs import BloscCodec
|
|
6
|
+
|
|
7
|
+
from volsegtools.core import LatticeKind
|
|
8
|
+
from volsegtools.model.chunking_mode import ChunkingMode
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class StoringParameters(pydantic.BaseModel):
    """Parameters used for storing a volume or a segmentation.

    Attributes
    ----------
    is_compression_enabled: bool, default: False
        Whether the compression is enabled.
    chunking_mode: ChunkingMode, default: ChunkingMode.AUTO
        Which chunking mode is used for this particular entry.
    storage_dtype: numpy.typing.DTypeLike, default: np.float64
        What is the type of data stored in this entry.
    resolution_level: pydantic.NonNegativeInt, default: 0
        Of which resolution level is this entry.
    time_frame: pydantic.NonNegativeInt, default: 0
        Which time frame is this entry.
    channel: pydantic.NonNegativeInt, default: 0
        Which channel is this entry.
    compressor: BytesBytesCodec, default: BloscCodec()
        Which compression codec is going to be used.
    lattice_kind: LatticeKind, default: LatticeKind.VOLUME
        Whether the entry is a volume or a segmentation.
    """

    is_compression_enabled: bool = False
    chunking_mode: ChunkingMode = ChunkingMode.AUTO
    storage_dtype: numpy.typing.DTypeLike = np.float64
    resolution_level: pydantic.NonNegativeInt = 0
    time_frame: pydantic.NonNegativeInt = 0
    channel: pydantic.NonNegativeInt = 0
    compressor: BytesBytesCodec = BloscCodec()
    lattice_kind: LatticeKind = LatticeKind.VOLUME

    def __str__(self) -> str:
        """Return a multi-line, human-readable listing of every parameter."""
        entries = (
            ("is_compression_enabled", self.is_compression_enabled),
            ("chunking_mode", self.chunking_mode),
            ("storage_dtype", self.storage_dtype),
            ("resolution_level", self.resolution_level),
            ("time_frame", self.time_frame),
            ("channel", self.channel),
            ("compressor", self.compressor),
            # Previously missing from the listing.
            ("lattice_kind", self.lattice_kind),
        )
        listing = "\n".join(f"    {name} {value}" for name, value in entries)
        return f"Storing Parameters:\n{listing}"

    class Config:
        # Needed because codec and dtype fields are not pydantic-native types.
        arbitrary_types_allowed = True
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
|
|
5
|
+
import dask.array as da
|
|
6
|
+
import numpy as np
|
|
7
|
+
import zarr
|
|
8
|
+
import zarr.storage
|
|
9
|
+
|
|
10
|
+
from volsegtools.core import LatticeKind
|
|
11
|
+
from volsegtools.model.chunking_mode import ChunkingMode
|
|
12
|
+
from volsegtools.model.metadata import Metadata
|
|
13
|
+
from volsegtools.model.opaque_data_handle import OpaqueDataHandle
|
|
14
|
+
from volsegtools.model.storing_parameters import StoringParameters
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Singleton(type):
    """Metaclass that hands out one shared instance per class.

    The first call of a class constructs the instance with the given
    arguments; every later call returns that same object and ignores its
    arguments.
    """

    # Maps each class using this metaclass to its single instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]

    @property
    def instance(cls):
        """The already-created instance; KeyError if none exists yet."""
        return cls._instances[cls]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class WorkingStore(metaclass=Singleton):
|
|
32
|
+
def __init__(self, store_path: Path):
|
|
33
|
+
self.data_store = zarr.storage.LocalStore(root=store_path)
|
|
34
|
+
self.root_group = zarr.create_group(store=self.data_store)
|
|
35
|
+
|
|
36
|
+
self._metadata = Metadata()
|
|
37
|
+
|
|
38
|
+
self.volume_dtype = np.float64
|
|
39
|
+
self.is_volume_dtype_set = False
|
|
40
|
+
|
|
41
|
+
self.segmentation_dtype = np.float64
|
|
42
|
+
self.is_segmentation_dtype_set = False
|
|
43
|
+
|
|
44
|
+
self._volume_data_group = self.root_group.require_group("volume_data")
|
|
45
|
+
self._segmentation_data_group = self.root_group.require_group(
|
|
46
|
+
"segmentation_data"
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
@property
|
|
50
|
+
def metadata(self):
|
|
51
|
+
return self._metadata
|
|
52
|
+
|
|
53
|
+
@metadata.setter
|
|
54
|
+
def metadata(self, value):
|
|
55
|
+
self._metadata = value
|
|
56
|
+
# TODO: it should return a dictionary
|
|
57
|
+
# self.root_group.attrs.put(dataclasses.asdict(self._metadata))
|
|
58
|
+
|
|
59
|
+
@property
|
|
60
|
+
def volume_data_group(self):
|
|
61
|
+
return self._volume_data_group
|
|
62
|
+
|
|
63
|
+
@property
|
|
64
|
+
def segmentation_data_group(self):
|
|
65
|
+
return self._segmentation_data_group
|
|
66
|
+
|
|
67
|
+
def get_data_array(
|
|
68
|
+
self, lattice_id, resolution, time_frame, channel, kind=LatticeKind.VOLUME
|
|
69
|
+
):
|
|
70
|
+
kind_group = self.get_data_group(kind)
|
|
71
|
+
lattice_group = kind_group.require_group(lattice_id)
|
|
72
|
+
resolution_group: zarr.Group = lattice_group.require_group(
|
|
73
|
+
f"resolution_{resolution}"
|
|
74
|
+
)
|
|
75
|
+
time_frame_group: zarr.Group = resolution_group.require_group(
|
|
76
|
+
f"time_frame_{time_frame}"
|
|
77
|
+
)
|
|
78
|
+
# FIX: this is unsafe, there should be some check!
|
|
79
|
+
return list(time_frame_group.arrays())[channel][1][:]
|
|
80
|
+
|
|
81
|
+
@staticmethod
|
|
82
|
+
def _compute_chunk_size_based_on_data(
|
|
83
|
+
data_shape: Tuple[int, ...],
|
|
84
|
+
) -> Tuple[int, ...]:
|
|
85
|
+
chunks = tuple([int(i / 4) if i > 4 else i for i in data_shape])
|
|
86
|
+
return chunks
|
|
87
|
+
|
|
88
|
+
@staticmethod
|
|
89
|
+
def _resolve_chunking_method(mode: ChunkingMode, data_shape: Tuple[int, ...]):
|
|
90
|
+
match mode:
|
|
91
|
+
case ChunkingMode.AUTO:
|
|
92
|
+
return "auto"
|
|
93
|
+
case ChunkingMode.NONE:
|
|
94
|
+
return (0, 0)
|
|
95
|
+
case ChunkingMode.CUSTOM:
|
|
96
|
+
return Data._compute_chunk_size_based_on_data(data_shape)
|
|
97
|
+
case _:
|
|
98
|
+
raise RuntimeError("Unsupported chunking method!")
|
|
99
|
+
|
|
100
|
+
def get_data_group(self, lattice_kind: LatticeKind):
|
|
101
|
+
match lattice_kind:
|
|
102
|
+
case LatticeKind.VOLUME:
|
|
103
|
+
return self.volume_data_group
|
|
104
|
+
case LatticeKind.SEGMENTATION:
|
|
105
|
+
return self.segmentation_data_group
|
|
106
|
+
case _:
|
|
107
|
+
raise RuntimeError("Unknown lattice kind encountered.")
|
|
108
|
+
|
|
109
|
+
def store_lattice_time_frame(
|
|
110
|
+
self,
|
|
111
|
+
params: StoringParameters,
|
|
112
|
+
data: da.Array,
|
|
113
|
+
lattice_id: str,
|
|
114
|
+
) -> OpaqueDataHandle:
|
|
115
|
+
kind_group = self.get_data_group(params.lattice_kind)
|
|
116
|
+
lattice_group = kind_group.require_group(lattice_id)
|
|
117
|
+
resolution_group: zarr.Group = lattice_group.require_group(
|
|
118
|
+
f"resolution_{params.resolution_level}"
|
|
119
|
+
)
|
|
120
|
+
time_frame_group: zarr.Group = resolution_group.require_group(
|
|
121
|
+
f"time_frame_{params.time_frame}"
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
used_compressor = None
|
|
125
|
+
if params.is_compression_enabled:
|
|
126
|
+
used_compressor = params.compressor
|
|
127
|
+
|
|
128
|
+
zarr_repr: zarr.Array = time_frame_group.create_array(
|
|
129
|
+
name=str(params.channel),
|
|
130
|
+
chunks=WorkingStore._resolve_chunking_method(
|
|
131
|
+
params.chunking_mode, data.shape
|
|
132
|
+
),
|
|
133
|
+
dtype=params.storage_dtype,
|
|
134
|
+
compressors=[used_compressor] if used_compressor is not None else None,
|
|
135
|
+
shape=data.shape,
|
|
136
|
+
overwrite=True,
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
da.to_zarr(arr=data, url=zarr_repr, overwrite=True, compute=True)
|
|
140
|
+
ref = OpaqueDataHandle(zarr_repr)
|
|
141
|
+
ref.metadata.lattice_id = lattice_id
|
|
142
|
+
return ref
|